Skip to content

Commit c613642

Browse files
authored
Merge pull request #112 from timkpaine/tkp/pub
Release publisher model, add tests
2 parents a135501 + a13963b commit c613642

File tree

5 files changed

+85
-1
lines changed

5 files changed

+85
-1
lines changed

ccflow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .context import *
77
from .enums import Enum
88
from .exttypes import *
9+
from .models import *
910
from .object_config import *
1011
from .publisher import *
1112
from .result import *

ccflow/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1+
from .publisher import *

ccflow/models/publisher.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from typing import Generic, Type
2+
3+
from pydantic import Field
4+
from typing_extensions import override
5+
6+
from ..callable import CallableModelType, ContextType, Flow, ResultType, WrapperModel
7+
from ..publisher import PublisherType
8+
from ..result import GenericResult
9+
10+
__all__ = ("PublisherModel",)
11+
12+
13+
class PublisherModel(
14+
WrapperModel[CallableModelType],
15+
Generic[CallableModelType, PublisherType],
16+
):
17+
"""Model that chains together a callable model and a publisher to publish the results of the callable model."""
18+
19+
publisher: PublisherType
20+
field: str = Field(None, description="Specific field on model output to publish")
21+
return_data: bool = Field(
22+
False,
23+
description="Whether to return the underlying model result as the output instead of the publisher output",
24+
)
25+
26+
@property
27+
def result_type(self) -> Type[ResultType]:
28+
"""Result type that will be returned. Could be over-ridden by child class."""
29+
if self.return_data:
30+
return self.model.result_type
31+
else:
32+
return GenericResult
33+
34+
def _get_publisher(self, context):
35+
publisher = self.publisher.model_copy()
36+
# Set the name, if needed
37+
if not publisher.name and self.meta.name:
38+
publisher.name = self.meta.name
39+
# Augment any existing name parameters with the context parameters
40+
name_params = publisher.name_params.copy()
41+
name_params.update(context.model_dump(exclude={"type_"}))
42+
publisher.name_params = name_params
43+
return publisher
44+
45+
@override
46+
@Flow.call
47+
def __call__(self, context: ContextType) -> ResultType:
48+
"""This method gets the result from the underlying model, and publishes it."""
49+
publisher = self._get_publisher(context)
50+
data = self.model(context)
51+
if self.field:
52+
pub_data = getattr(data, self.field)
53+
else:
54+
pub_data = data
55+
publisher.data = pub_data
56+
out = publisher()
57+
if self.return_data:
58+
return data
59+
else:
60+
return self.result_type(value=out)

ccflow/tests/models/__init__.py

Whitespace-only changes.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from unittest.mock import patch
2+
3+
from ccflow import CallableModel, DictResult, Flow, GenericResult, NullContext
4+
from ccflow.models import PublisherModel
5+
from ccflow.publishers import PrintPublisher
6+
7+
8+
class TestModel(CallableModel):
9+
@Flow.call
10+
def __call__(self, context: NullContext) -> DictResult[str, str]:
11+
return DictResult[str, str](value={"message": "Hello, World!"})
12+
13+
14+
class TestPublisherModel:
15+
def test_run(self):
16+
with patch("ccflow.publishers.print.print") as mock_print:
17+
model = PublisherModel(model=TestModel(), publisher=PrintPublisher())
18+
res = model(None)
19+
assert isinstance(res, GenericResult) # from PrintPublisher
20+
assert isinstance(res.value, DictResult[str, str])
21+
assert res.value.value == {"message": "Hello, World!"}
22+
assert mock_print.call_count == 1
23+
assert mock_print.call_args[0][0].value == {"message": "Hello, World!"}

0 commit comments

Comments
 (0)