@@ -31,8 +31,18 @@ namespace TMVA {
3131namespace Experimental {
3232namespace Internal {
3333
34+ // clang-format off
35+ /* *
36+ \class ROOT::TMVA::Experimental::Internal::RBatchLoader
37+ \ingroup tmva
38+ \brief Building and loading the batches from loaded chunks in RChunkLoader
39+
40+ In this class the chunks that are loaded into memory (see RChunkLoader) are split into batches used in the ML training which are loaded into a queue. This is done for both the training and validation chunks separatly.
41+ */
42+
3443class RBatchLoader {
3544private:
45+ // clang-format on
3646 std::size_t fChunkSize ;
3747 std::size_t fBatchSize ;
3848 std::size_t fNumColumns ;
@@ -45,14 +55,18 @@ private:
4555 std::mutex fBatchLock ;
4656 std::condition_variable fBatchCondition ;
4757
58+ // queuse of tensors of the training and validation batches
4859 std::queue<std::unique_ptr<TMVA::Experimental::RTensor<float >>> fTrainingBatchQueue ;
4960 std::queue<std::unique_ptr<TMVA::Experimental::RTensor<float >>> fValidationBatchQueue ;
5061
62+ // number of training and validation batches in the queue
5163 std::size_t fNumTrainingBatchQueue ;
5264 std::size_t fNumValidationBatchQueue ;
5365
66+ // current batch that is loaded into memeory
5467 std::unique_ptr<TMVA::Experimental::RTensor<float >> fCurrentBatch ;
5568
69+ // primary and secondary batches used to create batches from a chunk
5670 std::unique_ptr<TMVA::Experimental::RTensor<float >> fPrimaryLeftoverTrainingBatch ;
5771 std::unique_ptr<TMVA::Experimental::RTensor<float >> fSecondaryLeftoverTrainingBatch ;
5872
@@ -104,6 +118,8 @@ public:
104118
105119 // / \brief Return a batch of data as a unique pointer.
106120 // / After the batch has been processed, it should be destroyed.
121+ // / \param[in] chunkTensor RTensor with the data from the chunk
122+ // / \param[in] idxs Index of batch in the chunk
107123 // / \return Training batch
108124 std::unique_ptr<TMVA::Experimental::RTensor<float >>
109125 CreateBatch (TMVA::Experimental::RTensor<float > &chunkTensor, std::size_t idxs)
@@ -116,6 +132,9 @@ public:
116132 return batch;
117133 }
118134
135+
136+ // / \brief Loading the training batch from the queue
137+ // / \return Training batch
119138 TMVA::Experimental::RTensor<float > GetTrainBatch ()
120139 {
121140
@@ -130,6 +149,8 @@ public:
130149 return *fCurrentBatch ;
131150 }
132151
152+ // / \brief Loading the validation batch from the queue
153+ // / \return Training batch
133154 TMVA::Experimental::RTensor<float > GetValidationBatch ()
134155 {
135156
@@ -144,67 +165,89 @@ public:
144165 return *fCurrentBatch ;
145166 }
146167
168+ // / \brief Creating the training batches from a chunk and add them to the queue.
169+ // / \param[in] chunkTensor RTensor with the data from the chunk
170+ // / \param[in] lastbatch Check if the batch in the chunk is the last one
171+ // / \param[in] leftoverBatchSize Size of the leftover batch in the training dataset
172+ // / \param[in] dromRemainder Bool to drop the remainder batch or not
147173 void CreateTrainingBatches (TMVA::Experimental::RTensor<float > &chunkTensor, int lastbatch,
148174 std::size_t leftoverBatchSize, bool dropRemainder)
149175 {
150176 std::size_t ChunkSize = chunkTensor.GetShape ()[0 ];
151177 std::size_t Batches = ChunkSize / fBatchSize ;
152178 std::size_t LeftoverBatchSize = ChunkSize % fBatchSize ;
153179
180+ // create a vector of batches
154181 std::vector<std::unique_ptr<TMVA::Experimental::RTensor<float >>> batches;
155182
183+ // fill the full batches from the chunk into a vector
156184 for (std::size_t i = 0 ; i < Batches; i++) {
157185 // Fill a batch
158186 batches.emplace_back (CreateBatch (chunkTensor, i));
159187 }
160188
189+ // copy the remaining entries from the chunk into a leftover batch
161190 TMVA::Experimental::RTensor<float > LeftoverBatch ({LeftoverBatchSize, fNumColumns });
162191 std::copy (chunkTensor.GetData () + (Batches * fBatchSize * fNumColumns ),
163192 chunkTensor.GetData () + (Batches * fBatchSize * fNumColumns + LeftoverBatchSize * fNumColumns ),
164193 LeftoverBatch.GetData ());
165194
195+ // calculate how many empty slots are left in fPrimaryLeftoverTrainingBatch
166196 std::size_t PrimaryLeftoverSize = (*fPrimaryLeftoverTrainingBatch ).GetShape ()[0 ];
167197 std::size_t emptySlots = fBatchSize - PrimaryLeftoverSize;
168198
199+ // copy LeftoverBatch to end of fPrimaryLeftoverTrainingBatch
169200 if (emptySlots >= LeftoverBatchSize) {
170201 (*fPrimaryLeftoverTrainingBatch ) =
171202 (*fPrimaryLeftoverTrainingBatch ).Resize ({PrimaryLeftoverSize + LeftoverBatchSize, fNumColumns });
172203 std::copy (LeftoverBatch.GetData (), LeftoverBatch.GetData () + (LeftoverBatchSize * fNumColumns ),
173204 fPrimaryLeftoverTrainingBatch ->GetData () + (PrimaryLeftoverSize * fNumColumns ));
174205
206+ // copy LeftoverBatch to end of fPrimaryLeftoverTrainingBatch and add it to the batch vector
175207 if (emptySlots == LeftoverBatchSize) {
176208 auto copy =
177209 std::make_unique<TMVA::Experimental::RTensor<float >>(std::vector<std::size_t >{fBatchSize , fNumColumns });
178210 std::copy (fPrimaryLeftoverTrainingBatch ->GetData (),
179211 fPrimaryLeftoverTrainingBatch ->GetData () + (fBatchSize * fNumColumns ), copy->GetData ());
180212 batches.emplace_back (std::move (copy));
181213
214+ // reset fPrimaryLeftoverTrainingBatch and fSecondaryLeftoverTrainingBatch
182215 *fPrimaryLeftoverTrainingBatch = *fSecondaryLeftoverTrainingBatch ;
183216 fSecondaryLeftoverValidationBatch =
184217 std::make_unique<TMVA::Experimental::RTensor<float >>(std::vector<std::size_t >{0 , fNumColumns });
185218 }
186219 }
187220
221+ // copy LeftoverBatch to both fPrimaryLeftoverTrainingBatch and fSecondaryLeftoverTrainingBatch
188222 else if (emptySlots < LeftoverBatchSize) {
223+ // copy the first part of LeftoverBatch to end of fPrimaryLeftoverTrainingBatch
189224 (*fPrimaryLeftoverTrainingBatch ) = (*fPrimaryLeftoverTrainingBatch ).Resize ({fBatchSize , fNumColumns });
190225 std::copy (LeftoverBatch.GetData (), LeftoverBatch.GetData () + (emptySlots * fNumColumns ),
191226 fPrimaryLeftoverTrainingBatch ->GetData () + (PrimaryLeftoverSize * fNumColumns ));
192227
228+ // copy the last part of LeftoverBatch to the end of fSecondaryLeftoverTrainingBatch
193229 (*fSecondaryLeftoverTrainingBatch ) =
194230 (*fSecondaryLeftoverTrainingBatch ).Resize ({LeftoverBatchSize - emptySlots, fNumColumns });
195231 std::copy (LeftoverBatch.GetData () + (emptySlots * fNumColumns ),
196232 LeftoverBatch.GetData () + (LeftoverBatchSize * fNumColumns ),
197233 fSecondaryLeftoverTrainingBatch ->GetData ());
234+
235+ // add fPrimaryLeftoverTrainingBatch to the batch vector
198236 auto copy =
199237 std::make_unique<TMVA::Experimental::RTensor<float >>(std::vector<std::size_t >{fBatchSize , fNumColumns });
200238 std::copy (fPrimaryLeftoverTrainingBatch ->GetData (),
201239 fPrimaryLeftoverTrainingBatch ->GetData () + (fBatchSize * fNumColumns ), copy->GetData ());
202240 batches.emplace_back (std::move (copy));
241+
242+ // exchange fPrimaryLeftoverTrainingBatch and fSecondaryLeftoverValidationBatch
203243 *fPrimaryLeftoverTrainingBatch = *fSecondaryLeftoverTrainingBatch ;
244+
245+ // restet fSecondaryLeftoverValidationBatch
204246 fSecondaryLeftoverValidationBatch =
205247 std::make_unique<TMVA::Experimental::RTensor<float >>(std::vector<std::size_t >{0 , fNumColumns });
206248 }
207249
250+ // copy the content of fPrimaryLeftoverTrainingBatch to the leftover batch from the chunk
208251 if (lastbatch == 1 ) {
209252
210253 if (dropRemainder == false && leftoverBatchSize > 0 ) {
@@ -221,11 +264,17 @@ public:
221264 std::make_unique<TMVA::Experimental::RTensor<float >>(std::vector<std::size_t >{0 , fNumColumns });
222265 }
223266
267+ // append the batches from the batch vector from the chunk to the training batch queue
224268 for (std::size_t i = 0 ; i < batches.size (); i++) {
225269 fTrainingBatchQueue .push (std::move (batches[i]));
226270 }
227271 }
228-
272+
273+ // / \brief Creating the validation batches from a chunk and adding them to the queue
274+ // / \param[in] chunkTensor RTensor with the data from the chunk
275+ // / \param[in] lastbatch Check if the batch in the chunk is the last one
276+ // / \param[in] leftoverBatchSize Size of the leftover batch in the validation dataset
277+ // / \param[in] dromRemainder Bool to drop the remainder batch or not
229278 void CreateValidationBatches (TMVA::Experimental::RTensor<float > &chunkTensor, std::size_t lastbatch,
230279 std::size_t leftoverBatchSize, bool dropRemainder)
231280 {
0 commit comments