52 changes: 45 additions & 7 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -167,10 +167,48 @@ class KeyModelPathMapping(Generic[KeyT]):

class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
"""Has the ability to load and apply an ML model."""
def __init__(self):
"""Environment variables are set using a dict named 'env_vars' before
loading the model. Child classes can accept this dict as a kwarg."""
self._env_vars = {}
def __init__(
self,
*,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
max_batch_duration_secs: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
large_model: bool = False,
model_copies: Optional[int] = None,
**kwargs):
"""Initializes the ModelHandler.

Args:
min_batch_size: the minimum batch size to use when batching inputs.
max_batch_size: the maximum batch size to use when batching inputs.
max_batch_duration_secs: the maximum amount of time to buffer a batch
before emitting; used in streaming contexts.
max_batch_weight: the maximum weight of a batch. Requires element_size_fn.
element_size_fn: a function that returns the size (weight) of an element.
large_model: set to true if your model is large enough to run into
memory pressure if you load multiple copies.
model_copies: The exact number of models that you would like loaded
onto your machine.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.
"""
self._env_vars = kwargs.get('env_vars', {})
self._batching_kwargs: dict[str, Any] = {}
if min_batch_size is not None:
self._batching_kwargs['min_batch_size'] = min_batch_size
if max_batch_size is not None:
self._batching_kwargs['max_batch_size'] = max_batch_size
if max_batch_duration_secs is not None:
self._batching_kwargs['max_batch_duration_secs'] = max_batch_duration_secs
if max_batch_weight is not None:
self._batching_kwargs['max_batch_weight'] = max_batch_weight
if element_size_fn is not None:
self._batching_kwargs['element_size_fn'] = element_size_fn
self._large_model = large_model
self._model_copies = model_copies
self._share_across_processes = large_model or (model_copies is not None)

def load_model(self) -> ModelT:
"""Loads and initializes a model for processing."""
@@ -220,7 +258,7 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]:
Returns:
kwargs suitable for beam.BatchElements.
"""
return {}
return getattr(self, '_batching_kwargs', {})

def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
"""
@@ -325,14 +363,14 @@ def share_model_across_processes(self) -> bool:
memory. Multi-process support may vary by runner, but this will fall back to
loading per process as necessary. See
https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
return False
return getattr(self, '_share_across_processes', False)

def model_copies(self) -> int:
"""Returns the maximum number of model copies that should be loaded at one
time. This only impacts model handlers that are using
share_model_across_processes to share their model across processes instead
of being loaded per process."""
return 1
return getattr(self, '_model_copies', None) or 1

def override_metrics(self, metrics_namespace: str = '') -> bool:
"""Returns a boolean representing whether or not a model handler will
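With this change, subclasses no longer need to assemble _batching_kwargs, _share_across_processes, or _model_copies by hand; they can forward everything to the base constructor. Below is a minimal sketch of a handler written against the new base class. ToyModel, ToyModelHandler, and the chosen limits are illustrative and not part of this PR.

import sys
from typing import Optional, Sequence

from apache_beam.ml.inference.base import ModelHandler


class ToyModel:
  def predict(self, x: int) -> int:
    return x + 1


class ToyModelHandler(ModelHandler[int, int, ToyModel]):
  def __init__(self, **kwargs):
    # All batching and model-sharing knobs are forwarded to the base class;
    # there is no handler-local _batching_kwargs dict to maintain.
    super().__init__(
        max_batch_size=50,
        max_batch_weight=4096,
        element_size_fn=sys.getsizeof,
        **kwargs)

  def load_model(self) -> ToyModel:
    return ToyModel()

  def run_inference(
      self,
      batch: Sequence[int],
      model: ToyModel,
      inference_args: Optional[dict] = None):
    return [model.predict(x) for x in batch]


handler = ToyModelHandler()
# The base class now surfaces whatever was configured in __init__, roughly
# {'max_batch_size': 50, 'max_batch_weight': 4096, 'element_size_fn': <getsizeof>}.
print(handler.batch_elements_kwargs())
print(handler.share_model_across_processes())  # False: large_model was not set
print(handler.model_copies())  # 1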
58 changes: 58 additions & 0 deletions sdks/python/apache_beam/ml/inference/base_test.py
@@ -2133,5 +2133,63 @@ def request(self, batch, model, inference_args=None):
model_handler.run_inference([1], FakeModel())


class FakeModelHandlerForSizing(base.ModelHandler[int, int, FakeModel]):
"""A ModelHandler used to test element sizing behavior."""
def __init__(
self,
max_batch_size: int = 10,
max_batch_weight: Optional[int] = None,
element_size_fn=None):
self._max_batch_size = max_batch_size
self._max_batch_weight = max_batch_weight
self._element_size_fn = element_size_fn

def load_model(self) -> FakeModel:
return FakeModel()

def run_inference(self, batch, model, inference_args=None):
return [model.predict(x) for x in batch]

def batch_elements_kwargs(self):
kwargs = {'max_batch_size': self._max_batch_size}
if self._max_batch_weight is not None:
kwargs['max_batch_weight'] = self._max_batch_weight
if self._element_size_fn:
kwargs['element_size_fn'] = self._element_size_fn
return kwargs
Comment on lines +2138 to +2159
Contributor
high

The FakeModelHandlerForSizing re-implements the logic for handling batching keyword arguments, instead of leveraging the new implementation in the ModelHandler base class. To ensure the test correctly validates the base class behavior, this test handler should be refactored to call super().__init__ and remove the overridden batch_elements_kwargs method.

Suggested change
-  def __init__(
-      self,
-      max_batch_size: int = 10,
-      max_batch_weight: Optional[int] = None,
-      element_size_fn=None):
-    self._max_batch_size = max_batch_size
-    self._max_batch_weight = max_batch_weight
-    self._element_size_fn = element_size_fn
-
-  def load_model(self) -> FakeModel:
-    return FakeModel()
-
-  def run_inference(self, batch, model, inference_args=None):
-    return [model.predict(x) for x in batch]
-
-  def batch_elements_kwargs(self):
-    kwargs = {'max_batch_size': self._max_batch_size}
-    if self._max_batch_weight is not None:
-      kwargs['max_batch_weight'] = self._max_batch_weight
-    if self._element_size_fn:
-      kwargs['element_size_fn'] = self._element_size_fn
-    return kwargs
+  def __init__(
+      self,
+      max_batch_size: int = 10,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn=None):
+    super().__init__(
+        max_batch_size=max_batch_size,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn)
+
+  def load_model(self) -> FakeModel:
+    return FakeModel()
+
+  def run_inference(self, batch, model, inference_args=None):
+    return [model.predict(x) for x in batch]

Contributor
This is probably a good suggestion



class RunInferenceSizeTest(unittest.TestCase):
"""Tests for ModelHandler.batch_elements_kwargs with element_size_fn."""
def test_kwargs_are_passed_correctly(self):
"""Adds element_size_fn without clobbering existing kwargs."""
def size_fn(x):
return 10

sized_handler = FakeModelHandlerForSizing(
max_batch_size=20, max_batch_weight=100, element_size_fn=size_fn)

kwargs = sized_handler.batch_elements_kwargs()

self.assertEqual(kwargs['max_batch_size'], 20)
self.assertEqual(kwargs['max_batch_weight'], 100)
self.assertIn('element_size_fn', kwargs)
self.assertEqual(kwargs['element_size_fn'](1), 10)

def test_sizing_with_edge_cases(self):
"""Allows extreme values from element_size_fn."""
zero_size_fn = lambda x: 0
sized_handler = FakeModelHandlerForSizing(
max_batch_size=1, element_size_fn=zero_size_fn)
kwargs = sized_handler.batch_elements_kwargs()
self.assertEqual(kwargs['element_size_fn'](999), 0)

large_size_fn = lambda x: 1000000
sized_handler = FakeModelHandlerForSizing(
max_batch_size=1, element_size_fn=large_size_fn)
kwargs = sized_handler.batch_elements_kwargs()
self.assertEqual(kwargs['element_size_fn'](1), 1000000)


if __name__ == '__main__':
unittest.main()
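For intuition on the behavior these tests exercise: max_batch_weight caps the summed element_size_fn weight of a batch rather than its element count. The helper below is a Beam-independent sketch of that rule for illustration only; the actual batching is performed by Beam's batching transforms, not by this code.

from typing import Callable, Iterable, List, TypeVar

T = TypeVar('T')


def batch_by_weight(
    elements: Iterable[T],
    element_size_fn: Callable[[T], int],
    max_batch_weight: int) -> List[List[T]]:
  """Groups elements so each batch's total weight stays within the cap."""
  batches: List[List[T]] = []
  current: List[T] = []
  current_weight = 0
  for el in elements:
    weight = element_size_fn(el)
    # Start a new batch when adding this element would exceed the cap; a
    # non-empty batch is emitted even if a single element is over the cap.
    if current and current_weight + weight > max_batch_weight:
      batches.append(current)
      current, current_weight = [], 0
    current.append(el)
    current_weight += weight
  if current:
    batches.append(current)
  return batches


# Strings weighted by length with a cap of 10 split into two batches:
print(batch_by_weight(['aaaa', 'bbbb', 'cccc', 'dd'], len, 10))
# [['aaaa', 'bbbb'], ['cccc', 'dd']]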
16 changes: 14 additions & 2 deletions sdks/python/apache_beam/ml/inference/gemini_inference.py
@@ -112,6 +112,8 @@ def __init__(
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
max_batch_duration_secs: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
**kwargs):
"""Implementation of the ModelHandler interface for Google Gemini.
**NOTE:** This API and its implementation are under development and
@@ -134,15 +136,18 @@
project: the GCP project to use for Vertex AI requests. Setting this
parameter routes requests to Vertex AI. If this parameter is provided,
location must also be provided and api_key should not be set.
location: the GCP location to use for Vertex AI requests. Setting this
parameter routes requests to Vertex AI. If this parameter is provided,
project must also be provided and api_key should not be set.
min_batch_size: optional. the minimum batch size to use when batching
inputs.
max_batch_size: optional. the maximum batch size to use when batching
inputs.
max_batch_duration_secs: optional. the maximum amount of time to buffer
a batch before emitting; used in streaming contexts.
max_batch_weight: optional. the maximum total weight of a batch.
element_size_fn: optional. a function that returns the size (weight)
of an element.
"""
self._batching_kwargs = {}
self._env_vars = kwargs.get('env_vars', {})
@@ -152,6 +157,10 @@
self._batching_kwargs["max_batch_size"] = max_batch_size
if max_batch_duration_secs is not None:
self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
if max_batch_weight is not None:
self._batching_kwargs["max_batch_weight"] = max_batch_weight
if element_size_fn is not None:
self._batching_kwargs['element_size_fn'] = element_size_fn

self.model_name = model_name
self.request_fn = request_fn
@@ -174,6 +183,9 @@ def __init__(
retry_filter=_retry_on_appropriate_service_error,
**kwargs)

def batch_elements_kwargs(self):
return self._batching_kwargs

def create_client(self) -> genai.Client:
"""Creates the GenAI client used to send requests. Creates a version for
the Vertex AI API or the Gemini Developer API based on the arguments
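As a usage sketch for the two arguments added here, a text-oriented element_size_fn could weight each prompt by its UTF-8 byte length. Only the new arguments are shown; the remaining required constructor arguments (model_name, request_fn, and either api_key or project/location) are elided, the handler name follows the module's naming, and the weight cap is an arbitrary example value.

def prompt_weight(prompt: str) -> int:
  # Weight each request by the size of its encoded text.
  return len(prompt.encode('utf-8'))

# handler = GeminiModelHandler(
#     ...,  # model_name, request_fn, and api_key or project/location as documented above
#     max_batch_weight=100_000,  # cap on total prompt bytes per batch (example value)
#     element_size_fn=prompt_weight)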
96 changes: 39 additions & 57 deletions sdks/python/apache_beam/ml/inference/huggingface_inference.py
@@ -227,6 +227,8 @@ def __init__(
max_batch_duration_secs: Optional[int] = None,
large_model: bool = False,
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
**kwargs):
"""
Implementation of the ModelHandler interface for HuggingFace with
@@ -262,27 +264,28 @@ def __init__(
model_copies: The exact number of models that you would like loaded
onto your machine. This can be useful if you exactly know your CPU or
GPU capacity and want to maximize resource utilization.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

**Supported Versions:** HuggingFaceModelHandler supports
transformers>=4.18.0.
"""
super().__init__(
min_batch_size=min_batch_size,
max_batch_size=max_batch_size,
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
large_model=large_model,
model_copies=model_copies,
**kwargs)
self._model_uri = model_uri
self._model_class = model_class
self._device = device
self._inference_fn = inference_fn
self._model_config_args = load_model_args if load_model_args else {}
self._batching_kwargs = {}
self._env_vars = kwargs.get("env_vars", {})
if min_batch_size is not None:
self._batching_kwargs["min_batch_size"] = min_batch_size
if max_batch_size is not None:
self._batching_kwargs["max_batch_size"] = max_batch_size
if max_batch_duration_secs is not None:
self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
self._share_across_processes = large_model or (model_copies is not None)
self._model_copies = model_copies or 1
self._framework = framework

_validate_constructor_args(
@@ -352,15 +355,6 @@ def get_num_bytes(
return sum(
(el.element_size() for tensor in batch for el in tensor.values()))

def batch_elements_kwargs(self):
return self._batching_kwargs

def share_model_across_processes(self) -> bool:
return self._share_across_processes

def model_copies(self) -> int:
return self._model_copies

def get_metrics_namespace(self) -> str:
"""
Returns:
@@ -415,6 +409,8 @@ def __init__(
max_batch_duration_secs: Optional[int] = None,
large_model: bool = False,
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
**kwargs):
"""
Implementation of the ModelHandler interface for HuggingFace with
@@ -450,27 +446,28 @@ def __init__(
model_copies: The exact number of models that you would like loaded
onto your machine. This can be useful if you exactly know your CPU or
GPU capacity and want to maximize resource utilization.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

**Supported Versions:** HuggingFaceModelHandler supports
transformers>=4.18.0.
"""
super().__init__(
min_batch_size=min_batch_size,
max_batch_size=max_batch_size,
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
large_model=large_model,
model_copies=model_copies,
**kwargs)
self._model_uri = model_uri
self._model_class = model_class
self._device = device
self._inference_fn = inference_fn
self._model_config_args = load_model_args if load_model_args else {}
self._batching_kwargs = {}
self._env_vars = kwargs.get("env_vars", {})
if min_batch_size is not None:
self._batching_kwargs["min_batch_size"] = min_batch_size
if max_batch_size is not None:
self._batching_kwargs["max_batch_size"] = max_batch_size
if max_batch_duration_secs is not None:
self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
self._share_across_processes = large_model or (model_copies is not None)
self._model_copies = model_copies or 1
self._framework = ""

_validate_constructor_args(
@@ -547,15 +544,6 @@ def get_num_bytes(
return sum(
(el.element_size() for tensor in batch for el in tensor.values()))

def batch_elements_kwargs(self):
return self._batching_kwargs

def share_model_across_processes(self) -> bool:
return self._share_across_processes

def model_copies(self) -> int:
return self._model_copies

def get_metrics_namespace(self) -> str:
"""
Returns:
@@ -586,6 +574,8 @@ def __init__(
max_batch_duration_secs: Optional[int] = None,
large_model: bool = False,
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
**kwargs):
"""
Implementation of the ModelHandler interface for Hugging Face Pipelines.
@@ -629,27 +619,28 @@ def __init__(
model_copies: The exact number of models that you would like loaded
onto your machine. This can be useful if you exactly know your CPU or
GPU capacity and want to maximize resource utilization.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

**Supported Versions:** HuggingFacePipelineModelHandler supports
transformers>=4.18.0.
"""
super().__init__(
min_batch_size=min_batch_size,
max_batch_size=max_batch_size,
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
large_model=large_model,
model_copies=model_copies,
**kwargs)
self._task = task
self._model = model
self._inference_fn = inference_fn
self._load_pipeline_args = load_pipeline_args if load_pipeline_args else {}
self._batching_kwargs = {}
self._framework = "pt"
self._env_vars = kwargs.get('env_vars', {})
if min_batch_size is not None:
self._batching_kwargs['min_batch_size'] = min_batch_size
if max_batch_size is not None:
self._batching_kwargs['max_batch_size'] = max_batch_size
if max_batch_duration_secs is not None:
self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
self._share_across_processes = large_model or (model_copies is not None)
self._model_copies = model_copies or 1

# Check if the device is specified twice. If true then the device parameter
# of model handler is overridden.
@@ -726,15 +717,6 @@ def get_num_bytes(self, batch: Sequence[str]) -> int:
"""
return sum(sys.getsizeof(element) for element in batch)

def batch_elements_kwargs(self):
return self._batching_kwargs

def share_model_across_processes(self) -> bool:
return self._share_across_processes

def model_copies(self) -> int:
return self._model_copies

def get_metrics_namespace(self) -> str:
"""
Returns:
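Because the Hugging Face handlers now defer to the ModelHandler base class, weight-based batching is configured purely through constructor arguments. A hedged sketch using HuggingFacePipelineModelHandler follows; the task, checkpoint, and limits are illustrative, and sys.getsizeof is just one possible element_size_fn.

import sys

from apache_beam.ml.inference.huggingface_inference import (
    HuggingFacePipelineModelHandler)

handler = HuggingFacePipelineModelHandler(
    task='text-classification',
    model='distilbert-base-uncased-finetuned-sst-2-english',
    max_batch_size=32,
    max_batch_weight=64_000,  # cap on summed element weight per batch
    element_size_fn=sys.getsizeof)  # weight each input string by its in-memory size

# The kwargs handed to Beam's batching layer now come from the ModelHandler
# base class rather than a handler-local _batching_kwargs dict.
print(handler.batch_elements_kwargs())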