Skip to content

Commit 1052324

Browse files
Revert "[Backport stable] Print contents of array in __repr__ (#2106)"
This reverts commit 2a9cd4b.
1 parent 2a9cd4b commit 1052324

File tree

2 files changed

+3
-26
lines changed

2 files changed

+3
-26
lines changed

heat/core/printing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,7 @@ def __repr__(dndarray) -> str:
210210
Returns a printable representation of the passed DNDarray.
211211
Unlike the __str__ method, which prints a representation targeted at users, this method targets developers by showing key internal parameters of the DNDarray.
212212
"""
213-
tensor_string = torch._tensor_str._tensor_str(dndarray.larray, __INDENT + 1)
214-
return f"DNDarray(MPI-rank: {dndarray.comm.rank}, Shape: {dndarray.shape}, Split: {dndarray.split}, Local Shape: {dndarray.lshape}, Device: {dndarray.device}, Dtype: {dndarray.dtype.__name__}, Data:\n{' ' * __INDENT} {tensor_string})"
213+
return f"<DNDarray(MPI-rank: {dndarray.comm.rank}, Shape: {dndarray.shape}, Split: {dndarray.split}, Local Shape: {dndarray.lshape}, Device: {dndarray.device}, Dtype: {dndarray.dtype.__name__})>"
215214

216215

217216
def _torch_data(dndarray, summarize) -> DNDarray:

heat/core/tests/test_printing.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -431,36 +431,14 @@ def test_split_2_above_threshold(self):
431431
self.assertEqual(comparison, __str)
432432

433433
def test___repr__(self):
434-
a = ht.array([1, 2, 3, 4], split=0)
434+
a = ht.array([1, 2, 3, 4])
435435
r = a.__repr__()
436-
expect_meta = f"DNDarray(MPI-rank: {a.comm.rank}, Shape: {a.shape}, Split: {a.split}, Local Shape: {a.lshape}, Device: {a.device}, Dtype: {a.dtype.__name__}, Data:"
437-
self.assertEqual(r[:r.index('\n')], expect_meta)
438-
439-
if ht.comm.size == 1:
440-
loc_data_str = '1, 2, 3, 4'
441-
elif ht.comm.size == 2:
442-
loc_data_str = f'{ht.comm.rank*2+1}, {ht.comm.rank*2+2}'
443-
elif ht.comm.size == 3:
444-
if ht.comm.rank == 0:
445-
loc_data_str = '1, 2'
446-
else:
447-
loc_data_str = f'{ht.comm.rank + 2}'
448-
else:
449-
if ht.comm.rank < 4:
450-
loc_data_str = f'{ht.comm.rank + 1}'
451-
else:
452-
loc_data_str = ''
453-
454-
expect = f'{expect_meta}\n [{loc_data_str}])'
455-
456436
self.assertEqual(
457437
r,
458-
expect,
438+
f"<DNDarray(MPI-rank: {a.comm.rank}, Shape: {a.shape}, Split: {a.split}, Local Shape: {a.lshape}, Device: {a.device}, Dtype: {a.dtype.__name__})>",
459439
)
460440

461441

462-
463-
464442
class TestPrintingGPU(TestCase):
465443
def test_print_GPU(self):
466444
# this test case also includes GPU now, checking the output is not done; only test whether the routine itself works...

0 commit comments

Comments
 (0)