10 #ifndef PND_MVA_TRAINER_H 11 #define PND_MVA_TRAINER_H 21 #include "TMVA/Tools.h" 22 #include "TMVA/PDEFoam.h" 23 #include "TMVA/Event.h" 31 #define TRAIN_INC_FOAM 0 43 explicit PndMvaTrainer(std::vector<std::pair<std::string, std::vector<float> *>>
const &InputEvtsParam, std::vector<std::string>
const &ClassNames,
44 std::vector<std::string>
const &VarNames,
bool trim =
true);
52 explicit PndMvaTrainer(std::string
const &InPut, std::vector<std::string>
const &ClassNames, std::vector<std::string>
const &VarNames,
bool trim =
true);
58 virtual void Train() = 0;
76 void SetTestSet(std::set<size_t>
const &testSet);
123 inline std::vector<PndMvaClass>
const &
GetClasses()
const;
129 inline std::vector<PndMvaVariable>
const &
GetVariables()
const;
150 void WriteToWeightFile(std::vector<std::pair<std::string, std::vector<float> *>>
const &weights)
const;
152 #if (TRAIN_INC_FOAM > 0) 188 size_t m_testSetSize;
size_t m_RND_seed
Random seed.
std::set< size_t > const & GetTestEvetIdx() const
void SetTestSetSize(size_t percent=50)
void SetRndSeed(size_t const sd)
PndMvaDataSet m_dataSets
Data set. Holds event values.
void NormalizeData(NormType t=NONORM)
std::vector< PndMvaVariable > const & GetVars() const
Get the list of available variables.
virtual void EvalClassifierError()
void WriteErroVect(std::string const &FileName) const
virtual void Initialize()
virtual ~PndMvaTrainer()
Destructor.
void SetAppType(AppType t)
virtual void storeWeights()=0
PndMvaTrainer(std::vector< std::pair< std::string, std::vector< float > *>> const &InputEvtsParam, std::vector< std::string > const &ClassNames, std::vector< std::string > const &VarNames, bool trim=true)
void SetAppType(AppType t)
virtual void Train()=0
Derived classes need to implement this method.
std::set< size_t > m_testSet_indices
Indices of the test set.
std::vector< StepError > const & GetErrorValues() const
size_t GetRndSeed() const
void WriteToWeightFile(std::vector< std::pair< std::string, std::vector< float > *>> const &weights) const
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
std::string m_outFile
Output filename.
std::vector< StepError > m_StepErro
Container to keep per-step error values.
void SetOutPutFile(std::string const &outFile)
std::vector< PndMvaVariable > const & GetVariables() const
Get the list of available variables.
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
void SetTestSet(std::set< size_t > const &testSet)