Skip to content

Commit

Permalink
Added incremental mode 'extend' (#388)
Browse files Browse the repository at this point in the history
* Added incremental mode 'extend', which downloads files into new shards, and updated README.

* Made param 'start_shard_id' optional, and added docstring.

* fix lint

---------

Co-authored-by: Edward Guilfoyle <[email protected]>
  • Loading branch information
rom1504 and edwardguil authored Jan 13, 2024
1 parent e46a6c0 commit c48952a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ This module exposes a single function `download` which takes the same arguments
* **min_image_size** minimum size of the image to download (default *0*)
* **max_image_area** maximum area of the image to download (default *inf*)
* **max_aspect_ratio** maximum aspect ratio of the image to download (default *inf*)
* **incremental_mode** Can be "incremental" or "overwrite". For "incremental", img2dataset will download all the shards that were not downloaded, for "overwrite" img2dataset will delete recursively the output folder then start from zero (default *incremental*)
* **incremental_mode** Can be "incremental", "overwrite" or "extend". For "incremental", img2dataset will download all the shards that were not downloaded; for "overwrite", img2dataset will recursively delete the output folder and then start from zero; for "extend", img2dataset will download shards starting from the next available shard number (default *incremental*)
* **max_shard_retry** Number of times to retry failed shards at the end (default *1*)
* **user_agent_token** Additional identifying token that will be added to the User-Agent header sent with HTTP requests to download images; for example: "img2downloader". (default *None*)
* **disallowed_header_directives** List of X-Robots-Tags header directives that, if present in HTTP response when downloading an image, will cause the image to be excluded from the output dataset. To ignore x-robots-tags, pass '[]'. (default '["noai", "noimageai", "noindex", "noimageindex"]')
Expand Down
6 changes: 6 additions & 0 deletions img2dataset/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def signal_handler(signal_arg, frame): # pylint: disable=unused-argument
save_caption = caption_col is not None

fs, output_path = fsspec.core.url_to_fs(output_folder)
start_shard_id = 0

if not fs.exists(output_path):
fs.mkdir(output_path)
Expand All @@ -158,6 +159,10 @@ def signal_handler(signal_arg, frame): # pylint: disable=unused-argument
fs.rm(output_path, recursive=True)
fs.mkdir(output_path)
done_shards = set()
elif incremental_mode == "extend":
existing_shards = [int(x.split("/")[-1].split("_")[0]) for x in fs.glob(output_path + "/*.json")]
start_shard_id = max(existing_shards, default=-1) + 1
done_shards = set()
else:
raise ValueError(f"Unknown incremental mode {incremental_mode}")

Expand Down Expand Up @@ -187,6 +192,7 @@ def signal_handler(signal_arg, frame): # pylint: disable=unused-argument
number_sample_per_shard,
done_shards,
tmp_path,
start_shard_id,
)

if output_format == "webdataset":
Expand Down
5 changes: 4 additions & 1 deletion img2dataset/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Reader:
- save_additional_columns: the list of additional columns to save
- number_sample_per_shard: the number of samples per shard
- done_shards: a set of already done shards
- start_shard_id: the shard id to begin downloading from
"""

def __init__(
Expand All @@ -39,6 +40,7 @@ def __init__(
number_sample_per_shard,
done_shards,
tmp_path,
start_shard_id: int = 0,
) -> None:
self.input_format = input_format
self.url_col = url_col
Expand All @@ -48,6 +50,7 @@ def __init__(
self.save_additional_columns = save_additional_columns
self.number_sample_per_shard = number_sample_per_shard
self.done_shards = done_shards
self.start_shard_id = start_shard_id

fs, url_path = fsspec.core.url_to_fs(url_list)
self.fs = fs
Expand Down Expand Up @@ -190,7 +193,7 @@ def __iter__(self):
shard is a tuple (sample id, sample)
sample is a tuple of the columns
"""
start_shard_id = 0
start_shard_id = self.start_shard_id
for i, input_file in enumerate(self.input_files):
print("Sharding file number " + str(i + 1) + " of " + str(len(self.input_files)) + " called " + input_file)

Expand Down

0 comments on commit c48952a

Please sign in to comment.