PandaRoot
PndMvaDataSet.h
Go to the documentation of this file.
1 //****************************************************************************
2 //* This file is part of PandaRoot. *
3 //* *
4 //* PandaRoot is distributed under the terms of the *
5 //* GNU General Public License (GPL) version 3, *
6 //* copied verbatim in the file "LICENSE". *
7 //* *
8 //* Copyright (C) 2006 - 2024 FAIR GmbH and copyright holders of PandaRoot *
9 //* The copyright holders are listed in the file "COPYRIGHTHOLDERS". *
10 //* The authors are listed in the file "AUTHORS". *
11 //****************************************************************************
12 
13 /***************************************
14  * Class interface of DataSet class. *
15  * Author: M.Babai (M.Babai@rug.nl) *
16  * License: *
17  * Version: *
18  ***************************************/
19 //#pragma once
20 #ifndef PND_MVA_DATASET_H
21 #define PND_MVA_DATASET_H
22 
23 // C++ includes
24 #include <iostream>
25 #include <fstream>
26 #include <string>
27 #include <vector>
28 #include <map>
29 #include <set>
30 #include <algorithm>
31 #include <cmath>
32 #include <cassert>
33 #include <limits>
34 #include <typeinfo>
35 #include <exception>
36 #include <utility>
37 
38 // ROOT
39 #include "TFile.h"
40 #include "TTree.h"
41 #include "TRandom3.h"
42 
43 // Local includes
44 #include "PndMvaClass.h"
45 #include "PndMvaVariable.h"
46 #include "PndMvaVarPCATransform.h"
47 
48 // ========================================================================
// Application type: the mode a PndMvaDataSet is set up for (passed to its
// constructors and kept in m_AppType). Plain C++ enum; the C-style
// `typedef enum X {...} X;` wrapper is unnecessary in C++.
enum AppType
{
  UNKAPP = 0,        // Unknown application type.
  TRAIN = 1,         // Training algorithm.
  CLASSIFY = 2,      // Read weights to do classification.
  TMVATRAIN = 10,    // Provide input for TMVA Training.
  TMVACLS = 20,      // TMVA classification.
  PRE_INIT_EVTS = 30 // Pre-initialized event data.
};
58 
// Normalization schemes that can be applied to the data set (kept in
// m_NormType). Plain C++ enum; the explicit values are preserved.
enum NormType
{
  NONORM = 0, // Do nothing
  VARX = 1,   // Use Sample variance
  MINMAX = 2, // Use Sample Min and Max
  MEDIAN = 3, // Use median and interquartile range (IQR).
  VARNORM = 4 // Variable Normalize Transform
};
67 
69 
/// Exception type associated with PndMvaDataSet; it stores a message string.
class PndMvaDataSetException : public std::exception {
 public:
  /// Construct with a generic placeholder message.
  explicit PndMvaDataSetException() : m_message("UNKNOWN_MvaDataSetException") {}

  /// Construct with a specific message.
  /// @param val Message reported by what().
  explicit PndMvaDataSetException(std::string const &val) : m_message(val) {}

  /// `throw()` specifications are deprecated since C++17 and removed in
  /// C++20; `noexcept` is the equivalent and keeps this header compiling
  /// under newer standards.
  ~PndMvaDataSetException() noexcept override = default;

  /// std::exception interface.
  /// @return The stored message as a C string (valid while *this lives).
  char const *what() const noexcept override { return m_message.c_str(); }

  /// Non-const overload, kept for backward compatibility.
  /// NOTE: on a NON-const object, `e.what()` selects this overload and yields
  /// `std::string const&`, not `char const*`. To obtain the C string, call
  /// what() through a const reference (e.g. `catch (const std::exception&)`).
  /// @return The stored message.
  virtual std::string const &what() noexcept { return m_message; }

 private:
  std::string m_message;
};
86 
87 // ==================== Data set class ==========================
89  public:
  // Construct from in-memory events. Each element of InputEvtsParam pairs a
  // string (presumably the class label) with a pointer to a float vector
  // (presumably the variable values of one event). Ownership of these
  // pointers is not visible in this header -- TODO confirm in the .cpp.
  // classNames / varNames: names of the classes and variables to use.
  // type: application mode (see AppType).
100  explicit PndMvaDataSet(std::vector<std::pair<std::string, std::vector<float> *>> const &InputEvtsParam, std::vector<std::string> const &classNames,
101  std::vector<std::string> const &varNames, AppType type);
  // Construct from a file. WeightFile names a weight/event file (presumably
  // the name later returned by GetInFileName() -- confirm).
  // classNames / varNames / type: as for the constructor above.
