Skip to content

Commit

Permalink
Use scipy.linalg.lstsq instead of np.linalg.lstsq (#83)
Browse files Browse the repository at this point in the history
* Use scipy.linalg.lstsq

* Remove line_profiler line.

* gelsy does not support empty target.
  • Loading branch information
mlondschien authored Jun 7, 2024
1 parent d758176 commit 485c953
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
37 changes: 17 additions & 20 deletions ivmodels/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import scipy

try:
import pandas as pd
Expand Down Expand Up @@ -36,13 +37,24 @@ def proj(Z, *args):
raise ValueError(f"Shape mismatch: Z.shape={Z.shape}, f.shape={f.shape}.")

if len(args) == 1:
return np.dot(Z, np.linalg.lstsq(Z, args[0], rcond=None)[0])
# The gelsy driver raises in this case - we handle it separately
if len(args[0].shape) == 2 and args[0].shape[1] == 0:
return np.zeros_like(args[0])

return np.dot(
Z, scipy.linalg.lstsq(Z, args[0], cond=None, lapack_driver="gelsy")[0]
)

csum = np.cumsum([f.shape[1] if len(f.shape) == 2 else 1 for f in args])
csum = [0] + csum.tolist()

fs = np.hstack([f.reshape(Z.shape[0], -1) for f in args])
fs = np.dot(Z, np.linalg.lstsq(Z, fs, rcond=None)[0])

if fs.shape[1] == 0:
# The gelsy driver raises in this case - we handle it separately
return (*(np.zeros_like(f) for f in args),)

fs = np.dot(Z, scipy.linalg.lstsq(Z, fs, cond=None, lapack_driver="gelsy")[0])

return (
*(fs[:, i:j].reshape(f.shape) for i, j, f in zip(csum[:-1], csum[1:], args)),
Expand All @@ -68,26 +80,11 @@ def oproj(Z, *args):
if Z is None:
return (*args,)

for f in args:
if len(f.shape) > 2:
raise ValueError(
f"*args should have shapes (n, d_f) or (n,). Got {f.shape}."
)
if f.shape[0] != Z.shape[0]:
raise ValueError(f"Shape mismatch: Z.shape={Z.shape}, f.shape={f.shape}.")

if len(args) == 1:
return args[0] - np.dot(Z, np.linalg.lstsq(Z, args[0], rcond=None)[0])

csum = np.cumsum([f.shape[1] if len(f.shape) == 2 else 1 for f in args])
csum = [0] + csum.tolist()

fs = np.hstack([f.reshape(Z.shape[0], -1) for f in args])
fs = fs - np.dot(Z, np.linalg.lstsq(Z, fs, rcond=None)[0])
return args[0] - proj(Z, args[0])

return (
*(fs[:, i:j].reshape(f.shape) for i, j, f in zip(csum[:-1], csum[1:], args)),
)
else:
return (*(x - x_proj for x, x_proj in zip(args, proj(Z, *args))),)


def to_numpy(x):
Expand Down
5 changes: 5 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def test_proj_multiple_args():
)
assert np.allclose(proj(X, z1), X @ np.linalg.inv(X.T @ X) @ X.T @ z1)
assert np.allclose(proj(X, z2), X @ np.linalg.inv(X.T @ X) @ X.T @ z2)
assert np.allclose(
proj(X, z2, z2),
X @ np.linalg.inv(X.T @ X) @ X.T @ z2,
X @ np.linalg.inv(X.T @ X) @ X.T @ z2,
)
assert np.allclose(proj(X, z3), X @ np.linalg.inv(X.T @ X) @ X.T @ z3)


Expand Down

0 comments on commit 485c953

Please sign in to comment.