Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion heat/core/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def __repr__(dndarray) -> str:
Returns a printable representation of the passed DNDarray.
Unlike the __str__ method, which prints a representation targeted at users, this method targets developers by showing key internal parameters of the DNDarray.
"""
return f"<DNDarray(MPI-rank: {dndarray.comm.rank}, Shape: {dndarray.shape}, Split: {dndarray.split}, Local Shape: {dndarray.lshape}, Device: {dndarray.device}, Dtype: {dndarray.dtype.__name__})>"
tensor_string = torch._tensor_str._tensor_str(dndarray.larray, __INDENT + 1)
return f"DNDarray(MPI-rank: {dndarray.comm.rank}, Shape: {dndarray.shape}, Split: {dndarray.split}, Local Shape: {dndarray.lshape}, Device: {dndarray.device}, Dtype: {dndarray.dtype.__name__}\n{' ' * __INDENT} {tensor_string})"


def _torch_data(dndarray, summarize) -> DNDarray:
Expand Down
26 changes: 24 additions & 2 deletions heat/core/tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,14 +431,36 @@ def test_split_2_above_threshold(self):
self.assertEqual(comparison, __str)

def test___repr__(self):
a = ht.array([1, 2, 3, 4])
a = ht.array([1, 2, 3, 4], split=0)
r = a.__repr__()
expect_meta = f"DNDarray(MPI-rank: {a.comm.rank}, Shape: {a.shape}, Split: {a.split}, Local Shape: {a.lshape}, Device: {a.device}, Dtype: {a.dtype.__name__}"
self.assertEqual(r[:r.index('\n')], expect_meta)

if ht.comm.size == 1:
loc_data_str = '1, 2, 3, 4'
elif ht.comm.size == 2:
loc_data_str = f'{ht.comm.rank*2+1}, {ht.comm.rank*2+2}'
elif ht.comm.size == 3:
if ht.comm.rank == 0:
loc_data_str = '1, 2'
else:
loc_data_str = f'{ht.comm.rank + 2}'
else:
if ht.comm.rank < 4:
loc_data_str = f'{ht.comm.rank + 1}'
else:
loc_data_str = ''

expect = f'{expect_meta}\n [{loc_data_str}])'

self.assertEqual(
r,
f"<DNDarray(MPI-rank: {a.comm.rank}, Shape: {a.shape}, Split: {a.split}, Local Shape: {a.lshape}, Device: {a.device}, Dtype: {a.dtype.__name__})>",
expect,
)




class TestPrintingGPU(TestCase):
def test_print_GPU(self):
# this test case also includes GPU now, checking the output is not done; only test whether the routine itself works...
Expand Down
Loading