109  explicit PndMvaDataSet(std::string const &WeightFile, std::vector<std::string> const &classNames, std::vector<std::string> const &varNames, AppType type);
110 
  // Virtual destructor; whether it frees the raw float-vector pointers held
  // in m_events / m_ClassCondMeans is decided in the implementation.
112  virtual ~PndMvaDataSet();
113 
  // Write the data set to the file named outFile. (The commented-out line
  // below shows this was once marked deprecated.)
118  // void WriteDataSet(std::string const& outFile) __attribute__ ((deprecated));
119  virtual void WriteDataSet(std::string const &outFile);
120 
  // Compute the class-conditional means (see GetClassCondMeans()).
  // excludeIndxs: indices to leave out of the computation (presumably
  // variable indices -- confirm in the implementation).
130  virtual void InitClsCondMeans(std::set<size_t> const &excludeIndxs);
131 
  // Set the trim flag (m_trim); see the private Trim().
136  inline void SetTrim(bool t);
137 
  // Get available data vectors. Asserts that the event list is non-empty.
139  inline std::vector<std::pair<std::string, std::vector<float> *>> const &GetData() const;
140 
  // Get the list of available classes (labels).
142  inline std::vector<PndMvaClass> const &GetClasses() const;
143 
  // Get the list of available variables.
145  inline std::vector<PndMvaVariable> const &GetVars() const;
146 
  // Get class-conditional means for all classes (labels), keyed by name.
148  inline std::map<std::string, std::vector<float> *> const &GetClassCondMeans() const;
149 
  // Get the name of the input file (weight/event file).
151  inline std::string const &GetInFileName() const;
152 
153  //========================= PCA =====================//
  // Apply a PCA transformation to the data set (presumably storing the
  // transformation in m_PCA -- confirm in the implementation).
159  virtual void PCATransForm();
160 
  // Whether PCA was applied (returns m_UsePCA).
164  inline bool Used_PCA() const;
165 
  // Set the "PCA applied" flag (m_UsePCA).
169  inline void Use_PCA(bool t);
170 
  // Read-only access to the PCA transformation (m_PCA).
175  inline PndMvaVarPCATransform const &Get_PCA() const;
176 
177  //_________________________ PCA _____________________//
178 
  // Current normalization scheme (see NormType).
182  inline NormType GetNormType() const;
183 
  // Select the normalization scheme (see NormType).
187  inline void SetNormType(NormType t);
188 
  // Current application type (see AppType).
192  inline AppType GetAppType() const;
193 
  // Set the application type (see AppType).
196  inline void SetAppType(AppType t);
197 
  // Prepare the data set for use; the exact steps live in the implementation
  // (not visible in this header).
202  virtual void Initialize();
203 
  // Get / set the stored random seed (m_RND_seed). Presumably used to seed the
  // ROOT random generator (TRandom3 is included) -- confirm in the .cpp.
204  inline size_t GetRndSeed() const;
205  inline void SetRndSeed(size_t const sd);
206 
207  //______________________________________________________________
208  protected:
  // Read the input data (presumably from the file named by m_input).
212  void ReadInput();
213 
  // Read weights from the weight file (presumably the one named by m_input).
217  void ReadWeightsFromFile();
218 
219  //==============================================================
220  private:
  // Copy constructor and copy assignment are private so the class cannot be
  // copied (presumably left undefined). Since C++11, `= delete` expresses the
  // same intent with a clearer compile-time error.
221  // Private to avoid mistakes.
222  PndMvaDataSet(PndMvaDataSet const &other);
223  PndMvaDataSet &operator=(PndMvaDataSet const &other);
224 
  // Trim the data set (controlled by m_trim; details in the implementation).
228  void Trim();
229 
  // Normalize the data set. Presumably dispatches on m_NormType to one of the
  // helpers below -- confirm in the implementation.
233  void NormalizeDataSet();
234 
  // Set up m_classes from the given class labels.
239  void InitClasses(std::vector<std::string> const &labels);
240 
  // Set up m_vars from the given variable names.
245  void InitVariables(std::vector<std::string> const &variables);
246 
247  // Validate the input file
248  void ValidateWeightFile();
249 
  // Compute the conditional mean for class clsName, skipping the indices in
  // exCluds (presumably variable indices).
256  void CompClsCondMean(std::string const &clsName, std::set<size_t> const &exCluds);
257 
  // Normalization helpers. Presumably: ComputeVariance -> VARX,
  // DetermineMedian -> MEDIAN, MinMaxDiff/FindMinMax -> MINMAX,
  // VarNormalize -> VARNORM (see NormType). Confirm in the implementation.
262  void ComputeVariance();
263 
267  void DetermineMedian();
268 
272  void MinMaxDiff();
273 
277  void FindMinMax();
278 
282  void VarNormalize();
283 
284  // __________________________ Member parameters ___________
  // Input file name (weight/event file), returned by GetInFileName().
