-
Notifications
You must be signed in to change notification settings - Fork 1
feat(breadbox): Add upload progress updates #348
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from typing import List, Optional, Literal | ||
|
||
from breadbox.utils.progress_tracker import ProgressTracker | ||
from breadbox.schemas.custom_http_exception import FileValidationError | ||
from breadbox.schemas.dataframe_wrapper import ParquetDataFrameWrapper | ||
import h5py | ||
|
@@ -31,6 +32,7 @@ def write_hdf5_file( | |
path: str, | ||
df_wrapper: DataFrameWrapper, | ||
dtype: Literal["float", "str"], | ||
progress: ProgressTracker, | ||
batch_size: int = 5000, # Adjust batch size as needed | ||
): | ||
f = h5py.File(path, mode="w") | ||
|
@@ -79,7 +81,11 @@ def write_hdf5_file( | |
dtype=h5py.string_dtype() if dtype == "str" else np.float64, | ||
) | ||
|
||
progress.update_message("transforming columns") | ||
progress.update_process_max_value(len(cols)) | ||
for i in range(0, len(cols), batch_size): | ||
progress.update_progress(i) | ||
|
||
# Find the correct column slice to write | ||
end_col = i + batch_size | ||
|
||
|
@@ -102,8 +108,11 @@ def write_hdf5_file( | |
f"Failed to update {i}:{end_col} of hdf5 file {path} with {values}" | ||
) from e | ||
|
||
progress.update_process_max_value(None) | ||
progress.update_message("Creating indexes in database") | ||
create_index_dataset(f, "features", pd.Index(df_wrapper.get_column_names())) | ||
create_index_dataset(f, "samples", pd.Index(df_wrapper.get_index_names())) | ||
progress.update_message("Complete") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we should put this in the finally block? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @rcreasi That's true; it doesn't make sense for progress to be considered complete when the write fails. I realized we weren't catching failures in this try block, so I've recently added a handler for them. |
||
finally: | ||
f.close() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from typing import List, Optional | ||
import time | ||
|
||
from celery import Task | ||
|
||
|
||
class ProgressTracker:
    """Collect status updates for a long-running job and report them.

    Each report is printed and, when a Celery task was supplied and is
    actually running under a worker, pushed via ``task.update_state`` with
    state ``"PROGRESS"`` and ``meta={"message": ...}``.

    Message changes and changes of the progress maximum are reported
    immediately. Plain progress increments are rate-limited to at most one
    report every ``delay_between_updates`` seconds, and a report whose text
    is identical to the previously emitted one is always suppressed.
    """

    # Human-readable description of the current phase.
    message: str
    # Units of work finished in the current phase.
    progress_value: float
    # Total units of work in the current phase; None means "unknown", in
    # which case no percentage is appended to the message.
    progress_max_value: Optional[float]

    # Text and time.time() timestamp of the last report that was emitted.
    prev_emitted_message: str
    prev_emitted_timestamp: float
    # Minimum number of seconds between two non-forced reports.
    delay_between_updates: float

    # Forward reference: keeps the annotation from being evaluated when the
    # class body runs.
    task: Optional["Task"]

    def __init__(
        self,
        *,
        task: Optional["Task"] = None,
        delay_between_updates: float = 5,
    ):
        """Create a tracker.

        Args:
            task: Optional Celery task whose state is updated on every report.
            delay_between_updates: Minimum seconds between non-forced reports.
        """
        self.message = "Started"
        self.progress_value = 0.0
        self.progress_max_value = None

        self.prev_emitted_message = ""
        self.prev_emitted_timestamp = time.time()
        self.delay_between_updates = delay_between_updates
        self.task = task

    def update_message(self, message: str):
        """Set the phase description and report it immediately."""
        self.message = message
        self._emit_update(force=True)

    def update_progress(self, value: float):
        """Record progress within the current phase (rate-limited report)."""
        self.progress_value = value
        self._emit_update()

    def update_process_max_value(self, value: Optional[float]):
        """Start a new phase with ``value`` total units (None = unknown).

        Resets the progress counter to zero and reports immediately.
        """
        self.progress_value = 0
        self.progress_max_value = value
        self._emit_update(force=True)

    def _format_message(self) -> str:
        """Return the message, with a percentage when one can be computed."""
        maximum = self.progress_max_value
        # A maximum of None (unknown) or 0 (nothing to do) has no meaningful
        # percentage; skipping it also avoids a ZeroDivisionError.
        if not maximum or self.progress_value is None:
            return self.message
        percent = int(self.progress_value / maximum * 100)
        return f"{self.message} ({percent}% complete)"

    def _emit_update(self, force: bool = False):
        """Report the current state unless it is a duplicate or too soon.

        Args:
            force: When True, bypass the ``delay_between_updates`` rate limit.
                Duplicate messages are suppressed regardless.
        """
        message = self._format_message()
        if message == self.prev_emitted_message:
            return

        now = time.time()
        too_soon = (now - self.prev_emitted_timestamp) <= self.delay_between_updates
        if too_soon and not force:
            return

        print(message)

        task = self.task
        # A task invoked directly (not through a worker) is skipped here, as
        # in the original guard on ``request.called_directly``.
        if task is not None and not task.request.called_directly:
            task.update_state(state="PROGRESS", meta={"message": message})

        # Remember what was reported so the duplicate check and the rate
        # limit above have something to compare against next time.
        self.prev_emitted_message = message
        self.prev_emitted_timestamp = now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a legacy code path, so I don't care that it's not reporting status. We're not supposed to use this anyway.