diff --git a/.talismanrc b/.talismanrc index 976c575f94..9da3a58796 100644 --- a/.talismanrc +++ b/.talismanrc @@ -1,3 +1,4 @@ +version: "1.0" threshold: medium allowed_patterns: - 'uses: [A-Za-z-\/]+@[\w\d]+' diff --git a/heat/classification/kneighborsclassifier.py b/heat/classification/kneighborsclassifier.py index 90d1859537..e438645f70 100644 --- a/heat/classification/kneighborsclassifier.py +++ b/heat/classification/kneighborsclassifier.py @@ -122,11 +122,11 @@ def predict(self, x: DNDarray) -> DNDarray: """ distances = self.effective_metric_(x, self.x) _, indices = ht.topk(distances, self.n_neighbors, largest=False) - predictions = self.y[indices.flatten()] + + predictions = self.y[indices] predictions.balance_() - predictions = ht.reshape(predictions, (indices.gshape + (self.y.gshape[1],))) + predictions = ht.reshape(predictions, indices.gshape + (self.y.gshape[1],)) predictions = ht.sum(predictions, axis=1) self.classes_ = ht.argmax(predictions, axis=1) - return self.classes_ diff --git a/heat/cluster/batchparallelclustering.py b/heat/cluster/batchparallelclustering.py index de795cdb89..d6d64e3b1c 100644 --- a/heat/cluster/batchparallelclustering.py +++ b/heat/cluster/batchparallelclustering.py @@ -42,7 +42,25 @@ def _initialize_plus_plus( for i in range(1, n_clusters): dist = torch.cdist(X, X[idxs[:i]], p=p) dist = torch.min(dist, dim=1)[0] - idxs[i] = torch.multinomial(weights * dist, 1) + probs = weights * dist + probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0) + + # Minimal fallback ONLY if multinomial would crash + if probs.sum() <= 0: + # fall back to standard k-means++ (ignore weights) + probs = torch.nan_to_num(dist, nan=0.0, posinf=0.0, neginf=0.0) + + if probs.sum() <= 0: + # fully degenerate (all distances zero) -> pick any not-yet-picked index if possible + mask = torch.ones(X.shape[0], dtype=torch.bool, device=X.device) + mask[idxs[:i]] = False + candidates = torch.nonzero(mask, as_tuple=False).flatten() + if candidates.numel() > 0: + idxs[i] = candidates[torch.randint(0, candidates.numel(), (1,), device=X.device)] + else: + idxs[i] = torch.randint(0, X.shape[0], (1,), device=X.device) + else: + idxs[i] = torch.multinomial(probs, 1) return X[idxs] diff --git a/heat/cluster/kmedoids.py b/heat/cluster/kmedoids.py index fe65ba64d8..52ec093c61 100644 --- a/heat/cluster/kmedoids.py +++ b/heat/cluster/kmedoids.py @@ -141,9 +141,11 @@ def fit(self, x: DNDarray, oversampling: float = 2, iter_multiplier: float = 1): # increment the iteration count self._n_iter += 1 # determine the centroids + matching_centroids = self._assign_to_cluster(x) # update the centroids + new_cluster_centers = self._update_centroids(x, matching_centroids) # check whether centroid movement has converged diff --git a/heat/cluster/tests/test_batchparallelclustering.py b/heat/cluster/tests/test_batchparallelclustering.py index 684d9d9247..ffbf972b44 100644 --- a/heat/cluster/tests/test_batchparallelclustering.py +++ b/heat/cluster/tests/test_batchparallelclustering.py @@ -39,8 +39,42 @@ def test_kmex(self): _kmex(X, 2, 2, init, max_iter, tol) def test_initialize_plus_plus(self): - X = torch.rand(100, 3) - _initialize_plus_plus(X, 3, 2, random_state=None, max_samples=50) + with self.subTest("subsampling"): + X = torch.rand(100, 3) + centers = _initialize_plus_plus(X, 3, 2, random_state=0, max_samples=50) + self.assertEqual(centers.shape, (3, 3)) + + # 2) probs.sum() <= 0 because weights are all zero -> fallback to dist -> multinomial runs + with self.subTest("weights_zero_fallback_to_dist"): + X = torch.rand(30, 3) + weights = torch.zeros(X.shape[0], dtype=X.dtype) + centers = _initialize_plus_plus(X, 3, 2, random_state=0, weights=weights) + self.assertEqual(centers.shape, (3, 3)) + + # 3) fully degenerate distances (all points identical) -> probs.sum() <= 0 twice -> candidate selection branch + with self.subTest("all_distances_zero_candidate_selection"): + X = torch.ones(10, 3) + weights = torch.ones(X.shape[0], dtype=X.dtype) + centers = _initialize_plus_plus(X, 3, 2, random_state=0, weights=weights) + self.assertEqual(centers.shape, (3, 3)) + + # 4) extreme degenerate case: only one sample, n_clusters>1 -> candidates empty branch + with self.subTest("single_sample_candidates_empty"): + X = torch.ones(1, 3) + centers = _initialize_plus_plus(X, 2, 2, random_state=0) + self.assertEqual(centers.shape, (2, 3)) + + # 5) NaN-handling path -> nan_to_num is exercised (should not crash) + with self.subTest("nan_to_num_path"): + X = torch.tensor( + [[0.0, 0.0, 0.0], + [float("nan"), 0.0, 0.0], + [1.0, 0.0, 0.0]], + dtype=torch.float32, + ) + # seed chosen so first centroid is deterministic (helps avoid flakiness) + centers = _initialize_plus_plus(X, 2, 2, random_state=2) + self.assertEqual(centers.shape, (2, 3)) def test_BatchParallelKClustering(self): with self.assertRaises(TypeError): diff --git a/heat/cluster/tests/test_kmedians.py b/heat/cluster/tests/test_kmedians.py index ee8b534e50..6053659950 100644 --- a/heat/cluster/tests/test_kmedians.py +++ b/heat/cluster/tests/test_kmedians.py @@ -36,7 +36,7 @@ def test_fit_iris_unsplit(self): # fit the clusters k = 3 - kmedian = ht.cluster.KMedians(n_clusters=k) + kmedian = ht.cluster.KMedians(n_clusters=k, random_state=1) kmedian.fit(iris) # check whether the results are correct diff --git a/heat/cluster/tests/test_kmedoids.py b/heat/cluster/tests/test_kmedoids.py index a1a261eca8..bb6bd947e3 100644 --- a/heat/cluster/tests/test_kmedoids.py +++ b/heat/cluster/tests/test_kmedoids.py @@ -49,6 +49,8 @@ def test_fit_iris_unsplit(self): ht.any(ht.sum(ht.abs(kmedoid.cluster_centers_[i, :] - iris), axis=1) == 0) ) + + def test_exceptions(self): # get some test data iris_split = ht.load("heat/datasets/iris.csv", sep=";", split=1) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index ae1bb2f689..7fe1855199 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2,15 +2,13 @@ from __future__ import annotations -import math import numpy as np import torch import warnings -from inspect import stack from mpi4py import MPI -from pathlib import Path -from typing import List, Union, Tuple, TypeVar, Optional +from typing import TypeVar, Any +from collections.abc import Iterable warnings.simplefilter("always", ResourceWarning) @@ -45,7 +43,7 @@ class DNDarray: ---------- array : torch.Tensor Local array elements - gshape : Tuple[int,...] + gshape : tuple[int,...] The global shape of the array dtype : datatype The datatype of the array @@ -64,12 +62,12 @@ class DNDarray: def __init__( self, array: torch.Tensor, - gshape: Tuple[int, ...], + gshape: tuple[int, ...], dtype: datatype, - split: Union[int, None], device: Device, comm: Communication, balanced: bool, + split: int | None, ): self.__array = array self.__gshape = gshape @@ -77,10 +75,10 @@ def __init__( self.__split = split self.__device = device self.__comm = comm - self.__balanced = balanced + self.__balanced: bool = balanced self.__ishalo = False - self.__halo_next = None - self.__halo_prev = None + self.__halo_next: torch.Tensor | None = None + self.__halo_prev: torch.Tensor | None = None self.__partitions_dict__ = None self.__lshape_map = None @@ -116,7 +114,7 @@ def dtype(self) -> datatype: return self.__dtype @property - def gshape(self) -> Tuple: + def gshape(self) -> tuple: """ Returns the global shape of the ``DNDarray`` across all processes """ @@ -263,7 +261,7 @@ def lnumel(self) -> int: return np.prod(self.__array.shape) @property - def lloc(self) -> Union[DNDarray, None]: + def lloc(self) -> "DNDarray" | None: """ Local item setter and getter. i.e. this function operates on a local level and only on the PyTorch tensors composing the :class:`DNDarray`. @@ -272,7 +270,7 @@ def lloc(self) -> Union[DNDarray, None]: Parameters ---------- - key : int or slice or Tuple[int,...] + key : int or slice or tuple[int,...] Indices of the desired data. value : scalar, optional All types compatible with pytorch tensors, if none given then this is a getter function @@ -297,7 +295,7 @@ def lloc(self) -> Union[DNDarray, None]: return LocalIndex(self.__array) @property - def lshape(self) -> Tuple[int]: + def lshape(self) -> tuple[int]: """ Returns the shape of the ``DNDarray`` on each node """ @@ -318,36 +316,36 @@ def real(self) -> DNDarray: return complex_math.real(self) @property - def shape(self) -> Tuple[int]: + def shape(self) -> tuple[int, ...]: """ Returns the shape of the ``DNDarray`` as a whole """ return self.__gshape @property - def split(self) -> int: + def split(self) -> int | None: """ Returns the axis on which the ``DNDarray`` is split """ return self.__split @property - def stride(self) -> Tuple[int]: + def stride(self) -> tuple[int, ...]: """ Returns the steps in each dimension when traversing a ``DNDarray``. torch-like usage: ``self.stride()`` """ - return self.__array.stride + return self.__array.stride() @property - def strides(self) -> Tuple[int]: + def strides(self) -> tuple[int, ...]: """ Returns bytes to step in each dimension when traversing a ``DNDarray``. numpy-like usage: ``self.strides()`` """ - steps = list(self.larray.stride()) + steps = list(self.__array.stride()) try: - itemsize = self.larray.untyped_storage().element_size() + itemsize = self.__array.untyped_storage().element_size() except AttributeError: - itemsize = self.larray.storage().element_size() + itemsize = self.__array.storage().element_size() strides = tuple(step * itemsize for step in steps) return strides @@ -552,7 +550,7 @@ def astype(self, dtype, copy=True) -> DNDarray: return self - def balance_(self) -> DNDarray: + def balance_(self) -> None: """ Function for balancing a :class:`DNDarray` between all nodes. To determine if this is needed use the :func:`is_balanced()` function. If the ``DNDarray`` is already balanced this function will do nothing. This function modifies the ``DNDarray`` @@ -588,6 +586,8 @@ def balance_(self) -> DNDarray: [1/2] (7, 2) (2, 2) [2/2] (7, 2) (2, 2) """ + if not self.is_distributed(): + self.__balanced = True if self.is_balanced(force_check=True): return self.redistribute_() @@ -598,7 +598,7 @@ def __bool__(self) -> bool: """ return self.__cast(bool) - def __cast(self, cast_function) -> Union[float, int]: + def __cast(self, cast_function) -> float | int: """ Implements a generic cast function for ``DNDarray`` objects. @@ -624,7 +624,7 @@ def __cast(self, cast_function) -> Union[float, int]: raise TypeError("only size-1 arrays can be converted to Python scalars") - def collect_(self, target_rank: Optional[int] = 0) -> None: + def collect_(self, target_rank: int | None = 0) -> None: """ A method collecting a distributed DNDarray to one MPI rank, chosen by the `target_rank` variable. It is a specific case of the ``redistribute_`` method. @@ -677,7 +677,7 @@ def __complex__(self) -> DNDarray: """ return self.__cast(complex) - def counts_displs(self) -> Tuple[Tuple[int], Tuple[int]]: + def counts_displs(self) -> tuple[tuple[int], tuple[int]]: """ Returns actual counts (number of items per process) and displacements (offsets) of the DNDarray. Does not assume load balance. @@ -686,8 +686,8 @@ def counts_displs(self) -> Tuple[Tuple[int], Tuple[int]]: counts = self.lshape_map[:, self.split] displs = [0] + torch.cumsum(counts, dim=0)[:-1].tolist() return tuple(counts.tolist()), tuple(displs) - else: - raise ValueError("Non-distributed DNDarray. Cannot calculate counts and displacements.") + + raise ValueError("Non-distributed DNDarray. Cannot calculate counts and displacements.") def cpu(self) -> DNDarray: """ @@ -715,9 +715,10 @@ def create_lshape_map(self, force_check: bool = False) -> torch.Tensor: lshape_map = torch.zeros( (self.comm.size, self.ndim), dtype=torch.int64, device=self.device.torch_device ) - if not self.is_distributed: + if not self.is_distributed(): lshape_map[:] = torch.tensor(self.gshape, device=self.device.torch_device) - return lshape_map + self.__lshape_map = lshape_map + return lshape_map.clone() if self.is_balanced(force_check=True): for i in range(self.comm.size): _, lshape, _ = self.comm.chunk(self.gshape, self.split, rank=i) @@ -790,15 +791,15 @@ def create_partition_interface(self): part_tiling = [1] * self.ndim lcls = [0] * self.ndim - z = torch.tensor([0], device=self.device.torch_device, dtype=self.dtype.torch_type()) + z = torch.tensor([0], device=self.device.torch_device, dtype=torch.int64) + if self.split is not None: starts = torch.cat((z, torch.cumsum(lshape_map[:, self.split], dim=0)[:-1]), dim=0) lcls[self.split] = self.comm.rank part_tiling[self.split] = self.comm.size + start_idx_map[:, self.split] = starts else: - starts = torch.zeros(self.ndim, dtype=torch.int, device=self.device.torch_device) - - start_idx_map[:, self.split] = starts + start_idx_map[:] = 0 partitions = {} base_key = [0] * self.ndim @@ -879,7 +880,642 @@ def fill_diagonal(self, value: float) -> DNDarray: return self - def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray: + def __process_key( + arr: "DNDarray", + key: tuple[int, ...] | list[int], + return_local_indices: bool | None = False, + op: str | None = None, + ) -> tuple: + """ + Private method to process the key used for indexing a ``DNDarray`` so that it can be applied to the process-local data, i.e. `key` must be "torch-proof". + In a processed key: + - any ellipses or newaxis have been replaced with the appropriate number of slice objects + - ndarrays and DNDarrays have been converted to torch tensors + - the dimensionality is the same as the ``DNDarray`` it indexes + This function also manipulates `arr` if necessary, inserting and/or transposing dimensions as indicated by `key`. It calculates the output shape, split axis and balanced status of the indexed array. + + Parameters + ---------- + arr : DNDarray + The ``DNDarray`` to be indexed + key : int, tuple[int, ...], list[int, ...] + The key used for indexing + return_local_indices : bool, optional + Whether to return the process-local indices of the key in the split dimension. This is only possible when the indexing key in the split dimension is ordered e.g. `split_key_is_ordered == 1`. Default: False + op : str, optional + The indexing operation that the key is being processed for. Get be "get" for `__getitem__` or "set" for `__setitem__`. Default: "get". + + Returns + ------- + arr : DNDarray + The ``DNDarray`` to be indexed. Its dimensions might have been modified if advanced, dimensional, broadcasted indexing is used. + key : tuple[Any, ...] | "DNDarray" | np.ndarray | torch.Tensor | slice | int | list[int] + The processed key ready for indexing ``arr``. Its dimensions match the (potentially modified) dimensions of ``arr``. + Note: the key indices along the split axis are LOCAL indices, i.e. refer to the process-local data, if ordered indexing is used. Otherwise, they are GLOBAL indices, referring to the global memory-distributed DNDarray. Communication to extract the non-ordered elements of the input ``DNDarray`` is handled by the ``__getitem__`` function. + output_shape : tuple[int, ...] + The shape of the output ``DNDarray`` + new_split : int + The new split axis + split_key_is_ordered : int + Whether the split key is sorted or ordered. Can be 1: ascending, 0: not ordered, -1: descending order. + out_is_balanced : bool + Whether the output ``DNDarray`` is balanced + root : int + The root process for the ``MPI.Bcast`` call when single-element indexing along the split axis is used + backwards_transpose_axes : tuple[int, ...] + The axes to transpose the input ``DNDarray`` back to its original shape if it has been transposed for advanced indexing + """ + output_shape = list(arr.gshape) + split_bookkeeping = [None] * arr.ndim + new_split = arr.split + arr_is_distributed = False + if arr.split is not None: + split_bookkeeping[arr.split] = "split" + if arr.is_distributed(): + counts, displs = arr.counts_displs() + arr_is_distributed = True + + advanced_indexing = False + split_key_is_ordered = 1 + key_is_mask_like = False + out_is_balanced = True if not arr.is_distributed() else arr.balanced + root = None + backwards_transpose_axes = tuple(range(arr.ndim)) + + if isinstance(key, list): + try: + key = torch.tensor(key, device=arr.larray.device) + except RuntimeError: + raise IndexError("Invalid indices: expected a list of integers, got {}".format(key)) + if isinstance(key, (DNDarray, torch.Tensor, np.ndarray)): + if key.dtype in (ht_bool, ht_uint8, torch.bool, torch.uint8, np.bool_, np.uint8): + # boolean indexing: shape must be consistent with arr.shape + key_ndim = key.ndim + if not tuple(key.shape) == arr.shape[:key_ndim]: + raise IndexError( + "Boolean index of shape {} does not match indexed array of shape {}".format( + tuple(key.shape), arr.shape + ) + ) + # extract non-zero elements + try: + # key is torch tensor + key = key.nonzero(as_tuple=True) + except TypeError: + # key is np.ndarray or DNDarray + key = key.nonzero() + key_is_mask_like = True + else: + # advanced indexing on first dimension: first dim will expand to shape of key + output_shape = tuple(list(key.shape) + output_shape[1:]) + # adjust split axis accordingly + if arr_is_distributed: + if arr.split != 0: + # split axis is not affected + split_bookkeeping = [None] * key.ndim + split_bookkeeping[1:] + new_split = ( + split_bookkeeping.index("split") + if "split" in split_bookkeeping + else None + ) + out_is_balanced = arr.balanced + else: + # split axis is affected + if key.ndim > 1: + try: + key_numel = key.numel() + except AttributeError: + key_numel = key.size + if key_numel == arr.shape[0]: + new_split = tuple(key.shape).index(arr.shape[0]) + else: + new_split = key.ndim - 1 + try: + key_split = key[new_split].larray + sorted, _ = key_split.sort(stable=True) + except AttributeError: + key_split = key[new_split] + sorted = key_split.sort() + else: + new_split = 0 + # assess if key is sorted along split axis + try: + # DNDarray key + sorted, _ = torch.sort(key.larray, stable=True) + split_key_is_ordered = torch.tensor( + (key.larray == sorted).all(), + dtype=torch.uint8, + device=key.larray.device, + ) + if key.split is not None: + out_is_balanced = key.balanced + split_key_is_ordered = ( + factories.array( + [split_key_is_ordered], + is_split=0, + device=arr.device, + copy=False, + ) + .all() + .astype(types.canonical_heat_types.uint8) + .item() + ) + else: + split_key_is_ordered = split_key_is_ordered.item() + key = key.larray + except AttributeError: + try: + sorted, _ = torch.sort(key, stable=True) + except TypeError: + # ndarray key -> move key to same device as arr before any torch ops / comparisons + key = torch.as_tensor(key, device=arr.larray.device) + try: + sorted, _ = torch.sort(key, stable=True) + except TypeError: + # fallback for older torch without stable= + sorted, _ = torch.sort(key) + + split_key_is_ordered = (key == sorted).all().item() + if not split_key_is_ordered: + # prepare for distributed non-ordered indexing: distribute torch/numpy key + key = factories.array(key, split=0, device=arr.device).larray + out_is_balanced = True + if split_key_is_ordered: + # extract local key + cond1 = key >= displs[arr.comm.rank] + cond2 = key < displs[arr.comm.rank] + counts[arr.comm.rank] + key = key[cond1 & cond2] + if return_local_indices: + key -= displs[arr.comm.rank] + out_is_balanced = False + else: + try: + out_is_balanced = key.balanced + new_split = key.split + key = key.larray + except AttributeError: + # torch or numpy key, non-distributed indexed array + out_is_balanced = True + new_split = None + return ( + arr, + key, + output_shape, + new_split, + split_key_is_ordered, + key_is_mask_like, + out_is_balanced, + root, + backwards_transpose_axes, + ) + + key = list(key) if isinstance(key, Iterable) else [key] + + # check for ellipsis, newaxis. NB: (np.newaxis is None)==True + add_dims = sum(k is None for k in key) + ellipsis = sum(isinstance(k, type(...)) for k in key) + if ellipsis > 1: + raise ValueError("indexing key can only contain 1 Ellipsis (...)") + if ellipsis: + # key contains exactly 1 ellipsis + # replace with explicit `slice(None)` for affected dimensions + # output_shape, split_bookkeeping not affected + expand_key = [slice(None)] * (arr.ndim + add_dims) + ellipsis_index = key.index(...) + ellipsis_dims = arr.ndim - (len(key) - ellipsis - add_dims) + expand_key[:ellipsis_index] = key[:ellipsis_index] + expand_key[ellipsis_index + ellipsis_dims :] = key[ellipsis_index + 1 :] + key = expand_key + while add_dims > 0: + # expand array dims: output_shape, split_bookkeeping to reflect newaxis + # replace newaxis with slice(None) in key + for i, k in reversed(list(enumerate(key))): + if k is None: + key[i] = slice(None) + arr = arr.expand_dims(i - add_dims + 1) + output_shape = ( + output_shape[: i - add_dims + 1] + [1] + output_shape[i - add_dims + 1 :] + ) + split_bookkeeping = ( + split_bookkeeping[: i - add_dims + 1] + + [None] + + split_bookkeeping[i - add_dims + 1 :] + ) + add_dims -= 1 + + # recalculate new_split, transpose_axes after dimensions manipulation + new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None + transpose_axes, backwards_transpose_axes = tuple(range(arr.ndim)), tuple(range(arr.ndim)) + # check for advanced indexing and slices + advanced_indexing_dims = [] + advanced_indexing_shapes = [] + lose_dims = 0 + for i, k in enumerate(key): + if isinstance(k, DNDarray) and k.ndim == 0: + k = k.larray.item() + key[i] = k + # for robustness: handle list/tuple keys that contain DNDarrays + elif isinstance(k, (list, tuple)) and any(isinstance(kk, DNDarray) for kk in k): + # Case 1: singleton container (common from where/nonzero): (idx,) -> idx + if len(k) == 1 and isinstance(k[0], DNDarray): + k = k[0] + key[i] = k + + else: + # Case 2: sequence of scalar DNDarrays -> unwrap to python scalars + new_k = [] + all_scalar = True + for kk in k: + if isinstance(kk, DNDarray): + if kk.ndim != 0: + all_scalar = False + break + new_k.append(kk.larray.item()) + else: + new_k.append(kk) + + if all_scalar: + k = new_k + key[i] = k + else: + # This is an ambiguous nested "tuple of index arrays" inside a single axis. + # In NumPy semantics such tuples belong at TOP LEVEL (arr[idx0, idx1, ...]), + # not nested as one axis key. + raise TypeError( + "Nested tuple/list of non-scalar DNDarray indices is not supported. " + "Pass them as separate indices (e.g. arr[idx0, idx1, ...]) or unwrap " + "singleton tuples (e.g. idx = idx[0])." + ) + + if np.isscalar(k) or getattr(k, "ndim", 1) == 0: + # single-element indexing along axis i + try: + output_shape[i], split_bookkeeping[i] = None, None + except IndexError: + raise IndexError( + f"Too many indices for DNDarray: DNDarray is {arr.ndim}-dimensional, but {len(key)} dimensions were indexed" + ) + lose_dims += 1 + if i == arr.split: + key[i], root = arr.__process_scalar_key( + k, indexed_axis=i, return_local_indices=return_local_indices + ) + else: + key[i], _ = arr.__process_scalar_key( + k, indexed_axis=i, return_local_indices=False + ) + elif isinstance(k, Iterable) or isinstance(k, DNDarray): + advanced_indexing = True + advanced_indexing_dims.append(i) + + if not isinstance(k, DNDarray): + k = factories.array(k, device=arr.device, comm=arr.comm, copy=None) + + # Normalize negative integer indices (NumPy/PyTorch semantics) and validate bounds + if k.dtype in (types.int32, types.int64) and k.ndim >= 1: + dim = arr.gshape[i] + + # compute local flags even if k.larray is empty (any() on empty -> False) + invalid_local = ((k.larray < -dim) | (k.larray >= dim)).any().item() + has_neg_local = (k.larray < 0).any().item() + + # Decide once, then ALL ranks take the same path for collectives + do_reduce = ( + arr.comm is not None + and getattr(arr.comm, "size", 1) > 1 + and k.is_distributed() + ) + + if do_reduce: + invalid_sum = arr.comm.allreduce(int(invalid_local), op=MPI.SUM) + has_neg_sum = arr.comm.allreduce(int(has_neg_local), op=MPI.SUM) + else: + invalid_sum = int(invalid_local) + has_neg_sum = int(has_neg_local) + + if invalid_sum > 0: + raise IndexError(f"index out of bounds for axis {i} with size {dim}") + + if has_neg_sum > 0: + k_l = k.larray.clone() + k_l[k_l < 0] += dim + k = factories.array( + k_l, + dtype=k.dtype, + split=k.split, + device=arr.device, + comm=arr.comm, + copy=False, + ) + + advanced_indexing_shapes.append(k.gshape) + if arr_is_distributed and i == arr.split: + if ( + not k.is_distributed() + and k.ndim == 1 + and (k.larray == torch.sort(k.larray, stable=True)[0]).all() + ): + split_key_is_ordered = 1 + out_is_balanced = False + else: + split_key_is_ordered = 0 + + # redistribute key along last axis to match split axis of indexed array + k = k.resplit(-1) + out_is_balanced = True + key[i] = k + + elif isinstance(k, slice) and k != slice(None): + if k.step == 0: + raise ValueError("Slice step cannot be zero") + start, stop, step = slice(k.start, k.stop, k.step).indices(arr.gshape[i]) + + if step < 0 and start > stop: + # PyTorch doesn't support negative step as of 1.13 + # Lazy solution, potentially large memory footprint + # TODO: implement ht.fromiter (implemented in ASSET_ht) + key[i] = torch.arange( + start, stop, step, device=arr.larray.device, dtype=torch.int64 + ) + output_shape[i] = len(key[i]) + split_key_is_ordered = -1 + if arr_is_distributed and new_split == i: + if op == "set": + # setitem: flip key and keep process-local indices + key[i] = key[i].flip(0) + cond1 = key[i] >= displs[arr.comm.rank] + cond2 = key[i] < displs[arr.comm.rank] + counts[arr.comm.rank] + key[i] = key[i][cond1 & cond2] + if return_local_indices: + key[i] -= displs[arr.comm.rank] + else: + # getitem: distribute key and proceed with non-ordered indexing + key[i] = factories.array( + key[i], split=0, device=arr.device, copy=False + ).larray + out_is_balanced = True + elif step > 0 and start < stop: + # output_shape[i] = int(torch.tensor((stop - start) / step).ceil().item()) + output_shape[i] = len(range(start, stop, step)) + + if arr_is_distributed and new_split == i: + split_key_is_ordered = 1 + out_is_balanced = False + local_arr_end = displs[arr.comm.rank] + counts[arr.comm.rank] + if stop > displs[arr.comm.rank] and start < local_arr_end: + index_in_cycle = (displs[arr.comm.rank] - start) % step + if start >= displs[arr.comm.rank]: + # slice begins on current rank + local_start = start - displs[arr.comm.rank] + else: + local_start = 0 if index_in_cycle == 0 else step - index_in_cycle + if stop <= local_arr_end: + # slice ends on current rank + local_stop = stop - displs[arr.comm.rank] + else: + local_stop = counts[arr.comm.rank] + + key[i] = slice(local_start, local_stop, step) + else: + key[i] = slice(0, 0) + elif step == 0: + raise ValueError("Slice step cannot be zero") + else: + key[i] = slice(0, 0) + output_shape[i] = 0 + + if advanced_indexing: + # adv indexing key elements are DNDarrays: extract torch tensors + # options: 1. key is mask-like (covers boolean mask as well), 2. adv indexing along split axis, 3. everything else + # 1. define key as mask-like if each element of key is a DNDarray, and all elements of key are of the same shape, and the advanced-indexing dimensions are consecutive + key_is_mask_like = ( + all(isinstance(k, DNDarray) for k in key) + and len(set(k.shape for k in key)) == 1 + and torch.tensor(advanced_indexing_dims).diff().eq(1).all() + ) + # if split axis is affected by advanced indexing, keep track of non-split dimensions for later + if arr.is_distributed() and arr.split in advanced_indexing_dims: + non_split_dims = list(advanced_indexing_dims).copy() + if arr.split is not None: + non_split_dims.remove(arr.split) + # 1. key is mask-like + if key_is_mask_like: + key = list(key) + key_splits = [k.split for k in key] + if arr.split is not None and arr.split in advanced_indexing_dims: + split_key_pos = advanced_indexing_dims.index(arr.split) + + if not key_splits.count(key_splits[split_key_pos]) == len(key_splits): + if ( + key_splits[arr.split] is not None + and key_splits.count(None) == len(key_splits) - 1 + ): + for i in non_split_dims: + key[i] = factories.array( + key[i], + split=key_splits[arr.split], + device=arr.device, + comm=arr.comm, + copy=None, + ) + else: + raise IndexError( + f"Indexing arrays must be distributed along the same dimension, got splits {key_splits}." + ) + else: + # all key_splits must be the same, otherwise raise IndexError + if not key_splits.count(key_splits[0]) == len(key_splits): + raise IndexError( + f"Indexing arrays must be distributed along the same dimension, got splits {key_splits}." + ) + # all key elements are now DNDarrays of the same shape, same split axis + # 2. advanced indexing along split axis + if arr.is_distributed() and arr.split in advanced_indexing_dims: + if split_key_is_ordered == 1: + # extract torch tensors, keep process-local indices only + k = key[arr.split].larray + cond1 = k >= displs[arr.comm.rank] + cond2 = k < displs[arr.comm.rank] + counts[arr.comm.rank] + k = k[cond1 & cond2] + if return_local_indices: + k -= displs[arr.comm.rank] + key[arr.split] = k + for i in non_split_dims: + if key_is_mask_like: + # select the same elements along non-split dimensions + key[i] = key[i].larray[cond1 & cond2] + else: + key[i] = key[i].larray + elif split_key_is_ordered == 0: + # extract torch tensors, any other communication + mask-like case are handled in __getitem__ or __setitem__ + for i in advanced_indexing_dims: + key[i] = key[i].larray + # split_key_is_ordered == -1 not treated here as it is slicing, not advanced indexing + else: + # advanced indexing does not affect split axis, return torch tensors + for i in advanced_indexing_dims: + key[i] = key[i].larray + # all adv indexing keys are now torch tensors + + # shapes of adv indexing arrays must be broadcastable + try: + broadcasted_shape = torch.broadcast_shapes(*advanced_indexing_shapes) + except RuntimeError: + raise IndexError( + "Shape mismatch: indexing arrays could not be broadcast together with shapes: {}".format( + advanced_indexing_shapes + ) + ) + add_dims = len(broadcasted_shape) - len(advanced_indexing_dims) + if ( + len(advanced_indexing_dims) == 1 + or list(range(advanced_indexing_dims[0], advanced_indexing_dims[-1] + 1)) + == advanced_indexing_dims + ): + # dimensions affected by advanced indexing are consecutive: + output_shape[ + advanced_indexing_dims[0] : advanced_indexing_dims[0] + + len(advanced_indexing_dims) + ] = broadcasted_shape + if key_is_mask_like: + # advanced indexing dimensions will be collapsed into one dimension + if ( + "split" in split_bookkeeping + and split_bookkeeping.index("split") in advanced_indexing_dims + ): + split_bookkeeping[ + advanced_indexing_dims[0] : advanced_indexing_dims[0] + + len(advanced_indexing_dims) + ] = ["split"] + else: + split_bookkeeping[ + advanced_indexing_dims[0] : advanced_indexing_dims[0] + + len(advanced_indexing_dims) + ] = [None] + else: + split_bookkeeping = ( + split_bookkeeping[: advanced_indexing_dims[0]] + + [None] * add_dims + + split_bookkeeping[advanced_indexing_dims[0] :] + ) + else: + # advanced-indexing dimensions are not consecutive: + # transpose array to make the advanced-indexing dimensions consecutive as the first dimensions + non_adv_ind_dims = list( + i for i in range(arr.ndim) if i not in advanced_indexing_dims + ) + # keep track of transpose axes order, to be able to transpose back later + transpose_axes = tuple(advanced_indexing_dims + non_adv_ind_dims) + arr = arr.transpose(transpose_axes) + backwards_transpose_axes = tuple( + torch.tensor(transpose_axes, device=arr.larray.device) + .argsort(stable=True) + .tolist() + ) + # output shape and split bookkeeping + output_shape = list(output_shape[i] for i in transpose_axes) + output_shape[: len(advanced_indexing_dims)] = broadcasted_shape + split_bookkeeping = list(split_bookkeeping[i] for i in transpose_axes) + split_bookkeeping = [None] * add_dims + split_bookkeeping + # modify key to match the new dimension order + key = [key[i] for i in advanced_indexing_dims] + [key[i] for i in non_adv_ind_dims] + # update advanced-indexing dims + advanced_indexing_dims = list(range(len(advanced_indexing_dims))) + + # expand key to match the number of dimensions of the DNDarray + if arr.ndim > len(key): + key += [slice(None)] * (arr.ndim - len(key)) + + key = tuple(key) + for i in range(output_shape.count(None)): + lost_dim = output_shape.index(None) + output_shape.remove(None) + split_bookkeeping = split_bookkeeping[:lost_dim] + split_bookkeeping[lost_dim + 1 :] + output_shape = tuple(output_shape) + new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None + return ( + arr, + key, + output_shape, + new_split, + split_key_is_ordered, + key_is_mask_like, + out_is_balanced, + root, + backwards_transpose_axes, + ) + + def __process_scalar_key( + arr: "DNDarray", + key: int | "DNDarray" | torch.Tensor | np.ndarray, + indexed_axis: int, + return_local_indices: bool | None = False, + ) -> tuple[int, int]: + """ + Private method to process a single-item scalar key used for indexing a ``DNDarray``. + + """ + device = arr.larray.device + try: + # is key an ndarray or DNDarray or torch.Tensor? + key = key.item() + except AttributeError: + # key is already an integer, do nothing + pass + if not arr.is_distributed(): + root = None + return key, root + if arr.split == indexed_axis: + # adjust negative key + if key < 0: + key += arr.shape[0] + # work out active process + _, displs = arr.counts_displs() + if key in displs: + root = displs.index(key) + else: + displs = torch.cat( + ( + torch.tensor(displs, device=device), + torch.tensor(key, device=device).reshape(-1), + ), + dim=0, + ) + _, sorted_indices = displs.unique(sorted=True, return_inverse=True) + root = sorted_indices[-1].item() - 1 + displs = displs.tolist() + # correct key for rank-specific displacement + if return_local_indices: + if arr.comm.rank == root: + key -= displs[root] + else: + root = None + return key, root + + def __get_local_slice(self, key: slice): + split = self.split + if split is None: + return key + key = stride_tricks.sanitize_slice(key, self.shape[split]) + start, stop, step = key.start, key.stop, key.step + if step < 0: # NOT supported by torch, should be filtered by torch_proxy + key = self.__get_local_slice(slice(stop + 1, start + 1, abs(step))) + if key is None: + return None + start, stop, step = key.start, key.stop, key.step + return slice(key.stop - 1, key.start - 1, -1 * key.step) + + _, offsets = self.counts_displs() + offset = offsets[self.comm.rank] + range_proxy = range(self.lshape[split]) + local_inds = range_proxy[start - offset : stop - offset] # only works if stop - offset > 0 + local_inds = local_inds[max(offset - start, 0) % step :: step] + if len(local_inds) and stop > offset: + # otherwise if (stop-offset) > -self.lshape[split] this can index into the local chunk despite ending before it + return slice(local_inds.start, local_inds.stop, local_inds.step) + return None + + def __getitem__(self, key: int | tuple[int, ...] | list[int]) -> DNDarray: """ Global getter function for DNDarrays. Returns a new DNDarray composed of the elements of the original tensor selected by the indices @@ -889,7 +1525,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar Parameters ---------- - key : int, slice, Tuple[int,...], List[int,...] + key : int, slice, tuple[int,...], list[int,...] Indices to get from the tensor. Examples @@ -909,231 +1545,458 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar (1/2) >>> tensor([0.]) (2/2) >>> tensor([0., 0.]) """ - key = getattr(key, "copy()", key) - l_dtype = self.dtype.torch_type() - advanced_ind = False - if isinstance(key, DNDarray) and key.ndim == self.ndim: - """if the key is a DNDarray and it has as many dimensions as self, then each of the - entries in the 0th dim refer to a single element. To handle this, the key is split - into the torch tensors for each dimension. This signals that advanced indexing is - to be used.""" - # NOTE: this gathers the entire key on every process!! - # TODO: remove this resplit!! - key = manipulations.resplit(key) - if key.larray.dtype in [torch.bool, torch.uint8]: - key = indexing.nonzero(key) - - if key.ndim > 1: - key = list(key.larray.split(1, dim=1)) - # key is now a list of tensors with dimensions (key.ndim, 1) - # squeeze singleton dimension: - key = [key[i].squeeze_(1) for i in range(len(key))] - else: - key = [key] - advanced_ind = True - elif not isinstance(key, tuple): - """this loop handles all other cases. DNDarrays which make it to here refer to - advanced indexing slices, as do the torch tensors. Both DNDaarrys and torch.Tensors - are cast into lists here by PyTorch. lists mean advanced indexing will be used""" - h = [slice(None, None, None)] * max(self.ndim, 1) - if isinstance(key, DNDarray): - key = manipulations.resplit(key) - if key.larray.dtype in [torch.bool, torch.uint8]: - h[0] = torch.nonzero(key.larray).flatten() # .tolist() - else: - h[0] = key.larray.tolist() - elif isinstance(key, torch.Tensor): - if key.dtype in [torch.bool, torch.uint8]: - # (coquelin77) i am not certain why this works without being a list. but it works...for now - h[0] = torch.nonzero(key).flatten() # .tolist() - else: - h[0] = key.tolist() - else: - h[0] = key + # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof - key = list(h) + if key is None: + return self.expand_dims(0) + if ( + key is ... + or (isinstance(key, slice) and key == slice(None)) + or (isinstance(key, tuple) and key == ()) + ): + return self - if isinstance(key, (list, tuple)): - key = list(key) - for i, k in enumerate(key): - # this might be a good place to check if the dtype is there - try: - k = manipulations.resplit(k) - key[i] = k.larray - except AttributeError: - pass - - # ellipsis - key = list(key) - key_classes = [type(n) for n in key] - # if any(isinstance(n, ellipsis) for n in key): - n_elips = key_classes.count(type(...)) - if n_elips > 1: - raise ValueError("key can only contain 1 ellipsis") - elif n_elips == 1: - # get which item is the ellipsis - ell_ind = key_classes.index(type(...)) - kst = key[:ell_ind] - kend = key[ell_ind + 1 :] - slices = [slice(None)] * (self.ndim - (len(kst) + len(kend))) - key = kst + slices + kend - else: - key = key + [slice(None)] * (self.ndim - len(key)) + from .types import bool as ht_bool, uint8 as ht_uint8 # avoid circulars - self_proxy = self.__torch_proxy__() - for i in range(len(key)): - if self.__key_adds_dimension(key, i, self_proxy): - key[i] = slice(None) - return self.expand_dims(i)[tuple(key)] + original_split = self.split - key = tuple(key) - # assess final global shape - gout_full = list(self_proxy[key].shape) - - # calculate new split axis - new_split = self.split - # when slicing, squeezed singleton dimensions may affect new split axis - if self.split is not None and len(gout_full) < self.ndim: - if advanced_ind: - new_split = 0 - else: - for i in range(len(key[: self.split + 1])): - if self.__key_is_singular(key, i, self_proxy): - new_split = None if i == self.split else new_split - 1 + def _normalize_index_component(comp): + if isinstance(comp, DNDarray): + if comp.dtype in (ht_bool, ht_uint8): + return comp - key = tuple(key) - if not self.is_distributed(): - arr = self.__array[key].reshape(gout_full) - return DNDarray( - arr, tuple(gout_full), self.dtype, new_split, self.device, self.comm, self.balanced - ) + if comp.split is not None: + return comp - # else: (DNDarray is distributed) - arr = torch.tensor([], dtype=self.__array.dtype, device=self.__array.device) - rank = self.comm.rank - counts, chunk_starts = self.counts_displs() - counts, chunk_starts = torch.tensor(counts), torch.tensor(chunk_starts) - chunk_ends = chunk_starts + counts - chunk_start = chunk_starts[rank] - chunk_end = chunk_ends[rank] + return comp.larray.to(torch.int64) - if len(key) == 0: # handle empty list - # this will return an array of shape (0, ...) - arr = self.__array[key] + return comp - """ At the end of the following if/elif/elif block the output array will be set. - each block handles the case where the element of the key along the split axis - is a different type and converts the key from global indices to local indices. """ - lout = gout_full.copy() + if isinstance(key, DNDarray): + key = _normalize_index_component(key) + elif isinstance(key, (list, tuple)): + key = type(key)(_normalize_index_component(k) for k in key) - if ( - isinstance(key[self.split], (list, torch.Tensor, DNDarray, np.ndarray)) - and len(key[self.split]) > 1 - ): - # advanced indexing, elements in the split dimension are adjusted to the local indices - lkey = list(key) - if isinstance(key[self.split], DNDarray): - lkey[self.split] = key[self.split].larray - - if not isinstance(lkey[self.split], torch.Tensor): - inds = torch.tensor( - lkey[self.split], dtype=torch.long, device=self.device.torch_device + if isinstance(key, tuple) and len(key) >= 1 and self.ndim >= 1: + first = key[0] + + # Case 1: DNDarray boolean mask + if ( + isinstance(first, DNDarray) + and first.dtype in (ht_bool, ht_uint8) + and first.ndim == 1 + and first.gshape == (self.gshape[0],) + ): + nz = first.nonzero() + if isinstance(nz, tuple): + nz = nz[0] + if getattr(nz, "ndim", 1) > 1 and nz.shape[-1] == 1: + nz = nz.squeeze(-1) + idx0 = nz + key = (idx0,) + key[1:] + + # Case 2: torch.Tensor boolean mask + elif ( + isinstance(first, torch.Tensor) + and first.ndim == 1 + and first.shape[0] == self.gshape[0] + and first.dtype in (torch.bool, torch.uint8) + ): + idx0 = torch.nonzero(first, as_tuple=False).flatten() + key = (idx0,) + key[1:] + + # Case 3: numpy.ndarray boolean mask + elif ( + isinstance(first, np.ndarray) + and first.ndim == 1 + and first.shape[0] == self.gshape[0] + and first.dtype in (np.bool_, np.uint8) + ): + idx0 = np.nonzero(first)[0].astype(np.int64) + key = (idx0,) + key[1:] + + if isinstance(key, DNDarray): + # Exclude boolean masks; they have their own dedicated handling. + if key.ndim == 1 and key.dtype not in (ht_bool, ht_uint8): + key = key.larray.to(torch.int64) + + # Single-element indexing + scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 + if scalar: + # single-element indexing on axis 0 + if self.ndim == 0: + raise IndexError( + "Too many indices for DNDarray: DNDarray is 0-dimensional, but 1 were indexed" ) - elif lkey[self.split].dtype in [torch.bool, torch.uint8]: # or torch.byte? - # need to convert the bools to indices - inds = torch.nonzero(lkey[self.split]) + output_shape = self.gshape[1:] + if original_split is None or original_split == 0: + output_split = None else: - inds = lkey[self.split] - # todo: remove where in favor of nonzero? might be a speed upgrade. testing required - loc_inds = torch.where((inds >= chunk_start) & (inds < chunk_end)) - # if there are no local indices on a process, then `arr` is empty - # if local indices exist: - if len(loc_inds[0]) != 0: - # select same local indices for other (non-split) dimensions if necessary - for i, k in enumerate(lkey): - if isinstance(k, (list, torch.Tensor, DNDarray)) and i != self.split: - lkey[i] = k[loc_inds] - # correct local indices for offset - inds = inds[loc_inds] - chunk_start - lkey[self.split] = inds - lout[new_split] = len(inds) - arr = self.__array[tuple(lkey)].reshape(tuple(lout)) - elif len(loc_inds[0]) == 0: - if new_split is not None: - lout[new_split] = len(loc_inds[0]) - else: - lout = [0] * len(gout_full) - arr = torch.tensor([], dtype=self.larray.dtype, device=self.larray.device).reshape( - tuple(lout) + output_split = original_split - 1 + split_key_is_ordered = 1 + out_is_balanced = True + backwards_transpose_axes = tuple(range(self.ndim)) + key, root = self.__process_scalar_key(key, indexed_axis=0, return_local_indices=True) + if root is None: + # early out for single-element indexing not affecting split axis + indexed_arr = self.larray[key] + indexed_arr = DNDarray( + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + balanced=out_is_balanced, ) - - elif isinstance(key[self.split], slice): - # standard slicing along the split axis, - # adjust the slice start, stop, and step, then run it on the processes which have the requested data - key = list(key) - key[self.split] = stride_tricks.sanitize_slice(key[self.split], self.gshape[self.split]) - key_start, key_stop, key_step = ( - key[self.split].start, - key[self.split].stop, - key[self.split].step, - ) - og_key_start = key_start - st_pr = torch.where(key_start < chunk_ends)[0] - st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size - sp_pr = torch.where(key_stop >= chunk_starts)[0] - sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0 - actives = list(range(st_pr, sp_pr + 1)) - if rank in actives: - key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] - key_stop = counts[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] - key_start, key_stop = self.__xitem_get_key_start_stop( - rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start + return indexed_arr + else: + # ------------------------------------------------------------------ + # Special case: 2D array with 1D boolean mask along split axis 0 + # Pattern: x[mask_1d] with + # - self.ndim == 2 + # - self.split == 0 + # - key is DNDarray, bool, 1D, same split and length as axis 0 + # This corresponds to NumPy's "select rows by mask" semantics. + # ------------------------------------------------------------------ + if ( + isinstance(key, DNDarray) + and key.dtype in (ht_bool, ht_uint8) + and key.ndim == 1 + and self.ndim == 2 + and self.split == 0 + and key.split == 0 + and key.gshape == (self.gshape[0],) + ): + # Local boolean mask on this rank + local_mask = key.larray # torch.bool, shape (local_rows,) + local_result = self.larray[local_mask, :] # shape (n_local_true, 2) + + # Compute global number of selected rows (sum over ranks) + local_rows = torch.tensor( + [local_result.shape[0]], + device=self.larray.device, + dtype=torch.int64, ) - key[self.split] = slice(key_start, key_stop, key_step) - lout[new_split] = ( - math.ceil((key_stop - key_start) / key_step) - if key_step is not None - else key_stop - key_start + rows_buffer = torch.zeros( + (self.comm.size,), + device=self.larray.device, + dtype=torch.int64, ) - arr = self.__array[tuple(key)].reshape(lout) - else: - lout[new_split] = 0 - arr = torch.empty(lout, dtype=self.__array.dtype, device=self.__array.device) + self.comm.Allgather(local_rows, rows_buffer) + total_rows = int(rows_buffer.sum().item()) + + # Global output shape: (total_rows, n_cols) + output_shape = (total_rows, self.gshape[1]) + + # Result remains split along axis 0, generally unbalanced. + result = DNDarray( + local_result, + gshape=output_shape, + dtype=self.dtype, + split=0, + device=self.device, + comm=self.comm, + balanced=False, + ) + return result - elif self.__key_is_singular(key, self.split, self_proxy): - # getting one item along split axis: - key = list(key) - if isinstance(key[self.split], list): - key[self.split] = key[self.split].pop() - elif isinstance(key[self.split], (torch.Tensor, DNDarray, np.ndarray)): - key[self.split] = key[self.split].item() - # translate negative index - if key[self.split] < 0: - key[self.split] += self.gshape[self.split] - - active_rank = torch.where(key[self.split] >= chunk_starts)[0][-1].item() - # slice `self` on `active_rank`, allocate `arr` on all other ranks in preparation for Bcast - if rank == active_rank: - key[self.split] -= chunk_start.item() - arr = self.__array[tuple(key)].reshape(tuple(lout)) + # process multi-element key + ( + self, + key, + output_shape, + output_split, + split_key_is_ordered, + key_is_mask_like, + out_is_balanced, + root, + backwards_transpose_axes, + ) = self.__process_key(key, return_local_indices=True) + + # Do not treat keys that contain slices as "mask-like". + # For such keys, we fall back to the simpler non-mask-like + # path below, which only treats the split axis as globally indexed. + if key_is_mask_like and isinstance(key, (tuple, list)): + if any(isinstance(k, slice) for k in key): + key_is_mask_like = False + + # ------------------------------------------------------------ + # Fast path: pure BASIC slicing/indexing must never trigger any + # cross-rank reductions or communication. + # Example: X[:, 1:], X[5:10], X[:, :-1], ... + # ------------------------------------------------------------ + def _is_basic_component(k): + return k is ... or k is None or isinstance(k, (slice, int, np.integer)) + + _basic_index = isinstance(key, (tuple, list)) and all( + _is_basic_component(k) for k in key + ) + + if _basic_index: + # Slices are ordered by definition; also not mask-like. + split_key_is_ordered = 1 + key_is_mask_like = False else: - arr = torch.empty(tuple(lout), dtype=self.larray.dtype, device=self.larray.device) - # broadcast result - # TODO: Replace with `self.comm.Bcast(arr, root=active_rank)` after fixing #784 - arr = self.comm.bcast(arr, root=active_rank) - if arr.device != self.larray.device: - # todo: remove when unnecessary (also after #784) - arr = arr.to(device=self.larray.device) + if self.is_distributed(): + # branch_code: 2 => ordered (1), 1 => descending slice (-1), 0 => unordered (0) + # Use MIN so unordered dominates, then descending, then ordered. + local_code = ( + 2 if split_key_is_ordered == 1 else (1 if split_key_is_ordered == -1 else 0) + ) + global_code = self.comm.allreduce(local_code, op=MPI.MIN) + split_key_is_ordered = ( + 1 if global_code == 2 else (-1 if global_code == 1 else 0) + ) - return DNDarray( - arr.type(l_dtype), - gout_full if isinstance(gout_full, tuple) else tuple(gout_full), - self.dtype, - new_split, - self.device, - self.comm, - balanced=True if new_split is None else None, + # key_is_mask_like must also be consistent across ranks (False dominates) + km_local = 1 if key_is_mask_like else 0 + km_global = self.comm.allreduce(km_local, op=MPI.MIN) + key_is_mask_like = bool(km_global) + + if not self.is_distributed(): + # key is torch-proof, index underlying torch tensor + indexed_arr = self.larray[key] + # transpose array back if needed + if self.ndim > 0: + self = self.transpose(backwards_transpose_axes) + return DNDarray( + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + balanced=out_is_balanced, + ) + + if split_key_is_ordered == 1: + if root is not None: + # single-element indexing along split axis + # prepare for Bcast: allocate buffer on all processes + if self.comm.rank == root: + indexed_arr = self.larray[key] + else: + indexed_arr = torch.zeros( + output_shape, dtype=self.larray.dtype, device=self.larray.device + ) + # broadcast result to all processes + self.comm.Bcast(indexed_arr, root=root) + indexed_arr = DNDarray( + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + balanced=out_is_balanced, + ) + # transpose array back if needed + if self.ndim > 0: + self = self.transpose(backwards_transpose_axes) + return indexed_arr + # This covers patterns like A[idx] where A is distributed (split=0) and idx has global indices (e.g. (N,k)). + if self.is_distributed() and self.split == 0 and self.ndim == 1: + k0 = key + # key may be wrapped as a singleton tuple + if isinstance(k0, tuple) and len(k0) == 1: + k0 = k0[0] + + # tolerate DNDarray key (can still happen depending on __process_key path) + if isinstance(k0, DNDarray): + idx_t = k0.larray + else: + idx_t = k0 + + if isinstance(idx_t, torch.Tensor) and idx_t.dtype in ( + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + ): + return self.__take_split0_global_1d( + idx_t, + out_gshape=output_shape, + out_split=0, + out_is_balanced=out_is_balanced, + ) + # root is None, i.e. indexing does not affect split axis, apply as is + indexed_arr = self.larray[key] + # transpose array back if needed + if self.ndim > 0: + self = self.transpose(backwards_transpose_axes) + + return DNDarray( + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + balanced=out_is_balanced, + comm=self.comm, + ) + # key along split axis is not ordered, indices are GLOBAL + # prepare for communication of indices and data + counts, displs = self.counts_displs() + rank, size = self.comm.rank, self.comm.size + + key_is_single_tensor = isinstance(key, torch.Tensor) + if key_is_single_tensor: + split_key = key + else: + split_key = key[self.split] + # split_key might be multi-dimensional, flatten it for communication + if split_key.ndim > 1: + original_split_key_shape = split_key.shape + communication_split = output_split - (split_key.ndim - 1) + split_key = split_key.flatten() + else: + communication_split = output_split + + # determine the number of elements to be received from each process + recv_counts = torch.zeros((size,), dtype=torch.int64, device=self.larray.device) + if key_is_mask_like: + recv_indices = torch.zeros( + (len(split_key), len(key)), dtype=split_key.dtype, device=self.larray.device + ) + else: + recv_indices = torch.zeros( + (split_key.shape), dtype=split_key.dtype, device=self.larray.device + ) + for p in range(size): + cond1 = split_key >= displs[p] + cond2 = split_key < displs[p] + counts[p] + indices_from_p = torch.nonzero(cond1 & cond2, as_tuple=False) + incoming_indices = split_key[indices_from_p].flatten() + recv_counts[p] = incoming_indices.numel() + start = int(recv_counts[:p].sum().item()) + stop = start + int(recv_counts[p].item()) + if incoming_indices.numel() > 0: + if key_is_mask_like: + # apply selection to all dimensions + for i in range(len(key)): + recv_indices[start:stop, i] = key[i][indices_from_p].flatten() + recv_indices[start:stop, self.split] -= displs[p] + else: + recv_indices[start:stop] = incoming_indices - displs[p] + # build communication matrix by sharing recv_counts with all processes + # comm_matrix rows contain the send_counts for each process, columns contain the recv_counts + comm_matrix = torch.zeros((size, size), dtype=torch.int64, device=self.larray.device) + self.comm.Allgather(recv_counts, comm_matrix) + send_counts = comm_matrix[:, rank] + + active_rank_pairs = torch.nonzero(comm_matrix, as_tuple=False) + + # rank sicher als Python-int + rank = int(rank) + + mask_recv = active_rank_pairs[:, 1].eq(rank) + mask_send = active_rank_pairs[:, 0].eq(rank) + + active_recv_indices_from = [int(x.item()) for x in active_rank_pairs[mask_recv, 0]] + active_send_indices_to = [int(x.item()) for x in active_rank_pairs[mask_send, 1]] + rank_is_active = (len(active_recv_indices_from) > 0) or (len(active_send_indices_to) > 0) + + # allocate recv_buf for incoming data + recv_buf_shape = list(output_shape) + if communication_split != output_split: + # split key was flattened, flatten corresponding dims in recv_buf accordingly + recv_buf_shape = ( + recv_buf_shape[:communication_split] + + [recv_counts.sum().item()] + + recv_buf_shape[output_split + 1 :] + ) + else: + recv_buf_shape[communication_split] = recv_counts.sum().item() + recv_buf = torch.zeros( + tuple(recv_buf_shape), dtype=self.larray.dtype, device=self.larray.device ) + if rank_is_active: + # non-blocking send indices to `active_send_indices_to` + send_requests = [] + for i in active_send_indices_to: + start = recv_counts[:i].sum().item() + stop = start + recv_counts[i].item() + outgoing_indices = recv_indices[start:stop] + send_requests.append(self.comm.Isend(outgoing_indices, dest=i)) + del outgoing_indices + del recv_indices + for i in active_recv_indices_from: + # receive indices from `active_recv_indices_from` + if key_is_mask_like: + incoming_indices = torch.zeros( + (send_counts[i].item(), len(key)), + dtype=torch.int64, + device=self.larray.device, + ) + else: + incoming_indices = torch.zeros( + send_counts[i].item(), dtype=torch.int64, device=self.larray.device + ) + self.comm.Recv(incoming_indices, source=i) + # prepare send_buf for outgoing data + if key_is_single_tensor: + send_buf = self.larray[incoming_indices] + else: + if key_is_mask_like: + send_key = tuple( + incoming_indices[:, i].reshape(-1) + for i in range(incoming_indices.shape[1]) + ) + send_buf = self.larray[send_key] + else: + send_key = list(key) + send_key[self.split] = incoming_indices + send_buf = self.larray[tuple(send_key)] + # non-blocking send requested data to i + send_requests.append(self.comm.Isend(send_buf, dest=i)) + del send_buf + # allocate temporary recv_buf to receive data from all active processes + tmp_recv_buf_shape = recv_buf_shape.copy() + tmp_recv_buf_shape[communication_split] = recv_counts.max().item() + tmp_recv_buf = torch.zeros( + tuple(tmp_recv_buf_shape), dtype=self.larray.dtype, device=self.larray.device + ) + for i in active_send_indices_to: + # receive data from i + tmp_recv_slice = [slice(None)] * tmp_recv_buf.ndim + tmp_recv_slice[communication_split] = slice(0, recv_counts[i].item()) + self.comm.Recv(tmp_recv_buf[tmp_recv_slice], source=i) + # write received data to appropriate portion of recv_buf + cond1 = split_key >= displs[i] + cond2 = split_key < displs[i] + counts[i] + recv_buf_indices = torch.nonzero(cond1 & cond2, as_tuple=False).flatten() + recv_buf_key = [slice(None)] * recv_buf.ndim + recv_buf_key[communication_split] = recv_buf_indices + recv_buf[recv_buf_key] = tmp_recv_buf[tmp_recv_slice] + del tmp_recv_buf + # wait for all non-blocking communication to finish + for req in send_requests: + req.Wait() + if communication_split != output_split: + # split_key has been flattened, bring back recv_buf to intended shape + original_local_shape = ( + output_shape[:communication_split] + + original_split_key_shape + + output_shape[output_split + 1 :] + ) + recv_buf = recv_buf.reshape(original_local_shape) + + # construct indexed array from recv_buf + indexed_arr = DNDarray( + recv_buf, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + balanced=out_is_balanced, + ) + # transpose array back if needed + if self.ndim > 0: + self = self.transpose(backwards_transpose_axes) + return indexed_arr if torch.cuda.device_count() > 0: @@ -1219,9 +2082,13 @@ def __len__(self) -> int: """ The length of the ``DNDarray``, i.e. the number of items in the first dimension. """ - return self.shape[0] + try: + len = self.shape[0] + return len + except IndexError: + raise TypeError("len() of unsized DNDarray") - def numpy(self) -> np.array: + def numpy(self) -> np.typing.NDArray[Any]: """ Returns a copy of the :class:`DNDarray` as numpy ndarray. If the ``DNDarray`` resides on the GPU, the underlying data will be copied to the CPU first. @@ -1251,7 +2118,7 @@ def __repr__(self) -> str: """ return printing.__repr__(self) - def ravel(self): + def ravel(self) -> "DNDarray": """ Flattens the ``DNDarray``. @@ -1270,8 +2137,8 @@ def ravel(self): return manipulations.ravel(self) def redistribute_( - self, lshape_map: Optional[torch.Tensor] = None, target_map: Optional[torch.Tensor] = None - ): + self, lshape_map: torch.Tensor | None = None, target_map: torch.Tensor | None = None + ) -> None: """ Redistributes the data of the :class:`DNDarray` *along the split axis* to match the given target map. This function does not modify the non-split dimensions of the ``DNDarray``. @@ -1423,9 +2290,9 @@ def redistribute_( def __redistribute_shuffle( self, - snd_pr: Union[int, torch.Tensor], - send_amt: Union[int, torch.Tensor], - rcv_pr: Union[int, torch.Tensor], + snd_pr: int | torch.Tensor, + send_amt: int | torch.Tensor, + rcv_pr: int | torch.Tensor, snd_dtype: torch.dtype, ): """ @@ -1515,16 +2382,14 @@ def resplit_(self, axis: int = None): # sanitize the axis to check whether it is in range axis = sanitize_axis(self.shape, axis) + self.__partitions_dict__ = None + # early out for unchanged content if self.comm.size == 1: self.__split = axis - if axis is None: - self.__partitions_dict__ = None if axis == self.split: return self - self.__partitions_dict__ = None - if axis is None: gathered = torch.empty( self.shape, dtype=self.dtype.torch_type(), device=self.device.torch_device @@ -1570,17 +2435,17 @@ def resplit_(self, axis: int = None): def __setitem__( self, - key: Union[int, Tuple[int, ...], List[int, ...]], - value: Union[float, DNDarray, torch.Tensor], + key: int | tuple[int, ...] | list[int], + value: float | "DNDarray" | torch.Tensor, ): """ Global item setter Parameters ---------- - key : Union[int, Tuple[int,...], List[int,...]] + key : int | tuple[int, ...] | list[int] Index/indices to be set - value: Union[float, DNDarray,torch.Tensor] + value: float | "DNDarray" | torch.Tensor Value to be set to the specified positions in the DNDarray (self) Notes @@ -1602,265 +2467,866 @@ def __setitem__( (2/2) >>> tensor([[0., 1., 0., 0., 0.], [0., 1., 0., 0., 0.]]) """ - key = getattr(key, "copy()", key) - try: - if value.split != self.split: - val_split = int(value.split) - sp = self.split - warnings.warn( - f"\nvalue.split {val_split} not equal to this DNDarray's split:" - f" {sp}. this may cause errors or unwanted behavior", - category=RuntimeWarning, - ) - except (AttributeError, TypeError): - pass - # NOTE: for whatever reason, there is an inplace op which interferes with the abstraction - # of this next block of code. this is shared with __getitem__. I attempted to abstract it - # in a standard way, but it was causing errors in the test suite. If someone else is - # motived to do this they are welcome to, but i have no time right now - # print(key) - if isinstance(key, DNDarray) and key.ndim == self.ndim: - """if the key is a DNDarray and it has as many dimensions as self, then each of the - entries in the 0th dim refer to a single element. To handle this, the key is split - into the torch tensors for each dimension. This signals that advanced indexing is - to be used.""" - key = manipulations.resplit(key) - if key.larray.dtype in [torch.bool, torch.uint8]: - key = indexing.nonzero(key) - - if key.ndim > 1: - key = list(key.larray.split(1, dim=1)) - # key is now a list of tensors with dimensions (key.ndim, 1) - # squeeze singleton dimension: - key = [key[i].squeeze_(1) for i in range(len(key))] + def __broadcast_value( + arr: "DNDarray", + key: int | tuple[int, ...] | slice, + value: "DNDarray", + **kwargs, + ): + """ + Broadcasts the assignment DNDarray `value` to the shape of the indexed array `arr[key]` if necessary. + """ + is_scalar = ( + np.isscalar(value) + or getattr(value, "ndim", 1) == 0 + or (value.shape == (1,) and value.split is None) + ) + if is_scalar: + # no need to broadcast + return value, is_scalar + # need information on indexed array + output_shape = kwargs.get("output_shape", None) + if output_shape is not None: + indexed_dims = len(output_shape) else: - key = [key] - elif not isinstance(key, tuple): - """this loop handles all other cases. DNDarrays which make it to here refer to - advanced indexing slices, as do the torch tensors. Both DNDaarrys and torch.Tensors - are cast into lists here by PyTorch. lists mean advanced indexing will be used""" - h = [slice(None, None, None)] * self.ndim - if isinstance(key, DNDarray): - key = manipulations.resplit(key) - if key.larray.dtype in [torch.bool, torch.uint8]: - h[0] = torch.nonzero(key.larray).flatten() # .tolist() + if isinstance(key, (int, tuple)): + # direct indexing, output_shape has not been calculated + # use proxy to avoid MPI communication and limit memory usage + indexed_proxy = arr.__torch_proxy__()[key] + indexed_dims = indexed_proxy.ndim + output_shape = tuple(indexed_proxy.shape) else: - h[0] = key.larray.tolist() - elif isinstance(key, torch.Tensor): - if key.dtype in [torch.bool, torch.uint8]: - # (coquelin77) im not sure why this works without being a list...but it does...for now - h[0] = torch.nonzero(key).flatten() # .tolist() + raise RuntimeError( + "Not enough information to broadcast value to indexed array, please provide `output_shape`" + ) + value_shape = value.shape + # check if value needs to be broadcasted + if value_shape != output_shape: + # assess whether the shapes are compatible, starting from the trailing dimension + for i in range(1, min(len(value_shape), len(output_shape)) + 1): + if i == 1: + if value_shape[-i] != output_shape[-i] and not value_shape[-i] == 1: + # shapes are not compatible, raise error + raise ValueError( + f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + ) + else: + if value_shape[-i] != output_shape[-i] and (not value_shape[-i] == 1): + # shapes are not compatible, raise error + raise ValueError( + f"could not broadcast input from shape {value_shape} into shape {output_shape}" + ) + # value has more dimensions than indexed array + if value.ndim > indexed_dims: + # check if all dimensions except the indexed ones are singletons + all_singletons = value.shape[: value.ndim - indexed_dims] == (1,) * ( + value.ndim - indexed_dims + ) + if not all_singletons: + raise ValueError( + f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + ) + # squeeze out singleton dimensions + value = value.squeeze(tuple(range(value.ndim - indexed_dims))) else: - h[0] = key.tolist() + while value.ndim < indexed_dims: + # broadcasting + # expand missing dimensions to align split axis + value = value.expand_dims(0) + try: + value_shape = tuple(torch.broadcast_shapes(value.shape, output_shape)) + except RuntimeError: + raise ValueError( + f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + ) + return value, is_scalar + + def __dedup_last_wins_advanced_index( + key_in, + rhs_in: torch.Tensor, + target_shape: tuple[int, ...], + ): + """ + CUDA-safe handling for duplicate advanced indices: + enforce NumPy semantics (last assignment wins) by dropping earlier duplicates. + Works for: + - key_in: torch.Tensor (indexes axis 0) + - key_in: tuple/list of torch.Tensors (pure advanced indexing) + rhs_in must match the indexing result shape. + """ + # Scalars or single element: no need to dedup + if rhs_in.numel() <= 1: + return key_in, rhs_in + + # Normalize key to either a single tensor or tuple of tensors + if torch.is_tensor(key_in): + idx_tensors = (key_in,) + elif ( + isinstance(key_in, (tuple, list)) + and len(key_in) > 0 + and all(torch.is_tensor(k) for k in key_in) + ): + idx_tensors = tuple(key_in) else: - h[0] = key - key = list(h) + # Not pure advanced-tensor indexing -> don't touch + return key_in, rhs_in - # key must be torch-proof - if isinstance(key, (list, tuple)): - key = list(key) - for i, k in enumerate(key): - try: # extract torch tensor - k = manipulations.resplit(k) - key[i] = k.larray - except AttributeError: - pass - # remove bools from a torch tensor in favor of indexes - try: - if key[i].dtype in [torch.bool, torch.uint8]: - key[i] = torch.nonzero(key[i]).flatten() - except (AttributeError, TypeError): - pass - - key = list(key) - - # ellipsis stuff - key_classes = [type(n) for n in key] - # if any(isinstance(n, ellipsis) for n in key): - n_elips = key_classes.count(type(...)) - if n_elips > 1: - raise ValueError("key can only contain 1 ellipsis") - elif n_elips == 1: - # get which item is the ellipsis - ell_ind = key_classes.index(type(...)) - kst = key[:ell_ind] - kend = key[ell_ind + 1 :] - slices = [slice(None)] * (self.ndim - (len(kst) + len(kend))) - key = kst + slices + kend - # ---------- end ellipsis stuff ------------- - - for c, k in enumerate(key): + device = rhs_in.device + + # Broadcast indices to common shape, then flatten try: - key[c] = k.item() - except (AttributeError, ValueError, RuntimeError): - pass + idx_b = torch.broadcast_tensors(*idx_tensors) + except RuntimeError: + # If broadcast fails, leave it to PyTorch (will error appropriately) + return key_in, rhs_in - rank = self.comm.rank - if self.split is not None: - counts, chunk_starts = self.counts_displs() - else: - counts, chunk_starts = 0, [0] * self.comm.size - counts = torch.tensor(counts, device=self.device.torch_device) - chunk_starts = torch.tensor(chunk_starts, device=self.device.torch_device) - chunk_ends = chunk_starts + counts - chunk_start = chunk_starts[rank] - chunk_end = chunk_ends[rank] - # determine which elements are on the local process (if the key is a torch tensor) + pos_shape = idx_b[0].shape + pos_ndim = len(pos_shape) + n = idx_b[0].numel() + + idx_flat = [t.to(device=device, dtype=torch.int64).reshape(-1) for t in idx_b] + + # Build linear index for duplicate detection + if len(idx_flat) == 1: + lin = idx_flat[0] + else: + lin = idx_flat[0] + # linearize across the first len(idx_flat) dimensions of the target tensor + for d in range(1, len(idx_flat)): + lin = lin * int(target_shape[d]) + idx_flat[d] + + # Fast path: no duplicates + if torch.unique(lin).numel() == n: + return key_in, rhs_in + + # Determine "last occurrence" per linear index (last wins) + pos = torch.arange(n, device=device, dtype=torch.int64) + + # Prefer stable sort by lin if available; otherwise sort by combined key + try: + order = torch.argsort(lin, stable=True) + except TypeError: + # combined key sorts by lin, then by pos + combined = lin.to(torch.int64) * (n + 1) + pos + order = torch.argsort(combined) + + lin_s = lin[order] + pos_s = pos[order] + + is_last = torch.ones_like(lin_s, dtype=torch.bool) + is_last[:-1] = lin_s[1:] != lin_s[:-1] + keep_pos = pos_s[is_last] # positions in original stream + + # Reduce RHS accordingly: + # Flatten leading "pos_ndim" dims into one, keep trailing dims as payload + rhs_view = rhs_in.reshape(n, *rhs_in.shape[pos_ndim:]) + rhs_u = rhs_view[keep_pos].reshape(keep_pos.numel(), *rhs_in.shape[pos_ndim:]) + + # Reduce indices accordingly (use flattened 1D indices) + if torch.is_tensor(key_in): + key_u = idx_flat[0][keep_pos] + return key_u, rhs_u + + key_u = tuple(t[keep_pos] for t in idx_flat) + return key_u, rhs_u + + def __set( + arr: "DNDarray", + key: int | tuple[int, ...] | list[int], + value: float | "DNDarray" | torch.Tensor, + ): + """ + Setter for not advanced indexing, i.e. when arr[key] is an in-place view of arr. + """ + # only assign values if key does not contain empty slices + process_is_inactive = arr.larray[key].numel() == 0 + if not process_is_inactive: + rhs = value.larray.type(arr.dtype.torch_type()) + key_to_use = key + + # CUDA: make advanced indexing assignment deterministic for duplicate indices + if arr.larray.is_cuda: + key_to_use, rhs = __dedup_last_wins_advanced_index( + key_to_use, rhs, arr.larray.shape + ) + + arr.larray[key_to_use] = rhs + return + + # make sure `value` is a DNDarray try: - # if isinstance(key[self.split], torch.Tensor): - filter_key = torch.nonzero( - (chunk_start <= key[self.split]) & (key[self.split] < chunk_end) + value = factories.array(value) + except TypeError: + raise TypeError(f"Cannot assign object of type {type(value)} to DNDarray.") + + # keep the key in its original form to handle edge cases + original_key = key + + # single-element key + scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 + if scalar: + key, root = self.__process_scalar_key(key, indexed_axis=0, return_local_indices=True) + value, value_is_scalar = __broadcast_value(self, key, value) + + if root is not None: + if self.comm.rank == root: + indexed_proxy = self.__torch_proxy__()[key] + if indexed_proxy.names.count("split") != 0: + indexed_lshape_map = self.lshape_map[:, 1:] + if value.lshape_map != indexed_lshape_map: + try: + value.redistribute_(target_map=indexed_lshape_map) + except ValueError: + raise ValueError( + f"cannot assign value to indexed DNDarray because " + f"distribution schemes do not match: " + f"{value.lshape_map} vs. {indexed_lshape_map}" + ) + __set(self, key, value) + else: + if not value_is_scalar: + value = sanitation.sanitize_distribution(value, target=self[key]) + __set(self, key, value) + return + + if isinstance(key, tuple) and len(key) >= 1 and self.ndim >= 1: + first = key[0] + if isinstance(first, (DNDarray, torch.Tensor, np.ndarray)): + first_dtype = getattr(first, "dtype", None) + first_ndim = getattr(first, "ndim", 0) + first_shape = tuple(getattr(first, "shape", ())) + + if ( + first_ndim == 1 + and first_shape == (self.shape[0],) + and first_dtype + in (ht_bool, ht_uint8, torch.bool, torch.uint8, np.bool_, np.uint8) + ): + # 1D boolean row mask -> explicit integer indices + if isinstance(first, DNDarray): + nz = first.nonzero() + if isinstance(nz, tuple): + nz = nz[0] + idx0 = nz # DNDarray of int indices (global) + else: + first_t = torch.as_tensor(first, device=self.device.torch_device) + idx0 = torch.nonzero(first_t, as_tuple=False).flatten() + + # Build new key: (idx0, rest...) + new_key = (idx0,) + key[1:] + + # recursuve call with integer advanced indexing. + self[new_key] = value + return + + # handle negative indices in multi-element keys + if isinstance(key, tuple): + key_list = list(key) + for ax, k_ax in enumerate(key_list): + if isinstance(k_ax, (int, np.integer)) and not isinstance(k_ax, (bool, np.bool_)): + if k_ax < 0: + dim = self.gshape[ax] + if -dim <= k_ax < 0: + key_list[ax] = dim + k_ax + else: + raise IndexError( + f"index {k_ax} is out of bounds for axis {ax} with size {dim}" + ) + key = tuple(key_list) + + # multi-element key, incl. slicing and striding, ordered and non-ordered advanced indexing + ( + self, + key, + output_shape, + output_split, + split_key_is_ordered, + key_is_mask_like, + _, + root, + backwards_transpose_axes, + ) = self.__process_key(key, return_local_indices=True, op="set") + + if self.is_distributed(): + local_code = ( + 2 if split_key_is_ordered == 1 else (1 if split_key_is_ordered == -1 else 0) ) - for k in range(len(key)): - try: - key[k] = key[k][filter_key].flatten() - except TypeError: - pass - except TypeError: # this will happen if the key doesnt have that many - pass + global_code = self.comm.allreduce(local_code, op=MPI.MIN) + split_key_is_ordered = 1 if global_code == 2 else (-1 if global_code == 1 else 0) - key = tuple(key) + km_local = 1 if key_is_mask_like else 0 + km_global = self.comm.allreduce(km_local, op=MPI.MIN) + key_is_mask_like = bool(km_global) - if not self.is_distributed(): - return self.__setter(key, value) # returns None + # match dimensions + value, value_is_scalar = __broadcast_value(self, key, value, output_shape=output_shape) - # raise RuntimeError("split axis of array and the target value are not equal") removed - # this will occur if the local shapes do not match - rank = self.comm.rank - ends = [] - for pr in range(self.comm.size): - _, _, e = self.comm.chunk(self.shape, self.split, rank=pr) - ends.append(e[self.split].stop - e[self.split].start) - ends = torch.tensor(ends, device=self.device.torch_device) - chunk_ends = ends.cumsum(dim=0) - chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device) - _, _, chunk_slice = self.comm.chunk(self.shape, self.split) - chunk_start = chunk_slice[self.split].start - chunk_end = chunk_slice[self.split].stop - - self_proxy = self.__torch_proxy__() - - # if the value is a DNDarray, the divisions need to be balanced: - # this means that we need to know how much data is where for both DNDarrays - # if the value data is not in the right place, then it will need to be moved - - if isinstance(key[self.split], slice): - key = list(key) - key_start = key[self.split].start if key[self.split].start is not None else 0 - key_stop = ( - key[self.split].stop - if key[self.split].stop is not None - else self.gshape[self.split] + # early out for non-distributed case + if not self.is_distributed() and not value.is_distributed(): + # no communication needed, just apply the local set + __set(self, key, value) + + # For 0-D arrays there is nothing to transpose; avoid permute() with no dims + if self.ndim > 0: + self = self.transpose(backwards_transpose_axes) + + return + + # distributed case + if split_key_is_ordered == 1: + # key all local + if root is not None: + # single-element assignment along split axis, only one active process + if self.comm.rank == root: + self.larray[key] = value.larray.type(self.dtype.torch_type()) + else: + # indexed elements are process-local + if self.is_distributed() and not value_is_scalar: + if not value.is_distributed(): + # work with distributed `value` + value = factories.array( + value.larray, + dtype=value.dtype, + split=output_split, + device=self.device, + comm=self.comm, + ) + else: + if value.split != output_split: + raise RuntimeError( + f"Cannot assign distributed `value` with split axis {value.split} to indexed DNDarray with split axis {output_split}." + ) + # verify that `self[key]` and `value` distribution are aligned + target_shape = torch.tensor( + tuple(self.larray[key].shape), device=self.device.torch_device + ) + target_map = torch.zeros( + (self.comm.size, len(target_shape)), + dtype=torch.int64, + device=self.device.torch_device, + ) + self.comm.Allgather(target_shape, target_map) + value.redistribute_(target_map=target_map) + __set(self, key, value) + self = self.transpose(backwards_transpose_axes) + return + + if split_key_is_ordered == -1: + # key along split axis is in descending order, i.e. slice with negative step + # N.B. PyTorch doesn't support negative-step slices. Key has been processed into torch tensor. + + # flip value, match value distribution to key's + # NB: `value.ndim` can be smaller than `self.ndim`, hence `value.split` nominally different from `self.split` + flipped_value = manipulations.flip(value, axis=output_split) + split_key = factories.array( + key[self.split], is_split=0, device=self.device, comm=self.comm ) - if key_stop < 0: - key_stop = self.gshape[self.split] + key[self.split].stop - key_step = key[self.split].step - og_key_start = key_start - st_pr = torch.where(key_start < chunk_ends)[0] - st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size - sp_pr = torch.where(key_stop >= chunk_starts)[0] - sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0 - actives = list(range(st_pr, sp_pr + 1)) + if not flipped_value.is_distributed(): + # work with distributed `flipped_value` + flipped_value = factories.array( + flipped_value.larray, + dtype=flipped_value.dtype, + split=output_split, + device=self.device, + comm=self.comm, + ) + # match `value` distribution to `self[key]` distribution + target_map = flipped_value.lshape_map + target_map[:, output_split] = split_key.lshape_map[:, 0] + flipped_value.redistribute_(target_map=target_map) + __set(self, key, flipped_value) + self = self.transpose(backwards_transpose_axes) + return + + def _advanced_setitem_unordered_local( + x_local: torch.Tensor, + split_key: torch.Tensor, + value_torch: torch.Tensor, + *, + split_axis: int, + value_key_start_dim: int, + local_offset: int, + local_size: int, + value_is_scalar: bool, + out_dtype: torch.dtype, + base_index: tuple | None = None, + ) -> None: + """ + The function is a helper that updates ``x_local`` in-place according to the logical advanced + indexing pattern encoded by ``split_key`` and the broadcasted ``value_torch``. + This helper operates exclusively on local ``torch.Tensor`` views: + - ``x_local`` is the local slice of the distributed array on this rank. + - ``split_key`` contains GLOBAL indices along the split axis. + - Only those indices that fall into ``[local_offset, local_offset + local_size)`` + are applied on this rank. + """ + # 1) Local mask: which global indices in `split_key` belong to this rank? + global_indices = split_key + local_mask = (global_indices >= local_offset) & ( + global_indices < local_offset + local_size + ) + + coord = local_mask.nonzero(as_tuple=True) + + if coord[0].numel() == 0: + # Nothing to do on this rank, exit early. + return + + # 2) Map global → local indices along the split axis + global_split_indices = global_indices[coord] + local_split_indices = global_split_indices - local_offset + + # 3) Build LHS index for x_local (corresponds to self.larray) + if base_index is None: + lhs_index = [slice(None)] * x_local.ndim + else: + lhs_index = list(base_index) + + lhs_index[split_axis] = local_split_indices + lhs_index = tuple(lhs_index) + + # 4) Build RHS index for value_torch + if value_is_scalar: + # Scalar assignment: broadcast scalar to the selected positions + x_local[lhs_index] = value_torch.to(out_dtype) + return + + rhs_index = [slice(None)] * value_torch.ndim + m = split_key.ndim + + for d in range(m): + rhs_index[value_key_start_dim + d] = coord[d] + + rhs = value_torch[tuple(rhs_index)] + x_local[lhs_index] = rhs.to(out_dtype) + + if split_key_is_ordered == 0: + # key along split axis is unordered, communication needed in general + # key along the split axis is torch tensor, indices are GLOBAL + counts, displs = self.counts_displs() + rank, _ = self.comm.rank, self.comm.size + + key_is_single_tensor = isinstance(key, torch.Tensor) if ( - isinstance(value, type(self)) - and value.split is not None - and value.shape[self.split] != self.shape[self.split] + not value.is_distributed() + and value_is_scalar + and isinstance(original_key, tuple) + and len(original_key) == self.ndim + and all( + isinstance(k, DNDarray) + and k.ndim == 1 + and k.dtype in (types.int32, types.int64) + for k in original_key + ) ): - # setting elements in self with a DNDarray which is not the same size in the - # split dimension - local_keys = [] - # below is used if the target needs to be reshaped - target_reshape_map = torch.zeros( - (self.comm.size, self.ndim), dtype=torch.int64, device=self.device.torch_device + # Alle Indexvektoren global auf *jedem* Rang verfügbar machen, + # unabhängig davon, wie nz verteilt ist. + global_indices = [] + for k in original_key: + k_full = k.copy() + k_full.resplit_(None) # alle Ränge halten anschließend den kompletten 1D-Vektor + global_indices.append(k_full.larray) + + # Globale Indizes entlang der Split-Achse + idx_split_global = global_indices[self.split] + local_offset = displs[rank] + local_size = counts[rank] + + # Welche Einträge von nz gehören zu diesem Rang? + mask = (idx_split_global >= local_offset) & ( + idx_split_global < local_offset + local_size ) - for r in range(self.comm.size): - if r not in actives: - loc_key = key.copy() - loc_key[self.split] = slice(0, 0, 0) + if not mask.any(): + # Auf diesem Rang ist nichts zu tun + self = self.transpose(backwards_transpose_axes) + return + + # Pro Dimension einen lokalen Indextensor bauen + lhs_index = [] + for dim, gind in enumerate(global_indices): + sel = gind[mask] + if dim == self.split: + # globale -> lokale Indizes + sel = sel - local_offset + lhs_index.append(sel) + lhs_index = tuple(lhs_index) + + # Skalarwert in richtigen Torch-Typ/Device bringen + if hasattr(value, "larray"): + scalar_torch = value.larray + else: + scalar_torch = torch.as_tensor(value, device=self.device.torch_device) + scalar_torch = scalar_torch.type(self.dtype.torch_type()) + + # In-place Update der lokalen Daten + self.larray[lhs_index] = scalar_torch + self = self.transpose(backwards_transpose_axes) + return + + # No communication needed if `value` is not distributed, only set elements local to each process + if not value.is_distributed(): + # Edge case: pure boolean DNDarray mask with same split as `self` + if ( + key_is_mask_like + and isinstance(original_key, DNDarray) + and original_key.split == self.split + and original_key.larray.dtype == torch.bool + ): + local_mask = original_key.larray + + if value_is_scalar: + if hasattr(value, "larray"): + scalar_torch = value.larray + else: + scalar_torch = torch.as_tensor(value, device=self.device.torch_device) + scalar_torch = scalar_torch.type(self.dtype.torch_type()) + self.larray[local_mask] = scalar_torch + else: + if hasattr(value, "larray"): + value_torch = value.larray + else: + value_torch = torch.as_tensor(value, device=self.device.torch_device) + + if value_torch.ndim == 1: + # RHS is already flat, length == #True(global) + # -> we need to extract the appropriate section from value_torch for each rank + + # 1) Local number of True values + local_mask_flat = local_mask.flatten() + local_true = int(local_mask_flat.sum().item()) + + # 2) Prefix sum across ranks to find the start index + if self.comm.size > 1: + if self.comm.rank == 0: + offset = 0 + _ = self.comm.exscan(local_true) + else: + offset = self.comm.exscan(local_true) + else: + offset = 0 + + # 3) Extract the local section from RHS + rhs_local = value_torch[offset : offset + local_true].type( + self.dtype.torch_type() + ) + + # 4) Insert the local section into the True positions + x_flat = self.larray.view(-1) + x_flat[local_mask_flat] = rhs_local + else: + # Value has the same shape as arr (or is broadcastable) + self.larray[local_mask] = value_torch[local_mask].type( + self.dtype.torch_type() + ) + + self = self.transpose(backwards_transpose_axes) + return + + if key_is_single_tensor: + # key is a single torch.Tensor + split_key = key + # find elements of `split_key` that are local to this process + local_indices = torch.nonzero( + (split_key >= displs[rank]) & (split_key < displs[rank] + counts[rank]) + ).flatten() + # keep local indexing key only and correct for displacements along the split axis + key = key[local_indices] - displs[rank] + if value_is_scalar: + # no need to index value + self.larray[key] = value.larray.type(self.dtype.torch_type()) + else: + # set local elements of `self` to corresponding elements of `value` + self.larray[key] = value.larray[local_indices].type(self.dtype.torch_type()) + self = self.transpose(backwards_transpose_axes) + return + + if key_is_mask_like: + # Echte boolsche Maske entlang der Split-Achse, lokal auswerten. + split_part = key[self.split] + + if isinstance(split_part, DNDarray): + local_mask = split_part.larray + elif isinstance(split_part, torch.Tensor): + if split_part.dtype not in (torch.bool, torch.uint8): + raise TypeError( + f"mask-like key along the split axis must be boolean, got {split_part.dtype}" + ) + start = displs[rank] + stop = start + counts[rank] + local_mask = split_part[start:stop] + else: + raise TypeError("Unsupported mask-like key type along split axis") + + local_indices = torch.nonzero(local_mask, as_tuple=False).flatten() + + if local_indices.numel() == 0: + self = self.transpose(backwards_transpose_axes) + return + + # Lokalen Key bauen: Split-Achse bekommt lokale Integer-Indizes, + # DNDarray-Komponenten werden zu lokalen Torch-Tensoren. + new_key = [] + for i, k_i in enumerate(key): + if i == self.split: + new_key.append(local_indices) + else: + if isinstance(k_i, DNDarray): + new_key.append(k_i.larray) + else: + new_key.append(k_i) + + key_local = tuple(new_key) + + # Wert vorbereiten + if value_is_scalar: + if hasattr(value, "larray"): + scalar_torch = value.larray + else: + scalar_torch = torch.as_tensor(value, device=self.device.torch_device) + scalar_torch = scalar_torch.type(self.dtype.torch_type()) + self.larray[key_local] = scalar_torch else: - key_start_l = 0 if r != actives[0] else key_start - chunk_starts[r] - key_stop_l = ends[r] if r != actives[-1] else key_stop - chunk_starts[r] - key_start_l, key_stop_l = self.__xitem_get_key_start_stop( - r, actives, key_start_l, key_stop_l, key_step, chunk_ends, og_key_start + if hasattr(value, "larray"): + value_torch = value.larray + else: + value_torch = torch.as_tensor(value, device=self.device.torch_device) + self.larray[key_local] = value_torch[key_local].type( + self.dtype.torch_type() ) - loc_key = key.copy() - loc_key[self.split] = slice(key_start_l, key_stop_l, key_step) - gout_full = torch.tensor( - self_proxy[tuple(loc_key)].shape, device=self.device.torch_device - ) - target_reshape_map[r] = gout_full - local_keys.append(loc_key) + self = self.transpose(backwards_transpose_axes) + return + + # Use original split of ``value`` (applying __process_key splits it like the input array) + # and take care of transposes + original_split_axis = backwards_transpose_axes[self.split] + raw_split_part = original_key[original_split_axis] + + if isinstance(raw_split_part, DNDarray): + split_key = raw_split_part.larray + elif isinstance(raw_split_part, torch.Tensor): + split_key = raw_split_part + else: + # Fallback to previous behaviour: use processed key on the (possibly transposed) split axis + split_key = key[self.split] - key = local_keys[rank] - value = value.redistribute(target_map=target_reshape_map) + # Convert to torch.Tensor if a DNDarray was passed + if isinstance(split_key, DNDarray): + split_key = split_key.larray - if rank not in actives: - return # non-active ranks can exit here + if split_key.dtype == torch.bool: + # assume mask along the split axis: convert to global indices + split_key = torch.nonzero(split_key, as_tuple=False).flatten() - chunk_starts_v = target_reshape_map[:, self.split] - value_slice = [slice(None, None, None)] * value.ndim - step2 = key_step if key_step is not None else 1 - key_start = (chunk_starts_v[rank] - og_key_start).item() + local_offset = displs[rank] + local_size = counts[rank] - key_start = max(key_start, 0) - key_stop = key_start + key_stop - slice_loc = min(self.split, value.ndim - 1) - value_slice[slice_loc] = slice( - key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 + # Ensure value is a local torch.Tensor (avoid DNDarray-style indexing here) + if hasattr(value, "larray"): + value_torch = value.larray + else: + value_torch = torch.as_tensor(value, device=self.device.torch_device) + + feature_dims = self.larray.ndim - (self.split + 1) + + if value_is_scalar: + value_key_start_dim = 0 + else: + value_key_start_dim = value_torch.ndim - split_key.ndim - feature_dims + if value_key_start_dim < 0: + raise RuntimeError("value_key_start_dim < 0 – inconsistent shapes") + + local_split_axis = self.split + + base_index = [slice(None)] * self.larray.ndim + for dim, k_part in enumerate(original_key): + if dim == self.split: + continue + # DNDarray → torch.Tensor + if isinstance(k_part, DNDarray): + base_index[dim] = k_part.larray + else: + # slices, ints, torch.Tensor, ... + base_index[dim] = k_part + + # apply the advanced indexing setitem locally + _advanced_setitem_unordered_local( + x_local=self.larray, + split_key=split_key, + value_torch=value_torch, + split_axis=local_split_axis, + value_key_start_dim=value_key_start_dim, + local_offset=local_offset, + local_size=local_size, + value_is_scalar=value_is_scalar, + out_dtype=self.dtype.torch_type(), + base_index=tuple(base_index), ) - self.__setter(tuple(key), value.larray) + self = self.transpose(backwards_transpose_axes) return - # if rank in actives: - if rank not in actives: - return # non-active ranks can exit here - key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] - key_stop = ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] - key_start, key_stop = self.__xitem_get_key_start_stop( - rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start - ) - key[self.split] = slice(key_start, key_stop, key_step) - - # todo: need to slice the values to be the right size... - if isinstance(value, (torch.Tensor, type(self))): - # if its a torch tensor, it is assumed to exist on all processes - value_slice = [slice(None, None, None)] * value.ndim - step2 = key_step if key_step is not None else 1 - key_start = (chunk_starts[rank] - og_key_start).item() - key_start = max(key_start, 0) - key_stop = key_start + key_stop - slice_loc = min(self.split, value.ndim - 1) - value_slice[slice_loc] = slice( - key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 + # both `self` and `value` are distributed + # distribution of `key` and `value` must be aligned + if key_is_mask_like: + # redistribute `value` to match distribution of `key` in one pass + split_key = key[self.split] + global_split_key = factories.array( + split_key, is_split=0, device=self.device, comm=self.comm, copy=False ) - self.__setter(tuple(key), value[tuple(value_slice)]) + target_map = value.lshape_map + target_map[:, value.split] = global_split_key.lshape_map[:, 0] + value.redistribute_(target_map=target_map) else: - self.__setter(tuple(key), value) - elif isinstance(key[self.split], (torch.Tensor, list)): - key = list(key) - key[self.split] -= chunk_start - if len(key[self.split]) != 0: - self.__setter(tuple(key), value) - - elif key[self.split] in range(chunk_start, chunk_end): - key = list(key) - key[self.split] = key[self.split] - chunk_start - self.__setter(tuple(key), value) + # redistribute split-axis `key` to match distribution of `value` in one pass + if key_is_single_tensor: + # key is a single torch.Tensor + split_key = key + else: + split_key = key[self.split] + global_split_key = factories.array( + split_key, is_split=0, device=self.device, comm=self.comm, copy=False + ) + target_map = global_split_key.lshape_map + target_map[:, 0] = value.lshape_map[:, value.split] + global_split_key.redistribute_(target_map=target_map) + split_key = global_split_key.larray + + # key and value are now aligned + + # prepare for `value` Alltoallv: + # work along axis 0, transpose if necessary + transpose_axes = list(range(value.ndim)) + transpose_axes[0], transpose_axes[value.split] = ( + transpose_axes[value.split], + transpose_axes[0], + ) + value = value.transpose(transpose_axes) + send_counts = torch.zeros( + self.comm.size, dtype=torch.int64, device=self.device.torch_device + ) + send_displs = torch.zeros_like(send_counts) + # allocate send buffer: add 1 column to store sent indices + send_buf_shape = list(value.lshape) + if value.ndim < 2: + send_buf_shape.append(1) + if key_is_mask_like: + send_buf_shape[-1] += len(key) + else: + send_buf_shape[-1] += 1 + send_buf = torch.zeros( + send_buf_shape, dtype=value.dtype.torch_type(), device=self.device.torch_device + ) + for proc in range(self.comm.size): + # calculate what local elements of `value` belong on process `proc` + send_indices = torch.nonzero( + (split_key >= displs[proc]) & (split_key < displs[proc] + counts[proc]) + ).flatten() + # calculate outgoing counts and displacements for each process + send_counts[proc] = send_indices.numel() + send_displs[proc] = send_counts[:proc].sum() + # compose send buffer: stack local elements of `value` according to destination process + if send_indices.numel() > 0: + if value.ndim < 2: + # temporarily add a singleton dimension to value to accmodate column dimension for send_indices + send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], :-1] = ( + value.larray[send_indices].unsqueeze(1) + ) + else: + send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], :-1] = ( + value.larray[send_indices] + ) + # store outgoing GLOBAL indices in the last column of send_buf + # TODO: if key_is_mask_like: apply send_indices to all dimensions of key + if key_is_mask_like: + for i in range(-len(key), 0): + send_buf[ + send_displs[proc] : send_displs[proc] + send_counts[proc], i + ] = key[i + len(key)][send_indices] + else: + send_indices = split_key[send_indices] + send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], -1] = ( + send_indices + ) - elif key[self.split] < 0: + # compose communication matrix: share `send_counts` information with all processes + comm_matrix = torch.zeros( + (self.comm.size, self.comm.size), + dtype=torch.int64, + device=self.device.torch_device, + ) + self.comm.Allgather(send_counts, comm_matrix) + # comm_matrix columns contain recv_counts for each process + recv_counts = comm_matrix[:, self.comm.rank].squeeze(0) + recv_displs = torch.zeros_like(recv_counts) + recv_displs[1:] = recv_counts.cumsum(0)[:-1] + # allocate receive buffer, with 1 extra column for incoming indices + recv_buf_shape = value.lshape_map[self.comm.rank] + recv_buf_shape[value.split] = recv_counts.sum() + recv_buf_shape = recv_buf_shape.tolist() + if value.ndim < 2: + recv_buf_shape.append(1) + if key_is_mask_like: + recv_buf_shape[-1] += len(key) + else: + recv_buf_shape[-1] += 1 + recv_buf_shape = tuple(recv_buf_shape) + recv_buf = torch.zeros( + recv_buf_shape, dtype=value.dtype.torch_type(), device=self.device.torch_device + ) + # perform Alltoallv along the 0 axis + send_counts, send_displs, recv_counts, recv_displs = ( + send_counts.tolist(), + send_displs.tolist(), + recv_counts.tolist(), + recv_displs.tolist(), + ) + self.comm.Alltoallv( + (send_buf, send_counts, send_displs), (recv_buf, recv_counts, recv_displs) + ) + del send_buf, comm_matrix key = list(key) - if self.gshape[self.split] + key[self.split] in range(chunk_start, chunk_end): - key[self.split] = key[self.split] + self.shape[self.split] - chunk_start - self.__setter(tuple(key), value) + if key_is_mask_like: + # extract incoming indices from recv_buf + recv_indices = recv_buf[..., -len(key) :] + # correct split-axis indices for rank offset + recv_indices[:, 0] -= displs[rank] + key = recv_indices.split(1, dim=1) + key = [key[i].squeeze_(1) for i in range(len(key))] + # remove indices from recv_buf + recv_buf = recv_buf[..., : -len(key)] + else: + # store incoming indices in int 1-D tensor and correct for rank offset + recv_indices = recv_buf[..., -1].type(torch.int64) - displs[rank] + # remove last column from recv_buf + recv_buf = recv_buf[..., :-1] + # replace split-axis key with incoming local indices + key = list(key) + key[self.split] = recv_indices + key = tuple(key) + # transpose back value and recv_buf if necessary, wrap recv_buf in DNDarray + value = value.transpose(transpose_axes) + if value.ndim < 2: + recv_buf.squeeze_(1) + recv_buf = DNDarray( + recv_buf.permute(*transpose_axes), + gshape=value.gshape, + dtype=value.dtype, + split=value.split, + device=value.device, + comm=value.comm, + balanced=value.balanced, + ) + # set local elements of `self` to corresponding elements of `value` + __set(self, key, recv_buf) + self = self.transpose(backwards_transpose_axes) def __setter( self, - key: Union[int, Tuple[int, ...], List[int, ...]], - value: Union[float, DNDarray, torch.Tensor], + key: int | tuple[int, ...] | list[int], + value: float | "DNDarray" | torch.Tensor, ): """ Utility function for checking ``value`` and forwarding to :func:``__setitem__`` @@ -1885,13 +3351,127 @@ def __setter( else: raise NotImplementedError(f"Not implemented for {value.__class__.__name__}") + def __take_split0_global_1d( + self, + idx: torch.Tensor, + out_gshape: tuple[int, ...], + out_split: int | None, + out_is_balanced: bool, + ) -> "DNDarray": + """ + Distributed take for 1D arrays split along axis 0. + idx contains GLOBAL indices (any shape). Returns self[idx] with shape out_gshape. + + Communication strategy: + - each rank sends requested indices to owning ranks (Alltoallv) + - owners lookup local values and send them back (Alltoallv) + - requester reorders to original idx order and reshapes + """ + comm = self.comm + size = comm.Get_size() + rank = comm.Get_rank() + + # flatten local request + idx_flat = idx.reshape(-1).contiguous() + + # handle empty + if idx_flat.numel() == 0: + empty = self.larray.new_empty(idx.shape, dtype=self.larray.dtype) + return DNDarray( + empty, + out_gshape, + dtype=self.dtype, + split=out_split, + device=self.device, + comm=comm, + balanced=out_is_balanced, + ) + + # normalize negative indices + n = self.gshape[0] + if (idx_flat < 0).any(): + idx_flat = idx_flat.clone() + idx_flat[idx_flat < 0] += n + + # bounds check + if (idx_flat < 0).any() or (idx_flat >= n).any(): + raise IndexError("index out of bounds") + + # ownership map via counts/displs of self + counts, displs = self.counts_displs() # python lists + if size == 1: + vals = self.larray[idx_flat].reshape(idx.shape) + return DNDarray( + vals, + out_gshape, + dtype=self.dtype, + split=out_split, + device=self.device, + comm=comm, + balanced=out_is_balanced, + ) + + boundaries = torch.tensor(displs[1:], device=idx_flat.device, dtype=idx_flat.dtype) + owners = torch.bucketize(idx_flat, boundaries, right=True) + + # group requests by owner + owners_sorted, order = owners.sort(stable=True) + idx_sorted = idx_flat[order] + + # send counts/displs + send_counts_t = torch.bincount(owners_sorted, minlength=size).to(torch.int64) + send_counts = send_counts_t.cpu().tolist() + send_displs = [0] + for c in send_counts[:-1]: + send_displs.append(send_displs[-1] + c) + + # recv counts/displs + recv_counts = comm.alltoall(send_counts) + recv_displs = [0] + for c in recv_counts[:-1]: + recv_displs.append(recv_displs[-1] + c) + recv_total = sum(recv_counts) + + # exchange indices + recv_idx = torch.empty((recv_total,), dtype=idx_sorted.dtype, device=idx_sorted.device) + comm.Alltoallv((idx_sorted, send_counts, send_displs), (recv_idx, recv_counts, recv_displs)) + + # local lookup on owner + offset = displs[rank] + local_idx = recv_idx - offset + local_src = self.larray.contiguous() + send_vals = local_src[local_idx] + + # send values back (reverse pattern) + recv_vals_grouped = torch.empty( + (idx_sorted.numel(),), dtype=send_vals.dtype, device=send_vals.device + ) + comm.Alltoallv( + (send_vals, recv_counts, recv_displs), (recv_vals_grouped, send_counts, send_displs) + ) + + # undo grouping permutation + inv = torch.empty_like(order) + inv[order] = torch.arange(order.numel(), device=order.device, dtype=order.dtype) + vals = recv_vals_grouped[inv].reshape(idx.shape) + + return DNDarray( + vals, + out_gshape, + dtype=self.dtype, + split=out_split, + device=self.device, + comm=comm, + balanced=out_is_balanced, + ) + def __str__(self) -> str: """ Computes a string representation of the passed ``DNDarray``. """ return printing.__str__(self) - def tolist(self, keepsplit: bool = False) -> List: + def tolist(self, keepsplit: bool = False) -> list: """ Return a copy of the local array data as a (nested) Python list. For scalars, a standard Python number is returned. @@ -1937,11 +3517,16 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def __torch_proxy__(self) -> torch.Tensor: """ - Return a 1-element `torch.Tensor` strided as the global `self` shape. - Used internally for sanitation purposes. + Return a 1-element `torch.Tensor` strided as the global `self` shape. The split axis of the initial DNDarray is stored in the `names` attribute of the returned tensor. + Used internally to lower memory footprint of sanitation. """ - return torch.ones((1,), dtype=torch.int8, device=self.larray.device).as_strided( - self.gshape, [0] * self.ndim + names = [None] * self.ndim + if self.split is not None: + names[self.split] = "split" + return ( + torch.ones((1,), dtype=torch.int8, device=self.larray.device) + .as_strided(self.gshape, [0] * self.ndim) + .refine_names(*names) ) @staticmethod @@ -1953,7 +3538,7 @@ def __xitem_get_key_start_stop( step: int, ends: torch.Tensor, og_key_st: int, - ) -> Tuple[int, int]: + ) -> tuple[int, int]: # this does some basic logic for adjusting the starting and stoping of the a key for # setitem and getitem if step is not None and rank > actives[0]: @@ -1987,3 +3572,4 @@ def __xitem_get_key_start_stop( from .devices import Device from .stride_tricks import sanitize_axis from .types import datatype, canonical_heat_type +from .types import bool as ht_bool, uint8 as ht_uint8 diff --git a/heat/core/indexing.py b/heat/core/indexing.py index 27564ef02c..8ae88cfb56 100644 --- a/heat/core/indexing.py +++ b/heat/core/indexing.py @@ -3,22 +3,22 @@ """ import torch -from typing import List, Dict, Any, TypeVar, Union, Tuple, Sequence from .communication import MPI from .dndarray import DNDarray -from . import sanitation from . import types +from . import manipulations __all__ = ["nonzero", "where"] -def nonzero(x: DNDarray) -> DNDarray: +def nonzero(x: DNDarray) -> tuple[DNDarray, ...]: """ - Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero.. (using ``torch.nonzero``) - If ``x`` is split then the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray` + TODO: UPDATE DOCS! + Return a Tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``, + containing the indices of the non-zero elements in that dimension. If ``x`` is split then + the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray` can be UNBALANCED as it contains the indices of the non-zero elements on each node. - Returns an array with one entry for each dimension of ``x``, containing the indices of the non-zero elements in that dimension. The values in ``x`` are always tested and returned in row-major, C-style order. The corresponding non-zero values can be obtained with: ``x[nonzero(x)]``. @@ -32,10 +32,8 @@ def nonzero(x: DNDarray) -> DNDarray: >>> import heat as ht >>> x = ht.array([[3, 0, 0], [0, 4, 1], [0, 6, 0]], split=0) >>> ht.nonzero(x) - DNDarray([[0, 0], - [1, 1], - [1, 2], - [2, 1]], dtype=ht.int64, device=cpu:0, split=0) + (DNDarray([0, 1, 1, 2], dtype=ht.int64, device=cpu:0, split=None), + DNDarray([0, 1, 2, 1], dtype=ht.int64, device=cpu:0, split=None)) >>> y = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=0) >>> y > 3 DNDarray([[False, False, False], @@ -48,41 +46,80 @@ def nonzero(x: DNDarray) -> DNDarray: [2, 0], [2, 1], [2, 2]], dtype=ht.int64, device=cpu:0, split=0) + (DNDarray([1, 1, 1, 2, 2, 2], dtype=ht.int64, device=cpu:0, split=None), + DNDarray([0, 1, 2, 0, 1, 2], dtype=ht.int64, device=cpu:0, split=None)) >>> y[ht.nonzero(y > 3)] DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0) """ - sanitation.sanitize_in(x) - - if x.split is None: - # if there is no split then just return the values from torch - lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) - gout = list(lcl_nonzero.size()) - is_split = None + try: + local_x = x.larray + except AttributeError: + raise TypeError("Input must be a DNDarray, is {}".format(type(x))) + + if not x.is_distributed(): + # nonzero indices as tuple + lcl_nonzero = torch.nonzero(input=local_x, as_tuple=True) + # bookkeeping for final DNDarray construct + nonzero_size = lcl_nonzero[0].shape[0] + output_split = None + output_balanced = True else: - # a is split - lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) - _, _, slices = x.comm.chunk(x.shape, x.split) - lcl_nonzero[..., x.split] += slices[x.split].start - gout = list(lcl_nonzero.size()) - gout[0] = x.comm.allreduce(gout[0], MPI.SUM) - is_split = 0 - - if x.ndim == 1: - lcl_nonzero = lcl_nonzero.squeeze(dim=1) - - for g in range(len(gout) - 1, -1, -1): - if gout[g] == 1 and len(gout) > 1: - del gout[g] - - return DNDarray( - lcl_nonzero, - gshape=tuple(gout), - dtype=types.canonical_heat_type(lcl_nonzero.dtype), - split=is_split, - device=x.device, - comm=x.comm, - balanced=False, - ) + lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False) + nonzero_size = torch.tensor( + lcl_nonzero.shape[0], dtype=torch.int64, device=lcl_nonzero.device + ) + # global nonzero_size + x.comm.Allreduce(MPI.IN_PLACE, nonzero_size, MPI.SUM) + # correct indices along split axis + _, displs = x.counts_displs() + lcl_nonzero[:, x.split] += displs[x.comm.rank] + + if x.split != 0: + # construct global 2D DNDarray of nz indices: + shape_2d = (nonzero_size.item(), x.ndim) + global_nonzero = DNDarray( + lcl_nonzero, + gshape=shape_2d, + dtype=types.int64, + split=0, + device=x.device, + comm=x.comm, + balanced=False, + ) + # stabilize distributed result: vectorized sorting of nz indices along axis 0 + global_nonzero.balance_() + global_nonzero = manipulations.unique(global_nonzero, axis=0) + # return indices as tuple of columns + lcl_nonzero = global_nonzero.larray.split(1, dim=1) + output_balanced = True + else: + # return indices as tuple of columns + lcl_nonzero = lcl_nonzero.split(1, dim=1) + output_balanced = False + + nonzero_size = nonzero_size.item() + output_split = 0 + + # return global_nonzero as tuple of DNDarrays + global_nonzero = list(lcl_nonzero) + output_shape = (nonzero_size,) + for i, nz_tensor in enumerate(global_nonzero): + if nz_tensor.ndim > 1: + # extra dimension in distributed case from usage of torch.split() + nz_tensor = nz_tensor.squeeze(dim=-1) + nz_array = DNDarray( + nz_tensor, + gshape=output_shape, + dtype=types.int64, + split=output_split, + device=x.device, + comm=x.comm, + balanced=output_balanced, + ) + global_nonzero[i] = nz_array + global_nonzero = tuple(global_nonzero) + + return tuple(global_nonzero) DNDarray.nonzero = lambda self: nonzero(self) @@ -91,8 +128,8 @@ def nonzero(x: DNDarray) -> DNDarray: def where( cond: DNDarray, - x: Union[None, int, float, DNDarray] = None, - y: Union[None, int, float, DNDarray] = None, + x: None | int | float | DNDarray = None, + y: None | int | float | DNDarray = None, ) -> DNDarray: """ Return a :class:`~heat.core.dndarray.DNDarray` containing elements chosen from ``x`` or ``y`` depending on condition. @@ -131,20 +168,52 @@ def where( [ 0, 2, -1], [ 0, 3, -1]], dtype=ht.int64, device=cpu:0, split=None) """ + # ---- binary where(cond, x, y) branch ------------------------------------ if cond.split is not None and (isinstance(x, DNDarray) or isinstance(y, DNDarray)): if (isinstance(x, DNDarray) and cond.split != x.split) or ( isinstance(y, DNDarray) and cond.split != y.split ): - if len(y.shape) >= 1 and y.shape[0] > 1: + # Only raise if the "other" array has a meaningful first dimension. + if isinstance(y, DNDarray) and len(y.shape) >= 1 and y.shape[0] > 1: raise NotImplementedError("binary op not implemented for different split axes") + if isinstance(x, (DNDarray, int, float)) and isinstance(y, (DNDarray, int, float)): + # Simple elementwise selection using arithmetic: + # cond == 0 -> take y, cond == 1 -> take x for var in [x, y]: if isinstance(var, int): var = float(var) return cond.dtype(cond == 0) * y + cond * x + + # ---- where(cond) "indices only" branch ---------------------------------- elif x is None and y is None: - return nonzero(cond) + # General rule: delegate to nonzero(cond), and only wrap into a 2-D + # coordinate matrix in the special distributed case where the array + # is split along a non-zero axis. + nz = nonzero(cond) # tuple of DNDarrays, one per dimension + + # 1) Non-distributed: behave exactly like ht.nonzero(cond) + if cond.split is None: + return nz + + # 2) Distributed along axis 0: keep the legacy tuple-of-indices API. + # This is relied upon in several parts of the code base (e.g. KMeans). + if cond.split == 0: + return nz + + # 3) Distributed along a non-zero axis (split > 0) + coords = manipulations.stack(nz, axis=1) + coords = coords.astype(types.int64, copy=False) + + # Ensure indices are split along axis 0 for stable distributed behavior + if coords.split is None: + coords.resplit_(0) + + return coords + + # ---- invalid combinations ---------------------------------------------- else: raise TypeError( - f"either both or neither x and y must be given and both must be DNDarrays or numerical scalars({type(x)}, {type(y)})" + "either both or neither x and y must be given and both must be " + f"DNDarrays or numerical scalars (got {type(x)}, {type(y)})" ) diff --git a/heat/core/linalg/tests/test_basics.py b/heat/core/linalg/tests/test_basics.py index 87120adcc8..17e0e20c9e 100644 --- a/heat/core/linalg/tests/test_basics.py +++ b/heat/core/linalg/tests/test_basics.py @@ -301,6 +301,8 @@ def test_inv(self): self.assertTupleEqual(ainv.shape, a.shape) self.assertTrue(ht.allclose(ainv, ares, atol=1e-6)) + # distributed + # ares = ht.array([[2.0, 2, 1], [3, 4, 1], [0, 1, -1]], split=0) a = ht.array([[5.0, -3, 2], [-3, 2, -1], [-3, 2, -2]], split=0) ainv = ht.linalg.inv(a) self.assertEqual(ainv.split, a.split) @@ -566,6 +568,7 @@ def test_matmul(self): self.assertEqual(ret00.dtype, ht.float) self.assertEqual(ret00.split, 0) + # splits 1 None a = ht.ones((n, m), split=1) b = ht.ones((j, k), split=None) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 47e83a9d52..d1c003cb3f 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -4070,7 +4070,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: except AttributeError: x = factories.array(x).reshape(1) - x_proxy = x.__torch_proxy__() + x_proxy = x.__torch_proxy__().rename(None) # drop named-tensor metadata # torch-proof args/kwargs: # torch `reps`: int or sequence of ints; numpy `reps`: can be array-like @@ -4134,7 +4134,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: trans_axes[0], trans_axes[x.split] = x.split, 0 reps[0], reps[x.split] = reps[x.split], reps[0] x = linalg.transpose(x, trans_axes) - x_proxy = x.__torch_proxy__() + x_proxy = x.__torch_proxy__().rename(None) out_gshape = tuple(x_proxy.repeat(reps).shape) local_x = x.larray diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index c6123c1cf2..56af318556 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -605,6 +605,260 @@ def test_float_cast(self): with self.assertRaises(TypeError): float(ht.full((ht.MPI_WORLD.size,), 2, split=0)) + def test_getitem(self): + # following https://numpy.org/doc/stable/user/basics.indexing.html + + # Single element indexing + # 1D, local + x = ht.arange(10) + self.assertTrue(x[2].item() == 2) + self.assertTrue(x[-2].item() == 8) + self.assertTrue(x[2].dtype == ht.int32) + # 1D, distributed + x = ht.arange(10, split=0, dtype=ht.float64) + self.assertTrue(x[2].item() == 2.0) + self.assertTrue(x[-2].item() == 8.0) + self.assertTrue(x[2].dtype == ht.float64) + self.assertTrue(x[2].split is None) + # 2D, local + x = ht.arange(10).reshape(2, 5) + self.assertTrue((x[0] == ht.arange(5)).all().item()) + self.assertTrue(x[0].dtype == ht.int32) + # 2D, distributed + x_split0 = ht.array(x, split=0) + self.assertTrue((x_split0[0] == ht.arange(5, split=None)).all().item()) + x_split1 = ht.array(x, split=1) + self.assertTrue((x_split1[-2] == ht.arange(5, split=0)).all().item()) + # 3D, local + x = ht.arange(27).reshape(3, 3, 3) + key = -2 + indexed = x[key] + self.assertTrue((indexed.larray == x.larray[key]).all()) + self.assertTrue(indexed.dtype == ht.int32) + self.assertTrue(indexed.split is None) + # 3D, distributed, split = 0 + x_split0 = ht.array(x, dtype=ht.float32, split=0) + indexed_split0 = x_split0[key] + self.assertTrue((indexed_split0.larray == x.larray[key]).all()) + self.assertTrue(indexed_split0.dtype == ht.float32) + self.assertTrue(indexed_split0.split is None) + # 3D, distributed split, != 0 + x_split2 = ht.array(x, dtype=ht.int64, split=2) + key = ht.array(2) + indexed_split2 = x_split2[key] + self.assertTrue((indexed_split2.numpy() == x.numpy()[key.item()]).all()) + self.assertTrue(indexed_split2.dtype == ht.int64) + self.assertTrue(indexed_split2.split == 1) + + # Slicing and striding + x = ht.arange(20, split=0) + x_sliced = x[1:11:3] + x_np = np.arange(20) + x_sliced_np = x_np[1:11:3] + self.assert_array_equal(x_sliced, x_sliced_np) + self.assertTrue(x_sliced.split == 0) + + # 1-element slice along split axis + x = ht.arange(20).reshape(4, 5) + x.resplit_(axis=1) + x_sliced = x[:, 2:3] + x_np = np.arange(20).reshape(4, 5) + x_sliced_np = x_np[:, 2:3] + self.assert_array_equal(x_sliced, x_sliced_np) + self.assertTrue(x_sliced.split == 1) + + # slicing with negative step along split axis 0 + shape = (20, 4, 3) + x_3d = ht.arange(20 * 4 * 3, split=0).reshape(shape) + x_3d_sliced = x_3d[17:2:-2, :2, ht.array(1)] + x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[17:2:-2, :2, 1] + self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) + self.assertTrue(x_3d_sliced.split == 0) + + # slicing with negative step along split 1 + shape = (4, 20, 3) + x_3d = ht.arange(20 * 4 * 3).reshape(shape) + x_3d.resplit_(axis=1) + key = (slice(None, 2), slice(17, 2, -2), 1) + x_3d_sliced = x_3d[key] + x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[:2, 17:2:-2, 1] + self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) + self.assertTrue(x_3d_sliced.split == 1) + + # slicing with negative step along split 2 and loss of axis < split + shape = (4, 3, 20) + x_3d = ht.arange(20 * 4 * 3).reshape(shape) + x_3d.resplit_(axis=2) + key = (slice(None, 2), 1, slice(17, 10, -2)) + x_3d_sliced = x_3d[key] + x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[:2, 1, 17:10:-2] + self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) + self.assertTrue(x_3d_sliced.split == 1) + + # slicing with negative step along split 2 and loss of all axes but split + shape = (4, 3, 20) + x_3d = ht.arange(20 * 4 * 3).reshape(shape) + x_3d.resplit_(axis=2) + key = (0, 1, slice(17, 13, -1)) + x_3d_sliced = x_3d[key] + x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[0, 1, 17:13:-1] + self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) + self.assertTrue(x_3d_sliced.split == 0) + + # tests for bug 730: + a = ht.ones((10, 25, 30), split=1) + if a.comm.size > 1: + self.assertEqual(a[0].split, 0) + self.assertEqual(a[:, 0, :].split, None) + self.assertEqual(a[:, :, 0].split, 1) + + # DIMENSIONAL INDEXING + # ellipsis + x_np = np.array([[[1], [2], [3]], [[4], [5], [6]]]) + x_np_ellipsis = x_np[..., 0] + x = ht.array([[[1], [2], [3]], [[4], [5], [6]]]) + + # local + x_ellipsis = x[..., 0] + x_slice = x[:, :, 0] + self.assert_array_equal(x_ellipsis, x_np_ellipsis) + self.assert_array_equal(x_slice, x_np_ellipsis) + + # distributed + x.resplit_(axis=1) + x_ellipsis = x[..., 0] + x_slice = x[:, :, 0] + self.assert_array_equal(x_ellipsis, x_np_ellipsis) + self.assert_array_equal(x_slice, x_np_ellipsis) + self.assertTrue(x_ellipsis.split == 1) + + # newaxis: local + x = ht.array([[[1], [2], [3]], [[4], [5], [6]]]) + x_np_newaxis = x_np[:, np.newaxis, :2, :] + x_newaxis = x[:, np.newaxis, :2, :] + x_none = x[:, None, :2, :] + self.assert_array_equal(x_newaxis, x_np_newaxis) + self.assert_array_equal(x_none, x_np_newaxis) + + # newaxis: distributed + x.resplit_(axis=1) + x_newaxis = x[:, np.newaxis, :2, :] + x_none = x[:, None, :2, :] + self.assert_array_equal(x_newaxis, x_np_newaxis) + self.assert_array_equal(x_none, x_np_newaxis) + self.assertTrue(x_newaxis.split == 2) + self.assertTrue(x_none.split == 2) + + x = ht.arange(5, split=0) + x_np = np.arange(5) + y = x[:, np.newaxis] + x[np.newaxis, :] + y_np = x_np[:, np.newaxis] + x_np[np.newaxis, :] + self.assert_array_equal(y, y_np) + self.assertTrue(y.split == 0) + + # ADVANCED INDEXING + # "x[(1, 2, 3),] is fundamentally different from x[(1, 2, 3)]" + + x_np = np.arange(60).reshape(5, 3, 4) + indexed_x_np = x_np[(1, 2, 3)] + adv_indexed_x_np = x_np[(1, 2, 3),] + x = ht.array(x_np, split=0) + indexed_x = x[(1, 2, 3)] + self.assertTrue(indexed_x.item() == np.array(indexed_x_np)) + adv_indexed_x = x[(1, 2, 3),] + self.assert_array_equal(adv_indexed_x, adv_indexed_x_np) + + # 1d + x = ht.arange(10, 1, -1, split=0) + x_np = np.arange(10, 1, -1) + x_adv_ind = x[np.array([3, 3, 1, 8])] + x_np_adv_ind = x_np[np.array([3, 3, 1, 8])] + self.assert_array_equal(x_adv_ind, x_np_adv_ind) + + # 3d, split 0, non-unique, non-ordered key along split axis + x = ht.arange(60, split=0).reshape(5, 3, 4) + x_np = np.arange(60).reshape(5, 3, 4) + k1 = np.array([0, 4, 1, 0]) + k2 = np.array([0, 2, 1, 0]) + k3 = np.array([1, 2, 3, 1]) + self.assert_array_equal( + x[ht.array(k1, split=0), ht.array(k2, split=0), ht.array(k3, split=0)], x_np[k1, k2, k3] + ) + # advanced indexing on non-consecutive dimensions + x = ht.arange(60, split=0).reshape(5, 3, 4, new_split=1) + x_copy = x.copy() + x_np = np.arange(60).reshape(5, 3, 4) + k1 = np.array([0, 4, 1, 0]) + k2 = 0 + k3 = np.array([1, 2, 3, 1]) + key = (k1, k2, k3) + self.assert_array_equal(x[key], x_np[key]) + # check that x is unchanged after internal manipulation + self.assertTrue(x.shape == x_copy.shape) + self.assertTrue(x.split == x_copy.split) + self.assertTrue(x.lshape == x_copy.lshape) + self.assertTrue((x == x_copy).all().item()) + + # broadcasting shapes + x.resplit_(axis=0) + self.assert_array_equal(x[ht.array(k1, split=0), ht.array(1), 2], x_np[k1, 1, 2]) + # test exception: broadcasting mismatching shapes + k2 = np.array([0, 2, 1]) + with self.assertRaises(IndexError): + x[k1, k2, k3] + + # more broadcasting + x_np = np.arange(12).reshape(4, 3) + rows = np.array([0, 3]) + cols = np.array([0, 2]) + x = ht.arange(12).reshape(4, 3) + x.resplit_(1) + x_np_indexed = x_np[rows[:, np.newaxis], cols] + x_indexed = x[ht.array(rows)[:, np.newaxis], cols] + self.assert_array_equal(x_indexed, x_np_indexed) + self.assertTrue(x_indexed.split == 1) + + # combining advanced and basic indexing + y_np = np.arange(35).reshape(5, 7) + y_np_indexed = y_np[np.array([0, 2, 4]), 1:3] + y = ht.array(y_np, split=1) + y_indexed = y[ht.array([0, 2, 4]), 1:3] + self.assert_array_equal(y_indexed, y_np_indexed) + self.assertTrue(y_indexed.split == 1) + + x_np = np.arange(10 * 20 * 30).reshape(10, 20, 30) + x = ht.array(x_np, split=1) + ind_array = ht.random.randint(0, 20, (2, 3, 4), dtype=ht.int64) + ind_array_np = ind_array.numpy() + x_np_indexed = x_np[..., ind_array_np, :] + x_indexed = x[..., ind_array, :] + self.assert_array_equal(x_indexed, x_np_indexed) + self.assertTrue(x_indexed.split == 3) + + # boolean mask, local + arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5) + np.random.seed(42) + mask = np.random.randint(0, 2, arr.shape, dtype=bool) + self.assertTrue((arr[mask].numpy() == arr.numpy()[mask]).all()) + + # boolean mask, distributed + arr_split0 = ht.array(arr, split=0) + mask_split0 = ht.array(mask, split=0) + self.assertTrue((arr_split0[mask_split0].numpy() == arr.numpy()[mask]).all()) + + arr_split1 = ht.array(arr, split=1) + mask_split1 = ht.array(mask, split=1) + self.assert_array_equal(arr_split1[mask_split1], arr.numpy()[mask]) + + arr_split2 = ht.array(arr, split=2) + mask_split2 = ht.array(mask, split=2) + self.assert_array_equal(arr_split2[mask_split2], arr.numpy()[mask]) + + # boolean edge case + idx = ht.array([2, 0, 1], split=0) + mask = ht.array([True, False, True], split=0) + self.assertTrue((idx[mask] == ht.array([2, 1], dtype=idx.dtype, split=0)).all().item()) + def test_int_cast(self): # simple scalar tensor a = ht.ones(1) @@ -1146,6 +1400,309 @@ def test_resplit(self): self.assertTrue(ht.all(t1_sub == res)) self.assertEqual(t1_sub.split, None) + def test_setitem(self): + # following https://numpy.org/doc/stable/user/basics.indexing.html + + # Single element indexing + # 1D, local + x = ht.zeros(10) + x[2] = 2 + x[-2] = 8 + self.assertTrue(x[2].item() == 2) + self.assertTrue(x[-2].item() == 8) + self.assertTrue(x[2].dtype == ht.float32) + # 1D, distributed + x = ht.zeros(10, split=0, dtype=ht.float64) + x[2] = 2 + x[-2] = 8 + self.assertTrue(x[2].item() == 2.0) + self.assertTrue(x[-2].item() == 8.0) + self.assertTrue(x[2].dtype == ht.float64) + self.assertTrue(x.split == 0) + # 2D, local + x = ht.zeros(10).reshape(2, 5) + x[0] = ht.arange(5) + self.assertTrue((x[0] == ht.arange(5)).all().item()) + self.assertTrue(x[0].dtype == ht.float32) + # 2D, distributed + x_split0 = ht.zeros(10, split=0).reshape(2, 5) + x_split0[0] = ht.arange(5) + self.assertTrue((x_split0[0] == ht.arange(5, split=None)).all().item()) + x_split1 = ht.zeros(10, split=0).reshape(2, 5, new_split=1) + x_split1[-2] = ht.arange(5) + self.assertTrue((x_split1[-2] == ht.arange(5, split=0)).all().item()) + # 3D, distributed, split = 0 + x_split0 = ht.zeros(27, split=0).reshape(3, 3, 3) + key = -2 + x_split0[key] = ht.arange(3) + self.assertTrue((x_split0[key] == ht.arange(3, device=x_split0.device)).all().item()) + self.assertTrue(x_split0[key].dtype == ht.float32) + self.assertTrue(x_split0.split == 0) + # 3D, distributed split, != 0 + x_split2 = ht.zeros(27, dtype=ht.int64, split=0).reshape(3, 3, 3, new_split=2) + key = ht.array(2) + x_split2[key] = [6, 7, 8] + indexed_split2 = x_split2[key] + self.assertTrue((indexed_split2.numpy()[0] == np.array([6, 7, 8])).all()) + self.assertTrue(indexed_split2.dtype == ht.int64) + self.assertTrue(x_split2.split == 2) + + # Slicing and striding + x = ht.arange(20, split=0) + x[1:11:3] = ht.array([10, 40, 70, 100]) + x_np = np.arange(20) + x_np[1:11:3] = np.array([10, 40, 70, 100]) + self.assert_array_equal(x, x_np) + self.assertTrue(x.split == 0) + + # 1-element slice along split axis + x = ht.arange(20).reshape(4, 5) + x.resplit_(axis=1) + x[:, 2:3] = ht.array([10, 40, 70, 100]).reshape(4, 1) + x_np = np.arange(20).reshape(4, 5) + x_np[:, 2:3] = np.array([10, 40, 70, 100]).reshape(4, 1) + self.assert_array_equal(x, x_np) + self.assertTrue(x.split == 1) + with self.assertRaises(ValueError): + x[:, 2:3] = ht.array([10, 40, 70, 100]) + + # slicing with negative step along split axis 0 + # assign different dtype + shape = (20, 4, 3) + x_3d = ht.arange(20 * 4 * 3, split=0).reshape(shape) + value = ht.random.randn(8, 2) + x_3d[17:2:-2, :2, ht.array(1)] = value + x_3d_sliced = x_3d[17:2:-2, :2, ht.array(1)] + self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) + self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) + + # slicing with negative step along split 1 + shape = (4, 20, 3) + x_3d = ht.arange(20 * 4 * 3, dtype=ht.float32).reshape(shape) + x_3d.resplit_(axis=1) + key = (slice(None, 2), slice(17, 2, -2), 1) + value = ht.random.randn(2, 8) + x_3d[key] = value + x_3d_sliced = x_3d[key] + self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) + self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) + + # slicing with negative step along split 2 and loss of axis < split + shape = (4, 3, 20) + x_3d = ht.arange(20 * 4 * 3, dtype=ht.float64).reshape(shape) + x_3d.resplit_(axis=2) + key = (slice(None, 2), 1, slice(17, 10, -2)) + value = ht.random.randn(2, 4) + x_3d[key] = value + x_3d_sliced = x_3d[key] + self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) + self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) + + # slicing with negative step along split 2 and loss of all axes but split + shape = (4, 3, 20) + x_3d = ht.arange(20 * 4 * 3).reshape(shape) + x_3d.resplit_(axis=2) + key = (0, 1, slice(17, 13, -1)) + value = ht.random.randint( + 0, + 5, + ( + 1, + 4, + ), + split=1, + ) + x_3d[key] = value + x_3d_sliced = x_3d[key] + self.assertTrue(ht.allclose(x_3d_sliced, value.squeeze(0).astype(x_3d.dtype))) + self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) + + # DIMENSIONAL INDEXING + + # ellipsis + x = ht.array([[[1], [2], [3]], [[4], [5], [6]]]) + # local + value = x.squeeze() + 7 + x[..., 0] = value + self.assertTrue(ht.all(x[..., 0] == value).item()) + value -= 7 + x[:, :, 0] = value + self.assertTrue(ht.all(x[:, :, 0] == value).item()) + + # distributed + x.resplit_(axis=1) + value *= 2 + x[..., 0] = value + x_ellipsis = x[..., 0] + self.assertTrue(ht.all(x_ellipsis == value).item()) + value += 2 + x[:, :, 0] = value + self.assertTrue(ht.all(x[:, :, 0] == value).item()) + self.assertTrue(x_ellipsis.split == 1) + + # newaxis: local, w. broadcasting and different dtype + x = ht.array([[[1], [2], [3]], [[4], [5], [6]]]) + value = ht.array([10.0, 20.0]).reshape(2, 1) + x[:, None, :2, :] = value + x_newaxis = x[:, None, :2, :] + self.assertTrue(ht.all(x_newaxis == value).item()) + value += 2 + x[:, None, :2, :] = value + self.assertTrue(ht.all(x[:, None, :2, :] == value).item()) + self.assertTrue(x[:, None, :2, :].dtype == x.dtype) + + # newaxis: distributed w. broadcasting and different dtype + x.resplit_(axis=1) + value = ht.array([30.0, 40.0]).reshape(1, 2, 1) + x[:, np.newaxis, :2, :] = value + x_newaxis = x[:, np.newaxis, :2, :] + self.assertTrue(ht.all(x_newaxis == value).item()) + value += 2 + x[:, None, :2, :] = value + x_none = x[:, None, :2, :] + self.assertTrue(ht.all(x_none == value).item()) + self.assertTrue(x_none.dtype == x.dtype) + + # distributed value + x = ht.arange(6).reshape(1, 1, 2, 3) + x.resplit_(axis=-1) + value = ht.arange(3).reshape(1, 3) + value.resplit_(axis=1) + x[..., 0, :] = value + self.assertTrue(ht.all(x[..., 0, :] == value).item()) + + # ADVANCED INDEXING + # "x[(1, 2, 3),] is fundamentally different from x[(1, 2, 3)]" + + x = ht.arange(60, split=0).reshape(5, 3, 4) + value = 99.0 + x[(1, 2, 3)] = value + indexed_x = x[(1, 2, 3)] + self.assertTrue((indexed_x == value).item()) + self.assertTrue(indexed_x.dtype == x.dtype) + x[(1, 2, 3),] = value + adv_indexed_x = x[(1, 2, 3),] + self.assertTrue(ht.all(adv_indexed_x == value).item()) + self.assertTrue(adv_indexed_x.dtype == x.dtype) + + # 1d + x = ht.arange(10, 1, -1, split=0) + value = ht.arange(4) + x[ht.array([3, 2, 1, 8])] = value + x_adv_ind = x[np.array([3, 2, 1, 8])] + self.assertTrue(ht.all(x_adv_ind == value).item()) + self.assertTrue(x_adv_ind.dtype == x.dtype) + + # TODO: n-d value + + # 3d, split 0, non-unique, non-ordered key along split axis, key mask-like + x = ht.arange(60, split=0).reshape(5, 3, 4) + k1 = np.array([0, 4, 1, 0]) + k2 = np.array([0, 2, 1, 0]) + k3 = np.array([1, 2, 3, 1]) + value = ht.array([99, 98, 97, 96], split=0) + x[k1, k2, k3] = value + print("DEBUGGING: x[k1, k2, k3]", x[k1, k2, k3].larray) + self.assertTrue((x[k1, k2, k3] == ht.array([96, 98, 97, 96], split=0)).all().item()) + + # advanced indexing on non-consecutive dimensions, split dimension will be lost + x = ht.arange(60, split=0).reshape(5, 3, 4, new_split=1) + x_copy = x.copy() + k1 = np.array([0, 4, 1, 2]) + k2 = 0 + k3 = np.array([1, 2, 3, 1]) + key = (k1, k2, k3) + value = ht.array([99, 98, 97, 96]) + x[key] = value + self.assertTrue((x[key] == ht.array([99, 98, 97, 96])).all().item()) + # check that x is unchanged after internal manipulation + self.assertTrue(x.shape == x_copy.shape) + self.assertTrue(x.split == x_copy.split) + self.assertTrue(x.lshape == x_copy.lshape) + + # broadcasting shapes + x.resplit_(axis=0) + key = (ht.array(k1, split=0), ht.array(1), 2) + value = ht.array([99, 98, 97, 96], split=0) + x[key] = value + self.assertTrue((x[key] == value).all().item()) + # test exception: broadcasting mismatching shapes + k2 = np.array([0, 2, 1]) + with self.assertRaises(IndexError): + x[k1, k2, k3] = value + + # more broadcasting + x = ht.arange(12).reshape(4, 3) + x.resplit_(1) + rows = np.array([0, 3]) + cols = np.array([0, 2]) + key = (ht.array(rows)[:, np.newaxis], cols) + value = ht.array([[99, 98], [97, 96]], split=1) + x[key] = value + self.assertTrue((x[key] == value).all().item()) + if x.comm.size > 1: + with self.assertRaises(RuntimeError): + value = ht.array([[99, 98], [97, 96]], split=0) + x[key] = value + + # combining advanced and basic indexing + + y = ht.arange(35).reshape(5, 7) + y.resplit_(1) + y_copy = y.copy() + # assign non-distributed value + value = ht.arange(6).reshape(3, 2) + y[ht.array([0, 2, 4]), 1:3] = value + self.assertTrue((y[ht.array([0, 2, 4]), 1:3] == value).all().item()) + # assign distributed value + value.resplit_(1) + y_copy[ht.array([0, 2, 4]), 1:3] = value + self.assertTrue((y_copy[ht.array([0, 2, 4]), 1:3] == value).all().item()) + + + x = ht.arange(10 * 20 * 30).reshape(10, 20, 30) + x.resplit_(1) + ind_array = ht.array( + torch.tensor( + [ + [[11, 10, 3, 2], [13, 10, 0, 4], [9, 3, 2, 0]], + [[6, 10, 3, 8], [16, 10, 12, 9], [10, 18, 6, 15]], + ] + ), + dtype=ht.int64, + ) + value = ht.ones((1, 2, 3, 4, 1)) + x[..., ind_array, :] = value + self.assertTrue((x[..., ind_array, :] == value).all().item()) + + # boolean mask, local + arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5) + np.random.seed(42) + mask = np.random.randint(0, 2, arr.shape, dtype=bool) + value = 99.0 + arr[mask] = value + self.assertTrue((arr[mask] == value).all().item()) + self.assertTrue(arr[mask].dtype == arr.dtype) + value = ht.ones_like(arr) + arr[mask] = value[mask] + self.assertTrue((arr[mask] == value[mask]).all().item()) + + # boolean mask, distributed, non-distributed `value` + arr_split0 = ht.array(arr, split=0) + mask_split0 = ht.array(mask, split=0) + arr_split0[mask_split0] = value[mask] + indexed_arr = arr_split0[mask_split0] + indexed_arr.balance_() + self.assertTrue((indexed_arr == value[mask]).all().item()) + arr_split1 = ht.array(arr, split=1) + mask_split1 = ht.array(mask, split=1) + arr_split1[mask_split1] = value[mask] + self.assertTrue((arr_split1[mask_split1] == value[mask]).all().item()) + arr_split2 = ht.array(arr, split=2) + mask_split2 = ht.array(mask, split=2) + arr_split2[mask_split2] = value[mask] + self.assertTrue((arr_split2[mask_split2] == value[mask]).all().item()) + # 3D non-contiguous resplit testing (Column mayor ordering) torch_array = torch.arange(100, device=self.device.torch_device).reshape((10, 5, 2)) heat_array = ht.array(torch_array, split=2, order="F") @@ -1156,9 +1713,7 @@ def test_resplit(self): self.assertEqual(heat_array.split, 1) # 4D non-contiguous resplit testing (from transpose - torch_array = torch.arange(5 * 4 * 3 * 6, device=self.device.torch_device).reshape( - 5, 4, 3, 6 - ) + torch_array = torch.arange(5 * 4 * 3 * 6, device=self.device.torch_device).reshape(5, 4, 3, 6) res = torch_array.cpu().numpy().transpose((3, 1, 2, 0)) heat_array = ht.array(torch_array, split=2).transpose((3, 1, 2, 0)) heat_array.resplit_(axis=1) @@ -1166,34 +1721,26 @@ def test_resplit(self): self.assertTrue(ht.all(heat_array == ht.array(res))) self.assertEqual(heat_array.split, 1) - def test_setitem_getitem(self): # tests for bug #825 a = ht.ones((102, 102), split=0) setting = ht.zeros((100, 100), split=0) a[1:-1, 1:-1] = setting - self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) + self.assertTrue(ht.all(a[1:-1, 1:-1] == 0).item()) a = ht.ones((102, 102), split=1) setting = ht.zeros((30, 100), split=1) a[-30:, 1:-1] = setting - self.assertTrue(ht.all(a[-30:, 1:-1] == 0)) + self.assertTrue(ht.all(a[-30:, 1:-1] == 0).item()) a = ht.ones((102, 102), split=1) setting = ht.zeros((100, 100), split=1) a[1:-1, 1:-1] = setting - self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) + self.assertTrue(ht.all(a[1:-1, 1:-1] == 0).item()) a = ht.ones((102, 102), split=1) setting = ht.zeros((100, 20), split=1) a[1:-1, :20] = setting - self.assertTrue(ht.all(a[1:-1, :20] == 0)) - - # tests for bug 730: - a = ht.ones((10, 25, 30), split=1) - if a.comm.size > 1: - self.assertEqual(a[0].split, 0) - self.assertEqual(a[:, 0, :].split, None) - self.assertEqual(a[:, :, 0].split, 1) + self.assertTrue(ht.all(a[1:-1, :20] == 0).item()) # set and get single value a = ht.zeros((13, 5), split=0) @@ -1216,7 +1763,7 @@ def test_setitem_getitem(self): self.assertEqual(b.dtype, ht.float32) self.assertEqual(b.gshape, (5,)) - # slice in 1st dim only on 1 node + # slice in 1st dim only on 1 node a = ht.zeros((13, 5), split=0) a[1:4] = 1 self.assertTrue((a[1:4] == 1).all()) @@ -1295,6 +1842,13 @@ def test_setitem_getitem(self): if a.comm.rank == 0: self.assertEqual(a[3:13, 2:5:2].lshape, (4, 2)) + # setting with heat tensor + a = ht.zeros((4, 5), split=0) + a[1, 0:4] = ht.arange(4) + # if a.comm.size == 2: + for c, i in enumerate(range(4)): + self.assertEqual(a[1, c], i) + # setting with heat tensor a = ht.zeros((4, 5), split=0) if self.is_mps: @@ -1315,19 +1869,8 @@ def test_setitem_getitem(self): for c, i in enumerate(range(4)): self.assertEqual(a[1, c], i) - ################################################### - a = ht.zeros((13, 5), split=1) - # # set value on one node - a[10] = 1 - self.assertEqual(a[10].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[10].lshape, (3,)) - if a.comm.rank == 1: - self.assertEqual(a[10].lshape, (2,)) - a = ht.zeros((13, 5), split=1) - # # set value on one node + # set value on one node a[10, 0] = 1 self.assertEqual(a[10, 0], 1) self.assertEqual(a[10, 0].dtype, ht.float32) @@ -1409,6 +1952,14 @@ def test_setitem_getitem(self): if a.comm.rank == 0: self.assertEqual(a[:, 2:5:2].lshape, (13, 1)) + # setting with heat tensor + a = ht.zeros((4, 5), split=1) + a[1, 0:4] = ht.arange(4) + for c, i in enumerate(range(4)): + b = a[1, c] + if b.larray.numel() > 0: + self.assertEqual(b.item(), i) + # setting with heat tensor a = ht.zeros((4, 5), split=1) if self.is_mps: @@ -1429,18 +1980,6 @@ def test_setitem_getitem(self): for c, i in enumerate(range(4)): self.assertEqual(a[1, c], i) - #################################################### - a = ht.zeros((13, 5, 7), split=2) - # # set value on one node - a[10, :, :] = 1 - self.assertEqual(a[10, :, :].dtype, ht.float32) - self.assertEqual(a[10, :, :].gshape, (5, 7)) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[10, :, :].lshape, (5, 4)) - if a.comm.rank == 1: - self.assertEqual(a[10, :, :].lshape, (5, 3)) - a = ht.zeros((13, 5, 7), split=2) # # set value on one node a[10, ...] = 1 @@ -1550,6 +2089,8 @@ def test_setitem_getitem(self): ht_key = ht.array(np_key, split=split) arr[ht_key, 4] = 10.0 np_arr[np_key, 4] = 10.0 + #print(f"\n\n\n arr.numpy(): {arr.numpy()}, np_arr: {np_arr}\n\n\n ") + print(f"\n\n\n arr[ht_key, 4] : {arr[ht_key, 4] }\n\n\n ") self.assertTrue(np.all(arr.numpy() == np_arr)) self.assertTrue(ht.all(arr[ht_key, 4] == 10.0)) @@ -1600,7 +2141,7 @@ def test_setitem_getitem(self): with self.assertRaises(ValueError): a[..., ...] = 1 if a.comm.size > 1: - with self.assertRaises(ValueError): + with self.assertRaises(RuntimeError): x = ht.ones((10, 10), split=0) setting = ht.zeros((8, 8), split=1) x[1:-1, 1:-1] = setting @@ -1812,6 +2353,7 @@ def test_torch_proxy(self): dndarray_proxy.storage().size() * dndarray_proxy.storage().element_size() ) self.assertTrue(dndarray_proxy_nbytes == 1) + self.assertTrue(dndarray_proxy.names.index("split") == 1) def test_torch_function(self): arr = ht.array([1, 2, 3, 4]) @@ -1828,3 +2370,154 @@ def test_xor(self): self.assertTrue( ht.equal(int16_tensor ^ int16_vector, ht.bitwise_xor(int16_tensor, int16_vector)) ) + + def test_getitem_boolean_fewer_dims(self): + # Test case: 2D array, 1D boolean mask (selects rows) + # NumPy behavior: x_2D[bool_1D] selects entire rows + arr_np = np.arange(20).reshape((10, 2)) + mask_np = np.array([True, False, True, False, True, False, True, False, True, False]) + result_np = arr_np[mask_np] # Shape (5, 2) + + # Case 1: split=None (local) + arr_ht = ht.array(arr_np, split=None) + mask_ht = ht.array(mask_np, split=None) + result_ht = arr_ht[mask_ht] + self.assert_array_equal(result_ht, result_np) + self.assertEqual(result_ht.split, None) + self.assertEqual(result_ht.gshape, (5, 2)) + + # Case 2: split=0 (split on the indexed dimension) + arr_ht_s0 = ht.array(arr_np, split=0) + mask_ht_s0 = ht.array(mask_np, split=0) + + result_ht_s0 = arr_ht_s0[mask_ht_s0] + + self.assert_array_equal(result_ht_s0, result_np) + self.assertEqual(result_ht_s0.split, 0) + self.assertEqual(result_ht_s0.gshape, (5, 2)) + + # Case 3: split=1 (split on a non-indexed dimension) + arr_ht_s1 = ht.array(arr_np, split=1) + # Mask can be local or split=0, test local (None) for broadcasting + mask_ht_sNone = ht.array(mask_np, split=None) + result_ht_s1 = arr_ht_s1[mask_ht_sNone] + self.assert_array_equal(result_ht_s1, result_np) + print(f"result_ht_s1.split: {result_ht_s1.split}") + self.assertEqual(result_ht_s1.split, 1) + self.assertEqual(result_ht_s1.gshape, (5, 2)) + + # Case 4: 3D array, 2D boolean mask + arr_np_3d = np.arange(30).reshape((2, 3, 5)) + mask_np_2d = np.array([[True, True, False], [False, True, True]]) + result_np_3d = arr_np_3d[mask_np_2d] # Shape (4, 5) + + # Test split=None + arr_ht_3d = ht.array(arr_np_3d, split=None) + mask_ht_2d = ht.array(mask_np_2d, split=None) + result_ht_3d = arr_ht_3d[mask_ht_2d] + self.assert_array_equal(result_ht_3d, result_np_3d) + self.assertEqual(result_ht_3d.gshape, (4, 5)) + + # Test split=2 (split on the non-indexed dimension) + arr_ht_3d_s2 = ht.array(arr_np_3d, split=2) + mask_ht_2d_sNone = ht.array(mask_np_2d, split=None) # Broadcast mask + result_ht_3d_s2 = arr_ht_3d_s2[mask_ht_2d_sNone] + self.assert_array_equal(result_ht_3d_s2, result_np_3d) + self.assertEqual(result_ht_3d_s2.gshape, (4, 5)) + self.assertEqual(result_ht_3d_s2.split, 1) # New split axis (originally 2, 2 dims removed) + + def test_setitem_boolean_fewer_dims(self): + # Test case: 2D array, 1D boolean mask (selects rows) + arr_np = np.arange(20).reshape((10, 2)) + mask_np = np.array([True, False, True, False, True, False, True, False, True, False]) + value = 99 + arr_np_set = arr_np.copy() + arr_np_set[mask_np] = value + + # Case 1: split=None (local) + arr_ht = ht.array(arr_np, split=None) + mask_ht = ht.array(mask_np, split=None) + arr_ht[mask_ht] = value + self.assert_array_equal(arr_ht, arr_np_set) + + # Case 2: split=0 (split on the indexed dimension) + arr_ht_s0 = ht.array(arr_np, split=0) + mask_ht_s0 = ht.array(mask_np, split=0) + arr_ht_s0[mask_ht_s0] = value + self.assert_array_equal(arr_ht_s0, arr_np_set) + + # Case 3: split=1 (split on a non-indexed dimension) + arr_ht_s1 = ht.array(arr_np, split=1) + mask_ht_sNone = ht.array(mask_np, split=None) + arr_ht_s1[mask_ht_sNone] = value + self.assert_array_equal(arr_ht_s1, arr_np_set) + + def test_getitem_edge_cases(self): + # Test edge cases from NumPy docs + + # Case 1: 0-D (Scalar) DNDarray + x_ht_0d = ht.array(10) + self.assertEqual(x_ht_0d.ndim, 0) + result_0d = x_ht_0d[()] + # NumPy returns a scalar, heat returns a 0-D tensor + self.assertEqual(result_0d.ndim, 0) + self.assertEqual(result_0d.item(), 10) + + # Case 2: N-D local DNDarray + arr_np = np.arange(10).reshape((5, 2)) + arr_ht_local = ht.array(arr_np, split=None) + + # Test [...] + result_ellipsis = arr_ht_local[...] + self.assert_array_equal(result_ellipsis, arr_np) + self.assertIs(result_ellipsis.larray, arr_ht_local.larray) # Check for view + + # Test [()] + result_empty_tuple = arr_ht_local[()] + self.assert_array_equal(result_empty_tuple, arr_np) + self.assertIs(result_empty_tuple.larray, arr_ht_local.larray) # Check for view + + # Case 3: N-D split DNDarray + arr_ht_split = ht.array(arr_np, split=0) + + # Test [...] + result_split_ellipsis = arr_ht_split[...] + self.assert_array_equal(result_split_ellipsis, arr_np) + self.assertEqual(result_split_ellipsis.split, 0) + self.assertIs(result_split_ellipsis.larray, arr_ht_split.larray) # Check for view + + # Test [()] + result_split_empty_tuple = arr_ht_split[()] + self.assert_array_equal(result_split_empty_tuple, arr_np) + self.assertEqual(result_split_empty_tuple.split, 0) + self.assertIs(result_split_empty_tuple.larray, arr_ht_split.larray) # Check for view + + def test_setitem_edge_cases(self): + # Test edge cases from NumPy docs + + # Case 1: 0-D (Scalar) DNDarray + x_ht_0d = ht.array(10) + x_ht_0d[()] = 99 + self.assertEqual(x_ht_0d.item(), 99) + + # Case 2: N-D local DNDarray + arr_ht_local = ht.ones((5, 2), split=None) + + # Test [...] + arr_ht_local[...] = 99 + self.assertTrue(ht.all(arr_ht_local == 99).item()) + + # Test [()] + arr_ht_local[()] = 100 + self.assertTrue(ht.all(arr_ht_local == 100).item()) + + # Case 3: N-D split DNDarray + arr_ht_split = ht.ones((5, 2), split=0) + + # Test [...] + arr_ht_split[...] = 99 + self.assertTrue(ht.all(arr_ht_split == 99).item()) + + # Test [()] + arr_ht_split[()] = 100 + self.assertTrue(ht.all(arr_ht_split == 100).item()) diff --git a/heat/core/tests/test_factories.py b/heat/core/tests/test_factories.py index fe17e897c4..a6ec511e50 100644 --- a/heat/core/tests/test_factories.py +++ b/heat/core/tests/test_factories.py @@ -6,581 +6,581 @@ class TestFactories(TestCase): - def test_arange(self): - # testing one positional integer argument - one_arg_arange_int = ht.arange(10) - self.assertIsInstance(one_arg_arange_int, ht.DNDarray) - self.assertEqual(one_arg_arange_int.shape, (10,)) - self.assertLessEqual(one_arg_arange_int.lshape[0], 10) - self.assertEqual(one_arg_arange_int.dtype, ht.int32) - self.assertEqual(one_arg_arange_int.larray.dtype, torch.int32) - self.assertEqual(one_arg_arange_int.split, None) - # make an in direct check for the sequence, compare against the gaussian sum - self.assertEqual(one_arg_arange_int.sum(), 45) - - # testing one positional float argument - one_arg_arange_float = ht.arange(10.0) - self.assertIsInstance(one_arg_arange_float, ht.DNDarray) - self.assertEqual(one_arg_arange_float.shape, (10,)) - self.assertLessEqual(one_arg_arange_float.lshape[0], 10) - self.assertEqual(one_arg_arange_float.dtype, ht.float32) - self.assertEqual(one_arg_arange_float.larray.dtype, torch.float32) - self.assertEqual(one_arg_arange_float.split, None) - # make an in direct check for the sequence, compare against the gaussian sum - self.assertEqual(one_arg_arange_float.sum(), 45.0) - - # testing two positional integer arguments - two_arg_arange_int = ht.arange(0, 10) - self.assertIsInstance(two_arg_arange_int, ht.DNDarray) - self.assertEqual(two_arg_arange_int.shape, (10,)) - self.assertLessEqual(two_arg_arange_int.lshape[0], 10) - self.assertEqual(two_arg_arange_int.dtype, ht.int32) - self.assertEqual(two_arg_arange_int.larray.dtype, torch.int32) - self.assertEqual(two_arg_arange_int.split, None) - # make an in direct check for the sequence, compare against the gaussian sum - self.assertEqual(two_arg_arange_int.sum(), 45) - - # testing two positional arguments, one being float - two_arg_arange_float = ht.arange(0.0, 10) - self.assertIsInstance(two_arg_arange_float, ht.DNDarray) - self.assertEqual(two_arg_arange_float.shape, (10,)) - self.assertLessEqual(two_arg_arange_float.lshape[0], 10) - self.assertEqual(two_arg_arange_float.dtype, ht.float32) - self.assertEqual(two_arg_arange_float.larray.dtype, torch.float32) - self.assertEqual(two_arg_arange_float.split, None) - # make an in direct check for the sequence, compare against the gaussian sum - self.assertEqual(two_arg_arange_float.sum(), 45.0) - - # testing three positional integer arguments - three_arg_arange_int = ht.arange(0, 10, 2) - self.assertIsInstance(three_arg_arange_int, ht.DNDarray) - self.assertEqual(three_arg_arange_int.shape, (5,)) - self.assertLessEqual(three_arg_arange_int.lshape[0], 5) - self.assertEqual(three_arg_arange_int.dtype, ht.int32) - self.assertEqual(three_arg_arange_int.larray.dtype, torch.int32) - self.assertEqual(three_arg_arange_int.split, None) - # make an in direct check for the sequence, compare against the gaussian sum - self.assertEqual(three_arg_arange_int.sum(), 20) - - # testing three positional arguments, one being float - three_arg_arange_float = ht.arange(0, 10, 2.0) - self.assertIsInstance(three_arg_arange_float, ht.DNDarray) - self.assertEqual(three_arg_arange_float.shape, (5,)) - self.assertLessEqual(three_arg_arange_float.lshape[0], 5) - self.assertEqual(three_arg_arange_float.dtype, ht.float32) - self.assertEqual(three_arg_arange_float.larray.dtype, torch.float32) - self.assertEqual(three_arg_arange_float.split, None) - # make an in direct check for the sequence, compare against the gaussian sum - self.assertEqual(three_arg_arange_float.sum(), 20.0) - - # testing splitting - three_arg_arange_dtype_float32 = ht.arange(0, 10, 2.0, split=0) - self.assertIsInstance(three_arg_arange_dtype_float32, ht.DNDarray) - self.assertEqual(three_arg_arange_dtype_float32.shape, (5,)) - self.assertLessEqual(three_arg_arange_dtype_float32.lshape[0], 5) - self.assertEqual(three_arg_arange_dtype_float32.dtype, ht.float32) - self.assertEqual(three_arg_arange_dtype_float32.larray.dtype, torch.float32) - self.assertEqual(three_arg_arange_dtype_float32.split, 0) - # make an in direct check for the sequence, compare against the gaussian sum - self.assertEqual(three_arg_arange_dtype_float32.sum(axis=0, keepdims=True), 20.0) - - # testing setting dtype to int16 - three_arg_arange_dtype_short = ht.arange(0, 10, 2.0, dtype=torch.int16) - self.assertIsInstance(three_arg_arange_dtype_short, ht.DNDarray) - self.assertEqual(three_arg_arange_dtype_short.shape, (5,)) - self.assertLessEqual(three_arg_arange_dtype_short.lshape[0], 5) - self.assertEqual(three_arg_arange_dtype_short.dtype, ht.int16) - self.assertEqual(three_arg_arange_dtype_short.larray.dtype, torch.int16) - self.assertEqual(three_arg_arange_dtype_short.split, None) - # make an in direct check for the sequence, compare against the gaussian sum - self.assertEqual(three_arg_arange_dtype_short.sum(axis=0, keepdims=True), 20) - - # testing setting dtype to float64 - if not self.is_mps: - three_arg_arange_dtype_float64 = ht.arange(0, 10, 2, dtype=torch.float64) - self.assertIsInstance(three_arg_arange_dtype_float64, ht.DNDarray) - self.assertEqual(three_arg_arange_dtype_float64.shape, (5,)) - self.assertLessEqual(three_arg_arange_dtype_float64.lshape[0], 5) - self.assertEqual(three_arg_arange_dtype_float64.dtype, ht.float64) - self.assertEqual(three_arg_arange_dtype_float64.larray.dtype, torch.float64) - self.assertEqual(three_arg_arange_dtype_float64.split, None) - # make an in direct check for the sequence, compare against the gaussian sum - self.assertEqual(three_arg_arange_dtype_float64.sum(axis=0, keepdims=True), 20.0) - - check_precision = ht.arange(16777217.0, 16777218, 1, dtype=ht.float64) - self.assertEqual(check_precision.sum(), 16777217) - - # exceptions - with self.assertRaises(ValueError): - ht.arange(-5, 3, split=1) - with self.assertRaises(TypeError): - ht.arange() - with self.assertRaises(TypeError): - ht.arange(1, 2, 3, 4) - - def test_array(self): - # basic array function, unsplit data - unsplit_data = [[1, 2, 3], [4, 5, 6]] - a = ht.array(unsplit_data) - self.assertIsInstance(a, ht.DNDarray) - self.assertEqual(a.dtype, ht.int64) - self.assertEqual(a.lshape, (2, 3)) - self.assertEqual(a.gshape, (2, 3)) - self.assertEqual(a.split, None) - self.assertTrue( - (a.larray == torch.tensor(unsplit_data, device=self.device.torch_device)).all() - ) - - # basic array function, unsplit data, different datatype - tuple_data = ((0, 0), (1, 1)) - b = ht.array(tuple_data, dtype=ht.int8) - self.assertIsInstance(b, ht.DNDarray) - self.assertEqual(b.dtype, ht.int8) - self.assertEqual(b.larray.dtype, torch.int8) - self.assertEqual(b.lshape, (2, 2)) - self.assertEqual(b.gshape, (2, 2)) - self.assertEqual(b.split, None) - self.assertTrue( - ( - b.larray - == torch.tensor(tuple_data, dtype=torch.int8, device=self.device.torch_device) - ).all() - ) - if not self.is_mps: - check_precision = ht.array(16777217.0, dtype=ht.float64) - self.assertEqual(check_precision.sum(), 16777217) - - # basic array function, unsplit data, no copy - torch_tensor = torch.tensor([6, 5, 4, 3, 2, 1], device=self.device.torch_device) - c = ht.array(torch_tensor, copy=False) - self.assertIsInstance(c, ht.DNDarray) - self.assertEqual(c.dtype, ht.int64) - self.assertEqual(c.lshape, (6,)) - self.assertEqual(c.gshape, (6,)) - self.assertEqual(c.split, None) - self.assertIs(c.larray, torch_tensor) - self.assertTrue((c.larray == torch_tensor).all()) - - # basic array function, unsplit data, additional dimensions - vector_data = [4.0, 5.0, 6.0] - d = ht.array(vector_data, ndmin=3) - self.assertIsInstance(d, ht.DNDarray) - self.assertEqual(d.dtype, ht.float32) - self.assertEqual(d.lshape, (3, 1, 1)) - self.assertEqual(d.gshape, (3, 1, 1)) - self.assertEqual(d.split, None) - self.assertTrue( - ( - d.larray - == torch.tensor(vector_data, device=self.device.torch_device).reshape(-1, 1, 1) - ).all() - ) - - # basic array function, unsplit data, additional dimensions - vector_data = [4.0, 5.0, 6.0] - d = ht.array(vector_data, ndmin=-3) - self.assertIsInstance(d, ht.DNDarray) - self.assertEqual(d.dtype, ht.float32) - self.assertEqual(d.lshape, (1, 1, 3)) - self.assertEqual(d.gshape, (1, 1, 3)) - self.assertEqual(d.split, None) - self.assertTrue( - ( - d.larray - == torch.tensor(vector_data, device=self.device.torch_device).reshape(1, 1, -1) - ).all() - ) - - # distributed array, chunk local data (split), copy True - if self.is_mps: - np_dtype = np.float32 - torch_dtype = torch.float32 - else: - np_dtype = np.float64 - torch_dtype = torch.float64 - ht_dtype = ht.types.canonical_heat_type(torch_dtype) - - array_2d = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=np_dtype) - dndarray_2d = ht.array(array_2d, split=0, copy=True) - self.assertIsInstance(dndarray_2d, ht.DNDarray) - self.assertEqual(dndarray_2d.dtype, ht_dtype) - self.assertEqual(dndarray_2d.gshape, (3, 3)) - self.assertEqual(len(dndarray_2d.lshape), 2) - self.assertLessEqual(dndarray_2d.lshape[0], 3) - self.assertEqual(dndarray_2d.lshape[1], 3) - self.assertEqual(dndarray_2d.split, 0) - self.assertTrue( - ( - dndarray_2d.larray == torch.tensor([1.0, 2.0, 3.0], device=self.device.torch_device) - ).all() - ) - - # distributed array, chunk local data (split), copy False, torch devices - array_2d = torch.tensor( - [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], - dtype=torch_dtype, - device=self.device.torch_device, - ) - dndarray_2d = ht.array(array_2d, split=0, copy=False, dtype=ht_dtype) - self.assertIsInstance(dndarray_2d, ht.DNDarray) - self.assertEqual(dndarray_2d.dtype, ht_dtype) - self.assertEqual(dndarray_2d.gshape, (3, 3)) - self.assertEqual(len(dndarray_2d.lshape), 2) - self.assertLessEqual(dndarray_2d.lshape[0], 3) - self.assertEqual(dndarray_2d.lshape[1], 3) - self.assertEqual(dndarray_2d.split, 0) - self.assertTrue( - ( - dndarray_2d.larray == torch.tensor([1.0, 2.0, 3.0], device=self.device.torch_device) - ).all() - ) - # Check that the array is not a copy, (only really works when the array is not split) - if ht.communication.MPI_WORLD.size == 1: - self.assertIs(dndarray_2d.larray, array_2d) - - # The array should not change as all properties match - dndarray_2d_new = ht.array(dndarray_2d, split=0, copy=False, dtype=ht_dtype) - self.assertIsInstance(dndarray_2d_new, ht.DNDarray) - self.assertEqual(dndarray_2d_new.dtype, ht_dtype) - self.assertEqual(dndarray_2d_new.gshape, (3, 3)) - self.assertEqual(len(dndarray_2d_new.lshape), 2) - self.assertLessEqual(dndarray_2d_new.lshape[0], 3) - self.assertEqual(dndarray_2d_new.lshape[1], 3) - self.assertEqual(dndarray_2d_new.split, 0) - self.assertTrue( - ( - dndarray_2d.larray == torch.tensor([1.0, 2.0, 3.0], device=self.device.torch_device) - ).all() - ) - # Reuse the same array - self.assertIs(dndarray_2d_new.larray, dndarray_2d.larray) - - # Should throw exeception because it causes a resplit - with self.assertRaises(ValueError): - dndarray_2d_new = ht.array(dndarray_2d, split=1, copy=False, dtype=ht.double) - - # The array should not change as all properties match - dndarray_2d_new = ht.array(dndarray_2d, is_split=0, copy=False, dtype=ht_dtype) - self.assertIsInstance(dndarray_2d_new, ht.DNDarray) - self.assertEqual(dndarray_2d_new.dtype, ht_dtype) - self.assertEqual(dndarray_2d_new.gshape, (3, 3)) - self.assertEqual(len(dndarray_2d_new.lshape), 2) - self.assertLessEqual(dndarray_2d_new.lshape[0], 3) - self.assertEqual(dndarray_2d_new.lshape[1], 3) - self.assertEqual(dndarray_2d_new.split, 0) - self.assertTrue( - ( - dndarray_2d.larray == torch.tensor([1.0, 2.0, 3.0], device=self.device.torch_device) - ).all() - ) - - # Should throw exeception because of array is split along another dimension - with self.assertRaises(ValueError): - dndarray_2d_new = ht.array(dndarray_2d, is_split=1, copy=False, dtype=ht.double) - - # distributed array, partial data (is_split) - if ht.communication.MPI_WORLD.rank == 0: - split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]] - else: - split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]] - e = ht.array(split_data, ndmin=3, is_split=0) - - self.assertIsInstance(e, ht.DNDarray) - self.assertEqual(e.dtype, ht.float32) - if ht.communication.MPI_WORLD.rank == 0: - self.assertEqual(e.lshape, (3, 3, 1)) - else: - self.assertEqual(e.lshape, (2, 3, 1)) - self.assertEqual(e.split, 0) - for index, ele in enumerate(e.gshape): - if index != e.split: - self.assertEqual(ele, e.lshape[index]) - else: - self.assertGreaterEqual(ele, e.lshape[index]) - - # exception distributed shapes do not fit - if ht.communication.MPI_WORLD.size > 1: - if ht.communication.MPI_WORLD.rank == 0: - split_data = [4.0, 5.0, 6.0] - else: - split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]] - - # this will fail as the shapes do not match - with self.assertRaises(ValueError): - ht.array(split_data, is_split=0) - - # exception distributed shapes do not fit - if ht.communication.MPI_WORLD.size > 1: - if ht.communication.MPI_WORLD.rank == 0: - split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]] - else: - split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]] - - # this will fail as the shapes do not match on a specific axis (here: 0) - with self.assertRaises(ValueError): - ht.array(split_data, is_split=1) - - # check exception on mutually exclusive split and is_split - with self.assertRaises(ValueError): - ht.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], split=0, is_split=0) - - e = ht.array(split_data, ndmin=-3, is_split=1) - - self.assertIsInstance(e, ht.DNDarray) - self.assertEqual(e.dtype, ht.float32) - if ht.communication.MPI_WORLD.rank == 0: - self.assertEqual(e.lshape, (1, 3, 3)) - else: - self.assertEqual(e.lshape, (1, 2, 3)) - self.assertEqual(e.split, 1) - for index, ele in enumerate(e.gshape): - if index != e.split: - self.assertEqual(ele, e.lshape[index]) - else: - self.assertGreaterEqual(ele, e.lshape[index]) - - # exception distributed shapes do not fit - if ht.communication.MPI_WORLD.size > 1: - if ht.communication.MPI_WORLD.rank == 0: - split_data = [4.0, 5.0, 6.0] - else: - split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]] - - # this will fail as the shapes do not match - with self.assertRaises(ValueError): - ht.array(split_data, is_split=0) - - # exception distributed shapes do not fit - if ht.communication.MPI_WORLD.size > 1: - if ht.communication.MPI_WORLD.rank == 0: - split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]] - else: - split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]] - - # this will fail as the shapes do not match on a specific axis (here: 0) - with self.assertRaises(ValueError): - ht.array(split_data, is_split=1) - - # check exception on mutually exclusive split and is_split - with self.assertRaises(ValueError): - ht.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], split=1, is_split=1) - - # non iterable type - with self.assertRaises(TypeError): - ht.array(map) - # iterable, but unsuitable type - with self.assertRaises(TypeError): - ht.array("abc") - # iterable, but unsuitable type, with copy=True - with self.assertRaises(TypeError): - ht.array("abc", copy=True) - # unknown dtype - with self.assertRaises(TypeError): - ht.array((4,), dtype="a") - # invalid ndmin - with self.assertRaises(TypeError): - ht.array((4,), ndmin=3.0) - # invalid split axis type - with self.assertRaises(TypeError): - ht.array((4,), split="a") - # invalid split axis value - with self.assertRaises(ValueError): - ht.array((4,), split=3) - # invalid communicator - with self.assertRaises(TypeError): - ht.array((4,), comm={}) - # copy=False but copy is necessary - data = np.arange(10) - with self.assertRaises(ValueError): - ht.array(data, dtype=ht.int32, copy=False) - - # data already distributed but don't match in shape - if self.get_size() > 1: - with self.assertRaises(ValueError): - dim = self.get_rank() + 1 - ht.array([[0] * dim] * dim, is_split=0) - - def test_asarray(self): - # same heat array - arr = ht.array([1, 2]) - self.assertTrue(ht.asarray(arr) is arr) - - # from distributed python list - arr = ht.array([1, 2, 3, 4, 5, 6], split=0) - lst = arr.tolist(keepsplit=True) - asarr = ht.asarray(lst, is_split=0) - - self.assertEqual(asarr.shape, arr.shape) - self.assertEqual(asarr.split, 0) - self.assertEqual(asarr.device, ht.get_device()) - self.assertTrue(ht.equal(asarr, arr)) - - # from numpy array - arr = np.array([1, 2, 3, 4]) - asarr = ht.asarray(arr) - - self.assertTrue(np.all(np.equal(asarr.numpy(), arr))) - - asarr[0] = 0 - if asarr.device == ht.cpu: - self.assertEqual(asarr.numpy()[0], arr[0]) - - # from torch tensor - arr = torch.tensor([1, 2, 3, 4], device=self.device.torch_device) - asarr = ht.asarray(arr) - - self.assertTrue(torch.equal(asarr.larray, arr)) - - asarr[0] = 0 - self.assertEqual(asarr.larray[0].item(), arr[0].item()) - - def test_empty(self): - # scalar input - simple_empty_float = ht.empty(3) - self.assertIsInstance(simple_empty_float, ht.DNDarray) - self.assertEqual(simple_empty_float.shape, (3,)) - self.assertEqual(simple_empty_float.lshape, (3,)) - self.assertEqual(simple_empty_float.split, None) - self.assertEqual(simple_empty_float.dtype, ht.float32) - - # different data type - simple_empty_uint = ht.empty(5, dtype=ht.bool) - self.assertIsInstance(simple_empty_uint, ht.DNDarray) - self.assertEqual(simple_empty_uint.shape, (5,)) - self.assertEqual(simple_empty_uint.lshape, (5,)) - self.assertEqual(simple_empty_uint.split, None) - self.assertEqual(simple_empty_uint.dtype, ht.bool) - - # multi-dimensional - elaborate_empty_int = ht.empty((2, 3), dtype=ht.int32) - self.assertIsInstance(elaborate_empty_int, ht.DNDarray) - self.assertEqual(elaborate_empty_int.shape, (2, 3)) - self.assertEqual(elaborate_empty_int.lshape, (2, 3)) - self.assertEqual(elaborate_empty_int.split, None) - self.assertEqual(elaborate_empty_int.dtype, ht.int32) - - # split axis - elaborate_empty_split = ht.empty((6, 4), dtype=ht.int32, split=0) - self.assertIsInstance(elaborate_empty_split, ht.DNDarray) - self.assertEqual(elaborate_empty_split.shape, (6, 4)) - self.assertLessEqual(elaborate_empty_split.lshape[0], 6) - self.assertEqual(elaborate_empty_split.lshape[1], 4) - self.assertEqual(elaborate_empty_split.split, 0) - self.assertEqual(elaborate_empty_split.dtype, ht.int32) - - # exceptions - with self.assertRaises(TypeError): - ht.empty("(2, 3,)", dtype=ht.float64) - with self.assertRaises(ValueError): - ht.empty((-1, 3), dtype=ht.float64) - with self.assertRaises(TypeError): - ht.empty((2, 3), dtype=ht.float64, split="axis") - - def test_empty_like(self): - # scalar - like_int = ht.empty_like(3) - self.assertIsInstance(like_int, ht.DNDarray) - self.assertEqual(like_int.shape, (1,)) - self.assertEqual(like_int.lshape, (1,)) - self.assertEqual(like_int.split, None) - self.assertEqual(like_int.dtype, ht.int32) - - # sequence - like_str = ht.empty_like("abc") - self.assertIsInstance(like_str, ht.DNDarray) - self.assertEqual(like_str.shape, (3,)) - self.assertEqual(like_str.lshape, (3,)) - self.assertEqual(like_str.split, None) - self.assertEqual(like_str.dtype, ht.float32) - - # elaborate tensor - ones = ht.ones((2, 3), dtype=ht.uint8) - like_ones = ht.empty_like(ones) - self.assertIsInstance(like_ones, ht.DNDarray) - self.assertEqual(like_ones.shape, (2, 3)) - self.assertEqual(like_ones.lshape, (2, 3)) - self.assertEqual(like_ones.split, None) - self.assertEqual(like_ones.dtype, ht.uint8) - - # elaborate tensor with split - ones_split = ht.ones((2, 3), dtype=ht.uint8, split=0) - like_ones_split = ht.empty_like(ones_split) - self.assertIsInstance(like_ones_split, ht.DNDarray) - self.assertEqual(like_ones_split.shape, (2, 3)) - self.assertLessEqual(like_ones_split.lshape[0], 2) - self.assertEqual(like_ones_split.lshape[1], 3) - self.assertEqual(like_ones_split.split, 0) - self.assertEqual(like_ones_split.dtype, ht.uint8) - - # exceptions - with self.assertRaises(TypeError): - ht.empty_like(ones, dtype="abc") - with self.assertRaises(TypeError): - ht.empty_like(ones, split="axis") - - def test_eye(self): - def get_offset(tensor_array): - x, y = tensor_array.shape - for k in range(x): - for li in range(y): - if tensor_array[k][li] == 1: - return k, li - return x, y - - shape = 5 - eye = ht.eye(shape, dtype=ht.uint8, split=1) - self.assertIsInstance(eye, ht.DNDarray) - self.assertEqual(eye.dtype, ht.uint8) - self.assertEqual(eye.shape, (shape, shape)) - self.assertEqual(eye.split, 1) - - offset_x, offset_y = get_offset(eye.larray) - self.assertGreaterEqual(offset_x, 0) - self.assertGreaterEqual(offset_y, 0) - x, y = eye.larray.shape - for i in range(x): - for j in range(y): - expected = 1 if i - offset_x is j - offset_y else 0 - self.assertEqual(eye.larray[i][j], expected) - - shape = (10, 20) - eye = ht.eye(shape, dtype=ht.float32) - self.assertIsInstance(eye, ht.DNDarray) - self.assertEqual(eye.dtype, ht.float32) - self.assertEqual(eye.shape, shape) - self.assertEqual(eye.split, None) - - offset_x, offset_y = get_offset(eye.larray) - self.assertGreaterEqual(offset_x, 0) - self.assertGreaterEqual(offset_y, 0) - x, y = eye.larray.shape - for i in range(x): - for j in range(y): - expected = 1.0 if i - offset_x is j - offset_y else 0.0 - self.assertEqual(eye.larray[i][j], expected) - - shape = (10,) - eye = ht.eye(shape, dtype=ht.int32, split=0) - self.assertIsInstance(eye, ht.DNDarray) - self.assertEqual(eye.dtype, ht.int32) - self.assertEqual(eye.shape, shape * 2) - self.assertEqual(eye.split, 0) - - offset_x, offset_y = get_offset(eye.larray) - self.assertGreaterEqual(offset_x, 0) - self.assertGreaterEqual(offset_y, 0) - x, y = eye.larray.shape - for i in range(x): - for j in range(y): - expected = 1 if i - offset_x is j - offset_y else 0 - self.assertEqual(eye.larray[i][j], expected) - - shape = (11, 30) - eye = ht.eye(shape, split=1, dtype=ht.float32) - self.assertIsInstance(eye, ht.DNDarray) - self.assertEqual(eye.dtype, ht.float32) - self.assertEqual(eye.shape, shape) - self.assertEqual(eye.split, 1) + # def test_arange(self): + # # testing one positional integer argument + # one_arg_arange_int = ht.arange(10) + # self.assertIsInstance(one_arg_arange_int, ht.DNDarray) + # self.assertEqual(one_arg_arange_int.shape, (10,)) + # self.assertLessEqual(one_arg_arange_int.lshape[0], 10) + # self.assertEqual(one_arg_arange_int.dtype, ht.int32) + # self.assertEqual(one_arg_arange_int.larray.dtype, torch.int32) + # self.assertEqual(one_arg_arange_int.split, None) + # # make an in direct check for the sequence, compare against the gaussian sum + # self.assertEqual(one_arg_arange_int.sum(), 45) + + # # testing one positional float argument + # one_arg_arange_float = ht.arange(10.0) + # self.assertIsInstance(one_arg_arange_float, ht.DNDarray) + # self.assertEqual(one_arg_arange_float.shape, (10,)) + # self.assertLessEqual(one_arg_arange_float.lshape[0], 10) + # self.assertEqual(one_arg_arange_float.dtype, ht.float32) + # self.assertEqual(one_arg_arange_float.larray.dtype, torch.float32) + # self.assertEqual(one_arg_arange_float.split, None) + # # make an in direct check for the sequence, compare against the gaussian sum + # self.assertEqual(one_arg_arange_float.sum(), 45.0) + + # # testing two positional integer arguments + # two_arg_arange_int = ht.arange(0, 10) + # self.assertIsInstance(two_arg_arange_int, ht.DNDarray) + # self.assertEqual(two_arg_arange_int.shape, (10,)) + # self.assertLessEqual(two_arg_arange_int.lshape[0], 10) + # self.assertEqual(two_arg_arange_int.dtype, ht.int32) + # self.assertEqual(two_arg_arange_int.larray.dtype, torch.int32) + # self.assertEqual(two_arg_arange_int.split, None) + # # make an in direct check for the sequence, compare against the gaussian sum + # self.assertEqual(two_arg_arange_int.sum(), 45) + + # # testing two positional arguments, one being float + # two_arg_arange_float = ht.arange(0.0, 10) + # self.assertIsInstance(two_arg_arange_float, ht.DNDarray) + # self.assertEqual(two_arg_arange_float.shape, (10,)) + # self.assertLessEqual(two_arg_arange_float.lshape[0], 10) + # self.assertEqual(two_arg_arange_float.dtype, ht.float32) + # self.assertEqual(two_arg_arange_float.larray.dtype, torch.float32) + # self.assertEqual(two_arg_arange_float.split, None) + # # make an in direct check for the sequence, compare against the gaussian sum + # self.assertEqual(two_arg_arange_float.sum(), 45.0) + + # # testing three positional integer arguments + # three_arg_arange_int = ht.arange(0, 10, 2) + # self.assertIsInstance(three_arg_arange_int, ht.DNDarray) + # self.assertEqual(three_arg_arange_int.shape, (5,)) + # self.assertLessEqual(three_arg_arange_int.lshape[0], 5) + # self.assertEqual(three_arg_arange_int.dtype, ht.int32) + # self.assertEqual(three_arg_arange_int.larray.dtype, torch.int32) + # self.assertEqual(three_arg_arange_int.split, None) + # # make an in direct check for the sequence, compare against the gaussian sum + # self.assertEqual(three_arg_arange_int.sum(), 20) + + # # testing three positional arguments, one being float + # three_arg_arange_float = ht.arange(0, 10, 2.0) + # self.assertIsInstance(three_arg_arange_float, ht.DNDarray) + # self.assertEqual(three_arg_arange_float.shape, (5,)) + # self.assertLessEqual(three_arg_arange_float.lshape[0], 5) + # self.assertEqual(three_arg_arange_float.dtype, ht.float32) + # self.assertEqual(three_arg_arange_float.larray.dtype, torch.float32) + # self.assertEqual(three_arg_arange_float.split, None) + # # make an in direct check for the sequence, compare against the gaussian sum + # self.assertEqual(three_arg_arange_float.sum(), 20.0) + + # # testing splitting + # three_arg_arange_dtype_float32 = ht.arange(0, 10, 2.0, split=0) + # self.assertIsInstance(three_arg_arange_dtype_float32, ht.DNDarray) + # self.assertEqual(three_arg_arange_dtype_float32.shape, (5,)) + # self.assertLessEqual(three_arg_arange_dtype_float32.lshape[0], 5) + # self.assertEqual(three_arg_arange_dtype_float32.dtype, ht.float32) + # self.assertEqual(three_arg_arange_dtype_float32.larray.dtype, torch.float32) + # self.assertEqual(three_arg_arange_dtype_float32.split, 0) + # # make an in direct check for the sequence, compare against the gaussian sum + # self.assertEqual(three_arg_arange_dtype_float32.sum(axis=0, keepdims=True), 20.0) + + # # testing setting dtype to int16 + # three_arg_arange_dtype_short = ht.arange(0, 10, 2.0, dtype=torch.int16) + # self.assertIsInstance(three_arg_arange_dtype_short, ht.DNDarray) + # self.assertEqual(three_arg_arange_dtype_short.shape, (5,)) + # self.assertLessEqual(three_arg_arange_dtype_short.lshape[0], 5) + # self.assertEqual(three_arg_arange_dtype_short.dtype, ht.int16) + # self.assertEqual(three_arg_arange_dtype_short.larray.dtype, torch.int16) + # self.assertEqual(three_arg_arange_dtype_short.split, None) + # # make an in direct check for the sequence, compare against the gaussian sum + # self.assertEqual(three_arg_arange_dtype_short.sum(axis=0, keepdims=True), 20) + + # # testing setting dtype to float64 + # if not self.is_mps: + # three_arg_arange_dtype_float64 = ht.arange(0, 10, 2, dtype=torch.float64) + # self.assertIsInstance(three_arg_arange_dtype_float64, ht.DNDarray) + # self.assertEqual(three_arg_arange_dtype_float64.shape, (5,)) + # self.assertLessEqual(three_arg_arange_dtype_float64.lshape[0], 5) + # self.assertEqual(three_arg_arange_dtype_float64.dtype, ht.float64) + # self.assertEqual(three_arg_arange_dtype_float64.larray.dtype, torch.float64) + # self.assertEqual(three_arg_arange_dtype_float64.split, None) + # # make an in direct check for the sequence, compare against the gaussian sum + # self.assertEqual(three_arg_arange_dtype_float64.sum(axis=0, keepdims=True), 20.0) + + # check_precision = ht.arange(16777217.0, 16777218, 1, dtype=ht.float64) + # self.assertEqual(check_precision.sum(), 16777217) + + # # exceptions + # with self.assertRaises(ValueError): + # ht.arange(-5, 3, split=1) + # with self.assertRaises(TypeError): + # ht.arange() + # with self.assertRaises(TypeError): + # ht.arange(1, 2, 3, 4) + + # def test_array(self): + # # basic array function, unsplit data + # unsplit_data = [[1, 2, 3], [4, 5, 6]] + # a = ht.array(unsplit_data) + # self.assertIsInstance(a, ht.DNDarray) + # self.assertEqual(a.dtype, ht.int64) + # self.assertEqual(a.lshape, (2, 3)) + # self.assertEqual(a.gshape, (2, 3)) + # self.assertEqual(a.split, None) + # self.assertTrue( + # (a.larray == torch.tensor(unsplit_data, device=self.device.torch_device)).all() + # ) + + # # basic array function, unsplit data, different datatype + # tuple_data = ((0, 0), (1, 1)) + # b = ht.array(tuple_data, dtype=ht.int8) + # self.assertIsInstance(b, ht.DNDarray) + # self.assertEqual(b.dtype, ht.int8) + # self.assertEqual(b.larray.dtype, torch.int8) + # self.assertEqual(b.lshape, (2, 2)) + # self.assertEqual(b.gshape, (2, 2)) + # self.assertEqual(b.split, None) + # self.assertTrue( + # ( + # b.larray + # == torch.tensor(tuple_data, dtype=torch.int8, device=self.device.torch_device) + # ).all() + # ) + # if not self.is_mps: + # check_precision = ht.array(16777217.0, dtype=ht.float64) + # self.assertEqual(check_precision.sum(), 16777217) + + # # basic array function, unsplit data, no copy + # torch_tensor = torch.tensor([6, 5, 4, 3, 2, 1], device=self.device.torch_device) + # c = ht.array(torch_tensor, copy=False) + # self.assertIsInstance(c, ht.DNDarray) + # self.assertEqual(c.dtype, ht.int64) + # self.assertEqual(c.lshape, (6,)) + # self.assertEqual(c.gshape, (6,)) + # self.assertEqual(c.split, None) + # self.assertIs(c.larray, torch_tensor) + # self.assertTrue((c.larray == torch_tensor).all()) + + # # basic array function, unsplit data, additional dimensions + # vector_data = [4.0, 5.0, 6.0] + # d = ht.array(vector_data, ndmin=3) + # self.assertIsInstance(d, ht.DNDarray) + # self.assertEqual(d.dtype, ht.float32) + # self.assertEqual(d.lshape, (3, 1, 1)) + # self.assertEqual(d.gshape, (3, 1, 1)) + # self.assertEqual(d.split, None) + # self.assertTrue( + # ( + # d.larray + # == torch.tensor(vector_data, device=self.device.torch_device).reshape(-1, 1, 1) + # ).all() + # ) + + # # basic array function, unsplit data, additional dimensions + # vector_data = [4.0, 5.0, 6.0] + # d = ht.array(vector_data, ndmin=-3) + # self.assertIsInstance(d, ht.DNDarray) + # self.assertEqual(d.dtype, ht.float32) + # self.assertEqual(d.lshape, (1, 1, 3)) + # self.assertEqual(d.gshape, (1, 1, 3)) + # self.assertEqual(d.split, None) + # self.assertTrue( + # ( + # d.larray + # == torch.tensor(vector_data, device=self.device.torch_device).reshape(1, 1, -1) + # ).all() + # ) + + # # distributed array, chunk local data (split), copy True + # if self.is_mps: + # np_dtype = np.float32 + # torch_dtype = torch.float32 + # else: + # np_dtype = np.float64 + # torch_dtype = torch.float64 + # ht_dtype = ht.types.canonical_heat_type(torch_dtype) + + # array_2d = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=np_dtype) + # dndarray_2d = ht.array(array_2d, split=0, copy=True) + # self.assertIsInstance(dndarray_2d, ht.DNDarray) + # self.assertEqual(dndarray_2d.dtype, ht_dtype) + # self.assertEqual(dndarray_2d.gshape, (3, 3)) + # self.assertEqual(len(dndarray_2d.lshape), 2) + # self.assertLessEqual(dndarray_2d.lshape[0], 3) + # self.assertEqual(dndarray_2d.lshape[1], 3) + # self.assertEqual(dndarray_2d.split, 0) + # self.assertTrue( + # ( + # dndarray_2d.larray == torch.tensor([1.0, 2.0, 3.0], device=self.device.torch_device) + # ).all() + # ) + + # # distributed array, chunk local data (split), copy False, torch devices + # array_2d = torch.tensor( + # [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], + # dtype=torch_dtype, + # device=self.device.torch_device, + # ) + # dndarray_2d = ht.array(array_2d, split=0, copy=False, dtype=ht_dtype) + # self.assertIsInstance(dndarray_2d, ht.DNDarray) + # self.assertEqual(dndarray_2d.dtype, ht_dtype) + # self.assertEqual(dndarray_2d.gshape, (3, 3)) + # self.assertEqual(len(dndarray_2d.lshape), 2) + # self.assertLessEqual(dndarray_2d.lshape[0], 3) + # self.assertEqual(dndarray_2d.lshape[1], 3) + # self.assertEqual(dndarray_2d.split, 0) + # self.assertTrue( + # ( + # dndarray_2d.larray == torch.tensor([1.0, 2.0, 3.0], device=self.device.torch_device) + # ).all() + # ) + # # Check that the array is not a copy, (only really works when the array is not split) + # if ht.communication.MPI_WORLD.size == 1: + # self.assertIs(dndarray_2d.larray, array_2d) + + # # The array should not change as all properties match + # dndarray_2d_new = ht.array(dndarray_2d, split=0, copy=False, dtype=ht_dtype) + # self.assertIsInstance(dndarray_2d_new, ht.DNDarray) + # self.assertEqual(dndarray_2d_new.dtype, ht_dtype) + # self.assertEqual(dndarray_2d_new.gshape, (3, 3)) + # self.assertEqual(len(dndarray_2d_new.lshape), 2) + # self.assertLessEqual(dndarray_2d_new.lshape[0], 3) + # self.assertEqual(dndarray_2d_new.lshape[1], 3) + # self.assertEqual(dndarray_2d_new.split, 0) + # self.assertTrue( + # ( + # dndarray_2d.larray == torch.tensor([1.0, 2.0, 3.0], device=self.device.torch_device) + # ).all() + # ) + # # Reuse the same array + # self.assertIs(dndarray_2d_new.larray, dndarray_2d.larray) + + # # Should throw exeception because it causes a resplit + # with self.assertRaises(ValueError): + # dndarray_2d_new = ht.array(dndarray_2d, split=1, copy=False, dtype=ht.double) + + # # The array should not change as all properties match + # dndarray_2d_new = ht.array(dndarray_2d, is_split=0, copy=False, dtype=ht_dtype) + # self.assertIsInstance(dndarray_2d_new, ht.DNDarray) + # self.assertEqual(dndarray_2d_new.dtype, ht_dtype) + # self.assertEqual(dndarray_2d_new.gshape, (3, 3)) + # self.assertEqual(len(dndarray_2d_new.lshape), 2) + # self.assertLessEqual(dndarray_2d_new.lshape[0], 3) + # self.assertEqual(dndarray_2d_new.lshape[1], 3) + # self.assertEqual(dndarray_2d_new.split, 0) + # self.assertTrue( + # ( + # dndarray_2d.larray == torch.tensor([1.0, 2.0, 3.0], device=self.device.torch_device) + # ).all() + # ) + + # # Should throw exeception because of array is split along another dimension + # with self.assertRaises(ValueError): + # dndarray_2d_new = ht.array(dndarray_2d, is_split=1, copy=False, dtype=ht.double) + + # # distributed array, partial data (is_split) + # if ht.communication.MPI_WORLD.rank == 0: + # split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]] + # else: + # split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]] + # e = ht.array(split_data, ndmin=3, is_split=0) + + # self.assertIsInstance(e, ht.DNDarray) + # self.assertEqual(e.dtype, ht.float32) + # if ht.communication.MPI_WORLD.rank == 0: + # self.assertEqual(e.lshape, (3, 3, 1)) + # else: + # self.assertEqual(e.lshape, (2, 3, 1)) + # self.assertEqual(e.split, 0) + # for index, ele in enumerate(e.gshape): + # if index != e.split: + # self.assertEqual(ele, e.lshape[index]) + # else: + # self.assertGreaterEqual(ele, e.lshape[index]) + + # # exception distributed shapes do not fit + # if ht.communication.MPI_WORLD.size > 1: + # if ht.communication.MPI_WORLD.rank == 0: + # split_data = [4.0, 5.0, 6.0] + # else: + # split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]] + + # # this will fail as the shapes do not match + # with self.assertRaises(ValueError): + # ht.array(split_data, is_split=0) + + # # exception distributed shapes do not fit + # if ht.communication.MPI_WORLD.size > 1: + # if ht.communication.MPI_WORLD.rank == 0: + # split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]] + # else: + # split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]] + + # # this will fail as the shapes do not match on a specific axis (here: 0) + # with self.assertRaises(ValueError): + # ht.array(split_data, is_split=1) + + # # check exception on mutually exclusive split and is_split + # with self.assertRaises(ValueError): + # ht.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], split=0, is_split=0) + + # e = ht.array(split_data, ndmin=-3, is_split=1) + + # self.assertIsInstance(e, ht.DNDarray) + # self.assertEqual(e.dtype, ht.float32) + # if ht.communication.MPI_WORLD.rank == 0: + # self.assertEqual(e.lshape, (1, 3, 3)) + # else: + # self.assertEqual(e.lshape, (1, 2, 3)) + # self.assertEqual(e.split, 1) + # for index, ele in enumerate(e.gshape): + # if index != e.split: + # self.assertEqual(ele, e.lshape[index]) + # else: + # self.assertGreaterEqual(ele, e.lshape[index]) + + # # exception distributed shapes do not fit + # if ht.communication.MPI_WORLD.size > 1: + # if ht.communication.MPI_WORLD.rank == 0: + # split_data = [4.0, 5.0, 6.0] + # else: + # split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]] + + # # this will fail as the shapes do not match + # with self.assertRaises(ValueError): + # ht.array(split_data, is_split=0) + + # # exception distributed shapes do not fit + # if ht.communication.MPI_WORLD.size > 1: + # if ht.communication.MPI_WORLD.rank == 0: + # split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]] + # else: + # split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]] + + # # this will fail as the shapes do not match on a specific axis (here: 0) + # with self.assertRaises(ValueError): + # ht.array(split_data, is_split=1) + + # # check exception on mutually exclusive split and is_split + # with self.assertRaises(ValueError): + # ht.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], split=1, is_split=1) + + # # non iterable type + # with self.assertRaises(TypeError): + # ht.array(map) + # # iterable, but unsuitable type + # with self.assertRaises(TypeError): + # ht.array("abc") + # # iterable, but unsuitable type, with copy=True + # with self.assertRaises(TypeError): + # ht.array("abc", copy=True) + # # unknown dtype + # with self.assertRaises(TypeError): + # ht.array((4,), dtype="a") + # # invalid ndmin + # with self.assertRaises(TypeError): + # ht.array((4,), ndmin=3.0) + # # invalid split axis type + # with self.assertRaises(TypeError): + # ht.array((4,), split="a") + # # invalid split axis value + # with self.assertRaises(ValueError): + # ht.array((4,), split=3) + # # invalid communicator + # with self.assertRaises(TypeError): + # ht.array((4,), comm={}) + # # copy=False but copy is necessary + # data = np.arange(10) + # with self.assertRaises(ValueError): + # ht.array(data, dtype=ht.int32, copy=False) + + # # data already distributed but don't match in shape + # if self.get_size() > 1: + # with self.assertRaises(ValueError): + # dim = self.get_rank() + 1 + # ht.array([[0] * dim] * dim, is_split=0) + + # def test_asarray(self): + # # same heat array + # arr = ht.array([1, 2]) + # self.assertTrue(ht.asarray(arr) is arr) + + # # from distributed python list + # arr = ht.array([1, 2, 3, 4, 5, 6], split=0) + # lst = arr.tolist(keepsplit=True) + # asarr = ht.asarray(lst, is_split=0) + + # self.assertEqual(asarr.shape, arr.shape) + # self.assertEqual(asarr.split, 0) + # self.assertEqual(asarr.device, ht.get_device()) + # self.assertTrue(ht.equal(asarr, arr)) + + # # from numpy array + # arr = np.array([1, 2, 3, 4]) + # asarr = ht.asarray(arr) + + # self.assertTrue(np.all(np.equal(asarr.numpy(), arr))) + + # asarr[0] = 0 + # if asarr.device == ht.cpu: + # self.assertEqual(asarr.numpy()[0], arr[0]) + + # # from torch tensor + # arr = torch.tensor([1, 2, 3, 4], device=self.device.torch_device) + # asarr = ht.asarray(arr) + + # self.assertTrue(torch.equal(asarr.larray, arr)) + + # asarr[0] = 0 + # self.assertEqual(asarr.larray[0].item(), arr[0].item()) + + # def test_empty(self): + # # scalar input + # simple_empty_float = ht.empty(3) + # self.assertIsInstance(simple_empty_float, ht.DNDarray) + # self.assertEqual(simple_empty_float.shape, (3,)) + # self.assertEqual(simple_empty_float.lshape, (3,)) + # self.assertEqual(simple_empty_float.split, None) + # self.assertEqual(simple_empty_float.dtype, ht.float32) + + # # different data type + # simple_empty_uint = ht.empty(5, dtype=ht.bool) + # self.assertIsInstance(simple_empty_uint, ht.DNDarray) + # self.assertEqual(simple_empty_uint.shape, (5,)) + # self.assertEqual(simple_empty_uint.lshape, (5,)) + # self.assertEqual(simple_empty_uint.split, None) + # self.assertEqual(simple_empty_uint.dtype, ht.bool) + + # # multi-dimensional + # elaborate_empty_int = ht.empty((2, 3), dtype=ht.int32) + # self.assertIsInstance(elaborate_empty_int, ht.DNDarray) + # self.assertEqual(elaborate_empty_int.shape, (2, 3)) + # self.assertEqual(elaborate_empty_int.lshape, (2, 3)) + # self.assertEqual(elaborate_empty_int.split, None) + # self.assertEqual(elaborate_empty_int.dtype, ht.int32) + + # # split axis + # elaborate_empty_split = ht.empty((6, 4), dtype=ht.int32, split=0) + # self.assertIsInstance(elaborate_empty_split, ht.DNDarray) + # self.assertEqual(elaborate_empty_split.shape, (6, 4)) + # self.assertLessEqual(elaborate_empty_split.lshape[0], 6) + # self.assertEqual(elaborate_empty_split.lshape[1], 4) + # self.assertEqual(elaborate_empty_split.split, 0) + # self.assertEqual(elaborate_empty_split.dtype, ht.int32) + + # # exceptions + # with self.assertRaises(TypeError): + # ht.empty("(2, 3,)", dtype=ht.float64) + # with self.assertRaises(ValueError): + # ht.empty((-1, 3), dtype=ht.float64) + # with self.assertRaises(TypeError): + # ht.empty((2, 3), dtype=ht.float64, split="axis") + + # def test_empty_like(self): + # # scalar + # like_int = ht.empty_like(3) + # self.assertIsInstance(like_int, ht.DNDarray) + # self.assertEqual(like_int.shape, (1,)) + # self.assertEqual(like_int.lshape, (1,)) + # self.assertEqual(like_int.split, None) + # self.assertEqual(like_int.dtype, ht.int32) + + # # sequence + # like_str = ht.empty_like("abc") + # self.assertIsInstance(like_str, ht.DNDarray) + # self.assertEqual(like_str.shape, (3,)) + # self.assertEqual(like_str.lshape, (3,)) + # self.assertEqual(like_str.split, None) + # self.assertEqual(like_str.dtype, ht.float32) + + # # elaborate tensor + # ones = ht.ones((2, 3), dtype=ht.uint8) + # like_ones = ht.empty_like(ones) + # self.assertIsInstance(like_ones, ht.DNDarray) + # self.assertEqual(like_ones.shape, (2, 3)) + # self.assertEqual(like_ones.lshape, (2, 3)) + # self.assertEqual(like_ones.split, None) + # self.assertEqual(like_ones.dtype, ht.uint8) + + # # elaborate tensor with split + # ones_split = ht.ones((2, 3), dtype=ht.uint8, split=0) + # like_ones_split = ht.empty_like(ones_split) + # self.assertIsInstance(like_ones_split, ht.DNDarray) + # self.assertEqual(like_ones_split.shape, (2, 3)) + # self.assertLessEqual(like_ones_split.lshape[0], 2) + # self.assertEqual(like_ones_split.lshape[1], 3) + # self.assertEqual(like_ones_split.split, 0) + # self.assertEqual(like_ones_split.dtype, ht.uint8) + + # # exceptions + # with self.assertRaises(TypeError): + # ht.empty_like(ones, dtype="abc") + # with self.assertRaises(TypeError): + # ht.empty_like(ones, split="axis") + + # def test_eye(self): + # def get_offset(tensor_array): + # x, y = tensor_array.shape + # for k in range(x): + # for li in range(y): + # if tensor_array[k][li] == 1: + # return k, li + # return x, y + + # shape = 5 + # eye = ht.eye(shape, dtype=ht.uint8, split=1) + # self.assertIsInstance(eye, ht.DNDarray) + # self.assertEqual(eye.dtype, ht.uint8) + # self.assertEqual(eye.shape, (shape, shape)) + # self.assertEqual(eye.split, 1) + + # offset_x, offset_y = get_offset(eye.larray) + # self.assertGreaterEqual(offset_x, 0) + # self.assertGreaterEqual(offset_y, 0) + # x, y = eye.larray.shape + # for i in range(x): + # for j in range(y): + # expected = 1 if i - offset_x is j - offset_y else 0 + # self.assertEqual(eye.larray[i][j], expected) + + # shape = (10, 20) + # eye = ht.eye(shape, dtype=ht.float32) + # self.assertIsInstance(eye, ht.DNDarray) + # self.assertEqual(eye.dtype, ht.float32) + # self.assertEqual(eye.shape, shape) + # self.assertEqual(eye.split, None) + + # offset_x, offset_y = get_offset(eye.larray) + # self.assertGreaterEqual(offset_x, 0) + # self.assertGreaterEqual(offset_y, 0) + # x, y = eye.larray.shape + # for i in range(x): + # for j in range(y): + # expected = 1.0 if i - offset_x is j - offset_y else 0.0 + # self.assertEqual(eye.larray[i][j], expected) + + # shape = (10,) + # eye = ht.eye(shape, dtype=ht.int32, split=0) + # self.assertIsInstance(eye, ht.DNDarray) + # self.assertEqual(eye.dtype, ht.int32) + # self.assertEqual(eye.shape, shape * 2) + # self.assertEqual(eye.split, 0) + + # offset_x, offset_y = get_offset(eye.larray) + # self.assertGreaterEqual(offset_x, 0) + # self.assertGreaterEqual(offset_y, 0) + # x, y = eye.larray.shape + # for i in range(x): + # for j in range(y): + # expected = 1 if i - offset_x is j - offset_y else 0 + # self.assertEqual(eye.larray[i][j], expected) + + # shape = (11, 30) + # eye = ht.eye(shape, split=1, dtype=ht.float32) + # self.assertIsInstance(eye, ht.DNDarray) + # self.assertEqual(eye.dtype, ht.float32) + # self.assertEqual(eye.shape, shape) + # self.assertEqual(eye.split, 1) def test_from_partitioned(self): a = ht.zeros((120, 120), split=0) diff --git a/heat/core/tests/test_indexing.py b/heat/core/tests/test_indexing.py index 4707aa28ab..4ff9ead7a2 100644 --- a/heat/core/tests/test_indexing.py +++ b/heat/core/tests/test_indexing.py @@ -9,29 +9,35 @@ def test_nonzero(self): a = ht.array([[1, 2, 3], [4, 5, 2], [7, 8, 9]], split=None) cond = a > 3 nz = ht.nonzero(cond) - self.assertEqual(nz.gshape, (5, 2)) - self.assertEqual(nz.dtype, ht.int64) - self.assertEqual(nz.split, None) + self.assertEqual(len(nz), 2) + self.assertEqual(len(nz[0]), 5) + self.assertEqual(nz[0].dtype, ht.int64) # split a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1) cond = a > 3 nz = cond.nonzero() - self.assertEqual(nz.gshape, (6, 2)) - self.assertEqual(nz.dtype, ht.int64) - self.assertEqual(nz.split, 0) - a[nz] = 10.0 + self.assertEqual(len(nz), 2) + self.assertEqual(len(nz[0]), 6) + self.assertEqual(nz[0].dtype, ht.int64) + a[nz] = 10 self.assertEqual(ht.all(a[nz] == 10), 1) + # attribute error + a = a.numpy() + with self.assertRaises(TypeError): + ht.nonzero(a) + def test_where(self): # cases to test # no x and y a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=None) cond = a > 3 wh = ht.where(cond) - self.assertEqual(wh.gshape, (6, 2)) - self.assertEqual(wh.dtype, ht.int64) - self.assertEqual(wh.split, None) + self.assertEqual(len(wh), 2) + self.assertEqual(wh[0].gshape[0], 6) + self.assertEqual(wh[0].dtype, ht.int64) + self.assertEqual(wh[0].split, None) # split a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1) cond = a > 3 diff --git a/heat/core/tests/test_statistics.py b/heat/core/tests/test_statistics.py index 358c99e857..4c4167e9e4 100644 --- a/heat/core/tests/test_statistics.py +++ b/heat/core/tests/test_statistics.py @@ -390,6 +390,7 @@ def test_bucketize(self): self.assertEqual(a.dtype, ht.int64) self.assertTrue(a.shape, v.shape) + torch.manual_seed(42) boundaries, _ = torch.sort(torch.rand(5, device=self.device.torch_device)) v = torch.rand(6, device=self.device.torch_device) t = torch.bucketize(v, boundaries, out_int32=True) @@ -569,6 +570,7 @@ def test_histc(self): self.assertEqual(res.dtype, ht.float32) self.assertEqual(res.device, self.device) self.assertEqual(res.split, None) + print(f"\n\n ############ Debug ############ \n {res.larray=} \n {comp=} #################### \n\n") self.assertTrue(torch.equal(res.larray, comp)) a = ht.array(c, split=1)