16 changes: 14 additions & 2 deletions breadbox/breadbox/compute/dataset_uploads_tasks.py
@@ -34,6 +34,7 @@
from .celery import app, LogErrorsTask
import celery
from ..config import get_settings
+from ..utils.progress_tracker import ProgressTracker


@app.task(base=LogErrorsTask, bind=True)
@@ -46,7 +47,7 @@ def run_dataset_upload(
    else:
        params: DatasetParams = TableDatasetParams(**dataset_params)

-   upload_dataset_response = dataset_upload(db, params, user)
+   upload_dataset_response = dataset_upload(db, params, user, ProgressTracker())

    # because celery is going to want to serialize the response,
    # convert it to a json dict before returning it
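Since run_dataset_upload is declared with bind=True, the tracker could also be handed the bound task so updates surface through Celery's state API rather than only stdout; a hedged sketch of that alternative wiring (not part of this diff):

    # Hypothetical alternative (not in this PR): passing the bound task lets
    # ProgressTracker call task.update_state(state="PROGRESS", ...).
    upload_dataset_response = dataset_upload(db, params, user, ProgressTracker(task=self))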
@@ -55,8 +56,12 @@


def dataset_upload(
-   db: SessionWithUser, dataset_params: DatasetParams, user: str,
+   db: SessionWithUser,
+   dataset_params: DatasetParams,
+   user: str,
+   progress: ProgressTracker,
):
+   progress.update_message("started processing of uploaded file")
    settings = get_settings()

    # NOTE: We make this check in the dataset_crud.add_dataset function too, because we
@@ -80,6 +85,8 @@
        settings.compute_results_location,
    )

+   progress.update_message("reassembled uploaded file")

    dataset_id = str(uuid4())

    unknown_ids = []
@@ -92,6 +99,8 @@
    )
    sample_type = _get_dimension_type(db, dataset_params.sample_type, "sample")

+   progress.update_message("starting validation")

    df_wrapper = read_and_validate_matrix_df(
        file_path,
        dataset_params.value_type,
@@ -154,11 +163,14 @@
            dataset_params.version,
            dataset_params.description,
        )

+       progress.update_message("saving final version of dataset")
        save_dataset_file(
            dataset_id,
            df_wrapper,
            dataset_params.value_type,
            settings.filestore_location,
+           progress,
        )

    else:
5 changes: 4 additions & 1 deletion breadbox/breadbox/io/data_validation.py
@@ -35,6 +35,7 @@
from ..crud.dimension_types import get_dimension_type
import pyarrow

+from ..utils.progress_tracker import ProgressTracker

pd.set_option("mode.use_inf_as_na", True)

@@ -469,7 +470,9 @@ def validate_and_upload_dataset_files(
    )

    # TODO: Move save function to api layer. Need to make sure the db save is successful first
-   save_dataset_file(dataset_id, data_dfw, value_type, filestore_location)
+   save_dataset_file(
+       dataset_id, data_dfw, value_type, filestore_location, ProgressTracker()
[review comment, Contributor Author]: This is a legacy code path, so I don't care that it's not reporting status. We're not supposed to use this anyway.
+   )

    return dataframe_validated_dimensions

7 changes: 6 additions & 1 deletion breadbox/breadbox/io/filestore_crud.py
@@ -5,6 +5,7 @@

import pandas as pd

+from ..utils.progress_tracker import ProgressTracker
from ..schemas.dataframe_wrapper import DataFrameWrapper
from ..models.dataset import Dataset, MatrixDataset, ValueType
from .hdf5_utils import write_hdf5_file, read_hdf5_file
@@ -17,6 +18,7 @@ def save_dataset_file(
    df_wrapper: DataFrameWrapper,
    value_type: ValueType,
    filestore_location: str,
+   progress: ProgressTracker,
):
    base_path = os.path.join(filestore_location, dataset_id)
    os.makedirs(base_path)
@@ -27,7 +29,10 @@
dtype = "float"

write_hdf5_file(
get_file_location(dataset_id, filestore_location, DATA_FILE), df_wrapper, dtype
get_file_location(dataset_id, filestore_location, DATA_FILE),
df_wrapper,
dtype,
progress,
)


9 changes: 9 additions & 0 deletions breadbox/breadbox/io/hdf5_utils.py
@@ -1,5 +1,6 @@
from typing import List, Optional, Literal

+from breadbox.utils.progress_tracker import ProgressTracker
from breadbox.schemas.custom_http_exception import FileValidationError
from breadbox.schemas.dataframe_wrapper import ParquetDataFrameWrapper
import h5py
@@ -31,6 +32,7 @@ def write_hdf5_file(
    path: str,
    df_wrapper: DataFrameWrapper,
    dtype: Literal["float", "str"],
+   progress: ProgressTracker,
    batch_size: int = 5000,  # Adjust batch size as needed
):
    f = h5py.File(path, mode="w")
@@ -79,7 +81,11 @@
                dtype=h5py.string_dtype() if dtype == "str" else np.float64,
            )

+       progress.update_message("transforming columns")
+       progress.update_process_max_value(len(cols))
        for i in range(0, len(cols), batch_size):
+           progress.update_progress(i)

            # Find the correct column slice to write
            end_col = i + batch_size
@@ -102,8 +108,11 @@
f"Failed to update {i}:{end_col} of hdf5 file {path} with {values}"
) from e

progress.update_process_max_value(None)
progress.update_message("Creating indexes in database")
create_index_dataset(f, "features", pd.Index(df_wrapper.get_column_names()))
create_index_dataset(f, "samples", pd.Index(df_wrapper.get_index_names()))
progress.update_message("Complete")
[review comment, Contributor]: Maybe we should put this in the finally statement?

[review comment, Contributor]: finally is run on failure too.

[review comment, jessica-cheng, Aug 4, 2025]: @rcreasi That's true, it doesn't make sense for progress to be considered complete on a failure. I realize we haven't been catching failures in this try block, so I've added one recently.
    finally:
        f.close()
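While the column loop runs, emitted messages carry a percentage suffix, e.g. "transforming columns (45% complete)"; calling update_process_max_value(None) before index creation clears that suffix. Following the review thread above, one way to make the tracker reflect failures, rather than moving the "Complete" message into finally, would be an except block that reports before re-raising; a rough sketch of the shape (the structure below is illustrative, not the code in this PR):

    try:
        ...  # batched column writes and index creation, as above
        progress.update_message("Complete")
    except Exception:
        # Surface the failure to anyone polling progress, then re-raise
        # so the caller still sees the original error.
        progress.update_message("Failed while writing HDF5 file")
        raise
    finally:
        f.close()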

59 changes: 59 additions & 0 deletions breadbox/breadbox/utils/progress_tracker.py
@@ -0,0 +1,59 @@
from typing import Optional
import time

from celery import Task


class ProgressTracker:
    message: str
    progress_value: float
    progress_max_value: Optional[float]

    prev_emitted_message: str
    prev_emitted_timestamp: float

    delay_between_updates: float

    task: Optional[Task]

    def __init__(self, *, task: Optional[Task] = None, delay_between_updates: float = 5):
        self.message = "Started"
        self.progress_value = 0.0
        self.progress_max_value = None

        self.prev_emitted_message = ""
        self.prev_emitted_timestamp = time.time()
        self.delay_between_updates = delay_between_updates
        self.task = task

    def update_message(self, message: str):
        self.message = message
        self._emit_update(force=True)

    def update_progress(self, value: float):
        self.progress_value = value
        self._emit_update()

    def update_process_max_value(self, value: Optional[float]):
        # Setting a new max resets the progress counter; passing None drops
        # the percentage suffix from subsequent messages.
        self.progress_value = 0
        self.progress_max_value = value
        self._emit_update(force=True)

    def _emit_update(self, force: bool = False):
        message = self.message
        if self.progress_max_value:
            message = f"{message} ({int(self.progress_value / self.progress_max_value * 100)}% complete)"
        if self.prev_emitted_message != message:
            if force or (
                (time.time() - self.prev_emitted_timestamp) > self.delay_between_updates
            ):
                print(message)

                update_state = {"message": message}

                if self.task is not None:
                    if not self.task.request.called_directly:
                        self.task.update_state(
                            state="PROGRESS", meta=update_state,
                        )

                self.prev_emitted_timestamp = time.time()
                self.prev_emitted_message = message
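For reference, a minimal sketch of how the tracker is meant to be driven from a bound Celery task; the task body here is hypothetical:

    # Hypothetical usage: with task=self, update_state() reports PROGRESS to
    # Celery clients, while print() covers direct (non-worker) runs.
    @app.task(base=LogErrorsTask, bind=True)
    def some_long_task(self, n_items: int):
        progress = ProgressTracker(task=self)
        progress.update_message("starting work")
        progress.update_process_max_value(n_items)
        for i in range(n_items):
            progress.update_progress(i)  # throttled by delay_between_updates
            do_one_unit_of_work(i)  # placeholder for the real work
        progress.update_message("Complete")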
2 changes: 1 addition & 1 deletion breadbox/tests/api/test_dataset_uploads.py
@@ -181,7 +181,7 @@ def test_dataset_uploads_task(
        client, tabular_data_file_bad_list_strings
    )

-   def mock_failed_task_result(db, params, user):
+   def mock_failed_task_result(db, params, user, progress):
        return {"result": "Column 'attr1' failed validator"}

    monkeypatch.setattr(
23 changes: 17 additions & 6 deletions breadbox/tests/compute/test_dataset_uploads.py
@@ -21,7 +21,7 @@
    AnnotationType,
)
from breadbox.schemas.custom_http_exception import FileValidationError
-from breadbox.compute.dataset_uploads_tasks import dataset_upload
+from breadbox.compute.dataset_uploads_tasks import dataset_upload, ProgressTracker

from tests import factories

@@ -64,7 +64,9 @@ def test_matrix_dataset_uploads(
        **_default_params
    )
    user = settings.admin_users[0]
-   matrix_dataset_w_simple_metadata = dataset_upload(minimal_db, matrix_params, user)
+   matrix_dataset_w_simple_metadata = dataset_upload(
+       minimal_db, matrix_params, user, ProgressTracker()
+   )
    assert matrix_dataset_w_simple_metadata.datasetId
    dataset_id = matrix_dataset_w_simple_metadata.datasetId
@@ -99,7 +101,10 @@
        **_default_params
    )
    matrix_only_sample_type = dataset_upload(
-       minimal_db, matrix_params_only_sample_type, settings.admin_users[0]
+       minimal_db,
+       matrix_params_only_sample_type,
+       settings.admin_users[0],
+       ProgressTracker(),
    )
    assert matrix_only_sample_type.datasetId
    dataset_id = matrix_only_sample_type.datasetId
@@ -174,7 +179,9 @@ def test_tabular_uploads(
        **_default_params
    )
    user = settings.admin_users[0]
-   tabular_dataset = dataset_upload(minimal_db, tabular_params, user)
+   tabular_dataset = dataset_upload(
+       minimal_db, tabular_params, user, ProgressTracker()
+   )
    assert tabular_dataset.datasetId
    tabular_dataset_id = tabular_dataset.datasetId
    dataset = minimal_db.query(Dataset).filter(Dataset.id == tabular_dataset_id).one()
@@ -261,7 +268,9 @@ def test_tabular_bad_list_str_col(minimal_db, client, settings, private_group):
        },
        **_default_params
    )
-   dataset_upload(minimal_db, bad_list_str_params, settings.admin_users[0])
+   dataset_upload(
+       minimal_db, bad_list_str_params, settings.admin_users[0], ProgressTracker()
+   )


def test_tabular_dup_ids_failure(client, private_group, minimal_db, settings):
@@ -299,4 +308,6 @@ def test_tabular_dup_ids_failure(client, private_group, minimal_db, settings):
        },
        **_default_params
    )
-   dataset_upload(minimal_db, repeated_ids_params, settings.admin_users[0])
+   dataset_upload(
+       minimal_db, repeated_ids_params, settings.admin_users[0], ProgressTracker()
+   )
5 changes: 4 additions & 1 deletion breadbox/tests/conftest.py
@@ -9,6 +9,7 @@
from fastapi.exceptions import HTTPException

from breadbox.api.dependencies import get_db_with_user, get_user
+from breadbox.utils.progress_tracker import ProgressTracker
from breadbox.config import Settings, get_settings
from breadbox.db.base import Base
from breadbox.db.session import SessionWithUser, SessionLocalWithUser
@@ -248,7 +249,9 @@ def mock_run_dataset_upload_task(dataset_params, user):
        else:
            params = TableDatasetParams(**dataset_params)
        minimal_db.reset_user(user)
-       return dataset_uploads_tasks.dataset_upload(minimal_db, params, user)
+       return dataset_uploads_tasks.dataset_upload(
+           minimal_db, params, user, ProgressTracker()
+       )

    def mock_return_task(result):
        from celery.result import EagerResult
14 changes: 12 additions & 2 deletions breadbox/tests/io/test_hdf5_utils.py
@@ -9,6 +9,8 @@
import pytest
import h5py

+from breadbox.utils.progress_tracker import ProgressTracker


@pytest.fixture
def test_dataframe(row_length: int = 500, col_length: int = 11000):
@@ -61,7 +63,11 @@ def test_write_parquet_to_hdf5(tmpdir, test_dataframe, test_parquet_file):

    # override batch size to force multiple batches
    write_hdf5_file(
-       path=str(output_h5), df_wrapper=wrapper, dtype="float", batch_size=1000
+       path=str(output_h5),
+       df_wrapper=wrapper,
+       dtype="float",
+       batch_size=1000,
+       progress=ProgressTracker(),
    )

    # Verify output
@@ -102,7 +108,11 @@ def test_write_parquet_nulls_to_hdf5(tmpdir):

    # override batch size to force multiple batches
    write_hdf5_file(
-       path=str(output_h5), df_wrapper=wrapper, dtype="float", batch_size=1000
+       path=str(output_h5),
+       df_wrapper=wrapper,
+       dtype="float",
+       batch_size=1000,
+       progress=ProgressTracker(),
    )

    # Verify output