PandaRoot
PndMultiClassMlpTrain.h
Go to the documentation of this file.
1 /* ***************************************
2  * MultiClass MLP Training functions *
3  * Author: M.Babai@rug.nl *
4  * Version: *
5  * LICENSE: *
6  * ***************************************
7  */
8 /*
9  * Note: This is just an interface to the original TMVA
10  * implementation. To find out the available options, please read TMVA
11  * manuals. In case of errors or wrong outputs produced by TMVA
12  * classifiers, try to read their mailing list and send your questions
13  * to the same list.
14  ******* VERY IMPORTANT ****
15  * You NEED TMVA version > 4.1.X for this to work.
16  */
17 //#pragma once
18 #ifndef PND_MULTICLASS_MLP_TRAIN_H
19 #define PND_MULTICLASS_MLP_TRAIN_H
20 
21 // Local includes
22 #include "PndMvaTrainer.h"
23 
24 // TMVA && ROOT
25 #include "TMVA/Factory.h"
26 #include "TMVA/Config.h"
27 
28 // Interface definition for Multiclass MLP trainers.
30  //----------------------------------------
31  //================== public ==============
32  public:
40  explicit PndMultiClassMlpTrain(std::string const &InPut, std::vector<std::string> const &ClassNames, std::vector<std::string> const &VarNames, bool trim = true);
44  virtual ~PndMultiClassMlpTrain();
45 
49  void Train();
50 
55  void storeWeights();
56 
60  void Initialize();
61 
62  //______________________________________________
63  //====== Getters and setters.
64  // Set the name of the current job
65  inline void SetJobName(std::string const &name);
66 
67  // Set data Transformation scheme
68  inline void SetTransformation(std::string const &tran);
69 
70  // Set the options for the MLP alg. See TMVA manuals.
71  inline void SetMlpOptions(std::string const &opts);
72 
73  // Set the file name to store evaluation outputs.
74  inline void SetEvalFileName(std::string const &fname);
75 
76  // Set the directory where weights are stored.
77  inline void SetWeightsOutDir(std::string const &dirName);
78 
79  // Enable or disable evaluation of the classifier.
80  inline void SetEvaluation(bool evaluate);
81 
82  // Get the current job name.
83  inline std::string const &GetJobName() const;
84 
85  // Get the current transformation info.
86  inline std::string const &GetTransformation() const;
87 
88  // Get the classifier options.
89  inline std::string const &GetMlpOptions() const;
90 
91  // Get the name of the file that stores the evaluation outputs.
92  inline std::string const &GetEvalFileName() const;
93 
94  // Get the directory where the weights are stored.
95  inline std::string const &GetWeightsOutDir() const;
96  //----------------------------------------
97 
98  //================== protected ============
99  // protected:
100  //----------------------------------------
101 
102  //================== private =============
103  private:
104  // To avoid mistakes.
106  PndMultiClassMlpTrain &operator=(PndMultiClassMlpTrain const &oth);
107 
108  // Initialize mlp object and set the options.
109  void InitMlp();
110 
111  // Add the variables to the TMVA factory object.
112  void AddVariables();
113 
114  //==============================
115  TMVA::Factory *m_factory; // TMVA factory
116  TFile *EvalFile; // To store evaluation file
117  std::string m_JName; // Job name
118  std::string m_transform; // Transformation opt.
119  std::string m_MlpOptions; // mlp options.
120  std::string m_evalFileName; // evaluation file name.
121  std::string m_weightDirName; // Directory name to store weights.
122  bool m_Evaluate;
123 }; // End of interface definition.
124 //=============== inline functions implementation. ========
125 //__________________________________________
126 inline void PndMultiClassMlpTrain::SetJobName(std::string const &name)
127 {
128  this->m_JName = name;
129 };
130 
131 inline void PndMultiClassMlpTrain::SetTransformation(std::string const &tr)
132 {
133  this->m_transform = tr;
134 };
135 
136 inline void PndMultiClassMlpTrain::SetMlpOptions(std::string const &opt)
137 {
138  this->m_MlpOptions = opt;
139 };
140 
141 inline std::string const &PndMultiClassMlpTrain::GetJobName() const
142 {
143  return m_JName;
144 };
145 
146 inline std::string const &PndMultiClassMlpTrain::GetTransformation() const
147 {
148  return m_transform;
149 };
150 
151 inline std::string const &PndMultiClassMlpTrain::GetMlpOptions() const
152 {
153  return m_MlpOptions;
154 };
155 
156 inline void PndMultiClassMlpTrain::SetEvalFileName(std::string const &fname)
157 {
158  this->m_evalFileName = fname;
159 };
160 
161 inline std::string const &PndMultiClassMlpTrain::GetEvalFileName() const
162 {
163  return m_evalFileName;
164 };
165 
166 inline void PndMultiClassMlpTrain::SetWeightsOutDir(std::string const &dirName)
167 {
168  this->m_weightDirName = dirName;
169 };
170 
171 inline std::string const &PndMultiClassMlpTrain::GetWeightsOutDir() const
172 {
173  return m_weightDirName;
174 };
175 
176 inline void PndMultiClassMlpTrain::SetEvaluation(bool evaluate)
177 {
178  this->m_Evaluate = evaluate;
179 };
180 #endif
PndMultiClassMlpTrain(std::string const &InPut, std::vector< std::string > const &ClassNames, std::vector< std::string > const &VarNames, bool trim=true)
std::string const & GetJobName() const
void SetEvalFileName(std::string const &fname)
void SetWeightsOutDir(std::string const &dirName)
std::string const & GetTransformation() const
virtual ~PndMultiClassMlpTrain()
void SetTransformation(std::string const &tran)
void SetMlpOptions(std::string const &opts)
void SetJobName(std::string const &name)
std::string const & GetWeightsOutDir() const
std::string const & GetEvalFileName() const
void SetEvaluation(bool evaluate)
std::string const & GetMlpOptions() const