|
| 1 | +from typing import Generic, Type |
| 2 | + |
| 3 | +from pydantic import Field |
| 4 | +from typing_extensions import override |
| 5 | + |
| 6 | +from ..callable import CallableModelType, ContextType, Flow, ResultType, WrapperModel |
| 7 | +from ..publisher import PublisherType |
| 8 | +from ..result import GenericResult |
| 9 | + |
| 10 | +__all__ = ("PublisherModel",) |
| 11 | + |
| 12 | + |
| 13 | +class PublisherModel( |
| 14 | + WrapperModel[CallableModelType], |
| 15 | + Generic[CallableModelType, PublisherType], |
| 16 | +): |
| 17 | + """Model that chains together a callable model and a publisher to publish the results of the callable model.""" |
| 18 | + |
| 19 | + publisher: PublisherType |
| 20 | + field: str = Field(None, description="Specific field on model output to publish") |
| 21 | + return_data: bool = Field( |
| 22 | + False, |
| 23 | + description="Whether to return the underlying model result as the output instead of the publisher output", |
| 24 | + ) |
| 25 | + |
| 26 | + @property |
| 27 | + def result_type(self) -> Type[ResultType]: |
| 28 | + """Result type that will be returned. Could be over-ridden by child class.""" |
| 29 | + if self.return_data: |
| 30 | + return self.model.result_type |
| 31 | + else: |
| 32 | + return GenericResult |
| 33 | + |
| 34 | + def _get_publisher(self, context): |
| 35 | + publisher = self.publisher.model_copy() |
| 36 | + # Set the name, if needed |
| 37 | + if not publisher.name and self.meta.name: |
| 38 | + publisher.name = self.meta.name |
| 39 | + # Augment any existing name parameters with the context parameters |
| 40 | + name_params = publisher.name_params.copy() |
| 41 | + name_params.update(context.model_dump(exclude={"type_"})) |
| 42 | + publisher.name_params = name_params |
| 43 | + return publisher |
| 44 | + |
| 45 | + @override |
| 46 | + @Flow.call |
| 47 | + def __call__(self, context: ContextType) -> ResultType: |
| 48 | + """This method gets the result from the underlying model, and publishes it.""" |
| 49 | + publisher = self._get_publisher(context) |
| 50 | + data = self.model(context) |
| 51 | + if self.field: |
| 52 | + pub_data = getattr(data, self.field) |
| 53 | + else: |
| 54 | + pub_data = data |
| 55 | + publisher.data = pub_data |
| 56 | + out = publisher() |
| 57 | + if self.return_data: |
| 58 | + return data |
| 59 | + else: |
| 60 | + return self.result_type(value=out) |
0 commit comments