
Commit 06d2190

Tests for GAN
1 parent 16d2e75 commit 06d2190

8 files changed: +844, -257 lines (two of the changed files are shown below)


tmva/tmva/inc/TMVA/MethodGAN.h  (+18 -19)
@@ -1,5 +1,5 @@
 // @(#)root/tmva/tmva/dnn:$Id$
-// Author: Vladimir Ilievski, Saurav Shekhar
+// Author: Anushree Rankawat

 /**********************************************************************************
  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
@@ -11,8 +11,7 @@
  * Generative Adversarial Networks *
  * *
  * Authors (alphabetical): *
- * Vladimir Ilievski <[email protected]> - CERN, Switzerland *
- * Saurav Shekhar <[email protected]> - ETH Zurich, Switzerland *
+ * Anushree Rankawat <[email protected]> *
  * *
  * Copyright (c) 2005-2015: *
  * CERN, Switzerland *
@@ -52,6 +51,13 @@
 #include "TMVA/DNN/Architectures/Cuda.h"
 #endif

+#ifdef R__HAS_TMVACPU
+using ArchitectureImpl_t = TMVA::DNN::TCpu<Double_t>;
+#else
+using ArchitectureImpl_t = TMVA::DNN::TReference<Double_t>;
+#endif
+using DeepNetImpl_t = TMVA::DNN::TDeepNet<ArchitectureImpl_t>;
+
 #include "TMVA/DNN/Architectures/Reference.h"
 #include "TMVA/DNN/Functions.h"
 #include "TMVA/DNN/DeepNet.h"
@@ -65,8 +71,6 @@ using namespace TMVA::DNN;
 using Architecture_t = TCpu<Double_t>;
 using Scalar_t = Architecture_t::Scalar_t;
 using DeepNet_t = TMVA::DNN::TDeepNet<Architecture_t>;
-//using Matrix_t = typename TCpu<double>::Matrix_t;
-//using TensorInput = std::tuple<const std::vector<Matrix_t> &, const Matrix_t &, const Matrix_t &>;
 using TensorDataLoader_t = TTensorDataLoader<TMVAInput_t, Architecture_t>;

 using TMVA::DNN::EActivationFunction;
@@ -106,13 +110,6 @@ class MethodGAN : public MethodBase {
 private:
 // Key-Value vector type, contining the values for the training options
 using KeyValueVector_t = std::vector<std::map<TString, TString>>;
-//using TensorInput = std::tuple<const std::vector<TMatrixT<Double_t>> &>;
-#ifdef R__HAS_TMVACPU
-using ArchitectureImpl_t = TMVA::DNN::TCpu<Double_t>;
-#else
-using ArchitectureImpl_t = TMVA::DNN::TReference<Double_t>;
-#endif
-using DeepNetImpl_t = TMVA::DNN::TDeepNet<ArchitectureImpl_t>;
 std::unique_ptr<DeepNetImpl_t> generatorFNet, discriminatorFNet, combinedFNet;
 using Matrix_t = typename ArchitectureImpl_t::Matrix_t;

@@ -133,7 +130,7 @@
 * a reference in the function. */
 template <typename Architecture_t, typename Layer_t>
 void CreateDeepNet(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
-std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, std::unique_ptr<DeepNetImpl_t> &fNet, TString layoutString);
+std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, std::unique_ptr<DeepNetImpl_t> &modelNet, TString layoutString);

 size_t fGeneratorInputDepth; ///< The depth of the input of the generator.
 size_t fGeneratorInputHeight; ///< The height of the input of the generator.
@@ -197,20 +194,22 @@
 void Train();

 Double_t GetMvaValue(Double_t *err = 0, Double_t *errUpper = 0);
-Double_t GetMvaValueGAN(std::unique_ptr<DeepNetImpl_t> & fNet, Double_t *err = 0, Double_t *errUpper = 0);
-
+Double_t GetMvaValueGAN(std::unique_ptr<DeepNetImpl_t> & modelNet, Double_t *err = 0, Double_t *errUpper = 0);
 void CreateNoisyMatrices(std::vector<TMatrixT<Double_t>> &inputTensor, TMatrixT<Double_t> &outputMatrix, TMatrixT<Double_t> &weights, DeepNet_t &DeepNet, size_t nSamples, size_t classLabel);
 Double_t ComputeLoss(TTensorDataLoader<TensorInput, Architecture_t> &generalDataloader, DeepNet_t &DeepNet);
 Double_t ComputeLoss(TTensorDataLoader<TMVAInput_t, Architecture_t> &generalDataloader, DeepNet_t &DeepNet);
-void CreateDiscriminatorFakeData(std::vector<TMatrixT<Double_t>> &predTensor, TMatrixT<Double_t> &outputMatrix, TMatrixT<Double_t> &weights, TTensorDataLoader<TensorInput, Architecture_t> &trainingData, DeepNet_t &genDeepNet, DeepNet_t &disDeepNet, size_t nSamples, size_t classLabel);
-void CombineGAN(DeepNet_t &combinedDeepNet, DeepNet_t &generatorNet, DeepNet_t &discriminatorNet);
-
-//void AddWeightsXMLToGAN(std::unique_ptr<DeepNetImpl_t> & fNet, void * parent);
+void CreateDiscriminatorFakeData(std::vector<TMatrixT<Double_t>> &predTensor, TMatrixT<Double_t> &outputMatrix, TMatrixT<Double_t> &weights, TTensorDataLoader<TensorInput, Architecture_t> &trainingData, DeepNet_t &genDeepNet, DeepNet_t &disDeepNet, EOutputFunction outputFunction, size_t nSamples, size_t classLabel, size_t epochs);
+void CombineGAN(DeepNet_t &combinedDeepNet, DeepNet_t &generatorNet, DeepNet_t &discriminatorNet, std::unique_ptr<DeepNetImpl_t> & combinedNet);
+void SetDiscriminatorLayerTraining(DeepNet_t &discrimatorNet);

 /*! Methods for writing and reading weights */
 using MethodBase::ReadWeightsFromStream;
 void AddWeightsXMLTo(void *parent) const;
+void AddWeightsXMLToGenerator(void *parent) const;
+void AddWeightsXMLToDiscriminator(void *parent) const;
 void ReadWeightsFromXML(void *wghtnode);
+void ReadWeightsFromXMLGenerator(void *rootXML);
+void ReadWeightsFromXMLDiscriminator(void *rootXML);
 void ReadWeightsFromStream(std::istream &);

 /* Create ranking */
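
Note on the hunk above (reviewer commentary, not part of the commit): the ArchitectureImpl_t / DeepNetImpl_t aliases move from the private section of MethodGAN to file scope, presumably so that other translation units that include the header, such as the new GAN tests, can name these types directly. Below is a minimal, self-contained sketch of the selection idiom using the same type names as the diff; the particular include set is my choice for a standalone file, not taken from the commit.

// backend_alias_sketch.h -- illustrative only, not part of the commit.
// DeepNetImpl_t resolves to the CPU deep net when the CPU backend was built
// (R__HAS_TMVACPU defined) and to the reference backend otherwise, so code
// written against DeepNetImpl_t compiles in either configuration.
#include "RtypesCore.h"                        // Double_t
#include "TMVA/DNN/DeepNet.h"                  // TMVA::DNN::TDeepNet
#include "TMVA/DNN/Architectures/Reference.h"  // TMVA::DNN::TReference
#ifdef R__HAS_TMVACPU
#include "TMVA/DNN/Architectures/Cpu.h"        // TMVA::DNN::TCpu
using ArchitectureImpl_t = TMVA::DNN::TCpu<Double_t>;
#else
using ArchitectureImpl_t = TMVA::DNN::TReference<Double_t>;
#endif
using DeepNetImpl_t = TMVA::DNN::TDeepNet<ArchitectureImpl_t>;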

tmva/tmva/src/MethodBase.cxx  (+4 -4; two of the three hunks appear to change only whitespace)
@@ -764,8 +764,8 @@ void TMVA::MethodBase::AddRegressionOutput(Types::ETreeType type)
 regRes->Resize( nEvents );

 // Drawing the progress bar every event was causing a huge slowdown in the evaluation time
-// So we set some parameters to draw the progress bar a total of totalProgressDraws, i.e. only draw every 1 in 100
-
+// So we set some parameters to draw the progress bar a total of totalProgressDraws, i.e. only draw every 1 in 100
+
 Int_t totalProgressDraws = 100; // total number of times to update the progress bar
 Int_t drawProgressEvery = 1; // draw every nth event such that we have a total of totalProgressDraws
 if(nEvents >= totalProgressDraws) drawProgressEvery = nEvents/totalProgressDraws;
@@ -1570,7 +1570,7 @@ void TMVA::MethodBase::ReadStateFromXML( void* methodNode )
 fMVAPdfB->ReadXML(pdfnode);
 }
 }
-else if (nodeName=="Weights") {
+else if (nodeName.SubString("Weights")=="Weights") {
 ReadWeightsFromXML(ch);
 }
 else {
@@ -1994,7 +1994,7 @@ TDirectory* TMVA::MethodBase::BaseDir() const
 sdir = methodDir->mkdir(defaultDir);
 sdir->cd();
 // write weight file name into target file
-if (fModelPersistence) {
+if (fModelPersistence) {
 TObjString wfilePath( gSystem->WorkingDirectory() );
 TObjString wfileName( GetWeightFileName() );
 wfilePath.Write( "TrainingPath" );
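
For context (reviewer illustration, not part of the commit): TString::SubString(pat) returns the matched substring, or a null substring when pat does not occur, so the new condition in ReadStateFromXML is effectively a "node name contains 'Weights'" test. That would let a method with per-network weight nodes, such as the Generator/Discriminator readers declared in MethodGAN.h above, still be dispatched to ReadWeightsFromXML; the concrete node names used below are hypothetical. A minimal sketch of the behaviour:

// substring_match_demo.cxx -- illustrative only; the node names are hypothetical.
// Build against ROOT, e.g.: g++ substring_match_demo.cxx $(root-config --cflags --libs)
#include "TString.h"
#include <iostream>

// Same test as the new condition in MethodBase::ReadStateFromXML.
bool IsWeightsNode(const TString &nodeName)
{
   return nodeName.SubString("Weights") == "Weights";
}

int main()
{
   std::cout << IsWeightsNode("Weights") << '\n';              // 1 -- the old exact name still matches
   std::cout << IsWeightsNode("WeightsGenerator") << '\n';     // 1 -- a per-network node would match too
   std::cout << IsWeightsNode("WeightsDiscriminator") << '\n'; // 1
   std::cout << IsWeightsNode("MVAPdfs") << '\n';              // 0 -- unrelated nodes are still rejected
   return 0;
}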

0 commit comments
