From e4942e4ae9b1875c84d91161310deba619de5366 Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Thu, 30 May 2024 21:20:53 +0000
Subject: [PATCH] Update activation sharding to shape mesh

---
 llama/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llama/model.py b/llama/model.py
index f6e6d4f2a..e898878e5 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -216,9 +216,9 @@ def forward(
         import torch_xla.experimental.dynamo_mark_sharding
         if self.enable_activation_sharding:
             device_ids = [i for i in range(self.num_devices)]
-            mesh_shape = [self.num_devices, 1, 1]
+            mesh_shape = [self.num_devices, 1]
             axis_names = 'None'
-            partition_spec = '(2, 1, 0)'
+            partition_spec = '(None, None, None)'
             torch.ops.xla.dynamo_mark_sharding(output, device_ids, mesh_shape, axis_names, partition_spec)
         return output
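
Note (not part of the patch): for readers less familiar with the dynamo_mark_sharding custom op, the sketch below shows roughly what the new mesh_shape [num_devices, 1] and partition_spec '(None, None, None)' correspond to in torch_xla's eager-mode SPMD API (xs.Mesh / xs.mark_sharding). The axis names, tensor shapes, and variable names are illustrative assumptions, not taken from this repository.

    # Minimal sketch, assuming torch_xla's SPMD API; names/shapes are illustrative.
    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()  # enable SPMD execution mode

    num_devices = xr.global_runtime_device_count()
    device_ids = np.arange(num_devices)

    # Same 2-D mesh as in the patch: shape [num_devices, 1].
    mesh = xs.Mesh(device_ids, (num_devices, 1), ('data', 'model'))

    # Stand-in 3-D activation (batch, seq, hidden); dimensions are arbitrary.
    output = torch.randn(8, 128, 4096, device=xm.xla_device())

    # partition_spec '(None, None, None)' replicates the activation on every device
    # instead of sharding any of its three dimensions across the mesh.
    xs.mark_sharding(output, mesh, (None, None, None))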