Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More robust DenseMatrix._get_col_stds #436

Merged
merged 14 commits into from
Jan 29, 2025
9 changes: 8 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
name: CI
on: [push]

on:
pull_request:
branches:
- main
push:

jobs:
pre-commit-checks:
Expand Down Expand Up @@ -48,6 +53,8 @@ jobs:
run: pixi run -e ${{ matrix.environment }} install-nightlies
- name: Install repository
run: pixi run -e ${{ matrix.environment }} postinstall
- name: Run Malte's example
run: pixi run -e ${{ matrix.environment }} python example.py
- name: Run pytest
run: pixi run -e ${{ matrix.environment }} test -nauto -m "not high_memory"
- name: Run doctest
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Changelog
Unreleased
----------

**Bug fix:**

- A more robust :meth:`DenseMatrix._get_col_stds` makes the output of :meth:`StandardizedMatrix.sandwich` more accurate.

**Other changes:**

- Build wheel for pypi on python 3.13.
Expand Down
37 changes: 37 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np

import tabmat

# Print small values in fixed-point rather than scientific notation so the
# three result vectors below are easy to compare by eye.
np.set_printoptions(suppress=True)

# Run the identical comparison once in single and once in double precision.
for dtype in [np.float32, np.float64]:
    # Note that the first column is almost constant (all entries ~46.231).
    # NOTE(review): presumably chosen because a near-zero column variance is
    # where the standard-deviation computation loses precision -- confirm.
    X = np.array(
        [
            [46.231056, 126.05263, 144.46439],
            [46.231224, 128.66818, 0.7667693],
            [46.231186, 104.97506, 193.8872],
            [46.230835, 130.10156, 143.88954],
            [46.230896, 116.76007, 7.5629334],
        ],
        dtype=dtype,
    )
    v = np.array(
        [0.12428328, 0.67062443, 0.6471895, 0.6153851, 0.38367754], dtype=dtype
    )

    # Uniform weights that sum to one (1 / number of rows each).
    weights = np.full(X.shape[0], 1 / X.shape[0], dtype=dtype)

    # ``standardize`` returns three values; only the standardized matrix is
    # needed here, so the other two are bound to underscore-prefixed names.
    # NOTE(review): the two positional ``True`` flags presumably enable
    # centering and scaling -- confirm against the tabmat API.
    stmat, _out_means, _col_stds = tabmat.DenseMatrix(X).standardize(
        weights, True, True
    )

    # Two library code paths that should agree: materialize-then-multiply
    # versus the dedicated transpose matrix-vector product.
    print(stmat.toarray().T @ v)
    print(stmat.transpose_matvec(v))

    # compute by hand: column ``col`` of the standardized matrix is
    # ``shift[col] + mult[col] * mat[:, col]``; dot each column with ``v``.
    # The dense view of the underlying matrix does not depend on ``col``,
    # so materialize it once instead of once per column.
    dense = stmat.mat.toarray()
    res = np.zeros(X.shape[1], dtype=dtype)
    for col in range(X.shape[1]):
        res[col] += (stmat.shift[col] + stmat.mult[col] * dense[:, col]) @ v

    print(res)
    print("\n")
4 changes: 2 additions & 2 deletions src/tabmat/dense_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def _cross_sandwich(
raise TypeError

def _get_col_stds(self, weights: np.ndarray, col_means: np.ndarray) -> np.ndarray:
"""Get standard deviations of columns."""
sqrt_arg = transpose_square_dot_weights(self._array, weights) - col_means**2
"""Get standard deviations of columns using weights `weights`."""
sqrt_arg = transpose_square_dot_weights(self._array, weights, col_means)
# Minor floating point errors above can result in a very slightly
# negative sqrt_arg (e.g. -5e-16). We just set those values equal to
# zero.
Expand Down
6 changes: 3 additions & 3 deletions src/tabmat/ext/dense.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def dense_matvec(np.ndarray X, floating[:] v, int[:] rows, int[:] cols):
raise Exception("The matrix X is not contiguous.")
return out

def transpose_square_dot_weights(np.ndarray X, floating[:] weights):
def transpose_square_dot_weights(np.ndarray X, floating[:] weights, floating[:] shift):
cdef floating* Xp = <floating*>X.data
cdef int nrows = weights.shape[0]
cdef int ncols = X.shape[1]
Expand All @@ -112,11 +112,11 @@ def transpose_square_dot_weights(np.ndarray X, floating[:] weights):
if X.flags["C_CONTIGUOUS"]:
for j in prange(ncols, nogil=True):
for i in range(nrows):
outp[j] = outp[j] + weights[i] * (Xp[i * ncols + j] ** 2)
outp[j] = outp[j] + weights[i] * ((Xp[i * ncols + j] - shift[j]) ** 2)
elif X.flags["F_CONTIGUOUS"]:
for j in prange(ncols, nogil=True):
for i in range(nrows):
outp[j] = outp[j] + weights[i] * (Xp[j * nrows + i] ** 2)
outp[j] = outp[j] + weights[i] * ((Xp[j * nrows + i] - shift[j]) ** 2)
else:
raise Exception("The matrix X is not contiguous.")
return out
10 changes: 9 additions & 1 deletion src/tabmat/standardized_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,15 @@ def sandwich(
if not hasattr(d, "dtype"):
d = np.asarray(d)
check_sandwich_compatible(self, d)

# stat_mat = mat * mult[newaxis, :] + shift[newaxis, :]
# stat_mat.T @ d[:, newaxis] * stat_mat
# = mult[:, newaxis] * mat.T @ d[:, newaxis] * mat * mult[newaxis, :] + (1)
# mult[:, newaxis] * mat.T @ d[:, newaxis] * np.outer(ones, shift) + (2)
# shift[:, newaxis] @ d[:, newaxis] * mat * mult[newaxis, :] + (3)
# shift[:, newaxis] @ d[:, newaxis] * shift[newaxis, :] (4)
#
# (1) = self.mat.sandwich(d) * np.outer(limited_mult, limited_mult)
# (2) = mult * self.transpose_matvec(d) * shift[newaxis, :]
if rows is not None or cols is not None:
setup_rows, setup_cols = setup_restrictions(self.shape, rows, cols)
if rows is not None:
Expand Down
Loading