diff --git a/breadbox/breadbox/compute/dataset_uploads_tasks.py b/breadbox/breadbox/compute/dataset_uploads_tasks.py index 1b5414ae4..3c109abf0 100644 --- a/breadbox/breadbox/compute/dataset_uploads_tasks.py +++ b/breadbox/breadbox/compute/dataset_uploads_tasks.py @@ -34,6 +34,7 @@ from .celery import app, LogErrorsTask import celery from ..config import get_settings +from ..utils.progress_tracker import ProgressTracker @app.task(base=LogErrorsTask, bind=True) @@ -46,7 +47,7 @@ def run_dataset_upload( else: params: DatasetParams = TableDatasetParams(**dataset_params) - upload_dataset_response = dataset_upload(db, params, user) + upload_dataset_response = dataset_upload(db, params, user, ProgressTracker()) # because celery is going to want to serialize the response, # convert it to a json dict before returning it @@ -55,8 +56,12 @@ def run_dataset_upload( def dataset_upload( - db: SessionWithUser, dataset_params: DatasetParams, user: str, + db: SessionWithUser, + dataset_params: DatasetParams, + user: str, + progress: ProgressTracker, ): + progress.update_message("started processing of uploaded file") settings = get_settings() # NOTE: We make this check in the dataset_crud.add_dataset function too, because we @@ -80,6 +85,8 @@ def dataset_upload( settings.compute_results_location, ) + progress.update_message("reassembled uploaded file") + dataset_id = str(uuid4()) unknown_ids = [] @@ -92,6 +99,8 @@ def dataset_upload( ) sample_type = _get_dimension_type(db, dataset_params.sample_type, "sample") + progress.update_message("starting validation") + df_wrapper = read_and_validate_matrix_df( file_path, dataset_params.value_type, @@ -154,11 +163,14 @@ def dataset_upload( dataset_params.version, dataset_params.description, ) + + progress.update_message("saving final version of dataset") save_dataset_file( dataset_id, df_wrapper, dataset_params.value_type, settings.filestore_location, + progress, ) else: diff --git a/breadbox/breadbox/io/data_validation.py b/breadbox/breadbox/io/data_validation.py index 904f74617..05645654a 100644 --- a/breadbox/breadbox/io/data_validation.py +++ b/breadbox/breadbox/io/data_validation.py @@ -35,6 +35,7 @@ from ..crud.dimension_types import get_dimension_type import pyarrow +from ..utils.progress_tracker import ProgressTracker pd.set_option("mode.use_inf_as_na", True) @@ -469,7 +470,9 @@ def validate_and_upload_dataset_files( ) # TODO: Move save function to api layer. 
Need to make sure the db save is successful first
-    save_dataset_file(dataset_id, data_dfw, value_type, filestore_location)
+    save_dataset_file(
+        dataset_id, data_dfw, value_type, filestore_location, ProgressTracker()
+    )
 
     return dataframe_validated_dimensions
diff --git a/breadbox/breadbox/io/filestore_crud.py b/breadbox/breadbox/io/filestore_crud.py
index d32b7d1e4..68f20afe8 100644
--- a/breadbox/breadbox/io/filestore_crud.py
+++ b/breadbox/breadbox/io/filestore_crud.py
@@ -5,6 +5,7 @@
 
 import pandas as pd
 
+from ..utils.progress_tracker import ProgressTracker
 from ..schemas.dataframe_wrapper import DataFrameWrapper
 from ..models.dataset import Dataset, MatrixDataset, ValueType
 from .hdf5_utils import write_hdf5_file, read_hdf5_file
@@ -17,6 +18,7 @@ def save_dataset_file(
     df_wrapper: DataFrameWrapper,
     value_type: ValueType,
     filestore_location: str,
+    progress: ProgressTracker,
 ):
     base_path = os.path.join(filestore_location, dataset_id)
     os.makedirs(base_path)
@@ -27,7 +29,10 @@ def save_dataset_file(
         dtype = "float"
 
     write_hdf5_file(
-        get_file_location(dataset_id, filestore_location, DATA_FILE), df_wrapper, dtype
+        get_file_location(dataset_id, filestore_location, DATA_FILE),
+        df_wrapper,
+        dtype,
+        progress,
     )
diff --git a/breadbox/breadbox/io/hdf5_utils.py b/breadbox/breadbox/io/hdf5_utils.py
index d963160f1..cb3161b05 100644
--- a/breadbox/breadbox/io/hdf5_utils.py
+++ b/breadbox/breadbox/io/hdf5_utils.py
@@ -1,5 +1,6 @@
 from typing import List, Optional, Literal
 
+from breadbox.utils.progress_tracker import ProgressTracker
 from breadbox.schemas.custom_http_exception import FileValidationError
 from breadbox.schemas.dataframe_wrapper import ParquetDataFrameWrapper
 import h5py
@@ -31,6 +32,7 @@ def write_hdf5_file(
     path: str,
     df_wrapper: DataFrameWrapper,
     dtype: Literal["float", "str"],
+    progress: ProgressTracker,
     batch_size: int = 5000,  # Adjust batch size as needed
 ):
     f = h5py.File(path, mode="w")
@@ -79,7 +81,11 @@ def write_hdf5_file(
             dtype=h5py.string_dtype() if dtype == "str" else np.float64,
         )
 
+        progress.update_message("writing columns")
+        progress.update_progress_max_value(len(cols))
         for i in range(0, len(cols), batch_size):
+            progress.update_progress(i)
+
             # Find the correct column slice to write
             end_col = i + batch_size
 
@@ -102,8 +108,11 @@ def write_hdf5_file(
                     f"Failed to update {i}:{end_col} of hdf5 file {path} with {values}"
                 ) from e
 
+        progress.update_progress_max_value(None)
+        progress.update_message("creating feature and sample indexes")
         create_index_dataset(f, "features", pd.Index(df_wrapper.get_column_names()))
         create_index_dataset(f, "samples", pd.Index(df_wrapper.get_index_names()))
+        progress.update_message("Complete")
     finally:
         f.close()
diff --git a/breadbox/breadbox/utils/progress_tracker.py b/breadbox/breadbox/utils/progress_tracker.py
new file mode 100644
index 000000000..afa31c2fd
--- /dev/null
+++ b/breadbox/breadbox/utils/progress_tracker.py
@@ -0,0 +1,63 @@
+from typing import Optional
+import time
+
+from celery import Task
+
+
+class ProgressTracker:
+    message: str
+    progress_value: float
+    progress_max_value: Optional[float]
+
+    prev_emitted_message: str
+    prev_emitted_timestamp: float
+
+    task: Optional[Task]
+
+    def __init__(
+        self, *, task: Optional[Task] = None, delay_between_updates: float = 5
+    ):
+        self.message = "Started"
+        self.progress_value = 0.0
+        self.progress_max_value = None
+
+        self.prev_emitted_message = ""
+        self.prev_emitted_timestamp = time.time()
+        self.delay_between_updates = delay_between_updates
+        self.task = task
+
+    def update_message(self, message: str):
+        
self.message = message
+        self._emit_update(force=True)
+
+    def update_progress(self, value: float):
+        self.progress_value = value
+        self._emit_update()
+
+    def update_progress_max_value(self, value: Optional[float]):
+        # Starting a new phase: reset progress; pass None to stop
+        # reporting percentages.
+        self.progress_value = 0
+        self.progress_max_value = value
+        self._emit_update(force=True)
+
+    def _emit_update(self, force: bool = False):
+        message = self.message
+        if self.progress_max_value:
+            percent = int(self.progress_value / self.progress_max_value * 100)
+            message = f"{message} ({percent}% complete)"
+        # Skip duplicate messages, and throttle unforced updates to at most
+        # one every delay_between_updates seconds.
+        if self.prev_emitted_message != message:
+            if force or (
+                (time.time() - self.prev_emitted_timestamp) > self.delay_between_updates
+            ):
+                print(message)
+
+                if self.task is not None and not self.task.request.called_directly:
+                    self.task.update_state(
+                        state="PROGRESS", meta={"message": message},
+                    )
+
+                self.prev_emitted_timestamp = time.time()
+                self.prev_emitted_message = message
diff --git a/breadbox/tests/api/test_dataset_uploads.py b/breadbox/tests/api/test_dataset_uploads.py
index b0395f7ed..8d7248217 100644
--- a/breadbox/tests/api/test_dataset_uploads.py
+++ b/breadbox/tests/api/test_dataset_uploads.py
@@ -181,7 +181,7 @@ def test_dataset_uploads_task(
         client, tabular_data_file_bad_list_strings
     )
 
-    def mock_failed_task_result(db, params, user):
+    def mock_failed_task_result(db, params, user, progress):
         return {"result": "Column 'attr1' failed validator"}
 
     monkeypatch.setattr(
diff --git a/breadbox/tests/compute/test_dataset_uploads.py b/breadbox/tests/compute/test_dataset_uploads.py
index 59f5885e6..d88ca6d9d 100644
--- a/breadbox/tests/compute/test_dataset_uploads.py
+++ b/breadbox/tests/compute/test_dataset_uploads.py
@@ -21,7 +21,7 @@
     AnnotationType,
 )
 from breadbox.schemas.custom_http_exception import FileValidationError
-from breadbox.compute.dataset_uploads_tasks import dataset_upload
+from breadbox.compute.dataset_uploads_tasks import dataset_upload, ProgressTracker
 
 from tests import factories
 
@@ -64,7 +64,9 @@ def test_matrix_dataset_uploads(
         **_default_params
     )
     user = settings.admin_users[0]
-    matrix_dataset_w_simple_metadata = dataset_upload(minimal_db, matrix_params, user)
+    matrix_dataset_w_simple_metadata = dataset_upload(
+        minimal_db, matrix_params, user, ProgressTracker()
+    )
 
     assert matrix_dataset_w_simple_metadata.datasetId
     dataset_id = matrix_dataset_w_simple_metadata.datasetId
@@ -99,7 +101,10 @@ def test_matrix_dataset_uploads(
         **_default_params
     )
     matrix_only_sample_type = dataset_upload(
-        minimal_db, matrix_params_only_sample_type, settings.admin_users[0]
+        minimal_db,
+        matrix_params_only_sample_type,
+        settings.admin_users[0],
+        ProgressTracker(),
     )
     assert matrix_only_sample_type.datasetId
     dataset_id = matrix_only_sample_type.datasetId
@@ -174,7 +179,9 @@ def test_tabular_uploads(
         **_default_params
     )
     user = settings.admin_users[0]
-    tabular_dataset = dataset_upload(minimal_db, tabular_params, user)
+    tabular_dataset = dataset_upload(
+        minimal_db, tabular_params, user, ProgressTracker()
+    )
     assert tabular_dataset.datasetId
     tabular_dataset_id = tabular_dataset.datasetId
     dataset = minimal_db.query(Dataset).filter(Dataset.id == tabular_dataset_id).one()
@@ -261,7 +268,9 @@ def test_tabular_bad_list_str_col(minimal_db, client, settings, private_group):
         },
         **_default_params
     )
-    dataset_upload(minimal_db, bad_list_str_params, settings.admin_users[0])
+    dataset_upload(
+        minimal_db, bad_list_str_params, settings.admin_users[0], ProgressTracker()
+    )
 
 
 def test_tabular_dup_ids_failure(client, private_group, minimal_db, settings):
@@ 
-299,4 +308,6 @@ def test_tabular_dup_ids_failure(client, private_group, minimal_db, settings): }, **_default_params ) - dataset_upload(minimal_db, repeated_ids_params, settings.admin_users[0]) + dataset_upload( + minimal_db, repeated_ids_params, settings.admin_users[0], ProgressTracker() + ) diff --git a/breadbox/tests/conftest.py b/breadbox/tests/conftest.py index 1f273c91d..08308288e 100644 --- a/breadbox/tests/conftest.py +++ b/breadbox/tests/conftest.py @@ -9,6 +9,7 @@ from fastapi.exceptions import HTTPException from breadbox.api.dependencies import get_db_with_user, get_user +from breadbox.utils.progress_tracker import ProgressTracker from breadbox.config import Settings, get_settings from breadbox.db.base import Base from breadbox.db.session import SessionWithUser, SessionLocalWithUser @@ -248,7 +249,9 @@ def mock_run_dataset_upload_task(dataset_params, user): else: params = TableDatasetParams(**dataset_params) minimal_db.reset_user(user) - return dataset_uploads_tasks.dataset_upload(minimal_db, params, user) + return dataset_uploads_tasks.dataset_upload( + minimal_db, params, user, ProgressTracker() + ) def mock_return_task(result): from celery.result import EagerResult diff --git a/breadbox/tests/io/test_hdf5_utils.py b/breadbox/tests/io/test_hdf5_utils.py index c667292e8..99223d0f6 100644 --- a/breadbox/tests/io/test_hdf5_utils.py +++ b/breadbox/tests/io/test_hdf5_utils.py @@ -9,6 +9,8 @@ import pytest import h5py +from breadbox.utils.progress_tracker import ProgressTracker + @pytest.fixture def test_dataframe(row_length: int = 500, col_length: int = 11000): @@ -61,7 +63,11 @@ def test_write_parquet_to_hdf5(tmpdir, test_dataframe, test_parquet_file): # override batch size to force multiple batches write_hdf5_file( - path=str(output_h5), df_wrapper=wrapper, dtype="float", batch_size=1000 + path=str(output_h5), + df_wrapper=wrapper, + dtype="float", + batch_size=1000, + progress=ProgressTracker(), ) # Verify output @@ -102,7 +108,11 @@ def test_write_parquet_nulls_to_hdf5(tmpdir): # override batch size to force multiple batches write_hdf5_file( - path=str(output_h5), df_wrapper=wrapper, dtype="float", batch_size=1000 + path=str(output_h5), + df_wrapper=wrapper, + dtype="float", + batch_size=1000, + progress=ProgressTracker(), ) # Verify output
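
Note on wiring the tracker into Celery state: run_dataset_upload is declared with
bind=True, so the task instance is available as self, but this patch constructs a bare
ProgressTracker(), leaving ProgressTracker.task as None. As a result, progress messages
only reach the worker log via print(), and task.update_state() is never invoked. If the
tracker were instead created as ProgressTracker(task=self) inside the task, each update
would also be pushed to the result backend as a PROGRESS state. A minimal sketch of a
client-side consumer under that assumption (poll_progress and its one-second polling
interval are illustrative, not part of this change):

    import time

    from celery.result import AsyncResult

    from breadbox.compute.celery import app


    def poll_progress(result_id: str) -> None:
        # Print the messages that ProgressTracker._emit_update pushes via
        # task.update_state(state="PROGRESS", meta={"message": ...}).
        result = AsyncResult(result_id, app=app)
        while not result.ready():
            if result.state == "PROGRESS" and isinstance(result.info, dict):
                print(result.info.get("message"))
            time.sleep(1)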