PandaRoot
PndMvaTrainer.h
Go to the documentation of this file.
1 //****************************************************************************
2 //* This file is part of PandaRoot. *
3 //* *
4 //* PandaRoot is distributed under the terms of the *
5 //* GNU General Public License (GPL) version 3, *
6 //* copied verbatim in the file "LICENSE". *
7 //* *
8 //* Copyright (C) 2006 - 2024 FAIR GmbH and copyright holders of PandaRoot *
9 //* The copyright holders are listed in the file "COPYRIGHTHOLDERS". *
10 //* The authors are listed in the file "AUTHORS". *
11 //****************************************************************************
12 
13 /* **********************************************
14  * MVA classifier trainer interface. *
15  * Author: M. Babai *
16  * M.Babai@rug.nl *
17  * Version: 0.1 beta1. *
18  * LICENSE: *
19  * **********************************************
20  */
21 //#pragma once
22 #ifndef PND_MVA_TRAINER_H
23 #define PND_MVA_TRAINER_H
24 
25 // C++ includes
26 #include <cassert>
27 #include <limits>
28 #include <ctime>
29 #include <iomanip>
30 #include <cstdlib>
31 
32 // ROOT and PandaRoot
33 #include "TMVA/Tools.h"
34 #include "TMVA/PDEFoam.h"
35 #include "TMVA/Event.h"
36 
37 class TRandom3;
38 
39 // Local includes
40 #include "PndMvaDataSet.h"
41 #include "PndMvaUtil.h"
42 
43 #define TRAIN_INC_FOAM 0
44 
46  //==============================================
47  //================ Public =======================
48  public:
55  explicit PndMvaTrainer(std::vector<std::pair<std::string, std::vector<float> *>> const &InputEvtsParam, std::vector<std::string> const &ClassNames,
56  std::vector<std::string> const &VarNames, bool trim = true);
57 
64  explicit PndMvaTrainer(std::string const &InPut, std::vector<std::string> const &ClassNames, std::vector<std::string> const &VarNames, bool trim = true);
65 
67  virtual ~PndMvaTrainer();
68 
70  virtual void Train() = 0;
71 
76  virtual void storeWeights() = 0;
77 
82  void SetTestSetSize(size_t percent = 50);
83 
88  void SetTestSet(std::set<size_t> const &testSet);
89 
93  void NormalizeData(NormType t = NONORM);
94 
99  void PCATransForm();
100 
105  inline void SetOutPutFile(std::string const &outFile);
106 
111  void WriteErroVect(std::string const &FileName) const;
112 
118  inline std::vector<StepError> const &GetErrorValues() const;
119 
123  virtual void Initialize();
124 
129  inline std::set<size_t> const &GetTestEvetIdx() const;
130 
135  inline std::vector<PndMvaClass> const &GetClasses() const;
136 
141  inline std::vector<PndMvaVariable> const &GetVariables() const;
142 
146  virtual void EvalClassifierError();
147 
148  inline size_t GetRndSeed() const;
149  inline void SetRndSeed(size_t const sd);
150 
151  //______________________________________________
152  //================ Protected ===================
153  protected:
157  inline void SetAppType(AppType t);
158 
162  void WriteToWeightFile(std::vector<std::pair<std::string, std::vector<float> *>> const &weights) const;
163 
164 #if (TRAIN_INC_FOAM > 0)
165 
171  void WriteToWeightFile(std::vector<TMVA::PDEFoam *> const &foams) const;
172 #endif
173  // void WriteDataSetToOutFile();
174 
176  std::set<size_t> m_testSet_indices;
177 
180 
182  std::vector<StepError> m_StepErro;
183 
185  std::string m_outFile;
186 
188  size_t m_RND_seed;
189 
190  void splitTetsSet();
191  //______________________________________________
192  //================ Private =====================
193  private:
195  PndMvaTrainer(PndMvaTrainer const &other);
196  PndMvaTrainer &operator=(PndMvaTrainer const &other);
197 
198  // trim or not
199  bool m_trim;
200  size_t m_testSetSize;
201 }; // End of class definition.
202 
203 //========================= Inline implementations =================
204 inline size_t PndMvaTrainer::GetRndSeed() const
205 {
206  return this->m_RND_seed;
207 };
208 
209 inline void PndMvaTrainer::SetRndSeed(size_t const sd)
210 {
211  this->m_RND_seed = sd;
212 };
213 
214 inline void PndMvaTrainer::SetOutPutFile(std::string const &outFile)
215 {
216  m_outFile = outFile;
217 };
218 
220 {
222 };
223 
224 inline std::set<size_t> const &PndMvaTrainer::GetTestEvetIdx() const
225 {
226  return m_testSet_indices;
227 };
228 
230 inline std::vector<PndMvaClass> const &PndMvaTrainer::GetClasses() const
231 {
232  return m_dataSets.GetClasses();
233 };
234 
236 inline std::vector<PndMvaVariable> const &PndMvaTrainer::GetVariables() const
237 {
238  return m_dataSets.GetVars();
239 };
240 
241 // @return List of evaluation objects.
242 inline std::vector<StepError> const &PndMvaTrainer::GetErrorValues() const
243 {
244  return m_StepErro;
245 };
246 #endif
size_t m_RND_seed
Random seed.
std::set< size_t > const & GetTestEvetIdx() const
void SetTestSetSize(size_t percent=50)
void PCATransForm()
void SetRndSeed(size_t const sd)
PndMvaDataSet m_dataSets
Data set. Holds event values.
void NormalizeData(NormType t=NONORM)
std::vector< PndMvaVariable > const & GetVars() const
Get the list of available variables.
virtual void EvalClassifierError()
void WriteErroVect(std::string const &FileName) const
virtual void Initialize()
virtual ~PndMvaTrainer()
Destructor.
void SetAppType(AppType t)
virtual void storeWeights()=0
PndMvaTrainer(std::vector< std::pair< std::string, std::vector< float > *>> const &InputEvtsParam, std::vector< std::string > const &ClassNames, std::vector< std::string > const &VarNames, bool trim=true)
void SetAppType(AppType t)
NormType
Definition: PndMvaDataSet.h:60
virtual void Train()=0
Derived classes need to implement this method.
std::set< size_t > m_testSet_indices
Indices of the test set.
std::vector< StepError > const & GetErrorValues() const
size_t GetRndSeed() const
void WriteToWeightFile(std::vector< std::pair< std::string, std::vector< float > *>> const &weights) const
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
std::string m_outFile
Output filename.
std::vector< StepError > m_StepErro
Container to keep per step error values.
void SetOutPutFile(std::string const &outFile)
std::vector< PndMvaVariable > const & GetVariables() const
Get the list of available variables.
void splitTetsSet()
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
AppType
Definition: PndMvaDataSet.h:50
void SetTestSet(std::set< size_t > const &testSet)