Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CategoricalMatrix column indexing #110

Merged
merged 2 commits into from
Sep 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions src/quantcore/matrix/categorical_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -7,6 +7,7 @@
from .ext.categorical import matvec, sandwich_categorical, transpose_matvec
from .ext.split import sandwich_cat_cat, sandwich_cat_dense
from .matrix_base import MatrixBase
from .sparse_matrix import SparseMatrix
from .util import (
check_matvec_out_shape,
check_transpose_matvec_out_shape,
Expand All @@ -15,6 +16,23 @@
)


def _is_indexer_full_length(full_length: int, indexer: Any) -> bool:
    """Return whether ``indexer`` selects every position of an axis, in order.

    The result is ``True`` only when applying ``indexer`` to an axis of
    length ``full_length`` is a no-op, i.e. it yields positions
    ``0, 1, ..., full_length - 1`` exactly once each and in that order.
    Reordered or duplicated selections are *not* full length, because they
    change the layout of the result.

    Parameters
    ----------
    full_length
        Length of the axis being indexed.
    indexer
        An ``int`` (or NumPy integer), a ``list``, a ``numpy.ndarray``
        (integer positions or a boolean mask) or a ``slice``.

    Returns
    -------
    bool
        ``True`` if the indexer is the identity selection on the axis.

    Raises
    ------
    IndexError
        If an integer position lies outside ``[-full_length, full_length)``
        or a boolean mask does not have exactly ``full_length`` entries.
    ValueError
        If ``indexer`` is of an unsupported type.
    """
    if isinstance(indexer, (int, np.integer)):
        if not -full_length <= indexer < full_length:
            raise IndexError("Index out-of-range.")
        # A single position covers the axis only when the axis has one entry.
        return full_length == 1

    if isinstance(indexer, slice):
        # Comparing ranges (not just their lengths) rejects reversed slices.
        return range(*indexer.indices(full_length)) == range(full_length)

    if isinstance(indexer, (list, np.ndarray)):
        positions = np.asarray(indexer)

        if positions.dtype == bool:
            # A boolean mask selects everything only if it is all True.
            if positions.shape != (full_length,):
                raise IndexError(
                    "Boolean index length does not match axis length."
                )
            return bool(positions.all())

        if positions.size == 0:
            return full_length == 0

        if ((positions < -full_length) | (positions >= full_length)).any():
            raise IndexError("Index out-of-range.")

        # Map negative positions onto their non-negative equivalents so that
        # e.g. [0, -1] on a length-2 axis is recognised as [0, 1].
        positions = np.where(positions < 0, positions + full_length, positions)

        # Order and multiplicity matter: only the identity selection counts.
        return bool(np.array_equal(positions, np.arange(full_length)))

    raise ValueError(f"Indexing with {type(indexer)} is not allowed.")


def _none_to_slice(arr: Optional[np.ndarray], n: int) -> Union[slice, np.ndarray]:
if arr is None or len(arr) == n:
return slice(None, None, None)
Expand Down Expand Up @@ -262,8 +280,14 @@ def get_col_stds(self, weights: np.ndarray, col_means: np.ndarray) -> np.ndarray
def __getitem__(self, item):
if isinstance(item, tuple):
row, col = item
if not (isinstance(col, slice) and col == slice(None, None, None)):
raise IndexError("Only column indexing is supported.")
if _is_indexer_full_length(self.shape[1], col):
if isinstance(row, int):
row = [row]
return CategoricalMatrix(self.cat[row])
else:
# return a SparseMatrix if we subset columns
# TODO: this is inefficient. See issue #101.
return SparseMatrix(self.tocsr()[row, col], dtype=self.dtype)
else:
row = item
if isinstance(row, int):
Expand Down
6 changes: 6 additions & 0 deletions tests/test_categorical_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,9 @@ def test_nulls(mi_element):
vec = [0, mi_element, 1]
with pytest.raises(ValueError, match="Categorical data can't have missing values"):
CategoricalMatrix(vec)


def test_categorical_indexing():
    """Column indexing yields a matrix whose width matches the selection."""
    catvec = [0, 1, 2, 0, 1, 2]
    mat = CategoricalMatrix(catvec)

    # Selecting a strict subset of the columns must keep all rows and only
    # the requested columns.
    subset = mat[:, [0, 1]]
    assert subset.shape == (6, 2)

    # Selecting every column in order keeps the categorical representation.
    full = mat[:, [0, 1, 2]]
    assert isinstance(full, CategoricalMatrix)
    assert full.shape == (6, 3)