diff --git a/torchx/runner/api.py b/torchx/runner/api.py index 3e5290abb..4641e3899 100644 --- a/torchx/runner/api.py +++ b/torchx/runner/api.py @@ -177,7 +177,7 @@ def run_component( ComponentNotFoundException: if the ``component_path`` is failed to resolve. """ - with log_event("run_component") as ctx: + with log_event("run_component", workspace=workspace) as ctx: dryrun_info = self.dryrun_component( component, component_args, @@ -187,6 +187,7 @@ def run_component( parent_run_id=parent_run_id, ) handle = self.schedule(dryrun_info) + ctx._torchx_event.workspace = workspace ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler) ctx._torchx_event.app_image = none_throws(dryrun_info._app).roles[0].image ctx._torchx_event.app_id = parse_app_handle(handle)[2] @@ -237,7 +238,9 @@ def run( An application handle that is used to call other action APIs on the app. """ - with log_event(api="run", runcfg=json.dumps(cfg) if cfg else None) as ctx: + with log_event( + api="run", runcfg=json.dumps(cfg) if cfg else None, workspace=workspace + ) as ctx: dryrun_info = self.dryrun( app, scheduler, @@ -371,7 +374,12 @@ def dryrun( role.env[tracker_config_env_var_name(name)] = config cfg = cfg or dict() - with log_event("dryrun", scheduler, runcfg=json.dumps(cfg) if cfg else None): + with log_event( + "dryrun", + scheduler, + runcfg=json.dumps(cfg) if cfg else None, + workspace=workspace, + ): sched = self._scheduler(scheduler) resolved_cfg = sched.run_opts().resolve(cfg) if workspace and isinstance(sched, WorkspaceMixin): diff --git a/torchx/runner/events/__init__.py b/torchx/runner/events/__init__.py index 219def102..cedba10d6 100644 --- a/torchx/runner/events/__init__.py +++ b/torchx/runner/events/__init__.py @@ -85,9 +85,15 @@ def __init__( app_id: Optional[str] = None, app_image: Optional[str] = None, runcfg: Optional[str] = None, + workspace: Optional[str] = None, ) -> None: self._torchx_event: TorchxEvent = self._generate_torchx_event( - api, scheduler or "", app_id, app_image=app_image, runcfg=runcfg + api, + scheduler or "", + app_id, + app_image=app_image, + runcfg=runcfg, + workspace=workspace, ) self._start_cpu_time_ns = 0 self._start_wall_time_ns = 0 @@ -124,6 +130,7 @@ def _generate_torchx_event( app_image: Optional[str] = None, runcfg: Optional[str] = None, source: SourceType = SourceType.UNKNOWN, + workspace: Optional[str] = None, ) -> TorchxEvent: return TorchxEvent( session=app_id or "", @@ -133,4 +140,5 @@ def _generate_torchx_event( app_image=app_image, runcfg=runcfg, source=source, + workspace=workspace, ) diff --git a/torchx/runner/events/api.py b/torchx/runner/events/api.py index 70e8f2791..5cb5f11ab 100644 --- a/torchx/runner/events/api.py +++ b/torchx/runner/events/api.py @@ -47,6 +47,7 @@ class TorchxEvent: cpu_time_usec: Optional[int] = None wall_time_usec: Optional[int] = None start_epoch_time_usec: Optional[int] = None + workspace: Optional[str] = None def __str__(self) -> str: return self.serialize() diff --git a/torchx/runner/events/test/lib_test.py b/torchx/runner/events/test/lib_test.py index f17324f42..14c025ad7 100644 --- a/torchx/runner/events/test/lib_test.py +++ b/torchx/runner/events/test/lib_test.py @@ -46,12 +46,14 @@ def test_event_created(self) -> None: scheduler="test_scheduler", api="test_api", app_image="test_app_image", + workspace="test_workspace", ) self.assertEqual("test_session", event.session) self.assertEqual("test_scheduler", event.scheduler) self.assertEqual("test_api", event.api) self.assertEqual("test_app_image", event.app_image) self.assertEqual(SourceType.UNKNOWN, event.source) + self.assertEqual("test_workspace", event.workspace) def test_event_deser(self) -> None: event = TorchxEvent( @@ -59,6 +61,7 @@ def test_event_deser(self) -> None: scheduler="test_scheduler", api="test_api", app_image="test_app_image", + workspace="test_workspace", source=SourceType.EXTERNAL, ) json_event = event.serialize() @@ -74,6 +77,7 @@ def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> Non self.assertEqual(expected.api, actual.api) self.assertEqual(expected.app_image, actual.app_image) self.assertEqual(expected.source, actual.source) + self.assertEqual(expected.workspace, actual.workspace) def test_create_context(self, _) -> None: cfg = json.dumps({"test_key": "test_value"}) @@ -83,6 +87,7 @@ def test_create_context(self, _) -> None: "test_app_id", app_image="test_app_image_id", runcfg=cfg, + workspace="test_workspace", ) expected_torchx_event = TorchxEvent( "test_app_id", @@ -91,7 +96,9 @@ def test_create_context(self, _) -> None: "test_app_id", app_image="test_app_image_id", runcfg=cfg, + workspace="test_workspace", ) + self.assert_torchx_event(expected_torchx_event, context._torchx_event) def test_record_event(self, record_mock: MagicMock) -> None: @@ -102,6 +109,7 @@ def test_record_event(self, record_mock: MagicMock) -> None: "test_app_id", app_image="test_app_image_id", runcfg=cfg, + workspace="test_workspace", ) as ctx: pass @@ -112,6 +120,7 @@ def test_record_event(self, record_mock: MagicMock) -> None: "test_app_id", app_image="test_app_image_id", runcfg=cfg, + workspace="test_workspace", cpu_time_usec=ctx._torchx_event.cpu_time_usec, wall_time_usec=ctx._torchx_event.wall_time_usec, )