PandaRoot
PndMvaTrainer.h
Go to the documentation of this file.
1 /* **********************************************
2  * MVA classifier trainer interface. *
3  * Author: M. Babai *
4  * M.Babai@rug.nl *
5  * Version: 0.1 beta1. *
6  * LICENSE: *
7  * **********************************************
8  */
9 //#pragma once
10 #ifndef PND_MVA_TRAINER_H
11 #define PND_MVA_TRAINER_H
12 
13 // C++ includes
14 #include <cassert>
15 #include <limits>
16 #include <ctime>
17 #include <iomanip>
18 #include <cstdlib>
19 
20 // ROOT and PandaRoot
21 #include "TMVA/Tools.h"
22 #include "TMVA/PDEFoam.h"
23 #include "TMVA/Event.h"
24 
25 class TRandom3;
26 
27 // Local includes
28 #include "PndMvaDataSet.h"
29 #include "PndMvaUtil.h"
30 
31 #define TRAIN_INC_FOAM 0
32 
34  //==============================================
35  //================ Public =======================
36  public:
43  explicit PndMvaTrainer(std::vector<std::pair<std::string, std::vector<float> *>> const &InputEvtsParam, std::vector<std::string> const &ClassNames,
44  std::vector<std::string> const &VarNames, bool trim = true);
45 
52  explicit PndMvaTrainer(std::string const &InPut, std::vector<std::string> const &ClassNames, std::vector<std::string> const &VarNames, bool trim = true);
53 
55  virtual ~PndMvaTrainer();
56 
58  virtual void Train() = 0;
59 
64  virtual void storeWeights() = 0;
65 
70  void SetTestSetSize(size_t percent = 50);
71 
76  void SetTestSet(std::set<size_t> const &testSet);
77 
81  void NormalizeData(NormType t = NONORM);
82 
87  void PCATransForm();
88 
93  inline void SetOutPutFile(std::string const &outFile);
94 
99  void WriteErroVect(std::string const &FileName) const;
100 
106  inline std::vector<StepError> const &GetErrorValues() const;
107 
111  virtual void Initialize();
112 
117  inline std::set<size_t> const &GetTestEvetIdx() const;
118 
123  inline std::vector<PndMvaClass> const &GetClasses() const;
124 
129  inline std::vector<PndMvaVariable> const &GetVariables() const;
130 
134  virtual void EvalClassifierError();
135 
136  inline size_t GetRndSeed() const;
137  inline void SetRndSeed(size_t const sd);
138 
139  //______________________________________________
140  //================ Protected ===================
141  protected:
145  inline void SetAppType(AppType t);
146 
150  void WriteToWeightFile(std::vector<std::pair<std::string, std::vector<float> *>> const &weights) const;
151 
152 #if (TRAIN_INC_FOAM > 0)
153 
159  void WriteToWeightFile(std::vector<TMVA::PDEFoam *> const &foams) const;
160 #endif
161  // void WriteDataSetToOutFile();
162 
164  std::set<size_t> m_testSet_indices;
165 
168 
170  std::vector<StepError> m_StepErro;
171 
173  std::string m_outFile;
174 
176  size_t m_RND_seed;
177 
178  void splitTetsSet();
179  //______________________________________________
180  //================ Private =====================
181  private:
183  PndMvaTrainer(PndMvaTrainer const &other);
184  PndMvaTrainer &operator=(PndMvaTrainer const &other);
185 
186  // trim or not
187  bool m_trim;
188  size_t m_testSetSize;
189 }; // End of class definition.
190 
191 //========================= Inline implementations =================
192 inline size_t PndMvaTrainer::GetRndSeed() const
193 {
194  return this->m_RND_seed;
195 };
196 
197 inline void PndMvaTrainer::SetRndSeed(size_t const sd)
198 {
199  this->m_RND_seed = sd;
200 };
201 
202 inline void PndMvaTrainer::SetOutPutFile(std::string const &outFile)
203 {
204  m_outFile = outFile;
205 };
206 
208 {
210 };
211 
212 inline std::set<size_t> const &PndMvaTrainer::GetTestEvetIdx() const
213 {
214  return m_testSet_indices;
215 };
216 
218 inline std::vector<PndMvaClass> const &PndMvaTrainer::GetClasses() const
219 {
220  return m_dataSets.GetClasses();
221 };
222 
224 inline std::vector<PndMvaVariable> const &PndMvaTrainer::GetVariables() const
225 {
226  return m_dataSets.GetVars();
227 };
228 
229 // @return List of evaluation objects.
230 inline std::vector<StepError> const &PndMvaTrainer::GetErrorValues() const
231 {
232  return m_StepErro;
233 };
234 #endif
size_t m_RND_seed
Random seed.
std::set< size_t > const & GetTestEvetIdx() const
void SetTestSetSize(size_t percent=50)
void PCATransForm()
void SetRndSeed(size_t const sd)
PndMvaDataSet m_dataSets
Data set. Holds event values.
void NormalizeData(NormType t=NONORM)
std::vector< PndMvaVariable > const & GetVars() const
Get the list of available variables.
virtual void EvalClassifierError()
void WriteErroVect(std::string const &FileName) const
virtual void Initialize()
virtual ~PndMvaTrainer()
Destructor.
void SetAppType(AppType t)
virtual void storeWeights()=0
PndMvaTrainer(std::vector< std::pair< std::string, std::vector< float > *>> const &InputEvtsParam, std::vector< std::string > const &ClassNames, std::vector< std::string > const &VarNames, bool trim=true)
void SetAppType(AppType t)
NormType
Definition: PndMvaDataSet.h:48
virtual void Train()=0
Derived classes must implement this method.
std::set< size_t > m_testSet_indices
Indices of the test set.
std::vector< StepError > const & GetErrorValues() const
size_t GetRndSeed() const
void WriteToWeightFile(std::vector< std::pair< std::string, std::vector< float > *>> const &weights) const
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
std::string m_outFile
Output filename.
std::vector< StepError > m_StepErro
Container holding the per-step error values.
void SetOutPutFile(std::string const &outFile)
std::vector< PndMvaVariable > const & GetVariables() const
Get the list of available variables.
void splitTetsSet()
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
AppType
Definition: PndMvaDataSet.h:38
void SetTestSet(std::set< size_t > const &testSet)