-
Notifications
You must be signed in to change notification settings - Fork 1
Add support for deflation in arnoldi_decomposition #14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
26d9def
0a806ef
ddbe6ed
63b9552
0b4a5e3
eb141be
b03f465
c34ba5c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| .PHONY: lint tests | ||
|
|
||
| lint: | ||
| ruff check src tests | ||
|
|
||
| tests: | ||
| uv run pytest tests -s | ||
|
|
||
| tests-all: lint tests |
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,31 @@ | ||||||||||
| import numpy as np | ||||||||||
|
|
||||||||||
|
|
||||||||||
| from .decomposition import RitzDecomposition, arnoldi_decomposition | ||||||||||
| from .utils import rand_normalized_vector | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def naive_explicit_restarts(A, m=None, *, stopping_criterion=None, max_restarts=10): | ||||||||||
| tol = stopping_criterion or np.sqrt(np.finfo(A.dtype).eps) | ||||||||||
| dtype = np.promote_types(A.dtype, np.complex64) | ||||||||||
|
|
||||||||||
| n = A.shape[0] | ||||||||||
| k = 1 # Naive arnoldi w/o restart only really works for 1 eigenvalue | ||||||||||
| m = m or min(max(2 * k + 1, 20), n) | ||||||||||
|
|
||||||||||
| V = np.zeros((n, m+1), dtype) | ||||||||||
| H = np.zeros((m+1, m), dtype) | ||||||||||
|
Comment on lines
+16
to
+17
|
||||||||||
| V = np.zeros((n, m+1), dtype) | |
| H = np.zeros((m+1, m), dtype) | |
| V = np.zeros((n, m+1), dtype=dtype) | |
| H = np.zeros((m+1, m), dtype=dtype) |
Copilot
AI
Sep 27, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Division by zero risk if both ritz.values[0] and tol are zero. Consider using max(abs(ritz.values[0]), tol) or adding an explicit check for zero denominator.
| if residuals[0] / max(ritz.values[0], tol) < tol: | |
| if residuals[0] / max(abs(ritz.values[0]), tol) < tol: |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| # Max retries for short tests | ||
| MAX_RETRIES_SHORT = 3 | ||
|
|
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7,11 +7,12 @@ | |||||||||||||||
| from arnoldi.matrices import mark, laplace | ||||||||||||||||
| from arnoldi.utils import rand_normalized_vector | ||||||||||||||||
|
|
||||||||||||||||
| from .common import MAX_RETRIES_SHORT | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| ATOL = 1e-8 | ||||||||||||||||
| RTOL = 1e-4 | ||||||||||||||||
| # Max retries for short tests | ||||||||||||||||
| MAX_RETRIES_SHORT = 3 | ||||||||||||||||
|
|
||||||||||||||||
| norm = np.linalg.norm | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -25,7 +26,7 @@ def inject_noise(A): | |||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| def basis_vector(n, k, dtype=np.int64): | ||||||||||||||||
| """Create the basis vector e_k in R^n, aka e_k is (n,), and with e_k[k] = | ||||||||||||||||
| """Create the basis vector e_k in R^n, aka e_k is (n,), and with e_k[k-1] = | ||||||||||||||||
| 1 | ||||||||||||||||
| """ | ||||||||||||||||
|
Comment on lines
+29
to
31
|
||||||||||||||||
| """Create the basis vector e_k in R^n, aka e_k is (n,), and with e_k[k-1] = | |
| 1 | |
| """ | |
| """ | |
| Create the k-th standard basis vector in R^n (of shape (n,)), where the entry at index (k-1) | |
| is 1 and all other entries are 0. (Note: k is 1-based, as in mathematical notation; Python uses 0-based indexing.) | |
| """ |
Copilot
AI
Sep 27, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable V_m is not defined in this scope. Based on the context, this should use the existing variable V that is passed to the function, sliced appropriately as V[:, :m].
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| import pytest | ||
|
|
||
| from arnoldi.explicit_restarts import naive_explicit_restarts | ||
| from arnoldi.matrices import mark | ||
|
|
||
| from .common import MAX_RETRIES_SHORT | ||
|
|
||
|
|
||
| class TestNaiveExplicitRestarts: | ||
| @pytest.mark.parametrize( | ||
| "restarts, digits", [(1, 0), (2, 1), (3, 3), (4, 5), (5, 6)] | ||
| ) | ||
| @pytest.mark.flaky(reruns=MAX_RETRIES_SHORT) | ||
| def test_mark10(self, restarts, digits): | ||
| # For the numerical value, see table 6.2 of Numerical Methods for Large | ||
| # Eigenvalue Problems, 2nd edition. | ||
|
|
||
| ## Given | ||
| A = mark(10) | ||
| m = 10 | ||
|
|
||
| ## When | ||
| ritz, *_ = naive_explicit_restarts(A, m, max_restarts=restarts) | ||
|
|
||
| ## Then | ||
| assert ritz.compute_true_residuals(A) <= 2 * 10**(-digits) | ||
|
|
||
| @pytest.mark.flaky(reruns=MAX_RETRIES_SHORT) | ||
| def test_convergence(self): | ||
| ## Given | ||
| A = mark(10) | ||
| m = 20 | ||
| atol = 1e-6 | ||
|
|
||
| ## When | ||
| ritz, has_converged, *_ = naive_explicit_restarts(A, m, | ||
| max_restarts=200, | ||
| stopping_criterion=atol) | ||
|
|
||
| ## Then | ||
| assert ritz.compute_true_residuals(A) <= atol | ||
| assert has_converged |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing parentheses around
dtypeparameter. Should beV = np.zeros((n, m+1), dtype=dtype)to properly specify the dtype parameter.