 
 """`tfds build` command."""
 
-import argparse
+import dataclasses
 import functools
 import importlib
 import itertools
 import json
 import multiprocessing
 import os
 from typing import Any, Dict, Iterator, Optional, Tuple, Type, Union
 
 from absl import logging
+import simple_parsing
 import tensorflow_datasets as tfds
 from tensorflow_datasets.scripts.cli import cli_utils
 
 # pylint: disable=logging-fstring-interpolation
 
 
-def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
-  """Add subparser for `build` command."""
-  build_parser = parsers.add_parser(
-      'build', help='Commands for downloading and preparing datasets.'
-  )
-  build_parser.add_argument(
-      'datasets',  # Positional arguments
-      type=str,
-      nargs='*',
-      help=(
-          'Name(s) of the dataset(s) to build. Default to current dir. '
-          'See https://www.tensorflow.org/datasets/cli for accepted values.'
-      ),
-  )
-  build_parser.add_argument(  # Also accept keyword arguments
-      '--datasets',
-      type=str,
-      nargs='+',
-      dest='datasets_keyword',
-      help='Datasets can also be provided as keyword argument.',
-  )
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class _AutomationGroup:
+  """Used by automated scripts.
 
-  cli_utils.add_debug_argument_group(build_parser)
-  cli_utils.add_path_argument_group(build_parser)
-  cli_utils.add_generation_argument_group(build_parser)
-  cli_utils.add_publish_argument_group(build_parser)
+  Attributes:
+    exclude_datasets: If set, generate all datasets except the one defined here.
+      Comma separated list of datasets to exclude.
+    experimental_latest_version: Build the latest Version(experiments=...)
+      available rather than default version.
+  """
 
-  # **** Automation options ****
-  automation_group = build_parser.add_argument_group(
-      'Automation', description='Used by automated scripts.'
-  )
-  automation_group.add_argument(
-      '--exclude_datasets',
-      type=str,
-      help=(
-          'If set, generate all datasets except the one defined here. '
-          'Comma separated list of datasets to exclude. '
-      ),
+  exclude_datasets: list[str] = cli_utils.comma_separated_list_field()
+  experimental_latest_version: bool = False
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class CmdArgs:
+  """Commands for downloading and preparing datasets.
+
+  Attributes:
+    datasets: Name(s) of the dataset(s) to build. Default to current dir. See
+      https://www.tensorflow.org/datasets/cli for accepted values.
+    datasets_keyword: Datasets can also be provided as keyword argument.
+    debug: Debug & tests options.
+    path: Paths options.
+    generation: Generation options.
+    publish: Publishing options.
+    automation: Automation options.
+  """
+
+  datasets: list[str] = simple_parsing.field(
+      positional=True, default_factory=list, nargs='*'
   )
-  automation_group.add_argument(
-      '--experimental_latest_version',
-      action='store_true',
-      help=(
-          'Build the latest Version(experiments=...) available rather than '
-          'default version.'
-      ),
+  datasets_keyword: list[str] = simple_parsing.field(
+      alias='datasets', default_factory=list, nargs='*'
   )
+  debug: cli_utils.DebugGroup = simple_parsing.field(prefix='')
+  path: cli_utils.PathGroup = simple_parsing.field(prefix='')
+  generation: cli_utils.GenerationGroup = simple_parsing.field(prefix='')
+  publish: cli_utils.PublishGroup = simple_parsing.field(prefix='')
+  automation: _AutomationGroup = simple_parsing.field(prefix='')
 
-  build_parser.set_defaults(subparser_fn=_build_datasets)
+  def execute(self):
+    _build_datasets(self)
 
 
-def _build_datasets(args: argparse.Namespace) -> None:
+def _build_datasets(args: CmdArgs) -> None:
   """Build the given datasets."""
   # Eventually register additional datasets imports
-  if args.imports:
-    list(importlib.import_module(m) for m in args.imports.split(','))
+  if args.generation.imports:
+    list(importlib.import_module(m) for m in args.generation.imports)
 
   # Select datasets to generate
-  datasets = (args.datasets or []) + (args.datasets_keyword or [])
-  if args.exclude_datasets:  # Generate all datasets if `--exclude_datasets` set
+  datasets = args.datasets + args.datasets_keyword
+  if (
+      args.automation.exclude_datasets
+  ):  # Generate all datasets if `--exclude_datasets` set
     if datasets:
       raise ValueError("--exclude_datasets can't be used with `datasets`")
     datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
-        args.exclude_datasets.split(',')
+        args.automation.exclude_datasets
    )
     datasets = sorted(datasets)  # `set` is not deterministic
   else:
     datasets = datasets or ['']  # Empty string for default
 
   # Import builder classes
   builders_cls_and_kwargs = [
-      _get_builder_cls_and_kwargs(dataset, has_imports=bool(args.imports))
+      _get_builder_cls_and_kwargs(
+          dataset, has_imports=bool(args.generation.imports)
+      )
       for dataset in datasets
   ]
 
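Note: for orientation, this is roughly how a simple_parsing-based entry point would consume the new CmdArgs dataclass. A minimal sketch, not part of this diff; the `--num_processes` flag name is an assumption about what `cli_utils.GenerationGroup` exposes with `prefix=''`:

    import simple_parsing
    from tensorflow_datasets.scripts.cli import build

    parser = simple_parsing.ArgumentParser()
    # Registers one flag per dataclass field, recursing into the nested groups.
    parser.add_arguments(build.CmdArgs, dest='args')
    namespace = parser.parse_args(['mnist', '--num_processes=2'])
    namespace.args.execute()  # dispatches to _build_datasets(namespace.args)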
@@ -112,19 +112,20 @@ def _build_datasets(args: argparse.Namespace) -> None:
       for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
   ))
   process_builder_fn = functools.partial(
-      _download if args.download_only else _download_and_prepare, args
+      _download if args.generation.download_only else _download_and_prepare,
+      args,
   )
 
-  if args.num_processes == 1:
+  if args.generation.num_processes == 1:
     for builder in builders:
       process_builder_fn(builder)
   else:
-    with multiprocessing.Pool(args.num_processes) as pool:
+    with multiprocessing.Pool(args.generation.num_processes) as pool:
       pool.map(process_builder_fn, builders)
 
 
 def _make_builders(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder_cls: Type[tfds.core.DatasetBuilder],
     builder_kwargs: Dict[str, Any],
 ) -> Iterator[tfds.core.DatasetBuilder]:
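Note: the functools.partial/Pool.map pattern in the hunk above, shown standalone with toy names. Binding the shared `args` as the first argument turns a two-argument worker into the one-argument callable that `Pool.map` expects; everything crossing the process boundary must be picklable:

    import functools
    import multiprocessing


    def _process(prefix: str, name: str) -> str:  # stands in for _download_and_prepare
      return f'{prefix}:{name}'


    if __name__ == '__main__':
      process_fn = functools.partial(_process, 'build')  # bind the shared first arg
      with multiprocessing.Pool(2) as pool:
        print(pool.map(process_fn, ['mnist', 'cifar10']))  # ['build:mnist', 'build:cifar10']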
@@ -139,7 +140,7 @@ def _make_builders(
     Initialized dataset builders.
   """
   # Eventually overwrite version
-  if args.experimental_latest_version:
+  if args.automation.experimental_latest_version:
     if 'version' in builder_kwargs:
       raise ValueError(
           "Can't have both `--experimental_latest` and version set (`:1.0.0`)"
@@ -150,19 +151,19 @@ def _make_builders(
   builder_kwargs['config'] = _get_config_name(
       builder_cls=builder_cls,
       config_kwarg=builder_kwargs.get('config'),
-      config_name=args.config,
-      config_idx=args.config_idx,
+      config_name=args.generation.config,
+      config_idx=args.generation.config_idx,
   )
 
-  if args.file_format:
-    builder_kwargs['file_format'] = args.file_format
+  if args.generation.file_format:
+    builder_kwargs['file_format'] = args.generation.file_format
 
   make_builder = functools.partial(
       _make_builder,
       builder_cls,
-      overwrite=args.overwrite,
-      fail_if_exists=args.fail_if_exists,
-      data_dir=args.data_dir,
+      overwrite=args.debug.overwrite,
+      fail_if_exists=args.debug.fail_if_exists,
+      data_dir=args.path.data_dir,
       **builder_kwargs,
   )
 
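Note: `cli_utils.comma_separated_list_field` is defined elsewhere in this PR; it replaces the old `args.exclude_datasets.split(',')` calls. A plausible implementation sketch, assuming `simple_parsing.field` forwards extra keyword arguments to argparse's `add_argument` (an assumption about cli_utils internals, not the PR's actual code):

    import simple_parsing


    def _csv_to_list(value: str) -> list[str]:
      # 'a,b, c' -> ['a', 'b', 'c']; empty entries are dropped.
      return [v.strip() for v in value.split(',') if v.strip()]


    def comma_separated_list_field(**kwargs):
      # `type=` would end up as the add_argument parsing callable (assumption).
      return simple_parsing.field(default_factory=list, type=_csv_to_list, **kwargs)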
@@ -301,7 +302,7 @@ def _make_builder(
 
 
 def _download(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder: tfds.core.DatasetBuilder,
 ) -> None:
   """Downloads all files of the given builder."""
@@ -323,7 +324,7 @@ def _download(
   if builder.MAX_SIMULTANEOUS_DOWNLOADS is not None:
     max_simultaneous_downloads = builder.MAX_SIMULTANEOUS_DOWNLOADS
 
-  download_dir = args.download_dir or os.path.join(
+  download_dir = args.path.download_dir or os.path.join(
       builder._data_dir_root, 'downloads'  # pylint: disable=protected-access
   )
   dl_manager = tfds.download.DownloadManager(
@@ -345,51 +346,51 @@ def _download(
 
 
 def _download_and_prepare(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder: tfds.core.DatasetBuilder,
 ) -> None:
   """Generate a single builder."""
   cli_utils.download_and_prepare(
       builder=builder,
       download_config=_make_download_config(args, dataset_name=builder.name),
-      download_dir=args.download_dir,
-      publish_dir=args.publish_dir,
-      skip_if_published=args.skip_if_published,
-      overwrite=args.overwrite,
+      download_dir=args.path.download_dir,
+      publish_dir=args.publish.publish_dir,
+      skip_if_published=args.publish.skip_if_published,
+      overwrite=args.debug.overwrite,
   )
 
 
 def _make_download_config(
-    args: argparse.Namespace,
+    args: CmdArgs,
     dataset_name: str,
 ) -> tfds.download.DownloadConfig:
   """Generate the download and prepare configuration."""
   # Load the download config
-  manual_dir = args.manual_dir
-  if args.add_name_to_manual_dir:
+  manual_dir = args.path.manual_dir
+  if args.path.add_name_to_manual_dir:
     manual_dir = manual_dir / dataset_name
 
   kwargs = {}
-  if args.max_shard_size_mb:
-    kwargs['max_shard_size'] = args.max_shard_size_mb << 20
-  if args.download_config:
-    kwargs.update(json.loads(args.download_config))
+  if args.generation.max_shard_size_mb:
+    kwargs['max_shard_size'] = args.generation.max_shard_size_mb << 20
+  if args.generation.download_config:
+    kwargs.update(json.loads(args.generation.download_config))
 
   if 'download_mode' in kwargs:
     kwargs['download_mode'] = tfds.download.GenerateMode(
         kwargs['download_mode']
     )
   else:
     kwargs['download_mode'] = tfds.download.GenerateMode.REUSE_DATASET_IF_EXISTS
-  if args.update_metadata_only:
+  if args.generation.update_metadata_only:
     kwargs['download_mode'] = tfds.download.GenerateMode.UPDATE_DATASET_INFO
 
   dl_config = tfds.download.DownloadConfig(
-      extract_dir=args.extract_dir,
+      extract_dir=args.path.extract_dir,
       manual_dir=manual_dir,
-      max_examples_per_split=args.max_examples_per_split,
-      register_checksums=args.register_checksums,
-      force_checksums_validation=args.force_checksums_validation,
+      max_examples_per_split=args.debug.max_examples_per_split,
+      register_checksums=args.generation.register_checksums,
+      force_checksums_validation=args.generation.force_checksums_validation,
       **kwargs,
   )
 
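Note on the shard-size line above: `max_shard_size_mb << 20` converts the megabyte count to bytes with a bit shift, since shifting left by 20 multiplies by 2**20:

    max_shard_size_mb = 128
    assert max_shard_size_mb << 20 == 128 * 1024 * 1024 == 134_217_728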
@@ -400,9 +401,9 @@ def _make_download_config(
     beam = None
 
   if beam is not None:
-    if args.beam_pipeline_options:
+    if args.generation.beam_pipeline_options:
       dl_config.beam_options = beam.options.pipeline_options.PipelineOptions(
-          flags=[f'--{opt}' for opt in args.beam_pipeline_options.split(',')]
+          flags=[f'--{opt}' for opt in args.generation.beam_pipeline_options]
       )
 
   return dl_config
 
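Note: because CmdArgs is a plain frozen, keyword-only dataclass, tests can now construct arguments directly instead of going through a parser. A sketch, assuming the cli_utils group dataclasses are default-constructible (not verified against this PR):

    from tensorflow_datasets.scripts.cli import build, cli_utils

    args = build.CmdArgs(
        datasets=['mnist'],
        datasets_keyword=[],
        debug=cli_utils.DebugGroup(),            # assumed default-constructible
        path=cli_utils.PathGroup(),              # assumed default-constructible
        generation=cli_utils.GenerationGroup(),  # assumed default-constructible
        publish=cli_utils.PublishGroup(),        # assumed default-constructible
        automation=build._AutomationGroup(),
    )
    args.execute()  # equivalent to `tfds build mnist`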