PandaRoot
PndMultiClassBdtTrain.h
Go to the documentation of this file.
1 //****************************************************************************
2 //* This file is part of PandaRoot. *
3 //* *
4 //* PandaRoot is distributed under the terms of the *
5 //* GNU General Public License (GPL) version 3, *
6 //* copied verbatim in the file "LICENSE". *
7 //* *
8 //* Copyright (C) 2006 - 2024 FAIR GmbH and copyright holders of PandaRoot *
9 //* The copyright holders are listed in the file "COPYRIGHTHOLDERS". *
10 //* The authors are listed in the file "AUTHORS". *
11 //****************************************************************************
12 
13 /* ***************************************
14  * MultiClass BDTG Training functions *
15  * Author: M.Babai@rug.nl *
16  * Version: *
17  * LICENSE: *
18  * ***************************************
19  */
20 /*
21  * Note: This is just an interface to the original TMVA
22  * implementation. To find out the available options, please read TMVA
23  * manuals. In case of errors or wrong outputs produced by TMVA
24  * classifiers, try to read their mailing list and send your questions
25  * to the same list.
 ******* VERY IMPORTANT ****
 * You need TMVA version > 4.1.X for this to work.
28  */
29 //#pragma once
30 #ifndef PND_MULTICLASS_BDT_TRAIN_H
31 #define PND_MULTICLASS_BDT_TRAIN_H
32 
33 // Local includes
34 #include "PndMvaTrainer.h"
35 
36 // TMVA && ROOT
37 #include "TMVA/Factory.h"
38 #include "TMVA/Config.h"
39 
// Interface definition for multiclass BDT trainers.
42  //----------------------------------------
43  //================== public ==============
44  public:
52  explicit PndMultiClassBdtTrain(std::string const &InPut, std::vector<std::string> const &ClassNames, std::vector<std::string> const &VarNames, bool trim = true);
56  virtual ~PndMultiClassBdtTrain();
57 
61  void Train();
62 
67  void storeWeights();
68 
72  void Initialize();
73 
74  //______________________________________________
75  //====== Getters and setters.
76  // Set the name of the current job
77  inline void SetJobName(std::string const &name);
78 
79  // Set data Transformation scheme
80  inline void SetTransformation(std::string const &tran);
81 
    // Set the options for the BDT algorithm. See TMVA manuals.
83  inline void SetBdtOptions(std::string const &opts);
84 
85  // Set the file name to store evaluation outputs.
86  inline void SetEvalFileName(std::string const &fname);
87 
88  // Set the directory where weights are stored.
89  inline void SetWeightsOutDir(std::string const &dirName);
90 
91  // Evaluate the classifier?
92  inline void SetEvaluation(bool evaluate);
93 
94  // Get the current job name.
95  inline std::string const &GetJobName() const;
96 
97  // Get the current transformation info.
98  inline std::string const &GetTransformation() const;
99 
100  // Get the classifier options.
101  inline std::string const &GetBdtOptions() const;
102 
    // Get the name of the file used to store evaluation outputs.
104  inline std::string const &GetEvalFileName() const;
105 
106  // Get the directory where the weights are stored.
107  inline std::string const &GetWeightsOutDir() const;
108  //----------------------------------------
109 
110  //================== protected ============
111  // protected:
112  //----------------------------------------
113 
114  //================== private =============
115  private:
    // Declared private to prevent accidental copying of trainer objects.
118  PndMultiClassBdtTrain &operator=(PndMultiClassBdtTrain const &oth);
119 
    // Initialize the BDT classifier object and set its options.
121  void InitBdt();
122  // Add the variables to the TMVA factory object.
123  void AddVariables();
124 
125  //==============================
126  TMVA::Factory *m_factory; // TMVA factory
127  TFile *EvalFile; // To store evaluation file
128  std::string m_JName; // Job name
129  std::string m_transform; // Transformation opt.
130  std::string m_BdtOptions; // Bdt options.
131  std::string m_evalFileName; // evaluation file name.
132  std::string m_weightDirName; // Directory name to store weights.
133  bool m_Evaluate;
134 }; // End of interface definition.
135 //=============== inline functions implementation. ========
136 //__________________________________________
137 inline void PndMultiClassBdtTrain::SetJobName(std::string const &name)
138 {
139  this->m_JName = name;
140 };
141 
142 inline void PndMultiClassBdtTrain::SetTransformation(std::string const &tr)
143 {
144  this->m_transform = tr;
145 };
146 
147 inline void PndMultiClassBdtTrain::SetBdtOptions(std::string const &opt)
148 {
149  this->m_BdtOptions = opt;
150 };
151 
152 inline std::string const &PndMultiClassBdtTrain::GetJobName() const
153 {
154  return m_JName;
155 };
156 
157 inline std::string const &PndMultiClassBdtTrain::GetTransformation() const
158 {
159  return m_transform;
160 };
161 
162 inline std::string const &PndMultiClassBdtTrain::GetBdtOptions() const
163 {
164  return m_BdtOptions;
165 };
166 
167 inline void PndMultiClassBdtTrain::SetEvalFileName(std::string const &fname)
168 {
169  this->m_evalFileName = fname;
170 };
171 
172 inline std::string const &PndMultiClassBdtTrain::GetEvalFileName() const
173 {
174  return m_evalFileName;
175 };
176 
177 inline void PndMultiClassBdtTrain::SetWeightsOutDir(std::string const &dirName)
178 {
179  this->m_weightDirName = dirName;
180 };
181 
182 inline std::string const &PndMultiClassBdtTrain::GetWeightsOutDir() const
183 {
184  return m_weightDirName;
185 };
186 
187 inline void PndMultiClassBdtTrain::SetEvaluation(bool evaluate)
188 {
189  this->m_Evaluate = evaluate;
190 };
191 #endif
std::string const & GetTransformation() const
std::string const & GetBdtOptions() const
void SetEvaluation(bool evaluate)
std::string const & GetWeightsOutDir() const
PndMultiClassBdtTrain(std::string const &InPut, std::vector< std::string > const &ClassNames, std::vector< std::string > const &VarNames, bool trim=true)
void SetTransformation(std::string const &tran)
void SetEvalFileName(std::string const &fname)
void SetJobName(std::string const &name)
void SetBdtOptions(std::string const &opts)
void SetWeightsOutDir(std::string const &dirName)
std::string const & GetJobName() const
virtual ~PndMultiClassBdtTrain()
std::string const & GetEvalFileName() const