diff --git a/tango/integrations/torch/model.py b/tango/integrations/torch/model.py index c1bf8c1a7..961d63681 100644 --- a/tango/integrations/torch/model.py +++ b/tango/integrations/torch/model.py @@ -1,3 +1,5 @@ +# from typing import Any, Dict + import torch from tango.common.registrable import Registrable @@ -10,3 +12,6 @@ class Model(torch.nn.Module, Registrable): Its :meth:`~torch.nn.Module.forward()` method should return a :class:`dict` that includes the ``loss`` during training and any tracked metrics during validation. """ + + # def _to_params(self) -> Dict[str, Any]: + # return {} diff --git a/tests/integrations/torch/model_test.py b/tests/integrations/torch/model_test.py new file mode 100644 index 000000000..c206b2167 --- /dev/null +++ b/tests/integrations/torch/model_test.py @@ -0,0 +1,52 @@ +import pytest +from torch import nn + +from tango.common.testing import TangoTestCase +from tango.integrations.torch import Model +from tango.step import Step + + +class FeedForward(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(4, 4) + self.activation = nn.ReLU() + + def forward(self, x): + return self.activation(self.linear(x)) + + +@Model.register("simple_regression_model", exist_ok=True) +class SimpleRegressionModel(Model): + def __init__(self): + super().__init__() + self.blocks = nn.Sequential(*[FeedForward() for _ in range(3)]) + self.regression_head = nn.Linear(4, 1) + self.loss_fcn = nn.MSELoss() + + def forward(self, x, y): + output = self.blocks(x) + output = self.regression_head(output) + loss = self.loss_fcn(output, y) + return {"loss": loss} + + +@Step.register("step-that-takes-model-as-input") +class StepThatTakesModelAsInput(Step): + def run(self, model: Model) -> Model: # type: ignore + return model + + +class TestModelAsStepInput(TangoTestCase): + def test_step_that_takes_model_as_input(self): + config = { + "steps": { + "model": { + "type": "step-that-takes-model-as-input", + "model": {"type": "simple_regression_model"}, + } + } + } + + with pytest.raises(NotImplementedError): + self.run(config)