Skip to content

Commit 1c5ed39

Browse files
botev authored and KfacJaxDev committed
Making KFAC use pjit instead of pmap
PiperOrigin-RevId: 526595657
1 parent 3be9b1a commit 1c5ed39

13 files changed

Lines changed: 626 additions & 534 deletions

examples/datasets.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,14 @@
3535
_IMAGENET_STDDEV_RGB = (0.229, 0.224, 0.225)
3636

3737

38+
def sharded_iterator(
39+
dataset: tf.data.Dataset,
40+
sharding: jax.sharding.NamedSharding,
41+
) -> Iterator[Batch]:
42+
for batch in iter(tensorflow_datasets.as_numpy(dataset)):
43+
yield jax.device_put(batch, sharding)
44+
45+
3846
def mnist_dataset(
3947
split: str,
4048
has_labels: bool,
@@ -43,6 +51,7 @@ def mnist_dataset(
4351
repeat: bool,
4452
shuffle: bool,
4553
drop_remainder: bool,
54+
sharding: jax.sharding.NamedSharding,
4655
seed: Optional[int] = None,
4756
multi_device: bool = True,
4857
reshuffle_each_iteration: bool = True,
@@ -59,13 +68,13 @@ def mnist_dataset(
5968
shuffle: Whether to shuffle the dataset.
6069
drop_remainder: Whether to drop the remainder of the dataset if the number
6170
of data points is not divisible by the total batch size.
71+
sharding: Sharding spec for each batch.
6272
seed: Any seed to use for random pre-processing.
6373
multi_device: If the returned batch should take into account the number of
6474
devices present, in which case it will return an array with shape
6575
`(num_device, device_batch_size, ...)`.
6676
reshuffle_each_iteration: Whether to reshuffle the dataset in a new order
6777
after each iteration.
68-
dtype: The returned data type of the images.
6978
7079
Returns:
7180
The MNIST dataset as a tensorflow dataset.
@@ -74,14 +83,7 @@ def mnist_dataset(
7483
# Set for multi devices vs single device
7584
num_devices = jax.device_count() if multi_device else 1
7685
num_local_devices = jax.local_device_count() if multi_device else 1
77-
78-
if multi_device:
79-
host_batch_shape = [num_local_devices, device_batch_size]
80-
else:
81-
host_batch_shape = [device_batch_size]
82-
8386
host_batch_size = num_local_devices * device_batch_size
84-
8587
num_examples = tfds.builder("mnist").info.splits[split].num_examples
8688

8789
if num_examples % num_devices != 0:
@@ -95,8 +97,7 @@ def preprocess_batch(
9597
"""Standard reshaping of the images to (28, 28)."""
9698
images = tf.image.convert_image_dtype(images, dtype)
9799
single_example_shape = [784] if flatten_images else [28, 28]
98-
images = tf.reshape(images, host_batch_shape + single_example_shape)
99-
labels = tf.reshape(labels, host_batch_shape)
100+
images = tf.reshape(images, [host_batch_size] + single_example_shape)
100101
if has_labels:
101102
return dict(images=images, labels=labels)
102103
else:
@@ -123,7 +124,7 @@ def preprocess_batch(
123124

124125
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
125126

126-
return iter(tensorflow_datasets.as_numpy(ds))
127+
return sharded_iterator(ds, sharding)
127128

128129

129130
def imagenet_num_examples_and_split(

examples/optimizers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def __init__(
8585
axis_name=self.pmap_axis_name,
8686
)
8787

88+
@property
89+
def state_sharding(self) -> jax.sharding.NamedSharding:
90+
raise NotImplementedError()
91+
8892
def init(
8993
self,
9094
params: Params,
@@ -438,7 +442,6 @@ def create_optimizer(
438442
value_func_has_aux=has_aux,
439443
value_func_has_state=has_func_state,
440444
value_func_has_rng=has_rng,
441-
multi_device=True,
442445
**kwargs,
443446
)
444447
elif name == "sgd":

0 commit comments

Comments
 (0)