1
1
// @(#)root/tmva/tmva/dnn:$Id$
2
- // Author: Vladimir Ilievski, Saurav Shekhar
2
+ // Author: Anushree Rankawat
3
3
4
4
/* *********************************************************************************
5
5
* Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
11
11
* Generative Adversarial Networks *
12
12
* *
13
13
* Authors (alphabetical): *
14
- * Vladimir Ilievski <[email protected] > - CERN, Switzerland *
15
- * Saurav Shekhar <[email protected] > - ETH Zurich, Switzerland *
14
+ * Anushree Rankawat <[email protected] > *
16
15
* *
17
16
* Copyright (c) 2005-2015: *
18
17
* CERN, Switzerland *
52
51
#include " TMVA/DNN/Architectures/Cuda.h"
53
52
#endif
54
53
54
+ #ifdef R__HAS_TMVACPU
55
+ using ArchitectureImpl_t = TMVA::DNN::TCpu<Double_t>;
56
+ #else
57
+ using ArchitectureImpl_t = TMVA::DNN::TReference<Double_t>;
58
+ #endif
59
+ using DeepNetImpl_t = TMVA::DNN::TDeepNet<ArchitectureImpl_t>;
60
+
55
61
#include " TMVA/DNN/Architectures/Reference.h"
56
62
#include " TMVA/DNN/Functions.h"
57
63
#include " TMVA/DNN/DeepNet.h"
@@ -65,8 +71,6 @@ using namespace TMVA::DNN;
65
71
using Architecture_t = TCpu<Double_t>;
66
72
using Scalar_t = Architecture_t::Scalar_t;
67
73
using DeepNet_t = TMVA::DNN::TDeepNet<Architecture_t>;
68
- // using Matrix_t = typename TCpu<double>::Matrix_t;
69
- // using TensorInput = std::tuple<const std::vector<Matrix_t> &, const Matrix_t &, const Matrix_t &>;
70
74
using TensorDataLoader_t = TTensorDataLoader<TMVAInput_t, Architecture_t>;
71
75
72
76
using TMVA::DNN::EActivationFunction;
@@ -106,13 +110,6 @@ class MethodGAN : public MethodBase {
106
110
private:
107
111
// Key-Value vector type, contining the values for the training options
108
112
using KeyValueVector_t = std::vector<std::map<TString, TString>>;
109
- // using TensorInput = std::tuple<const std::vector<TMatrixT<Double_t>> &>;
110
- #ifdef R__HAS_TMVACPU
111
- using ArchitectureImpl_t = TMVA::DNN::TCpu<Double_t>;
112
- #else
113
- using ArchitectureImpl_t = TMVA::DNN::TReference<Double_t>;
114
- #endif
115
- using DeepNetImpl_t = TMVA::DNN::TDeepNet<ArchitectureImpl_t>;
116
113
std::unique_ptr<DeepNetImpl_t> generatorFNet, discriminatorFNet, combinedFNet;
117
114
using Matrix_t = typename ArchitectureImpl_t::Matrix_t;
118
115
@@ -133,7 +130,7 @@ class MethodGAN : public MethodBase {
133
130
* a reference in the function. */
134
131
template <typename Architecture_t, typename Layer_t>
135
132
void CreateDeepNet (DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
136
- std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, std::unique_ptr<DeepNetImpl_t> &fNet , TString layoutString);
133
+ std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, std::unique_ptr<DeepNetImpl_t> &modelNet , TString layoutString);
137
134
138
135
size_t fGeneratorInputDepth ; // /< The depth of the input of the generator.
139
136
size_t fGeneratorInputHeight ; // /< The height of the input of the generator.
@@ -197,20 +194,22 @@ class MethodGAN : public MethodBase {
197
194
void Train ();
198
195
199
196
Double_t GetMvaValue (Double_t *err = 0 , Double_t *errUpper = 0 );
200
- Double_t GetMvaValueGAN (std::unique_ptr<DeepNetImpl_t> & fNet , Double_t *err = 0 , Double_t *errUpper = 0 );
201
-
197
+ Double_t GetMvaValueGAN (std::unique_ptr<DeepNetImpl_t> & modelNet, Double_t *err = 0 , Double_t *errUpper = 0 );
202
198
void CreateNoisyMatrices (std::vector<TMatrixT<Double_t>> &inputTensor, TMatrixT<Double_t> &outputMatrix, TMatrixT<Double_t> &weights, DeepNet_t &DeepNet, size_t nSamples, size_t classLabel);
203
199
Double_t ComputeLoss (TTensorDataLoader<TensorInput, Architecture_t> &generalDataloader, DeepNet_t &DeepNet);
204
200
Double_t ComputeLoss (TTensorDataLoader<TMVAInput_t, Architecture_t> &generalDataloader, DeepNet_t &DeepNet);
205
- void CreateDiscriminatorFakeData (std::vector<TMatrixT<Double_t>> &predTensor, TMatrixT<Double_t> &outputMatrix, TMatrixT<Double_t> &weights, TTensorDataLoader<TensorInput, Architecture_t> &trainingData, DeepNet_t &genDeepNet, DeepNet_t &disDeepNet, size_t nSamples, size_t classLabel);
206
- void CombineGAN (DeepNet_t &combinedDeepNet, DeepNet_t &generatorNet, DeepNet_t &discriminatorNet);
207
-
208
- // void AddWeightsXMLToGAN(std::unique_ptr<DeepNetImpl_t> & fNet, void * parent);
201
+ void CreateDiscriminatorFakeData (std::vector<TMatrixT<Double_t>> &predTensor, TMatrixT<Double_t> &outputMatrix, TMatrixT<Double_t> &weights, TTensorDataLoader<TensorInput, Architecture_t> &trainingData, DeepNet_t &genDeepNet, DeepNet_t &disDeepNet, EOutputFunction outputFunction, size_t nSamples, size_t classLabel, size_t epochs);
202
+ void CombineGAN (DeepNet_t &combinedDeepNet, DeepNet_t &generatorNet, DeepNet_t &discriminatorNet, std::unique_ptr<DeepNetImpl_t> & combinedNet);
203
+ void SetDiscriminatorLayerTraining (DeepNet_t &discrimatorNet);
209
204
210
205
/* ! Methods for writing and reading weights */
211
206
using MethodBase::ReadWeightsFromStream;
212
207
void AddWeightsXMLTo (void *parent) const ;
208
+ void AddWeightsXMLToGenerator (void *parent) const ;
209
+ void AddWeightsXMLToDiscriminator (void *parent) const ;
213
210
void ReadWeightsFromXML (void *wghtnode);
211
+ void ReadWeightsFromXMLGenerator (void *rootXML);
212
+ void ReadWeightsFromXMLDiscriminator (void *rootXML);
214
213
void ReadWeightsFromStream (std::istream &);
215
214
216
215
/* Create ranking */
0 commit comments