Skip to content

Commit 2a9cd4b

Browse files
github-actions[bot]brownbaerchenClaudiaComito
authored
[Backport stable] Print contents of array in __repr__ (#2106)
* Printing debug information only when debugger has set a trace (cherry picked from commit 24ccced) * `__repr__` now prints both debug information and array contents (cherry picked from commit d3bef15) * Added `Data` tag in `dndarray.__repr__` (cherry picked from commit 9dba529) --------- Co-authored-by: Thomas Baumann <[email protected]> Co-authored-by: Claudia Comito <[email protected]>
1 parent 15cb8ac commit 2a9cd4b

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

heat/core/printing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,8 @@ def __repr__(dndarray) -> str:
210210
Returns a printable representation of the passed DNDarray.
211211
Unlike the __str__ method, which prints a representation targeted at users, this method targets developers by showing key internal parameters of the DNDarray.
212212
"""
213-
return f"<DNDarray(MPI-rank: {dndarray.comm.rank}, Shape: {dndarray.shape}, Split: {dndarray.split}, Local Shape: {dndarray.lshape}, Device: {dndarray.device}, Dtype: {dndarray.dtype.__name__})>"
213+
tensor_string = torch._tensor_str._tensor_str(dndarray.larray, __INDENT + 1)
214+
return f"DNDarray(MPI-rank: {dndarray.comm.rank}, Shape: {dndarray.shape}, Split: {dndarray.split}, Local Shape: {dndarray.lshape}, Device: {dndarray.device}, Dtype: {dndarray.dtype.__name__}, Data:\n{' ' * __INDENT} {tensor_string})"
214215

215216

216217
def _torch_data(dndarray, summarize) -> DNDarray:

heat/core/tests/test_printing.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,14 +431,36 @@ def test_split_2_above_threshold(self):
431431
self.assertEqual(comparison, __str)
432432

433433
def test___repr__(self):
434-
a = ht.array([1, 2, 3, 4])
434+
a = ht.array([1, 2, 3, 4], split=0)
435435
r = a.__repr__()
436+
expect_meta = f"DNDarray(MPI-rank: {a.comm.rank}, Shape: {a.shape}, Split: {a.split}, Local Shape: {a.lshape}, Device: {a.device}, Dtype: {a.dtype.__name__}, Data:"
437+
self.assertEqual(r[:r.index('\n')], expect_meta)
438+
439+
if ht.comm.size == 1:
440+
loc_data_str = '1, 2, 3, 4'
441+
elif ht.comm.size == 2:
442+
loc_data_str = f'{ht.comm.rank*2+1}, {ht.comm.rank*2+2}'
443+
elif ht.comm.size == 3:
444+
if ht.comm.rank == 0:
445+
loc_data_str = '1, 2'
446+
else:
447+
loc_data_str = f'{ht.comm.rank + 2}'
448+
else:
449+
if ht.comm.rank < 4:
450+
loc_data_str = f'{ht.comm.rank + 1}'
451+
else:
452+
loc_data_str = ''
453+
454+
expect = f'{expect_meta}\n [{loc_data_str}])'
455+
436456
self.assertEqual(
437457
r,
438-
f"<DNDarray(MPI-rank: {a.comm.rank}, Shape: {a.shape}, Split: {a.split}, Local Shape: {a.lshape}, Device: {a.device}, Dtype: {a.dtype.__name__})>",
458+
expect,
439459
)
440460

441461

462+
463+
442464
class TestPrintingGPU(TestCase):
443465
def test_print_GPU(self):
444466
# this test case also includes GPU now, checking the output is not done; only test whether the routine itself works...

0 commit comments

Comments
 (0)