PandaRoot
PndLVQTrain.h
Go to the documentation of this file.
1 //****************************************************************************
2 //* This file is part of PandaRoot. *
3 //* *
4 //* PandaRoot is distributed under the terms of the *
5 //* GNU General Public License (GPL) version 3, *
6 //* copied verbatim in the file "LICENSE". *
7 //* *
8 //* Copyright (C) 2006 - 2024 FAIR GmbH and copyright holders of PandaRoot *
9 //* The copyright holders are listed in the file "COPYRIGHTHOLDERS". *
10 //* The authors are listed in the file "AUTHORS". *
11 //****************************************************************************
12 
13 /* ***************************************
14  * LVQ Training functions *
15  * Author: M.Babai@rug.nl *
16  * Version: *
17  * LICENSE: *
18  * ***************************************
19  */
20 //#pragma once
21 #ifndef PND_LVQ_TRAIN_H
22 #define PND_LVQ_TRAIN_H
23 
24 #ifdef _OPENMP
25 #include <omp.h>
26 #endif
27 
28 // Local includes
29 #include "PndMvaTrainer.h"
30 #include "PndMvaCluster.h"
31 
// ____ Local CPP definitions _________
// Set to 1 to compile the debug-only helpers declared in PndLVQTrain
// (train1sec(), train2sec() and EvalClassifierError(TestEvts)).
// Guarded so the build can override it (e.g. -DDEBUG_LVQ_TRAIN=1) without a
// macro-redefinition warning. It must have the same value in every
// translation unit that includes this header, because it changes the class
// definition.
#ifndef DEBUG_LVQ_TRAIN
#define DEBUG_LVQ_TRAIN 0
#endif
// ____________________________________
35 
/// How to initialize LVQ code books (the prototypes).
///
/// Unscoped on purpose: callers (e.g. the default argument of
/// PndLVQTrain::setProtoInitType) use the enumerators unqualified.
/// The explicit values are kept so that any persisted/compared integer
/// values stay the same.
enum ProtoInitType {
  RAND_FROM_DATA = 0, ///< Select randomly from data vector.
  CCM_PR = 1,         ///< Random init around Class Conditional Mean.
  KMEANS_PR = 10,     ///< Init using K-Means clustering.
  FILE_PR = 20        ///< Read pre-init from file.
};
43 
45 class PndLVQTrain : public PndMvaTrainer {
46  //----------------------------------------
47  //================== public ==============
48  public:
55  explicit PndLVQTrain(std::vector<std::pair<std::string, std::vector<float> *>> const &InputEvtsParam, std::vector<std::string> const &ClassNames,
56  std::vector<std::string> const &VarNames, bool trim = false);
63  explicit PndLVQTrain(std::string const &InPut, std::vector<std::string> const &ClassNames, std::vector<std::string> const &VarNames, bool trim = true);
67  virtual ~PndLVQTrain();
68 
73  void storeWeights();
74 
78  void Train();
79 
83  void Train21();
84 
89  inline void setProtoInitType(ProtoInitType iniTypeVal = RAND_FROM_DATA);
90 
96  inline void SetInitProtoFileName(std::string const &fileName);
97 
106  inline void SetLearnPrameters(double const initConst, double const etZ, double const etF, unsigned int const Nswp);
107 
114  void SetNumberOfProto(size_t const numProto);
115 
121  void SetNumberOfProto(std::map<std::string, size_t> const &labelMap);
122 
129  inline void SetErrorStepSize(unsigned int const val = 1000);
130 
135  inline void SetLVQ2_1WindowSize(float const Wsize = 0.3);
136 
140  void EvalClassifierError();
141 
149  inline void SetPerEpochEval(bool val);
150 
155  inline bool GetPerEpochEval() const;
156 
158 #if DEBUG_LVQ_TRAIN == 1
159  inline std::vector<std::pair<std::string, std::vector<float> *>> const &train1sec()
160  {
161  InitProtoK_Means();
162  return m_LVQProtos;
163  };
164 
165  inline std::vector<std::pair<std::string, std::vector<float> *>> const &train2sec()
166  {
167  InitProtoRand();
168  return m_LVQProtos;
169  };
170 
178  float EvalClassifierError(std::vector<std::pair<std::string, std::vector<float> *>> const &TestEvts) const;
179 #endif
180 
182  //----------------------------------------
183  // protected:
184  //================== private =============
185  private:
186  // To avoid mistakes, :).
187  PndLVQTrain(PndLVQTrain const &other);
188  PndLVQTrain &operator=(PndLVQTrain const &other);
189 
191  void EvalClassifierError(unsigned int stp);
192 
197  void InitProtoRand();
198 
203  void InitRandProtoFromData();
204 
208  void InitProtoK_Means();
209 
213  void InitProtoTypes();
214 
218  void cleanProtoList();
219 
223  void UpdateProto(std::vector<float> const &EvtData, std::vector<float> &proto, int const delta, double const ethaT);
224 
231  void ValidateProtoUpdate(std::vector<float> &p);
232 
237  void ReadProtoFromFile();
238  //=====================================
239 
241  std::vector<std::pair<std::string, std::vector<float> *>> m_LVQProtos;
242 
244  std::vector<PndMvaDistObj> m_distances;
245 
250  double m_initConst;
251  double m_ethaZero;
252  double m_ethaFinal;
253  float m_WindowSize;
254 
256  unsigned int m_NumSweep;
257 
259  ProtoInitType m_proto_init;
260 
262  std::string m_initProtoFile;
263 
265  unsigned int m_ErrorStep;
266  unsigned int m_ProgStep;
267 
269  std::map<std::string, size_t> m_numProtoPerClass;
270 
272  bool m_PerEpoch;
273 };
274 // END Interface definition
275 
276 //_______________________ Inline functions _________________________
278 {
279  m_proto_init = iniTypeVal;
280 };
281 
282 inline void PndLVQTrain::SetInitProtoFileName(std::string const &fileName)
283 {
284  m_initProtoFile = fileName;
285 };
286 
287 inline void PndLVQTrain::SetLearnPrameters(double const initConst, double const etZ, double const etF, unsigned int const Nswp)
288 {
289  m_initConst = initConst;
290  m_ethaZero = etZ;
291  m_ethaFinal = etF;
292  m_NumSweep = Nswp;
293 };
294 
295 inline void PndLVQTrain::SetErrorStepSize(unsigned int const val)
296 {
297  m_ErrorStep = val;
298 };
299 
300 inline void PndLVQTrain::SetLVQ2_1WindowSize(float const Wsize)
301 {
302  m_WindowSize = Wsize;
303 };
304 
305 inline void PndLVQTrain::SetPerEpochEval(bool val)
306 {
307  m_PerEpoch = val;
308 };
309 
310 inline bool PndLVQTrain::GetPerEpochEval() const
311 {
312  return m_PerEpoch;
313 };
314 #endif // End of interface definition
void SetPerEpochEval(bool val)
Definition: PndLVQTrain.h:305
bool GetPerEpochEval() const
Definition: PndLVQTrain.h:310
void storeWeights()
void Train21()
void Train()
virtual ~PndLVQTrain()
ProtoInitType
How to initialize LVQ code books.
Definition: PndLVQTrain.h:37
void SetLVQ2_1WindowSize(float const Wsize=0.3)
Definition: PndLVQTrain.h:300
void EvalClassifierError()
PndLVQTrain(std::vector< std::pair< std::string, std::vector< float > *>> const &InputEvtsParam, std::vector< std::string > const &ClassNames, std::vector< std::string > const &VarNames, bool trim=false)
Interface definition for LVQ trainers.
Definition: PndLVQTrain.h:45
void SetLearnPrameters(double const initConst, double const etZ, double const etF, unsigned int const Nswp)
Definition: PndLVQTrain.h:287
void setProtoInitType(ProtoInitType iniTypeVal=RAND_FROM_DATA)
Definition: PndLVQTrain.h:277
void SetNumberOfProto(size_t const numProto)
void SetInitProtoFileName(std::string const &fileName)
Definition: PndLVQTrain.h:282
void SetErrorStepSize(unsigned int const val=1000)
Definition: PndLVQTrain.h:295