Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,32 @@
# Author: Kristupas Pranckietis, Vilnius University 05/2024
# Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
# Author: Vincenzo Eduardo Padulano, CERN 10/2024
# Author: Martin Føll, University of Oslo (UiO) & CERN 05/2025

################################################################################
# Copyright (C) 1995-2024, Rene Brun and Fons Rademakers. #
# Copyright (C) 1995-2025, Rene Brun and Fons Rademakers. #
# All rights reserved. #
# #
# For the licensing terms see $ROOTSYS/LICENSE. #
# For the list of contributors see $ROOTSYS/README/CREDITS. #
################################################################################

from __future__ import annotations

from typing import Any, Callable, Tuple, TYPE_CHECKING
import atexit

Check failure on line 18 in bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:15:1: I001 Import block is un-sorted or un-formatted

if TYPE_CHECKING:
import numpy as np
import tensorflow as tf
import torch
import ROOT

Check failure on line 24 in bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:21:5: I001 Import block is un-sorted or un-formatted


class BaseGenerator:
def get_template(
self,
x_rdf: RNode,
x_rdf: ROOT.RDF.RNode,
columns: list[str] = list(),
max_vec_sizes: dict[str, int] = dict(),
) -> Tuple[str, list[int]]:
Expand Down Expand Up @@ -80,9 +82,10 @@

def __init__(
self,
rdataframe: RNode,
rdataframe: ROOT.RDF.RNode,
batch_size: int,
chunk_size: int,
block_size: int,
columns: list[str] = list(),
max_vec_sizes: dict[str, int] = dict(),
vec_padding: int = 0,
Expand All @@ -92,6 +95,7 @@
max_chunks: int = 0,
shuffle: bool = True,
drop_remainder: bool = True,
set_seed: int = 0,
):
"""Wrapper around the Cpp RBatchGenerator

Expand Down Expand Up @@ -126,13 +130,17 @@
drop_remainder (bool):
Drop the remainder of data that is too small to compose full batch.
Defaults to True.
set_seed (int):
For reproducibility: set the seed of the random number generator that is used
to split the dataset into training and validation sets and to shuffle the chunks.
Defaults to 0, which means that the seed is taken from the random device.
"""

import ROOT

Check failure on line 139 in bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:139:16: F401 `ROOT` imported but unused
from ROOT import RDF

try:
import numpy as np

Check failure on line 143 in bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:143:29: F401 `numpy` imported but unused; consider using `importlib.util.find_spec` to test for availability

