Skip to content

Commit

Permalink
CategoricalMatrix column indexing (#110)
Browse files Browse the repository at this point in the history
* categorical indexing

* moved test
  • Loading branch information
MarcAntoineSchmidtQC authored Sep 22, 2021
1 parent 24c3413 commit 462e65e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
30 changes: 27 additions & 3 deletions src/quantcore/matrix/categorical_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -7,6 +7,7 @@
from .ext.categorical import matvec, sandwich_categorical, transpose_matvec
from .ext.split import sandwich_cat_cat, sandwich_cat_dense
from .matrix_base import MatrixBase
from .sparse_matrix import SparseMatrix
from .util import (
check_matvec_out_shape,
check_transpose_matvec_out_shape,
Expand All @@ -15,6 +16,23 @@
)


def _is_indexer_full_length(full_length: int, indexer: Any) -> bool:
    """Return whether ``indexer`` selects every position exactly once, in order.

    A ``True`` result means that applying ``indexer`` along an axis of length
    ``full_length`` is the identity selection, so the indexing step can be
    skipped. Permutations, repeated positions, reversed slices and partial
    boolean masks all change the result of indexing and therefore yield
    ``False``.

    Parameters
    ----------
    full_length : int
        Length of the axis being indexed.
    indexer : int, numpy integer, list, numpy.ndarray or slice
        The index expression for that axis. Lists and arrays may hold
        integer positions (negative values count from the end) or a boolean
        mask with exactly ``full_length`` entries.

    Returns
    -------
    bool
        ``True`` if the indexer selects positions ``0 .. full_length - 1`` in
        that exact order, ``False`` otherwise.

    Raises
    ------
    IndexError
        If an integer position lies outside ``[-full_length, full_length)``
        or a boolean mask does not have ``full_length`` entries.
    ValueError
        If ``indexer`` is of an unsupported type.
    """
    if isinstance(indexer, slice):
        # Range equality compares the generated sequences, so a reversed or
        # strided slice is not mistaken for the identity selection.
        return range(*indexer.indices(full_length)) == range(full_length)

    if isinstance(indexer, (int, np.integer)):
        if not -full_length <= indexer < full_length:
            raise IndexError("Index out-of-range.")
        # A single in-range position covers the axis only if it has length 1.
        return full_length == 1

    if not isinstance(indexer, (list, np.ndarray)):
        raise ValueError(f"Indexing with {type(indexer)} is not allowed.")

    positions = np.asarray(indexer).ravel()
    if positions.size == 0:
        # Selecting nothing is only "everything" on an empty axis.
        return full_length == 0

    if positions.dtype == np.bool_:
        # Boolean masks select by truth value, not by position.
        if positions.size != full_length:
            raise IndexError("Boolean index length does not match axis length.")
        return bool(positions.all())

    if positions.dtype.kind in "iu":
        # Use a wide signed type so the wrap-around arithmetic below cannot
        # overflow for narrow or unsigned index dtypes.
        positions = positions.astype(np.int64, copy=False)

    if ((positions < -full_length) | (positions >= full_length)).any():
        raise IndexError("Index out-of-range.")

    # Map negative positions onto their non-negative equivalents so that
    # aliases of the same position (e.g. 0 and -full_length) compare equal.
    positions = np.where(positions < 0, positions + full_length, positions)
    return bool(np.array_equal(positions, np.arange(full_length)))


def _none_to_slice(arr: Optional[np.ndarray], n: int) -> Union[slice, np.ndarray]:
if arr is None or len(arr) == n:
return slice(None, None, None)
Expand Down Expand Up @@ -262,8 +280,14 @@ def get_col_stds(self, weights: np.ndarray, col_means: np.ndarray) -> np.ndarray
def __getitem__(self, item):
if isinstance(item, tuple):
row, col = item
if not (isinstance(col, slice) and col == slice(None, None, None)):
raise IndexError("Only column indexing is supported.")
if _is_indexer_full_length(self.shape[1], col):
if isinstance(row, int):
row = [row]
return CategoricalMatrix(self.cat[row])
else:
# return a SparseMatrix if we subset columns
# TODO: this is inefficient. See issue #101.
return SparseMatrix(self.tocsr()[row, col], dtype=self.dtype)
else:
row = item
if isinstance(row, int):
Expand Down
6 changes: 6 additions & 0 deletions tests/test_categorical_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,9 @@ def test_nulls(mi_element):
vec = [0, mi_element, 1]
with pytest.raises(ValueError, match="Categorical data can't have missing values"):
CategoricalMatrix(vec)


def test_categorical_indexing():
    """Smoke test: selecting a subset of columns must not raise."""
    categories = [0, 1, 2, 0, 1, 2]
    matrix = CategoricalMatrix(categories)
    # Only the absence of an exception is checked; the result is discarded.
    matrix[:, [0, 1]]

0 comments on commit 462e65e

Please sign in to comment.