PandaRoot
PndKnnClassify.h
Go to the documentation of this file.
1 //****************************************************************************
2 //* This file is part of PandaRoot. *
3 //* *
4 //* PandaRoot is distributed under the terms of the *
5 //* GNU General Public License (GPL) version 3, *
6 //* copied verbatim in the file "LICENSE". *
7 //* *
8 //* Copyright (C) 2006 - 2024 FAIR GmbH and copyright holders of PandaRoot *
9 //* The copyright holders are listed in the file "COPYRIGHTHOLDERS". *
10 //* The authors are listed in the file "AUTHORS". *
11 //****************************************************************************
12 
13 /* ************************************
14  * Author: M. Babai (M.Babai@rug.nl) *
15  * *
16  * KNN based pid classifier *
17  * *
18  * Modified: *
19  * *
20  * ************************************/
21 //#pragma once
22 #ifndef PND_KNN_CLASSIFY_H
23 #define PND_KNN_CLASSIFY_H
24 
25 // LOCAL includes
26 #include "PndMvaClassifier.h"
27 
28 // TMVA
29 #include "TMVA/NodekNN.h"
30 #include "TMVA/ModulekNN.h"
31 
32 //____________________________________________
34 // typedef std::list < std::pair<const TMVA::kNN::Node<TMVA::kNN::Event>*, float> > ResList;
35 //____________________________________________
36 
41  public:
49  explicit PndKnnClassify(std::string const &inputFile, std::vector<std::string> const &classNames, std::vector<std::string> const &varNames);
51  virtual ~PndKnnClassify();
52 
58  void GetMvaValues(std::vector<float> eventData, std::map<std::string, float> &result);
64  std::string *Classify(std::vector<float> EvtData);
65 
71  inline void SetEvtParam(float const scFact, double const weight);
72 
74  inline void SetKnn(size_t const N);
75 
77  inline size_t GetKnn();
78 
82  virtual void Initialize();
83 
85  void print() { m_module->Print(); }
87 
88  // ================== Private ===============
89  private:
90  // Copying is disabled: the copy constructor and copy assignment operator are declared private.
91  PndKnnClassify(PndKnnClassify const &other);
92  PndKnnClassify &operator=(PndKnnClassify const &other);
93 
97  void InitKNN();
98 
100  typedef std::list<std::pair<const TMVA::kNN::Node<TMVA::kNN::Event> *, float>> ResList;
101 
102  size_t m_knn;
103  float m_ScaleFact;
104  double m_weight;
105  TMVA::kNN::ModulekNN *m_module; // TMVA Knn module.
106 
110  std::map<std::string, size_t> m_classIndices;
111 }; // End of classifier interface definition
112 
113 //___________________ Inline implementation. __________________________________
114 inline void PndKnnClassify::SetEvtParam(float const scFact, double const weight)
115 {
116  m_ScaleFact = scFact;
117  m_weight = weight;
118 };
119 
120 inline void PndKnnClassify::SetKnn(size_t const N)
121 {
122  m_knn = N;
123 };
124 
125 inline size_t PndKnnClassify::GetKnn()
126 {
127  return m_knn;
128 };
129 #endif
void GetMvaValues(std::vector< float > eventData, std::map< std::string, float > &result)
Type definition of the neighbour list (ResList).
virtual ~PndKnnClassify()
Destructor.
PndKnnClassify(std::string const &inputFile, std::vector< std::string > const &classNames, std::vector< std::string > const &varNames)
void SetEvtParam(float const scFact, double const weight)
void SetKnn(size_t const N)
Set the number of neighbours.
std::string * Classify(std::vector< float > EvtData)
void print()
DEBUG Produces a lot of output.
virtual void Initialize()
size_t GetKnn()
Get the number of neighbours.