PandaRoot
PndMultiClassMlpTrain.h
Go to the documentation of this file.
1 //****************************************************************************
2 //* This file is part of PandaRoot. *
3 //* *
4 //* PandaRoot is distributed under the terms of the *
5 //* GNU General Public License (GPL) version 3, *
6 //* copied verbatim in the file "LICENSE". *
7 //* *
8 //* Copyright (C) 2006 - 2024 FAIR GmbH and copyright holders of PandaRoot *
9 //* The copyright holders are listed in the file "COPYRIGHTHOLDERS". *
10 //* The authors are listed in the file "AUTHORS". *
11 //****************************************************************************
12 
13 /* ***************************************
14  * MultiClass MLP Training functions *
15  * Author: M.Babai@rug.nl *
16  * Version: *
17  * LICENSE: *
18  * ***************************************
19  */
20 /*
21  * Note: This is just an interface to the original TMVA
22  * implementation. To find out the available options, please read TMVA
23  * manuals. In case of errors or wrong output produced by TMVA
24  * classifiers, consult the TMVA mailing list archive and send your
25  * questions to that list.
26  ******* VERY IMPORTANT ****
27  * You NEED TMVA version > 4.1.X for this to work.
28  */
29 //#pragma once
30 #ifndef PND_MULTICLASS_MLP_TRAIN_H
31 #define PND_MULTICLASS_MLP_TRAIN_H
32 
33 // Local includes
34 #include "PndMvaTrainer.h"
35 
36 // TMVA && ROOT
37 #include "TMVA/Factory.h"
38 #include "TMVA/Config.h"
39 
40 // Interface definition for Multiclass MLP trainers.
42  //----------------------------------------
43  //================== public ==============
44  public:
52  explicit PndMultiClassMlpTrain(std::string const &InPut, std::vector<std::string> const &ClassNames, std::vector<std::string> const &VarNames, bool trim = true);
56  virtual ~PndMultiClassMlpTrain();
57 
61  void Train();
62 
67  void storeWeights();
68 
72  void Initialize();
73 
74  //______________________________________________
75  //====== Getters and setters.
76  // Set the name of the current job
77  inline void SetJobName(std::string const &name);
78 
79  // Set data Transformation scheme
80  inline void SetTransformation(std::string const &tran);
81 
82  // Set the options for the MLP alg. See TMVA manuals.
83  inline void SetMlpOptions(std::string const &opts);
84 
85  // Set the file name to store evaluation outputs.
86  inline void SetEvalFileName(std::string const &fname);
87 
88  // Set the directory where weights are stored.
89  inline void SetWeightsOutDir(std::string const &dirName);
90 
91  // Evaluate the classifier?
92  inline void SetEvaluation(bool evaluate);
93 
94  // Get the current job name.
95  inline std::string const &GetJobName() const;
96 
97  // Get the current transformation info.
98  inline std::string const &GetTransformation() const;
99 
100  // Get the classifier options.
101  inline std::string const &GetMlpOptions() const;
102 
103  // Get the name of the weight file.
104  inline std::string const &GetEvalFileName() const;
105 
106  // Get the directory where the weights are stored.
107  inline std::string const &GetWeightsOutDir() const;
108  //----------------------------------------
109 
110  //================== protected ============
111  // protected:
112  //----------------------------------------
113 
114  //================== private =============
115  private:
116  // To avoid mistakes.
118  PndMultiClassMlpTrain &operator=(PndMultiClassMlpTrain const &oth);
119 
120  // Initialize mlp object and set the options.
121  void InitMlp();
122 
123  // Add the variables to the TMVA factory object.
124  void AddVariables();
125 
126  //==============================
127  TMVA::Factory *m_factory; // TMVA factory
128  TFile *EvalFile; // To store evaluation file
129  std::string m_JName; // Job name
130  std::string m_transform; // Transformation opt.
131  std::string m_MlpOptions; // mlp options.
132  std::string m_evalFileName; // evaluation file name.
133  std::string m_weightDirName; // Directory name to store weights.
134  bool m_Evaluate;
135 }; // End of interface definition.
136 //=============== inline functions implementation. ========
137 //__________________________________________
138 inline void PndMultiClassMlpTrain::SetJobName(std::string const &name)
139 {
140  this->m_JName = name;
141 };
142 
143 inline void PndMultiClassMlpTrain::SetTransformation(std::string const &tr)
144 {
145  this->m_transform = tr;
146 };
147 
148 inline void PndMultiClassMlpTrain::SetMlpOptions(std::string const &opt)
149 {
150  this->m_MlpOptions = opt;
151 };
152 
153 inline std::string const &PndMultiClassMlpTrain::GetJobName() const
154 {
155  return m_JName;
156 };
157 
158 inline std::string const &PndMultiClassMlpTrain::GetTransformation() const
159 {
160  return m_transform;
161 };
162 
163 inline std::string const &PndMultiClassMlpTrain::GetMlpOptions() const
164 {
165  return m_MlpOptions;
166 };
167 
168 inline void PndMultiClassMlpTrain::SetEvalFileName(std::string const &fname)
169 {
170  this->m_evalFileName = fname;
171 };
172 
173 inline std::string const &PndMultiClassMlpTrain::GetEvalFileName() const
174 {
175  return m_evalFileName;
176 };
177 
178 inline void PndMultiClassMlpTrain::SetWeightsOutDir(std::string const &dirName)
179 {
180  this->m_weightDirName = dirName;
181 };
182 
183 inline std::string const &PndMultiClassMlpTrain::GetWeightsOutDir() const
184 {
185  return m_weightDirName;
186 };
187 
188 inline void PndMultiClassMlpTrain::SetEvaluation(bool evaluate)
189 {
190  this->m_Evaluate = evaluate;
191 };
192 #endif
PndMultiClassMlpTrain(std::string const &InPut, std::vector< std::string > const &ClassNames, std::vector< std::string > const &VarNames, bool trim=true)
std::string const & GetJobName() const
void SetEvalFileName(std::string const &fname)
void SetWeightsOutDir(std::string const &dirName)
std::string const & GetTransformation() const
virtual ~PndMultiClassMlpTrain()
void SetTransformation(std::string const &tran)
void SetMlpOptions(std::string const &opts)
void SetJobName(std::string const &name)
std::string const & GetWeightsOutDir() const
std::string const & GetEvalFileName() const
void SetEvaluation(bool evaluate)
std::string const & GetMlpOptions() const