PandaRoot
PndMvaDataSet.h
Go to the documentation of this file.
1 /***************************************
2  * Class interface of DataSet class. *
3  * Author: M.Babai (M.Babai@rug.nl) *
4  * License: *
5  * Version: *
6  ***************************************/
7 //#pragma once
8 #ifndef PND_MVA_DATASET_H
9 #define PND_MVA_DATASET_H
10 
11 // C++ includes
12 #include <iostream>
13 #include <fstream>
14 #include <string>
15 #include <vector>
16 #include <map>
17 #include <set>
18 #include <algorithm>
19 #include <cmath>
20 #include <cassert>
21 #include <limits>
22 #include <typeinfo>
23 #include <exception>
24 #include <utility>
25 
26 // ROOT
27 #include "TFile.h"
28 #include "TTree.h"
29 #include "TRandom3.h"
30 
31 // Local includes
32 #include "PndMvaClass.h"
33 #include "PndMvaVariable.h"
34 #include "PndMvaVarPCATransform.h"
35 
36 // ========================================================================
37 // Application type
/**
 * Application type selecting how a data set is going to be used.
 *
 * Declared as a plain C++ enum; the C-style `typedef enum X {...} X;`
 * wrapper is redundant in C++ because the enum name is already a type name.
 * Enumerator names and numeric values are unchanged.
 */
enum AppType {
  UNKAPP = 0,         ///< No application type set (presumably "unknown application").
  TRAIN = 1,          ///< Training algorithm.
  CLASSIFY = 2,       ///< Read weights to do classification.
  TMVATRAIN = 10,     ///< Provide input for TMVA training.
  TMVACLS = 20,       ///< TMVA classification.
  PRE_INIT_EVTS = 30  ///< Pre-initialized event data.
};
46 
47 // Normalization schemes
/**
 * Normalization scheme applied to the variables of a data set.
 *
 * Declared as a plain C++ enum; the C-style `typedef enum X {...} X;`
 * wrapper is redundant in C++. Enumerator names and numeric values are
 * unchanged.
 */
