Skip to content

Commit

Permalink
#0: Fix ttnn.distribute(..) API for dtype=bfloat8_b
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Feb 18, 2025
1 parent 0dacb45 commit 243a961
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
12 changes: 12 additions & 0 deletions tests/ttnn/unit_tests/test_multi_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,15 @@ def test_line_all_gather_after_reshape(mesh_device):
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
)


@pytest.mark.parametrize("mesh_device", [pytest.param((1, 8), id="1x8_line")], indirect=True)
def test_distribute_api(mesh_device):
torch_hidden_states = torch.rand((1, 1, 32, 32), dtype=torch.bfloat16)
with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)):
hidden_states = ttnn.from_torch(
torch_hidden_states,
dtype=ttnn.bfloat8_b,
layout=ttnn.TILE_LAYOUT,
device=mesh_device,
)
2 changes: 1 addition & 1 deletion ttnn/ttnn/operations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def from_torch(
if layout != ttnn.TILE_LAYOUT:
raise RuntimeError("ttnn.from_torch: bfloat8_b/bfloat4_b requires TILE_LAYOUT!")
# Tilize tensor
tensor = ttnn.from_torch(tensor, layout=ttnn.TILE_LAYOUT, tile=tile, pad_value=pad_value)
tensor = ttnn.from_torch(tensor, layout=ttnn.TILE_LAYOUT, tile=tile, pad_value=pad_value, mesh_mapper=None)
logical_shape = tensor.shape
padded_shape = tensor.padded_shape
tensor = tensor.reshape(tensor.padded_shape)
Expand Down

0 comments on commit 243a961

Please sign in to comment.