Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion heat/core/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def __repr__(dndarray) -> str:
Returns a printable representation of the passed DNDarray.
Unlike the __str__ method, which prints a representation targeted at users, this method targets developers by showing key internal parameters of the DNDarray.
"""
return f"<DNDarray(MPI-rank: {dndarray.comm.rank}, Shape: {dndarray.shape}, Split: {dndarray.split}, Local Shape: {dndarray.lshape}, Device: {dndarray.device}, Dtype: {dndarray.dtype.__name__})>"
tensor_string = torch._tensor_str._tensor_str(dndarray.larray, __INDENT + 1)
return f"DNDarray(MPI-rank: {dndarray.comm.rank}, Shape: {dndarray.shape}, Split: {dndarray.split}, Local Shape: {dndarray.lshape}, Device: {dndarray.device}, Dtype: {dndarray.dtype.__name__}, Data:\n{' ' * __INDENT} {tensor_string})"


def _torch_data(dndarray, summarize) -> DNDarray:
Expand Down
26 changes: 24 additions & 2 deletions heat/core/tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,14 +431,36 @@ def test_split_2_above_threshold(self):
self.assertEqual(comparison, __str)

def test___repr__(self):
a = ht.array([1, 2, 3, 4])
a = ht.array([1, 2, 3, 4], split=0)
r = a.__repr__()
expect_meta = f"DNDarray(MPI-rank: {a.comm.rank}, Shape: {a.shape}, Split: {a.split}, Local Shape: {a.lshape}, Device: {a.device}, Dtype: {a.dtype.__name__}, Data:"
self.assertEqual(r[:r.index('\n')], expect_meta)

if ht.comm.size == 1:
loc_data_str = '1, 2, 3, 4'
elif ht.comm.size == 2:
loc_data_str = f'{ht.comm.rank*2+1}, {ht.comm.rank*2+2}'
elif ht.comm.size == 3:
if ht.comm.rank == 0:
loc_data_str = '1, 2'
else:
loc_data_str = f'{ht.comm.rank + 2}'
else:
if ht.comm.rank < 4:
loc_data_str = f'{ht.comm.rank + 1}'
else:
loc_data_str = ''

expect = f'{expect_meta}\n [{loc_data_str}])'

self.assertEqual(
r,
f"<DNDarray(MPI-rank: {a.comm.rank}, Shape: {a.shape}, Split: {a.split}, Local Shape: {a.lshape}, Device: {a.device}, Dtype: {a.dtype.__name__})>",
expect,
)




class TestPrintingGPU(TestCase):
def test_print_GPU(self):
# this test case also includes GPU now, checking the output is not done; only test whether the routine itself works...
Expand Down