enum NormType {
  NONORM = 0,  ///< Do nothing.
  VARX = 1,    ///< Use the sample variance.
  MINMAX = 2,  ///< Use the sample minimum and maximum.
  MEDIAN = 3,  ///< Use the median and the interquartile range (IQR).
  VARNORM = 4  ///< Variable normalize transform.
};
55 
57 
58 class PndMvaDataSetException : public std::exception {
59  public:
60  explicit PndMvaDataSetException() : m_message("UNKNOWN_MvaDataSetException"){};
61 
62  explicit PndMvaDataSetException(std::string const &val) : m_message(val){};
63 
64  virtual ~PndMvaDataSetException() throw(){};
65 
66  virtual char const *what() const throw() { return m_message.c_str(); };
67 
68  virtual std::string const &what() { return m_message; };
69 
70  private:
71  std::string m_message;
72 };
74 
75 // ==================== Data set class ==========================
77  public:
88  explicit PndMvaDataSet(std::vector<std::pair<std::string, std::vector<float> *>> const &InputEvtsParam, std::vector<std::string> const &classNames,
89  std::vector<std::string> const &varNames, AppType type);
97  explicit PndMvaDataSet(std::string const &WeightFile, std::vector<std::string> const &classNames, std::vector<std::string> const &varNames, AppType type);
98 
100  virtual ~PndMvaDataSet();
101 
106  // void WriteDataSet(std::string const& outFile) __attribute__ ((deprecated));
107  virtual void WriteDataSet(std::string const &outFile);
108 
118  virtual void InitClsCondMeans(std::set<size_t> const &excludeIndxs);
119 
124  inline void SetTrim(bool t);
125 
127  inline std::vector<std::pair<std::string, std::vector<float> *>> const &GetData() const;
128 
130  inline std::vector<PndMvaClass> const &GetClasses() const;
131 
133  inline std::vector<PndMvaVariable> const &GetVars() const;
134 
136  inline std::map<std::string, std::vector<float> *> const &GetClassCondMeans() const;
137 
139  inline std::string const &GetInFileName() const;
140 
141  //========================= PCA =====================//
147  virtual void PCATransForm();
148 
152  inline bool Used_PCA() const;
153 
157  inline void Use_PCA(bool t);
158 
163  inline PndMvaVarPCATransform const &Get_PCA() const;
164 
165  //_________________________ PCA _____________________//
166 
170  inline NormType GetNormType() const;
171 
175  inline void SetNormType(NormType t);
176 
180  inline AppType GetAppType() const;
181 
184  inline void SetAppType(AppType t);
185 
190  virtual void Initialize();
191 
192  inline size_t GetRndSeed() const;
193  inline void SetRndSeed(size_t const sd);
194 
195  //______________________________________________________________
196  protected:
200  void ReadInput();
201 
205  void ReadWeightsFromFile();
206 
207  //==============================================================
208  private:
209  // Private to avoid mistakes.
210  PndMvaDataSet(PndMvaDataSet const &other);
211  PndMvaDataSet &operator=(PndMvaDataSet const &other);
212 
216  void Trim();
217 
221  void NormalizeDataSet();
222 
227  void InitClasses(std::vector<std::string> const &labels);
228 
233  void InitVariables(std::vector<std::string> const &variables);
234 
235  // Validate the input file
236  void ValidateWeightFile();
237 
244  void CompClsCondMean(std::string const &clsName, std::set<size_t> const &exCluds);
245 
250  void ComputeVariance();
251 
255  void DetermineMedian();
256 
260  void MinMaxDiff();
261 
265  void FindMinMax();
266 
270  void VarNormalize();
271 
272  // __________________________ Member parameters ___________
274  std::string m_input;
275 
277  std::vector<PndMvaClass> m_classes;
278 
280  std::vector<PndMvaVariable> m_vars;
281 
283  std::vector<std::pair<std::string, std::vector<float> *>> m_events;
284 
286  std::map<std::string, std::vector<float> *> m_ClassCondMeans;
287 
288  // PCA transformation.
289  PndMvaVarPCATransform m_PCA;
290 
291  // If PCA was applied.
292  bool m_UsePCA;
293 
294  // Normalization scheme
295  NormType m_NormType;
296 
297  // Application type.
298  AppType m_AppType;
299  bool m_trim;
300  size_t m_RND_seed;
301 };
302 // End of class interface definition.
303 
304 // ============= Inline implementation ==================
305 
306 inline size_t PndMvaDataSet::GetRndSeed() const
307 {
308  return this->m_RND_seed;
309 };
310 
311 inline void PndMvaDataSet::SetRndSeed(size_t const sd)
312 {
313  this->m_RND_seed = sd;
314 };
315 
316 inline std::vector<std::pair<std::string, std::vector<float> *>> const &PndMvaDataSet::GetData() const
317 {
318  assert(m_events.size() != 0);
319  return m_events;
320 };
321 
322 inline std::vector<PndMvaClass> const &PndMvaDataSet::GetClasses() const
323 {
324  return m_classes;
325 };
326 
327 inline std::vector<PndMvaVariable> const &PndMvaDataSet::GetVars() const
328 {
329  return m_vars;
330 };
331 
332 inline std::map<std::string, std::vector<float> *> const &PndMvaDataSet::GetClassCondMeans() const
333 {
334  return m_ClassCondMeans;
335 };
336 
337 inline std::string const &PndMvaDataSet::GetInFileName() const
338 {
339  return m_input;
340 };
341 
342 inline bool PndMvaDataSet::Used_PCA() const
343 {
344  return m_UsePCA;
345 };
346 inline void PndMvaDataSet::Use_PCA(bool t)
347 {
348  m_UsePCA = t;
349 };
351 {
352  return m_PCA;
353 };
355 {
356  return m_NormType;
357 };
359 {
360  m_NormType = t;
361 };
363 {
364  return m_AppType;
365 };
367 {
368  m_AppType = t;
369 };
370 inline void PndMvaDataSet::SetTrim(bool t)
371 {
372  m_trim = t;
373 };
374 #endif
std::map< std::string, std::vector< float > * > const & GetClassCondMeans() const
Get class-conditional means for all classes (labels).
std::vector< PndMvaVariable > const & GetVars() const
Get the list of available variables.
AppType GetAppType() const
std::string const & GetInFileName() const
Get the name of the input file (weight/event file).
bool Used_PCA() const
void SetNormType(NormType t)
PndMvaDataSetException(std::string const &val)
Definition: PndMvaDataSet.h:62
void SetAppType(AppType t)
PndMvaVarPCATransform const & Get_PCA() const
void SetRndSeed(size_t const sd)
NormType
Definition: PndMvaDataSet.h:48
void Use_PCA(bool t)
std::vector< std::pair< std::string, std::vector< float > * > > const & GetData() const
Get available data vectors.
void SetTrim(bool t)
virtual std::string const & what()
Definition: PndMvaDataSet.h:68
virtual char const * what() const
Definition: PndMvaDataSet.h:66
virtual ~PndMvaDataSetException()
Definition: PndMvaDataSet.h:64
size_t GetRndSeed() const
NormType GetNormType() const
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
AppType
Definition: PndMvaDataSet.h:38