Skip to content

Commit 57396ba

Browse files
committed
Add comments, remove unused variables and clean up code
1 parent 1b14def commit 57396ba

File tree

4 files changed

+20
-31
lines changed

4 files changed

+20
-31
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,6 @@ def __init__(
227227
block_size,
228228
batch_size,
229229
self.given_columns,
230-
self.num_columns,
231230
max_vec_sizes_list,
232231
vec_padding,
233232
validation_split,
@@ -477,6 +476,7 @@ def GetValidationBatch(self) -> Any:
477476
class LoadingThreadContext:
478477
def __init__(self, base_generator: BaseGenerator):
479478
self.base_generator = base_generator
479+
# create training batches from the first chunk
480480
self.base_generator.CreateTrainBatches();
481481

482482
def __enter__(self):
@@ -537,7 +537,6 @@ def last_batch_no_of_rows(self) -> int:
537537
return self.base_generator.generator.TrainRemainderRows()
538538

539539
def __iter__(self):
540-
# batch = self.base_generator.GetTrainBatch()
541540

542541
self._callable = self.__call__()
543542

@@ -546,7 +545,6 @@ def __iter__(self):
546545
def __next__(self):
547546
batch = self._callable.__next__()
548547

549-
# self.base_generator.ActivateTrainingEpoch()
550548
if batch is None:
551549
raise StopIteration
552550

@@ -558,14 +556,11 @@ def __call__(self) -> Any:
558556
Yields:
559557
Union[np.ndarray, torch.Tensor]: A batch of data
560558
"""
561-
# if (not self.base_generator.is_training_active):
562-
# self.base_generator.DeActivateTrainingEpoch()
563559

564560
with LoadingThreadContext(self.base_generator):
565561
while True:
566562
batch = self.base_generator.GetTrainBatch()
567563
if batch is None:
568-
# self.base_generator.DeActivateTrainingEpoch()
569564
break
570565
yield self.conversion_function(batch)
571566

@@ -574,6 +569,7 @@ def __call__(self) -> Any:
574569
class LoadingThreadContextVal:
575570
def __init__(self, base_generator: BaseGenerator):
576571
self.base_generator = base_generator
572+
# create validation batches from the first chunk
577573
self.base_generator.CreateValidationBatches()
578574

579575
def __enter__(self):
@@ -639,7 +635,7 @@ def __next__(self):
639635
return batch
640636

641637
def __call__(self) -> Any:
642-
"""Start the loading of batches and Yield the results
638+
"""Start the loading of batches and yield the results
643639
644640
Yields:
645641
Union[np.ndarray, torch.Tensor]: A batch of data

bindings/pyroot/pythonizations/test/rbatchgenerator_completeness.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import unittest
22
import os
33
import ROOT
4-
# from ROOT import RVec
4+
from ROOT import RVec
55
import numpy as np
66
from random import randrange
7-
ROOT.gInterpreter.Declare("#include <ROOT/RVec.hxx>")
87

98
class RBatchGeneratorMultipleFiles(unittest.TestCase):
109

tmva/tmva/inc/TMVA/BatchGenerator/RBatchGenerator.hxx

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ private:
108108

109109
public:
110110
RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t blockSize,
111-
const std::size_t batchSize, const std::vector<std::string> &cols, const std::size_t numColumns,
111+
const std::size_t batchSize, const std::vector<std::string> &cols,
112112
const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
113113
const float validationSplit = 0.0, const std::size_t maxChunks = 0, bool shuffle = true,
114114
bool dropRemainder = true, const std::size_t setSeed = 0)
@@ -149,9 +149,10 @@ public:
149149
// split the dataset into training and validation sets
150150
fChunkLoader->SplitDataset();
151151

152-
fNumTrainingEntries = fChunkLoader->GetNumTrainingEntries();
153-
fNumValidationEntries = fChunkLoader->GetNumValidationEntries();
154-
152+
// number of training and validation entries after the split
153+
fNumValidationEntries = static_cast<std::size_t>(fValidationSplit * fNumEntries);
154+
fNumTrainingEntries = fNumEntries - fNumValidationEntries;
155+
155156
fLeftoverTrainingBatchSize = fNumTrainingEntries % fBatchSize;
156157
fLeftoverValidationBatchSize = fNumValidationEntries % fBatchSize;
157158

@@ -171,6 +172,7 @@ public:
171172
fNumValidationBatches = fNumFullValidationBatches + fNumLeftoverValidationBatches;
172173
}
173174

175+
// number of training and validation chunks, calculated in RChunkConstructor
174176
fNumTrainingChunks = fChunkLoader->GetNumTrainingChunks();
175177
fNumValidationChunks = fChunkLoader->GetNumValidationChunks();
176178

@@ -228,8 +230,6 @@ public:
228230
void CreateTrainBatches()
229231
{
230232

231-
auto batchQueue = fBatchLoader->GetNumTrainingBatchQueue();
232-
233233
fChunkLoader->CreateTrainingChunksIntervals();
234234
fTrainingEpochActive = true;
235235
fTrainingChunkNum = 0;
@@ -244,8 +244,6 @@ public:
244244
void CreateValidationBatches()
245245
{
246246

247-
auto batchQueue = fBatchLoader->GetNumValidationBatchQueue();
248-
249247
fChunkLoader->CreateValidationChunksIntervals();
250248
fValidationEpochActive = true;
251249
fValidationChunkNum = 0;

tmva/tmva/inc/TMVA/BatchGenerator/RChunkConstructor.hxx

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
// Author: Dante Niewenhuis, VU Amsterdam 07/2023
2-
// Author: Kristupas Pranckietis, Vilnius University 05/2024
3-
// Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
4-
// Author: Vincenzo Eduardo Padulano, CERN 10/2024
51
// Author: Martin Føll, University of Oslo (UiO) & CERN 05/2025
62

73
/*************************************************************************
@@ -55,9 +51,9 @@ The blocks are defined by their start and end entries, which correspond to posit
5551

5652
struct RChunkConstructor {
5753
// clang-format on
58-
std::size_t fNumEntries;
59-
std::size_t fChunkSize;
60-
std::size_t fBlockSize;
54+
std::size_t fNumEntries{};
55+
std::size_t fChunkSize{};
56+
std::size_t fBlockSize{};
6157

6258
// size of full and leftover chunks
6359
std::size_t SizeOfFullChunk;
@@ -102,17 +98,17 @@ struct RChunkConstructor {
10298
std::size_t NumberOfBlocks;
10399

104100
// pair of start and end entries in the different block types
105-
std::vector<std::pair<Long_t, Long_t>> BlockIntervals = {};
101+
std::vector<std::pair<Long_t, Long_t>> BlockIntervals;
106102

107-
std::vector<std::pair<Long_t, Long_t>> FullBlockIntervalsInFullChunks = {};
108-
std::vector<std::pair<Long_t, Long_t>> LeftoverBlockIntervalsInFullChunks = {};
103+
std::vector<std::pair<Long_t, Long_t>> FullBlockIntervalsInFullChunks;
104+
std::vector<std::pair<Long_t, Long_t>> LeftoverBlockIntervalsInFullChunks;
109105

110-
std::vector<std::pair<Long_t, Long_t>> FullBlockIntervalsInLeftoverChunks = {};
111-
std::vector<std::pair<Long_t, Long_t>> LeftoverBlockIntervalsInLeftoverChunks = {};
106+
std::vector<std::pair<Long_t, Long_t>> FullBlockIntervalsInLeftoverChunks;
107+
std::vector<std::pair<Long_t, Long_t>> LeftoverBlockIntervalsInLeftoverChunks;
112108

113-
std::vector<std::vector<std::pair<Long_t, Long_t>>> ChunksIntervals = {};
109+
std::vector<std::vector<std::pair<Long_t, Long_t>>> ChunksIntervals;
114110

115-
std::vector<std::size_t> ChunksSizes = {};
111+
std::vector<std::size_t> ChunksSizes;
116112

117113
RChunkConstructor(const std::size_t numEntries, const std::size_t chunkSize, const std::size_t blockSize)
118114
: fNumEntries(numEntries), fChunkSize(chunkSize), fBlockSize(blockSize)

0 commit comments

Comments
 (0)