except ImportError:
raise ImportError(
Expand All @@ -154,11 +162,6 @@

self.noded_rdf = RDF.AsRNode(rdataframe)

if ROOT.Internal.RDF.GetDataSourceLabel(self.noded_rdf) != "TTreeDS":
raise ValueError(
"RNode object must be created out of TTrees or files of TTree"
)

if isinstance(target, str):
target = [target]

Expand Down Expand Up @@ -221,15 +224,16 @@
self.generator = TMVA.Experimental.Internal.RBatchGenerator(template)(
self.noded_rdf,
chunk_size,
block_size,
batch_size,
self.given_columns,
self.num_columns,
max_vec_sizes_list,
vec_padding,
validation_split,
max_chunks,
shuffle,
drop_remainder,
set_seed,
)

atexit.register(self.DeActivate)
Expand All @@ -238,6 +242,9 @@
def is_active(self):
    """Return the value reported by ``IsActive()`` of the wrapped generator."""
    return self.generator.IsActive()

def is_training_active(self):
    """Return the value reported by ``TrainingIsActive()`` of the wrapped generator."""
    return self.generator.TrainingIsActive()

def Activate(self):
    """Initialize the generator to be used for a loop

    Forwards to ``Activate()`` of the wrapped generator; nothing is returned.
    """
    self.generator.Activate()
Expand All @@ -246,6 +253,30 @@
"""Deactivate the generator"""
self.generator.DeActivate()

def ActivateTrainingEpoch(self):
    """Activate the training epoch

    Forwards to ``ActivateTrainingEpoch()`` of the wrapped generator;
    nothing is returned.
    """
    self.generator.ActivateTrainingEpoch()

def ActivateValidationEpoch(self):
    """Activate the validation epoch

    Forwards to ``ActivateValidationEpoch()`` of the wrapped generator;
    nothing is returned.
    """
    self.generator.ActivateValidationEpoch()

def DeActivateTrainingEpoch(self):
    """Deactivate the training epoch

    Forwards to ``DeActivateTrainingEpoch()`` of the wrapped generator;
    nothing is returned.
    """
    self.generator.DeActivateTrainingEpoch()

def DeActivateValidationEpoch(self):
    """Deactivate the validation epoch

    Forwards to ``DeActivateValidationEpoch()`` of the wrapped generator;
    nothing is returned.
    """
    self.generator.DeActivateValidationEpoch()

def CreateTrainBatches(self):
    """Create the training batches

    Forwards to ``CreateTrainBatches()`` of the wrapped generator. The
    result of that call is discarded; nothing is returned.
    """
    self.generator.CreateTrainBatches()

def CreateValidationBatches(self):
    """Create the validation batches

    Forwards to ``CreateValidationBatches()`` of the wrapped generator. The
    result of that call is discarded; nothing is returned.
    """
    self.generator.CreateValidationBatches()

def GetSample(self):
"""
Return a sample of data that has the same size and types as the actual
Expand Down Expand Up @@ -286,7 +317,7 @@
np.zeros((self.batch_size)).reshape(-1, 1),
)

def ConvertBatchToNumpy(self, batch: "RTensor") -> np.ndarray:

Check failure on line 320 in bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:320:43: F821 Undefined name `RTensor`
"""Convert a RTensor into a NumPy array

Args:
Expand Down Expand Up @@ -337,8 +368,8 @@
Returns:
torch.Tensor: converted batch
"""
import torch
import numpy as np

Check failure on line 372 in bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:371:9: I001 Import block is un-sorted or un-formatted

data = batch.GetData()
batch_size, num_columns = tuple(batch.GetShape())
Expand Down Expand Up @@ -445,12 +476,14 @@
class LoadingThreadContext:
    """Context manager wrapping one pass over the training batches

    Creating the context creates the training batches. Entering the ``with``
    block activates the generator and the training epoch; leaving it
    deactivates both again.
    """

    def __init__(self, base_generator: BaseGenerator):
        """Store the generator and create the training batches

        Args:
            base_generator (BaseGenerator):
                Generator driven by this context. Its ``CreateTrainBatches``
                method is called immediately.
        """
        self.base_generator = base_generator

        # create training batches from the first chunk
        self.base_generator.CreateTrainBatches()

    def __enter__(self):
        """Activate the generator, then the training epoch

        Nothing is returned, so ``with ... as x`` binds ``None``.
        """
        # NOTE(review): Activate() and ActivateTrainingEpoch() are both
        # invoked here - confirm that the plain Activate() call is still
        # required alongside the epoch-specific one.
        self.base_generator.Activate()
        self.base_generator.ActivateTrainingEpoch()

    def __exit__(self, type, value, traceback):
        """Deactivate the generator, then the training epoch

        Returns:
            bool: Always True. A true return value from ``__exit__`` makes
            Python suppress any exception raised inside the ``with`` block.
        """
        # NOTE(review): same question as in __enter__ - confirm that the
        # plain DeActivate() call is still required.
        self.base_generator.DeActivate()
        self.base_generator.DeActivateTrainingEpoch()
        # Swallows exceptions from the with-body (see Returns above).
        return True


Expand All @@ -469,6 +502,7 @@
self.base_generator = base_generator
self.conversion_function = conversion_function


def Activate(self):
"""Start the loading of training batches"""
self.base_generator.Activate()
Expand Down Expand Up @@ -503,6 +537,7 @@
return self.base_generator.generator.TrainRemainderRows()

def __iter__(self):

self._callable = self.__call__()

return self
Expand All @@ -522,16 +557,28 @@
Union[np.ndarray, torch.Tensor]: A batch of data
"""

with LoadingThreadContext(self.base_generator):
with LoadingThreadContext(self.base_generator):
while True:
batch = self.base_generator.GetTrainBatch()

if batch is None:
break

yield self.conversion_function(batch)

return None

class LoadingThreadContextVal:
    """Context manager wrapping one pass over the validation batches

    Creating the context creates the validation batches. Entering the
    ``with`` block activates the validation epoch; leaving it deactivates
    the validation epoch again.
    """

    def __init__(self, base_generator: BaseGenerator):
        """Store the generator and create the validation batches

        Args:
            base_generator (BaseGenerator):
                Generator driven by this context. Its
                ``CreateValidationBatches`` method is called immediately.
        """
        self.base_generator = base_generator
        # create validation batches from the first chunk
        self.base_generator.CreateValidationBatches()

    def __enter__(self):
        """Activate the validation epoch

        Nothing is returned, so ``with ... as x`` binds ``None``.
        """
        self.base_generator.ActivateValidationEpoch()

    def __exit__(self, type, value, traceback):
        """Deactivate the validation epoch

        Returns:
            bool: Always True. A true return value from ``__exit__`` makes
            Python suppress any exception raised inside the ``with`` block.
        """
        self.base_generator.DeActivateValidationEpoch()
        # Swallows exceptions from the with-body (see Returns above).
        return True



class ValidationRBatchGenerator:
Expand Down Expand Up @@ -588,27 +635,27 @@
return batch

def __call__(self) -> Any:
"""Loop through the validation batches
"""Start the loading of batches and yield the results

Yields:
Union[np.ndarray, torch.Tensor]: A batch of data
"""
if self.base_generator.is_active:
self.base_generator.DeActivate()

while True:
batch = self.base_generator.GetValidationBatch()

if not batch:
break

yield self.conversion_function(batch)



with LoadingThreadContextVal(self.base_generator):
while True:
batch = self.base_generator.GetValidationBatch()
if batch is None:
self.base_generator.DeActivateValidationEpoch()
break
yield self.conversion_function(batch)

return None

def CreateNumPyGenerators(
rdataframe: RNode,
rdataframe: ROOT.RDF.RNode,
batch_size: int,
chunk_size: int,
block_size: int,
columns: list[str] = list(),
max_vec_sizes: dict[str, int] = dict(),
vec_padding: int = 0,
Expand All @@ -618,6 +665,7 @@
max_chunks: int = 0,
shuffle: bool = True,
drop_remainder=True,
set_seed: int = 0,
) -> Tuple[TrainRBatchGenerator, ValidationRBatchGenerator]:
"""
Return two batch generators based on the given ROOT file and tree or RDataFrame
Expand Down Expand Up @@ -670,12 +718,13 @@
validation generator will return no batches.
"""

import numpy as np

Check failure on line 721 in bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:721:21: F401 `numpy` imported but unused

base_generator = BaseGenerator(
rdataframe,
batch_size,
chunk_size,
block_size,
columns,
max_vec_sizes,
vec_padding,
Expand All @@ -685,6 +734,7 @@
max_chunks,
shuffle,
drop_remainder,
set_seed,
)

train_generator = TrainRBatchGenerator(
Expand All @@ -702,9 +752,10 @@


def CreateTFDatasets(
rdataframe: RNode,
rdataframe: ROOT.RDF.RNode,
batch_size: int,
chunk_size: int,
block_size: int,
columns: list[str] = list(),
max_vec_sizes: dict[str, int] = dict(),
vec_padding: int = 0,
Expand All @@ -714,6 +765,7 @@
max_chunks: int = 0,
shuffle: bool = True,
drop_remainder=True,
set_seed: int = 0,
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
"""
Return two Tensorflow Datasets based on the given ROOT file and tree or RDataFrame
Expand Down Expand Up @@ -771,6 +823,7 @@
rdataframe,
batch_size,
chunk_size,
block_size,
columns,
max_vec_sizes,
vec_padding,
Expand All @@ -780,6 +833,7 @@
max_chunks,
shuffle,
drop_remainder,
set_seed,
)

train_generator = TrainRBatchGenerator(
Expand Down Expand Up @@ -847,9 +901,10 @@


def CreatePyTorchGenerators(
rdataframe: RNode,
rdataframe: ROOT.RDF.RNode,
batch_size: int,
chunk_size: int,
block_size: int,
columns: list[str] = list(),
max_vec_sizes: dict[str, int] = dict(),
vec_padding: int = 0,
Expand All @@ -859,6 +914,7 @@
max_chunks: int = 0,
shuffle: bool = True,
drop_remainder=True,
set_seed: int = 0,
) -> Tuple[TrainRBatchGenerator, ValidationRBatchGenerator]:
"""
Return two PyTorch batch generators (training and validation) based on the given ROOT file and tree or RDataFrame
Expand Down Expand Up @@ -914,6 +970,7 @@
rdataframe,
batch_size,
chunk_size,
block_size,
columns,
max_vec_sizes,
vec_padding,
Expand All @@ -923,6 +980,7 @@
max_chunks,
shuffle,
drop_remainder,
set_seed,
)

train_generator = TrainRBatchGenerator(
Expand Down
Loading
Loading