286  std::string m_input;
287 
  // Available classes (labels), returned by GetClasses().
289  std::vector<PndMvaClass> m_classes;
290 
  // Available variables, returned by GetVars().
292  std::vector<PndMvaVariable> m_vars;
293 
  // Event data, returned by GetData(). Holds raw float-vector pointers;
  // ownership is not visible in this header.
295  std::vector<std::pair<std::string, std::vector<float> *>> m_events;
296 
  // Class-conditional means keyed by class name, returned by
  // GetClassCondMeans(). Holds raw pointers; ownership not visible here.
298  std::map<std::string, std::vector<float> *> m_ClassCondMeans;
299 
300  // PCA transformation.
301  PndMvaVarPCATransform m_PCA;
302 
303  // If PCA was applied.
304  bool m_UsePCA;
305 
306  // Normalization scheme
307  NormType m_NormType;
308 
309  // Application type.
310  AppType m_AppType;
  // Trim flag (set by SetTrim()) and random seed (GetRndSeed()/SetRndSeed()).
311  bool m_trim;
312  size_t m_RND_seed;
313 };
314 // End of class interface definition.
315 
316 // ============= Inline implementation ==================
317 
318 inline size_t PndMvaDataSet::GetRndSeed() const
319 {
320  return this->m_RND_seed;
321 };
322 
323 inline void PndMvaDataSet::SetRndSeed(size_t const sd)
324 {
325  this->m_RND_seed = sd;
326 };
327 
328 inline std::vector<std::pair<std::string, std::vector<float> *>> const &PndMvaDataSet::GetData() const
329 {
330  assert(m_events.size() != 0);
331  return m_events;
332 };
333 
334 inline std::vector<PndMvaClass> const &PndMvaDataSet::GetClasses() const
335 {
336  return m_classes;
337 };
338 
339 inline std::vector<PndMvaVariable> const &PndMvaDataSet::GetVars() const
340 {
341  return m_vars;
342 };
343 
344 inline std::map<std::string, std::vector<float> *> const &PndMvaDataSet::GetClassCondMeans() const
345 {
346  return m_ClassCondMeans;
347 };
348 
349 inline std::string const &PndMvaDataSet::GetInFileName() const
350 {
351  return m_input;
352 };
353 
354 inline bool PndMvaDataSet::Used_PCA() const
355 {
356  return m_UsePCA;
357 };
358 inline void PndMvaDataSet::Use_PCA(bool t)
359 {
360  m_UsePCA = t;
361 };
363 {
  // Presumably the body of Get_PCA(): read-only access to the stored PCA
  // transformation. NOTE(review): its signature line is missing from this listing.
364  return m_PCA;
365 };
367 {
  // Presumably GetNormType(): returns the current normalization scheme.
  // NOTE(review): signature line missing from this listing.
368  return m_NormType;
369 };
371 {
  // Presumably SetNormType(NormType t): selects the normalization scheme.
  // NOTE(review): signature line missing from this listing.
372  m_NormType = t;
373 };
375 {
  // Presumably GetAppType(): returns the application type.
  // NOTE(review): signature line missing from this listing.
376  return m_AppType;
377 };
379 {
  // Presumably SetAppType(AppType t): sets the application type.
  // NOTE(review): signature line missing from this listing.
380  m_AppType = t;
381 };
382 inline void PndMvaDataSet::SetTrim(bool t)
383 {
384  m_trim = t;
385 };
386 #endif
std::map< std::string, std::vector< float > * > const & GetClassCondMeans() const
Get class-conditional means for all classes (labels).
std::vector< PndMvaVariable > const & GetVars() const
Get the list of available variables.
AppType GetAppType() const
std::string const & GetInFileName() const
Get the name of the input file (weight/event file).
bool Used_PCA() const
void SetNormType(NormType t)
PndMvaDataSetException(std::string const &val)
Definition: PndMvaDataSet.h:74
void SetAppType(AppType t)
PndMvaVarPCATransform const & Get_PCA() const
void SetRndSeed(size_t const sd)
NormType
Definition: PndMvaDataSet.h:60
void Use_PCA(bool t)
std::vector< std::pair< std::string, std::vector< float > * > > const & GetData() const
Get available data vectors.
void SetTrim(bool t)
virtual std::string const & what()
Definition: PndMvaDataSet.h:80
virtual char const * what() const
Definition: PndMvaDataSet.h:78
virtual ~PndMvaDataSetException()
Definition: PndMvaDataSet.h:76
size_t GetRndSeed() const
NormType GetNormType() const
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
AppType
Definition: PndMvaDataSet.h:50