Closed
Description
np.nonzero(a) returns a tuple of a.ndim 1-D ndarrays, one index array per dimension. This tuple can then be used directly, without modification, for advanced indexing across all dimensions:
>>> a
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
>>> mask
array([[False,  True,  True],
       [False,  True,  True],
       [False,  True,  True]])
>>> a[mask].shape
(6,)
>>> a[np.nonzero(mask)].shape
(6,)
>>> a[mask] == a[np.nonzero(mask)]
array([ True,  True,  True,  True,  True,  True])
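The session above never prints the tuple itself. The sketch below shows what np.nonzero(mask) contains for this example. The definitions of `a` and `mask` are an assumption: they are simply one way to reproduce the arrays shown in the session.

```python
import numpy as np

# Assumed reconstruction of the arrays shown in the session above
a = np.arange(9).reshape(3, 3)
mask = a % 3 != 0  # True in columns 1 and 2 of every row

idx = np.nonzero(mask)
# idx is a tuple holding one 1-D index array per dimension (len(idx) == a.ndim)
rows, cols = idx
print(rows)  # [0 0 1 1 2 2]
print(cols)  # [1 2 1 2 1 2]

# The tuple can be passed straight to advanced indexing
print(a[idx])  # [1 2 4 5 7 8]
```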
ht.nonzero, by contrast, returns the indices in torch.nonzero format. For the example above it returns a single DNDarray of shape (6, 2), with one row per nonzero element and one column per dimension.
This is not what users expect, and it unnecessarily complicates our advanced indexing.
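To illustrate why this layout gets in the way of indexing, the sketch below uses NumPy only, assuming that np.argwhere is a fair stand-in for the (N, ndim) torch.nonzero layout (no Heat or PyTorch code is involved). Indexing with the index matrix directly does not select elements. Instead, it has to be split into one index array per dimension first.

```python
import numpy as np

a = np.arange(9).reshape(3, 3)
mask = a % 3 != 0

# np.argwhere returns an (N, ndim) index matrix, the same layout as torch.nonzero's default output
coords = np.argwhere(mask)
print(coords.shape)  # (6, 2)

# Passing the matrix directly is NOT element selection: the integer array
# indexes only axis 0, so every entry picks a whole row of `a`
print(a[coords].shape)  # (6, 2, 3)

# Selecting the elements requires one index array per dimension
print(a[tuple(coords.T)])  # [1 2 4 5 7 8]
```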
This will be addressed together with #914.
Expected behavior
ht.nonzero should return a tuple of 1-D DNDarrays, as np.nonzero does, both for API consistency and to simplify Heat's advanced indexing implementation.
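A minimal sketch of the expected conversion follows, with NumPy arrays standing in for DNDarrays. `index_matrix_to_tuple` is a hypothetical helper for illustration only. It is not part of Heat's API and not necessarily how the fix will be implemented. It takes an (N, ndim) index matrix and returns the np.nonzero-style tuple.

```python
import numpy as np

def index_matrix_to_tuple(coords):
    """Hypothetical helper: convert an (N, ndim) index matrix, as produced by
    torch.nonzero or np.argwhere, into the np.nonzero-style tuple of ndim
    1-D index arrays."""
    coords = np.asarray(coords)
    if coords.ndim != 2:
        raise ValueError("expected a 2-D (N, ndim) index matrix")
    # Column d holds the indices along dimension d
    return tuple(coords[:, d] for d in range(coords.shape[1]))

mask = np.array([[False, True, True],
                 [False, True, True],
                 [False, True, True]])
idx = index_matrix_to_tuple(np.argwhere(mask))
print(len(idx))  # 2
print(all(np.array_equal(x, y) for x, y in zip(idx, np.nonzero(mask))))  # True
```

Each returned array is a column view of the index matrix, so the conversion copies no data. An all-False mask yields a tuple of ndim empty arrays, which still works as an index.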
Version Info
main