Skip to content

Commit 032e423

Browse files
committed
[#47] Adding filesystem support for save_df
... Signed-off-by: Todd Gaugler <[email protected]> ...
1 parent 8868835 commit 032e423

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

raydar/task_tracker/task_tracker.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import logging
44
import os
55
from collections.abc import Iterable
6-
from typing import Dict, List, Optional
6+
from typing import Dict, List, Optional, Type
77

88
import coolname
99
import pandas as pd
1010
import polars as pl
11+
import pyarrow.fs as fs
12+
import pyarrow.parquet as pq
1113
import ray
1214
from packaging.version import Version
1315
from ray.serve import shutdown
@@ -84,13 +86,7 @@ def exit(self) -> None:
8486

8587
@ray.remote(resources={"node:__internal_head__": 0.1}, num_cpus=0)
8688
class AsyncMetadataTracker:
87-
def __init__(
88-
self,
89-
name: str,
90-
namespace: str,
91-
path: Optional[str] = None,
92-
enable_perspective_dashboard: bool = False,
93-
):
89+
def __init__(self, name: str, namespace: str, enable_perspective_dashboard: bool = False, filesystem: Type[fs.FileSystem] = fs.LocalFileSystem):
9490
"""An async Ray Actor Class to track task level metadata.
9591
9692
This class constructs an AsyncMetadataTrackerCallback actor, which points back to this actor. Its process(...)
@@ -114,13 +110,13 @@ def __init__(
114110
lifetime="detached",
115111
get_if_exists=True,
116112
).remote(name, namespace)
117-
self.path = path
118113
self.df = None
119114
self.finished_tasks = {}
120115
self.user_defined_metadata = {}
121116
self.perspective_dashboard_enabled = enable_perspective_dashboard
122117
self.pending_tasks = []
123118
self.perspective_table_name = f"{name}_data"
119+
self.filesystem = filesystem()
124120

125121
# WARNING: Do not move this import. Importing these modules elsewhere can cause
126122
# difficult to diagnose, "There is no current event loop in thread 'ray_client_server_" errors.
@@ -306,14 +302,10 @@ def get_proxy_server(self) -> ray.serve.handle.DeploymentHandle:
306302
return self.proxy_server
307303
raise Exception("This task_tracker has no active proxy_server.")
308304

309-
def save_df(self) -> None:
310-
"""Saves the internally maintained dataframe of task related information from the ray GCS"""
311-
self.get_df()
312-
if self.path is not None and self.df is not None:
313-
logger.info(f"Writing DataFrame to {self.path}")
314-
self.df.write_parquet(self.path)
315-
return True
316-
return False
305+
def save_df(self, path: str) -> None:
306+
"""Saves the internally maintained dataframe of task related information from the ray GCS to the provided path, using this tracker's filesystem attribute"""
307+
logger.info(f"Writing DataFrame to {path}")
308+
pq.write_table(self.get_df().to_arrow(), path, filesystem=self.filesystem)
317309

318310
def clear_df(self) -> None:
319311
"""Clears the internally maintained dataframe of task related information from the ray GCS"""
@@ -363,9 +355,9 @@ def get_df(self, process_user_metadata_column=False) -> pl.DataFrame:
363355
return df_with_user_metadata
364356
return df
365357

366-
def save_df(self) -> None:
358+
def save_df(self, path: str) -> None:
367359
"""Save the dataframe used by this object's AsyncMetadataTracker actor"""
368-
return ray.get(self.tracker.save_df.remote())
360+
return ray.get(self.tracker.save_df.remote(path))
369361

370362
def clear(self) -> None:
371363
"""Clear the dataframe used by this object's AsyncMetadataTracker actor"""

raydar/tests/test_task_tracker.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import os
2+
import tempfile
13
import time
24

5+
import pandas as pd
36
import pytest
47
import ray
58
import requests
@@ -39,3 +42,15 @@ def test_get_proxy_server(self):
3942
time.sleep(2)
4043
response = requests.get("http://localhost:8000/tables")
4144
assert eval(response.text) == ["test_table"]
45+
46+
def test_save_df(self):
47+
task_tracker = RayTaskTracker()
48+
refs = [do_some_work.remote() for _ in range(100)]
49+
task_tracker.process(refs)
50+
_ = ray.get(refs)
51+
df = task_tracker.get_df()
52+
with tempfile.TemporaryDirectory() as tempdir:
53+
path = os.path.join(tempdir, "output_dir")
54+
task_tracker.save_df(path)
55+
loaded_df = pd.read_parquet(path)
56+
assert loaded_df.equals(df.to_pandas())

0 commit comments

Comments
 (0)