3535_IMAGENET_STDDEV_RGB = (0.229 , 0.224 , 0.225 )
3636
3737
38+ def sharded_iterator (
39+ dataset : tf .data .Dataset ,
40+ sharding : jax .sharding .NamedSharding ,
41+ ) -> Iterator [Batch ]:
42+ for batch in iter (tensorflow_datasets .as_numpy (dataset )):
43+ yield jax .device_put (batch , sharding )
44+
45+
3846def mnist_dataset (
3947 split : str ,
4048 has_labels : bool ,
@@ -43,6 +51,7 @@ def mnist_dataset(
4351 repeat : bool ,
4452 shuffle : bool ,
4553 drop_remainder : bool ,
54+ sharding : jax .sharding .NamedSharding ,
4655 seed : Optional [int ] = None ,
4756 multi_device : bool = True ,
4857 reshuffle_each_iteration : bool = True ,
@@ -59,13 +68,13 @@ def mnist_dataset(
5968 shuffle: Whether to shuffle the dataset.
6069 drop_remainder: Whether to drop the remainder of the dataset if the number
6170 of data points is not divisible by the total batch size.
71+ sharding: Sharding spec for each batch.
6272 seed: Any seed to use for random pre-processing.
6373 multi_device: If the returned batch should take into account the number of
6474 devices present, in which case it will return an array with shape
6575 `(num_device, device_batch_size, ...)`.
6676 reshuffle_each_iteration: Whether to reshuffle the dataset in a new order
6777 after each iteration.
68- dtype: The returned data type of the images.
6978
7079 Returns:
7180 The MNIST dataset as a tensorflow dataset.
@@ -74,14 +83,7 @@ def mnist_dataset(
7483 # Set for multi devices vs single device
7584 num_devices = jax .device_count () if multi_device else 1
7685 num_local_devices = jax .local_device_count () if multi_device else 1
77-
78- if multi_device :
79- host_batch_shape = [num_local_devices , device_batch_size ]
80- else :
81- host_batch_shape = [device_batch_size ]
82-
8386 host_batch_size = num_local_devices * device_batch_size
84-
8587 num_examples = tfds .builder ("mnist" ).info .splits [split ].num_examples
8688
8789 if num_examples % num_devices != 0 :
@@ -95,8 +97,7 @@ def preprocess_batch(
9597 """Standard reshaping of the images to (28, 28)."""
9698 images = tf .image .convert_image_dtype (images , dtype )
9799 single_example_shape = [784 ] if flatten_images else [28 , 28 ]
98- images = tf .reshape (images , host_batch_shape + single_example_shape )
99- labels = tf .reshape (labels , host_batch_shape )
100+ images = tf .reshape (images , [host_batch_size ] + single_example_shape )
100101 if has_labels :
101102 return dict (images = images , labels = labels )
102103 else :
@@ -123,7 +124,7 @@ def preprocess_batch(
123124
124125 ds = ds .prefetch (tf .data .experimental .AUTOTUNE )
125126
126- return iter ( tensorflow_datasets . as_numpy ( ds ) )
127+ return sharded_iterator ( ds , sharding )
127128
128129
129130def imagenet_num_examples_and_split (
0 commit comments