
Commit f99b00c

IlievskiVl authored and moneta committed
test: Add test for Method DL, for the DNN case
1 parent a94c617 commit f99b00c

3 files changed, 112 insertions(+), 1 deletion(-)

.gitignore (+2 -1)

@@ -510,4 +510,5 @@ tags
 
 # TMVA datasets
 /tmva/tmva/test/DNN/CNN/dataset/SingleElectronPt50_FEVTDEBUG_n250k_IMG_CROPS32.root
-/tmva/tmva/test/DNN/CNN/dataset/SinglePhotonPt50_FEVTDEBUG_n250k_IMG_CROPS32.root
+/tmva/tmva/test/DNN/CNN/dataset/SinglePhotonPt50_FEVTDEBUG_n250k_IMG_CROPS32.root
+/tmva/tmva/test/DNN/CNN/dataset/tmva_class_example.root

tmva/tmva/test/DNN/CNN/TestMethodDL.cxx (+1)

@@ -33,4 +33,5 @@ int main()
 
    TString archCPU = "CPU";
    testMethodDL(archCPU);
+   testMethodDL_DNN(archCPU);
 }

tmva/tmva/test/DNN/CNN/TestMethodDL.h (+109)

@@ -148,4 +148,113 @@ void testMethodDL(TString architectureStr)
    delete dataloader;
 }
 
+/** Testing the entire pipeline of the Method DL, when only a Multilayer Perceptron
+ * is constructed. */
+//______________________________________________________________________________
+void testMethodDL_DNN(TString architectureStr)
+{
+   TFile *input(0);
+   TString fname = "/Users/vladimirilievski/Desktop/Vladimir/GSoC/ROOT-CI/common-version/root/tmva/tmva/test/DNN/CNN/"
+                   "dataset/tmva_class_example.root";
+   input = TFile::Open(fname);
+
+   // Register the training and test trees
+   TTree *signalTree = (TTree *)input->Get("TreeS");
+   TTree *background = (TTree *)input->Get("TreeB");
+
+   TString outfileName("results/TMVA_DNN.root");
+   TFile *outputFile = TFile::Open(outfileName, "RECREATE");
+
+   // create factory
+   TMVA::Factory *factory =
+      new TMVA::Factory("TMVAClassification", outputFile,
+                        "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification");
+
+   // create dataset and add variables
+   TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
+
+   dataloader->AddVariable("myvar1 := var1+var2", 'F');
+   dataloader->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
+   dataloader->AddVariable("var3", "Variable 3", "units", 'F');
+   dataloader->AddVariable("var4", "Variable 4", "units", 'F');
+
+   dataloader->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
+   dataloader->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');
+
+   // Add Signal and Background Trees
+   Double_t signalWeight = 1.0;
+   Double_t backgroundWeight = 1.0;
+
+   dataloader->AddSignalTree(signalTree, signalWeight);
+   dataloader->AddBackgroundTree(background, backgroundWeight);
+
+   // Prepare training and testing set
+   dataloader->SetBackgroundWeightExpression("weight");
+   TCut mycuts = "";
+   TCut mycutb = "";
+   dataloader->PrepareTrainingAndTestTree(
+      mycuts, mycutb, "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V");
+
+   // Input Layout
+   TString inputLayoutString("InputLayout=1|1|4");
+
+   // Batch Layout
+   TString batchLayoutString("BatchLayout=1|256|4");
+
+   // General layout.
+   TString layoutString("Layout=DENSE|128|TANH,DENSE|128|TANH,DENSE|128|TANH,DENSE|1|LINEAR");
+
+   // Training strategies.
+   TString training0("LearningRate=1e-1,Momentum=0.9,Repetitions=1,"
+                     "ConvergenceSteps=20,BatchSize=256,TestRepetitions=10,"
+                     "WeightDecay=1e-4,Regularization=L2,"
+                     "DropConfig=0.0+0.5+0.5+0.5, Multithreading=True");
+   TString training1("LearningRate=1e-2,Momentum=0.9,Repetitions=1,"
+                     "ConvergenceSteps=20,BatchSize=256,TestRepetitions=10,"
+                     "WeightDecay=1e-4,Regularization=L2,"
+                     "DropConfig=0.0+0.0+0.0+0.0, Multithreading=True");
+   TString training2("LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
+                     "ConvergenceSteps=20,BatchSize=256,TestRepetitions=10,"
+                     "WeightDecay=1e-4,Regularization=L2,"
+                     "DropConfig=0.0+0.0+0.0+0.0, Multithreading=True");
+   TString trainingStrategyString("TrainingStrategy=");
+   trainingStrategyString += training0 + "|" + training1 + "|" + training2;
+
+   // General Options.
+   TString dnnOptions("!H:V:ErrorStrategy=CROSSENTROPY:"
+                      "WeightInitialization=XAVIERUNIFORM");
+
+   // Concatenate all option strings
+   dnnOptions.Append(":");
+   dnnOptions.Append(inputLayoutString);
+
+   // Concatenate all option strings
+   dnnOptions.Append(":");
+   dnnOptions.Append(batchLayoutString);
+
+   dnnOptions.Append(":");
+   dnnOptions.Append(layoutString);
+
+   dnnOptions.Append(":");
+   dnnOptions.Append(trainingStrategyString);
+
+   dnnOptions.Append(":Architecture=");
+   dnnOptions.Append(architectureStr);
+
+   TString methodTitle = "DL_CPU";
+   factory->BookMethod(dataloader, TMVA::Types::kDL, methodTitle, dnnOptions);
+
+   // Train MVAs using the set of training events
+   factory->TrainAllMethods();
+
+   // Save the output
+   outputFile->Close();
+
+   std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
+   std::cout << "==> TMVAClassification is done!" << std::endl;
+
+   delete factory;
+   delete dataloader;
+}
+
 #endif
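
For reference, when testMethodDL_DNN is invoked from main() in TestMethodDL.cxx with archCPU = "CPU", the successive Append calls above assemble the single booking string below before it is passed to BookMethod. This is a reconstruction from the option fragments shown in the diff, with line breaks inserted here purely for readability:

    !H:V:ErrorStrategy=CROSSENTROPY:WeightInitialization=XAVIERUNIFORM
       :InputLayout=1|1|4
       :BatchLayout=1|256|4
       :Layout=DENSE|128|TANH,DENSE|128|TANH,DENSE|128|TANH,DENSE|1|LINEAR
       :TrainingStrategy=LearningRate=1e-1,Momentum=0.9,Repetitions=1,ConvergenceSteps=20,BatchSize=256,TestRepetitions=10,WeightDecay=1e-4,Regularization=L2,DropConfig=0.0+0.5+0.5+0.5, Multithreading=True
          |LearningRate=1e-2,Momentum=0.9,Repetitions=1,ConvergenceSteps=20,BatchSize=256,TestRepetitions=10,WeightDecay=1e-4,Regularization=L2,DropConfig=0.0+0.0+0.0+0.0, Multithreading=True
          |LearningRate=1e-3,Momentum=0.0,Repetitions=1,ConvergenceSteps=20,BatchSize=256,TestRepetitions=10,WeightDecay=1e-4,Regularization=L2,DropConfig=0.0+0.0+0.0+0.0, Multithreading=True
       :Architecture=CPU

The three "|"-separated blocks in TrainingStrategy correspond to the three training phases (training0, training1, training2) with decreasing learning rates; dropout is only active in the first phase (DropConfig=0.0+0.5+0.5+0.5).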
