diff --git a/.github/workflows/test_src.yml b/.github/workflows/test_src.yml index ac037bab..1a37b23d 100644 --- a/.github/workflows/test_src.yml +++ b/.github/workflows/test_src.yml @@ -22,11 +22,11 @@ jobs: - windows-latest - macos-latest python-version: - - "3.9" - "3.10" - "3.11" - "3.12" - "3.13" + - "3.14" runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 @@ -39,5 +39,8 @@ jobs: python -m pip install --upgrade pip python -m pip install tox - name: Run tests + if: runner.os == 'Windows' + env: + MPLBACKEND: Agg # Fix a bug in GitHub Actions for Windows run: | python -m tox -e py${{ matrix.python-version }} diff --git a/.github/workflows/test_style.yml b/.github/workflows/test_style.yml index 40869206..96db50b6 100644 --- a/.github/workflows/test_style.yml +++ b/.github/workflows/test_style.yml @@ -19,7 +19,7 @@ jobs: - name: Install Python uses: actions/setup-python@v5 with: - python-version: 3.12 + python-version: 3.14 - name: Install tox run: | python -m pip install --upgrade pip diff --git a/docs/source/contributing/how_to_contribute.md b/docs/source/contributing/how_to_contribute.md index 52a01064..a442cbca 100644 --- a/docs/source/contributing/how_to_contribute.md +++ b/docs/source/contributing/how_to_contribute.md @@ -58,7 +58,7 @@ Style checks, unit tests, and documentation builds are managed with [`tox`](http For each of these tasks, `tox` creates a new virtual environment, installs the dependencies (e.g., `pytest` for running unit tests), and executes the task recipe. :::{note} -Unit tests are executed for Python 3.9 through 3.12 if they are installed on your system. +Unit tests are executed for Python 3.10 through 3.14 if they are installed on your system. The best way to install multiple Python versions varies by platform; for MacOS, [we suggest](https://stackoverflow.com/questions/36968425/how-can-i-install-multiple-versions-of-python-on-latest-os-x-and-use-them-in-par#answer-65094122) using [Homebrew](https://brew.sh/). ```shell diff --git a/docs/source/contributing/testing.md b/docs/source/contributing/testing.md index 318b9b98..1b6f0a96 100644 --- a/docs/source/contributing/testing.md +++ b/docs/source/contributing/testing.md @@ -176,8 +176,8 @@ If all tests pass, a line coverage report will be generated. Open `htmlcov/index.html` in a browser to view the report. :::{note} -Running `tox` without any arguments tests the code for Python 3.9 through 3.12 (if they are installed on your system). -To test a single Python version, use `tox -e py310` for Python 3.10, `tox -e py311` for Python3.11, and so on. +Running `tox` without any arguments tests the code for Python 3.10 through 3.14 (if they are installed on your system). +To test a single Python version, use `tox -e py314` for Python 3.14, `tox -e py313` for Python3.13, and so on. ::: ## GitHub Actions diff --git a/docs/source/opinf/changelog.md b/docs/source/opinf/changelog.md index c63ee877..ab5c9c9f 100644 --- a/docs/source/opinf/changelog.md +++ b/docs/source/opinf/changelog.md @@ -5,6 +5,15 @@ New versions may introduce substantial new features or API adjustments. ::: +## Version 0.5.17 + +- Dropped support for Python 3.9, added test coverage for Python 3.14. 
+- Tikhonov / L2 least-squares now take an initial guess +- Expansion of operators to larger or smaller bases (for nested OpInf) +- Parametric and nonparametric polynomial operators of arbitrary order +- Bases now have a `fit_compress()` method +- Update `utils.TimedBlock` to use `time.process_time()` + ## Version 0.5.16 Backend improvements to the regularization selection procedure. diff --git a/pyproject.toml b/pyproject.toml index 003d06fe..6621c906 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ ] maintainers = [{ name = "Shane A. McQuarrie", email = "smcquar@sandia.gov" }] readme = { file = "README.md", content-type = "text/markdown" } -requires-python = ">=3.9" +requires-python = ">=3.10" license = { file = "LICENSE" } keywords = [ "operator inference", diff --git a/src/opinf/__init__.py b/src/opinf/__init__.py index 9d6c372f..72585310 100644 --- a/src/opinf/__init__.py +++ b/src/opinf/__init__.py @@ -7,7 +7,7 @@ https://github.com/Willcox-Research-Group/rom-operator-inference-Python3 """ -__version__ = "0.5.16" +__version__ = "0.5.17" from . import ( basis, diff --git a/src/opinf/basis/_pod.py b/src/opinf/basis/_pod.py index 8a549f10..f38584b9 100644 --- a/src/opinf/basis/_pod.py +++ b/src/opinf/basis/_pod.py @@ -120,10 +120,26 @@ def method_of_snapshots( eigvals = eigvals[::-1] eigvecs = eigvecs[:, ::-1] + # can at most have as many non-zero eigenvalues as there was data + if eigvals.shape[0] > states.shape[0]: + eigvals[states.shape[0] :] = 0 + if eigvals.shape[0] > states.shape[1]: + eigvals[states.shape[1] :] = 0 + # By definition the Gramian is symmetric positive semi-definite. # If any eigenvalues are smaller than zero, they are only measuring # numerical error and can be truncated. positives = eigvals > max(minthresh, abs(np.min(eigvals))) + + if sum(positives) > states.shape[0]: + print(eigvals) + print(positives) + raise RuntimeError( + f"""selected more non-zero singular values {sum(positives)} + than there are state dimensions ({states.shape[0]})""" + ) + + # can at most have as many non-zero entries as there were state dofs eigvecs = eigvecs[:, positives] eigvals = eigvals[positives] @@ -131,6 +147,37 @@ def method_of_snapshots( svals = np.sqrt(eigvals * n_states) V = states @ (eigvecs / svals) + # if basis functions are not orthonormal, apply Gram-Schmidt + counter = 0 + + def inner_product(V): + if inner_product_matrix is None: + return V.T @ V + return V.T @ _Wmult(inner_product_matrix, V) + + while ( + not np.isclose(inner_product(V), np.eye(V.shape[1])).all() + ) and counter < 10: + for i in range(V.shape[1]): + v_i = V[:, i] + for j in range(i): + v_i = ( + v_i + - (v_i.T @ _Wmult(inner_product_matrix, V[:, j])) * V[:, j] + ) + norm_vi = np.sqrt(v_i.T @ _Wmult(inner_product_matrix, v_i)) + V[:, i] = v_i / norm_vi + + if V.shape[0] != states.shape[0]: + print(V.shape, counter) + raise RuntimeError( + "something went seriously wrong in computation of V" + ) + counter += 1 + + if not np.isclose(inner_product(V), np.eye(V.shape[1])).all(): + raise RuntimeWarning("Computed basis functions are not orthonormal.") + return V, svals, eigvecs.T diff --git a/src/opinf/ddt/_base.py b/src/opinf/ddt/_base.py index bca7c1af..4c22ae94 100644 --- a/src/opinf/ddt/_base.py +++ b/src/opinf/ddt/_base.py @@ -144,7 +144,6 @@ def estimate(self, states, inputs=None): """ raise NotImplementedError # pragma: no cover - @abc.abstractmethod def mask(self, arr): """Map an array from the training time domain to the domain of the estimated time derivatives. 
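The re-orthonormalization loop added to `method_of_snapshots` above amounts to a weighted modified Gram-Schmidt pass over the columns of `V`. A minimal self-contained sketch of that idea (not the library routine; `W` stands in for the optional inner-product matrix and defaults to the Euclidean inner product):

```python
import numpy as np


def reorthonormalize(V, W=None, maxiter=10, atol=1e-12):
    """Re-orthonormalize the columns of V with modified Gram-Schmidt in
    the inner product <u, v> = u.T @ W @ v (Euclidean if W is None)."""

    def ip(u, v):
        return u @ v if W is None else u @ (W @ v)

    for _ in range(maxiter):
        for i in range(V.shape[1]):
            vi = V[:, i].copy()
            for j in range(i):
                vi -= ip(vi, V[:, j]) * V[:, j]   # remove earlier directions
            V[:, i] = vi / np.sqrt(ip(vi, vi))    # normalize
        gram = np.array(
            [[ip(V[:, a], V[:, b]) for b in range(V.shape[1])]
             for a in range(V.shape[1])]
        )
        if np.allclose(gram, np.eye(V.shape[1]), atol=atol):
            break
    return V
```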
@@ -171,7 +170,7 @@ def mask(self, arr): >>> Q3.shape == Q.shape True """ - raise NotImplementedError # pragma: no cover + return arr # Verification ------------------------------------------------------------ def verify_shapes(self, r: int = 5, m: int = 3): # pragma: no cover diff --git a/src/opinf/lstsq/_tikhonov.py b/src/opinf/lstsq/_tikhonov.py index 02ef0c71..df50dd57 100644 --- a/src/opinf/lstsq/_tikhonov.py +++ b/src/opinf/lstsq/_tikhonov.py @@ -54,18 +54,45 @@ class _BaseRegularizedSolver(SolverTemplate): .. math:: (\D\trp\D + \bfGamma_i\trp\bfGamma_i)\ohat_i = \D\trp\z_i. + + If an initial guess :math:`\Ohat^{(0)}` is provided, the minimization + problem is changed to + + .. math:: + \argmin_{\ohat_i} + \|\D\ohat_i - \z_i\|_2^2 + + \sum_{i=1}^{r}\|\bfGamma_i(\ohat_i - \ohat_i^{(0)})\|_2^2, + \quad i = 1, \ldots, r. """ + def __init__(self, initial_guess=None): + SolverTemplate.__init__(self=self) + self.initial_guess = initial_guess + # Properties: regularization ---------------------------------------------- @abc.abstractmethod def regularizer(self): """Regularization scalar, matrix, or list of these.""" raise NotImplementedError # pragma: no cover + @property + def initial_guess(self): + r"""Initial guess :math:`\Ohat^{(0)}` for the regression solution.""" + return self.__Ohat0 + + @initial_guess.setter + def initial_guess(self, Ohat0): + if Ohat0 is not None and not hasattr(Ohat0, "shape"): + raise TypeError("initial_guess must be ndarray or None") + self.__Ohat0 = Ohat0 + # Main methods ------------------------------------------------------------ def fit(self, data_matrix: np.ndarray, lhs_matrix: np.ndarray): r"""Verify dimensions and save the data matrices. + If an :attr:`initial_guess` was provided during the initialization, + use ``lhs_matrix - initial_guess`` for the residual data matrix. + Parameters ---------- data_matrix : (k, d) ndarray @@ -74,6 +101,39 @@ def fit(self, data_matrix: np.ndarray, lhs_matrix: np.ndarray): "Left-hand side" data matrix :math:`\Z` (not its transpose!). If one-dimensional, assume :math:`r = 1`. """ + if self.initial_guess is not None: + + # check if initial guess fits to data matrix shape + if self.initial_guess.shape[1] != data_matrix.shape[1]: + raise RuntimeError( + f"""In _BaseRegularizedSolver.fit: + Shape of data matrix ({data_matrix.shape}) + does not match initial guess + ({self.initial_guess.shape})""" + ) + + # check if initial guess shape fits to lhs matrix shape + if self.initial_guess.shape[0] != lhs_matrix.shape[0]: + raise RuntimeError( + f"""In _BaseRegularizedSolver.fit: + Shape of lhs matrix ({lhs_matrix.shape}) + does not match initial guess + (shape {self.initial_guess.shape})""" + ) + if ( + len(lhs_matrix.shape) == 1 + and len(self.initial_guess.shape) != 1 + ): + raise RuntimeError( + f"""In _BaseRegularizedSolver.fit: + Shape of lhs matrix (shape {lhs_matrix.shape} columns) + does not match initial guess + (shape {self.initial_guess.shape})""" + ) + + # adjust lhs matrix + lhs_matrix -= (data_matrix @ self.initial_guess.T).T + SolverTemplate.fit(self, data_matrix, lhs_matrix) if self.k < self.d: warnings.warn( @@ -112,6 +172,20 @@ def regresidual(self, Ohat: np.ndarray) -> np.ndarray: """Compute the residual of the regularized regression problem.""" raise NotImplementedError # pragma: no cover + def _add_initial_guess(self, Ohat): + """Adds the initial guess if one was provided. 
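The `fit()` shift above (`lhs_matrix -= (data_matrix @ self.initial_guess.T).T`), together with the `_add_initial_guess()` helper defined here, implements a change of variables: solve the least-squares problem for the difference from the guess, then add the guess back. A small unregularized sketch with random, purely illustrative data:

```python
import numpy as np

rng = np.random.default_rng(0)
k, d, r = 40, 6, 3
D = rng.normal(size=(k, d))        # data matrix
Z = rng.normal(size=(r, k))        # left-hand-side matrix
Ohat0 = rng.normal(size=(r, d))    # initial guess

# same shift as in fit(), then solve for Delta = Ohat - Ohat0
Z_shifted = Z - (D @ Ohat0.T).T
Delta = np.linalg.lstsq(D, Z_shifted.T, rcond=None)[0].T

# _add_initial_guess step: recover the full operator matrix
Ohat = Delta + Ohat0

# without regularization this matches solving the original problem directly
Ohat_direct = np.linalg.lstsq(D, Z.T, rcond=None)[0].T
assert np.allclose(Ohat, Ohat_direct)
```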
+ Does nothing otherwise.""" + if self.initial_guess is None: + return Ohat + return Ohat + self.initial_guess + + def _subtract_initial_guess(self, Ohat): + """Subtracts the initial guess if one was provided. + Does nothing otherwise.""" + if self.initial_guess is None: + return Ohat + return Ohat - self.initial_guess + # Persistence ------------------------------------------------------------- def reset(self) -> None: """Reset the solver by deleting data matrices and the regularizer.""" @@ -151,6 +225,9 @@ def _save(self, savefile, overwrite=False, extras=tuple()): for attr in extras: hf.create_dataset(attr, data=getattr(self, attr)) + if self.initial_guess is not None: + hf.create_dataset("initial_guess", data=self.initial_guess) + @classmethod def _load(cls, loadfile: str, extras=tuple()): """Load a serialized solver from an HDF5 file, created previously from @@ -176,10 +253,12 @@ def _load(cls, loadfile: str, extras=tuple()): if cls is L2Solver: reg = reg[0] + if "initial_guess" in hf: + cls.initial_guess = hf["initial_guess"] + options = cls._load_dict(hf, "options") kwargs = dict( - regularizer=reg, - lapack_driver=options["lapack_driver"], + regularizer=reg, lapack_driver=options["lapack_driver"] ) if issubclass(cls, TikhonovSolver): @@ -228,6 +307,13 @@ class L2Solver(_BaseRegularizedSolver): :math:`\bfSigma^{*}` is a diagonal matrix with :math:`i`-th diagonal entry :math:`\Sigma_{i,i}^{*} = \Sigma_{i,i}/(\Sigma_{i,i}^{2} + \lambda^2).` + If an initial guess :math:`\Ohat^{(0)}` is provided, the original + minimization problem is changed to + + .. math:: + \argmin_{\Ohat}\|\D\Ohat\trp - \Z\trp\|_F^2 + + \|\lambda(\Ohat - \Ohat^{(0)})\trp\|_F^2 + Parameters ---------- regularizer : float @@ -235,11 +321,19 @@ class L2Solver(_BaseRegularizedSolver): lapack_driver : str LAPACK routine for computing the singular value decomposition. See :func:`scipy.linalg.svd()`. + initial_guess : ndarray or None + Initial guess :math:`\Ohat^{(0)}` for the regression solution. + Defaults to zero. """ - def __init__(self, regularizer=None, lapack_driver: str = "gesdd"): + def __init__( + self, + regularizer=None, + lapack_driver: str = "gesdd", + initial_guess=None, + ): """Store the regularizer and initialize attributes.""" - _BaseRegularizedSolver.__init__(self) + _BaseRegularizedSolver.__init__(self, initial_guess=initial_guess) self.regularizer = regularizer self.__options = types.MappingProxyType( dict(full_matrices=False, lapack_driver=lapack_driver) @@ -323,7 +417,8 @@ def solve(self) -> np.ndarray: raise AttributeError("solver regularizer not set") svals = self._svals.reshape((-1, 1)) svals_inv = svals / (svals**2 + self.regularizer**2) - return (self._ZPhi * svals_inv.T) @ self._PsiT + Odiff = (self._ZPhi * svals_inv.T) @ self._PsiT + return self._add_initial_guess(Odiff) def posterior(self): r"""Solve the Bayesian operator inference regression, constructing the @@ -403,11 +498,13 @@ def regresidual(self, Ohat: np.ndarray) -> np.ndarray: Specifically, given a potential :math:`\Ohat`, compute .. math:: - \|\D\ohat_i - \z_i\|_2^2 + \|\lambda\ohat_i\|_2^2, + \|\D\ohat_i - \z_i\|_2^2 + \|\lambda(\ohat_i-\ohat_i^{(0)})\|_2^2, \quad i = 1, \ldots, r, where :math:`\ohat_i` and :math:`\z_i` are the :math:`i`-th rows of - :math:`\Ohat` and :math:`\Z`, respectively. + :math:`\Ohat` and :math:`\Z`, respectively, and :math:`\ohat_i^{(0)}` + are the rows of the initial guess set during initialization. 
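A quick numerical check of the closed-form solution quoted in the `L2Solver` docstring and used in `solve()` above (shown here without an initial guess; with one, the same formula is applied to the shifted problem and the guess is added back). Random data, illustrative only:

```python
import numpy as np

rng = np.random.default_rng(0)
D = rng.normal(size=(30, 5))       # data matrix (k x d)
Z = rng.normal(size=(2, 30))       # left-hand-side matrix (r x k)
lam = 0.1                          # scalar regularizer

# SVD route, mirroring solve(): Ohat = (Z Phi) diag(s / (s^2 + lam^2)) Psi^T
Phi, svals, PsiT = np.linalg.svd(D, full_matrices=False)
svals_inv = svals / (svals**2 + lam**2)
Ohat_svd = ((Z @ Phi) * svals_inv) @ PsiT

# normal-equations route: (D^T D + lam^2 I) Ohat^T = D^T Z^T
Ohat_ne = np.linalg.solve(D.T @ D + lam**2 * np.eye(5), D.T @ Z.T).T

assert np.allclose(Ohat_svd, Ohat_ne)
```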
+ If no initial guess was provided, :math:`\ohat_i^{(0)} = \bf0` Parameters ---------- @@ -421,8 +518,9 @@ def regresidual(self, Ohat: np.ndarray) -> np.ndarray: """ if self.regularizer is None: raise AttributeError("solver regularizer not set") - residual = self.residual(Ohat) - return residual + (self.regularizer**2 * np.sum(Ohat**2, axis=-1)) + Odiff = self._subtract_initial_guess(Ohat) + residual = self.residual(Odiff) + return residual + (self.regularizer**2 * np.sum(Odiff**2, axis=-1)) # Persistence ------------------------------------------------------------- def save(self, savefile: str, overwrite: bool = False): @@ -495,6 +593,15 @@ class L2DecoupledSolver(L2Solver): using the singular value decomposition of the data matrix (see :class:`L2Solver`). + If initial guesses :math:`\Ohat^{(0)}` is provided, the original + minimization problem is changed to + + .. math:: + \argmin_{\Ohat}\|\D\ohat_i - \z_i\|_2^2 + + \|\lambda_i(\ohat_i - \ohat_i^{(0)})\|_2^2 + + where :math:`\ohat_i^{(0)}` are the rows of :math:`\Ohat^{(0)}`. + Parameters ---------- regularizer : (r,) ndarray @@ -503,6 +610,9 @@ class L2DecoupledSolver(L2Solver): lapack_driver : str LAPACK routine for computing the singular value decomposition. See :func:`scipy.linalg.svd()`. + initial_guess : ndarray or None + Initial guess :math:`\Ohat^{(0)}` for the regression solution. + Defaults to zero. """ # Properties -------------------------------------------------------------- @@ -627,12 +737,15 @@ def regresidual(self, Ohat: np.ndarray) -> np.ndarray: Specifically, given a potential :math:`\Ohat`, compute .. math:: - \|\D\ohat_i - \z_i\|_2^2 + \|\lambda_i\ohat_i\|_2^2, + \|\D\ohat_i - \z_i\|_2^2 + \|\lambda_i(\ohat_i-\ohat_i^{(0)})\|_2^2, \quad i = 1, \ldots, r, where :math:`\ohat_i` and :math:`\z_i` are the :math:`i`-th rows of :math:`\Ohat` and :math:`\Z`, respectively, and :math:`\lambda_i \ge 0` is the corresponding regularization constant. + If an initial guess was provided during initialization, + :math:`\ohat_i^{(0)}` are the rows of the initial guess. Otherwise + :math:`\ohat_i^{(0)} = \bf0`. Parameters ---------- @@ -677,6 +790,13 @@ class TikhonovSolver(_BaseRegularizedSolver): .. math:: \Ohat = \Z\D(\D\trp\D + \bfGamma\trp\bfGamma)^{-\mathsf{T}}. + If initial guesses :math:`\Ohat^{(0)}` is provided, the original + minimization problem is changed to + + .. math:: + \argmin_\Ohat\|\D\Ohat\trp - \Z\trp\|_F^2 + + \|\bfGamma(\Ohat - \Ohat^{(0)})\trp)\|_F^2 + Parameters ---------- regularizer : (d, d) or (d,) ndarray @@ -701,6 +821,9 @@ class TikhonovSolver(_BaseRegularizedSolver): lapack_driver : str or None Which LAPACK driver is used to solve the least-squares problem, see :func:`scipy.linalg.lstsq()`. Ignored if ``method = "normal"``. + initial_guess : ndarray or None + Initial guess :math:`\Ohat^{(i)}` for the regularization solution. + Defaults to zero. 
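A sketch of how the new `initial_guess` option is meant to be used, based on the constructor signatures in this diff: warm-start a regularized inference from a previously computed operator matrix. Shapes follow the docstrings; the data is random and purely illustrative.

```python
import numpy as np
import opinf

rng = np.random.default_rng(0)
D = rng.normal(size=(60, 6))         # data matrix (k x d)
Z = rng.normal(size=(3, 60))         # left-hand-side matrix (r x k)
Ohat_prev = rng.normal(size=(3, 6))  # operator matrix from an earlier fit

solver = opinf.lstsq.TikhonovSolver(
    regularizer=1e-2 * np.eye(6),
    initial_guess=Ohat_prev,
)
solver.fit(D, Z)
Ohat = solver.solve()
# stronger regularization pulls Ohat toward Ohat_prev rather than toward 0
```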
""" def __init__( @@ -709,9 +832,10 @@ def __init__( method: str = "lstsq", cond: float = None, lapack_driver: str = None, + initial_guess=None, ): """Store the regularizer and initialize attributes.""" - _BaseRegularizedSolver.__init__(self) + _BaseRegularizedSolver.__init__(self, initial_guess=initial_guess) self.regularizer = regularizer self.method = method self.__options = dict(cond=cond, lapack_driver=lapack_driver) @@ -951,7 +1075,7 @@ def solve(self) -> np.ndarray: elif self.method == "normal": regD = self._DtD + (self.regularizer.T @ self.regularizer) Ohat = la.solve(regD, self._DtZt, assume_a="pos").T - return Ohat + return self._add_initial_guess(Ohat) def posterior(self): r"""Solve the Bayesian operator inference regression, constructing the @@ -1024,11 +1148,13 @@ def regresidual(self, Ohat: np.ndarray) -> np.ndarray: Specifically, given a potential :math:`\Ohat`, compute .. math:: - \|\D\ohat_i - \z_i\|_2^2 + \|\bfGamma\ohat_i\|_2^2, + \|\D\ohat_i - \z_i\|_2^2 + \|\bfGamma(\ohat_i-\ohat_i^{(0)})\|_2^2, \quad i = 1, \ldots, r, where :math:`\ohat_i` and :math:`\z_i` are the :math:`i`-th rows of :math:`\Ohat` and :math:`\Z`, respectively. + The :math:`\ohat_i^{(0)}` are the rows of the initial guess provided + during the initialization (defaulted to zero). Parameters ---------- @@ -1042,8 +1168,9 @@ def regresidual(self, Ohat: np.ndarray) -> np.ndarray: """ if self.regularizer is None: raise AttributeError("solver regularizer not set") - residual = self.residual(Ohat) - return residual + np.sum((self.regularizer @ Ohat.T) ** 2, axis=0) + Odiff = self._subtract_initial_guess(Ohat) + residual = self.residual(Odiff) + return residual + np.sum((self.regularizer @ Odiff.T) ** 2, axis=0) def save(self, savefile: str, overwrite: bool = False): """Serialize the solver, saving it in HDF5 format. @@ -1123,6 +1250,13 @@ class TikhonovDecoupledSolver(TikhonovSolver): .. math:: (\D\trp\D + \bfGamma_i\trp\bfGamma_i)\ohat_i = \D\trp\z_i. + If initial guesses :math:`\Ohat^{(0)}` is provided, the original + minimization problem is changed to + + .. math:: + \argmin_\Ohat\|\D\Ohat\trp - \Z\trp\|_F^2 + + \|\bfGamma(\Ohat - \Ohat^{(0)})\trp\|_F^2 + Parameters ---------- regularizer : list of r (d, d) or (d,) ndarrays @@ -1148,6 +1282,9 @@ class TikhonovDecoupledSolver(TikhonovSolver): lapack_driver : str or None Which LAPACK driver is used to solve the least-squares problem, see :func:`scipy.linalg.lstsq()`. Ignored if ``method = "normal"``. + initial_guess : ndarray or None + Initial guess :math:`\Ohat^{(0)}` for the regression solution. + Defaults to zero. """ # Properties -------------------------------------------------------------- @@ -1215,7 +1352,7 @@ def solve(self) -> np.ndarray: elif self.method == "normal": regD = self._DtD + Gamma.T @ Gamma Ohat[i] = la.solve(regD, self._DtZt[:, i], assume_a="pos") - return Ohat + return self._add_initial_guess(Ohat) def posterior(self): r"""Solve the Bayesian operator inference regression, constructing the @@ -1295,12 +1432,15 @@ def regresidual(self, Ohat: np.ndarray) -> np.ndarray: Specifically, given a potential :math:`\Ohat`, compute .. math:: - \|\D\ohat_i - \z_i\|_2^2 + \|\bfGamma_i\ohat_i\|_2^2, + \|\D\ohat_i - \z_i\|_2^2 + + \|\bfGamma_i(\ohat_i-\ohat_i^{(0)})\|_2^2, \quad i = 1, \ldots, r, where :math:`\ohat_i` and :math:`\z_i` are the :math:`i`-th rows of :math:`\Ohat` and :math:`\Z`, respectively, and :math:`\bfGamma_i` is the corresponding symmetric-positive-definite regularization matrix. 
+ The :math:`\ohat_i^{(0)}` are the rows of the initial guess provided + during the initialization (defaulted to zero). Parameters ---------- @@ -1314,6 +1454,7 @@ def regresidual(self, Ohat: np.ndarray) -> np.ndarray: """ if self.regularizer is None: raise AttributeError("solver regularizer not set") - residual = self.residual(Ohat) - rg = [np.sum((G @ oi) ** 2) for G, oi in zip(self.regularizer, Ohat)] + Odiff = self._subtract_initial_guess(Ohat) + residual = self.residual(Odiff) + rg = [np.sum((G @ oi) ** 2) for G, oi in zip(self.regularizer, Odiff)] return residual + np.array(rg) diff --git a/src/opinf/models/mono/_nonparametric.py b/src/opinf/models/mono/_nonparametric.py index 3f99e50a..fb73f94d 100644 --- a/src/opinf/models/mono/_nonparametric.py +++ b/src/opinf/models/mono/_nonparametric.py @@ -244,6 +244,7 @@ def _fit_solver(self, states, lhs, inputs=None): states, lhs, inputs ) D = self._assemble_data_matrix(states_, inputs_) + self.solver.fit(D, lhs_) def refit(self): @@ -266,7 +267,8 @@ def refit(self): return self # Execute non-intrusive learning. - self._extract_operators(self.solver.solve()) + self._extract_operators(Ohat=self.solver.solve()) + return self def fit(self, states, lhs, inputs=None): @@ -710,7 +712,12 @@ def fit(self, states, nextstates=None, inputs=None): states = states[:, :-1] if inputs is not None: inputs = inputs[..., : states.shape[1]] - return _NonparametricModel.fit(self, states, nextstates, inputs=inputs) + return _NonparametricModel.fit( + self, + states, + nextstates, + inputs=inputs, + ) def rhs(self, state, input_=None): r"""Evaluate the right-hand side of the model by applying each operator diff --git a/src/opinf/models/mono/_parametric.py b/src/opinf/models/mono/_parametric.py index 9e109a10..8e6e81f6 100644 --- a/src/opinf/models/mono/_parametric.py +++ b/src/opinf/models/mono/_parametric.py @@ -22,6 +22,7 @@ ) from ... import errors, utils, operators as _operators from ...operators import _utils as oputils +from ...operators._polynomial_operator import PolynomialOperator # Base classes ================================================================ @@ -70,7 +71,15 @@ def _check_operator_types_unique(ops): of operation (e.g., two constant operators). """ OpClasses = { - (op._OperatorClass if oputils.is_parametric(op) else type(op)) + ( + op._OperatorClass + if oputils.is_parametric(op) + else ( + op.polynomial_order + if type(op) is PolynomialOperator + else type(op) + ) + ) for op in ops } if len(OpClasses) != len(ops): @@ -266,7 +275,7 @@ def _assemble_data_matrix(self, parameters, states, inputs): def _fit_solver(self, parameters, states, lhs, inputs=None): """Construct a solver for the operator inference least-squares - regression.""" + regression""" ( parameters_, states_, @@ -281,7 +290,8 @@ def _fit_solver(self, parameters, states, lhs, inputs=None): # Set up non-intrusive learning. D = self._assemble_data_matrix(parameters_, states_, inputs_) - self.solver.fit(D, np.hstack(lhs_)) + R = np.hstack(lhs_) + self.solver.fit(D, R) self.__s = len(parameters_) def _extract_operators(self, Ohat): @@ -320,7 +330,7 @@ def refit(self): return self # Execute non-intrusive learning. 
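For context, the `_fit_solver()` / `refit()` split above supports a fit-once, re-solve-many workflow: the data matrix is assembled a single time and `refit()` re-solves with the solver's current settings. A hedged usage sketch (assuming the public `ContinuousModel` / `L2Solver` API; random data, illustrative only):

```python
import numpy as np
import opinf

rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 200))    # reduced snapshots (r x k)
dQ = rng.normal(size=(3, 200))   # reduced time derivatives (r x k)

model = opinf.models.ContinuousModel(
    "A", solver=opinf.lstsq.L2Solver(regularizer=1e-4)
)
model.fit(Q, dQ)                 # assembles the data matrix and solves once
model.solver.regularizer = 1e-2  # adjust the regularization ...
model.refit()                    # ... and re-solve without re-assembling data
```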
- self._extract_operators(self.solver.solve()) + self._extract_operators(Ohat=self.solver.solve()) return self def fit(self, parameters, states, lhs, inputs=None): @@ -511,7 +521,13 @@ class _ParametricDiscreteMixin: _ModelClass = _FrozenDiscreteModel - def fit(self, parameters, states, nextstates=None, inputs=None): + def fit( + self, + parameters, + states, + nextstates=None, + inputs=None, + ): r"""Learn the model operators from data. The operators are inferred by solving the regression problem @@ -799,7 +815,12 @@ def fit(self, parameters, states, ddts, inputs=None): ------- self """ - return super().fit(parameters, states, ddts, inputs=inputs) + return super().fit( + parameters, + states, + ddts, + inputs=inputs, + ) def rhs(self, t, parameter, state, input_func=None): r"""Evaluate the right-hand side of the model by applying each operator @@ -1150,11 +1171,7 @@ def _fit_solver(self, parameters, states, lhs, inputs=None): ], solver=self.solver.copy(), ) - model_i._fit_solver( - states_[i], - lhs_[i], - inputs_[i], - ) + model_i._fit_solver(states_[i], lhs_[i], inputs_[i]) nonparametric_models.append(model_i) self.solvers = [mdl.solver for mdl in nonparametric_models] diff --git a/src/opinf/operators/_affine.py b/src/opinf/operators/_affine.py index eb8ee04b..481a50e4 100644 --- a/src/opinf/operators/_affine.py +++ b/src/opinf/operators/_affine.py @@ -10,6 +10,7 @@ "AffineCubicOperator", "AffineInputOperator", "AffineStateInputOperator", + "AffinePolynomialOperator", ] import h5py @@ -27,6 +28,7 @@ InputOperator, StateInputOperator, ) +from ._polynomial_operator import PolynomialOperator # Helper functions ============================================================ @@ -668,6 +670,209 @@ class AffineCubicOperator(_AffineOperator): _OperatorClass = CubicOperator +class AffinePolynomialOperator(_AffineOperator): + # TODO: update description + r"""Affine-parametric cubic operator + :math:`\Ophat_{\ell}(\qhat,\u;\bfmu) + = \Ghat_{\ell}(\bfmu)[\qhat\otimes\qhat\otimes\qhat] = \left( + \sum_{a=0}^{A_{\ell}-1}\theta_{\ell}^{(a)}\!(\bfmu)\,\Ghat_{\ell}^{(a)} + \right)[\qhat\otimes\qhat\otimes\qhat].` + + Here, each :math:`\theta_\ell^{(a)}:\RR^{p}\to\RR` is a scalar-valued + function of the parameter vector + and each :math:`\Ghat_{\ell}^{(a)} \in \RR^{r\times r^3}` is a constant + matrix, see :class:`opinf.operators.CubicOperator`. + + Parameters + ---------- + coeffs : callable, (iterable of callables), or int + Coefficient functions for the terms of the affine expansion. + + * If callable, it should receive a parameter vector + :math:`\bfmu` and return the vector of affine coefficients, + :math:`[~\theta_{\ell}^{(0)}(\bfmu) + ~~\cdots~~\theta_{\ell}^{(A_{\ell}-1)}(\bfmu)~]\trp`. + In this case, ``nterms`` is a required argument. + * If an iterable, each entry should be a callable representing a + single affine coefficient function :math:`\theta_{\ell}^{(a)}`. + * If an integer :math:`p`, set :math:`A_{\ell} = p` and define + :math:`\theta_{\ell}^{(i)}\!(\bfmu) = \mu_i`. This is equivalent to + using ``coeffs=lambda mu: mu``, except the parameter dimension is + also captured and ``nterms`` is not required. + entries : list of ndarrays, or None + Operator matrices for each term of the affine expansion, i.e., + :math:`\Ghat_{\ell}^{(0)},\ldots,\Ghat_{\ell}^{(A_{\ell}-1)}.` + If not provided in the constructor, use :meth:`set_entries` later. 
+ fromblock : bool + If ``True``, interpret ``entries`` as a horizontal concatenation + of arrays; if ``False`` (default), interpret ``entries`` as a list + of arrays. + """ + + _OperatorClass = PolynomialOperator + + def __init__( + self, + coeffs, + polynomial_order: int, + nterms: int = None, + entries=None, + fromblock: bool = False, + ): + """same as AffineOperator.__init__, + except the polynomial order is + passed as an additional input""" + self.polynomial_order = polynomial_order + super().__init__( + coeffs=coeffs, nterms=nterms, entries=entries, fromblock=fromblock + ) + + def operator_dimension(self, s: int, r: int, m: int) -> int: + r"""Number of columns in the concatenated operator matrix. + See AffineOperator.operator_dimension for detailed description. + Implementation is just slightly different because + PolynomialOperator.operator_dimension is not static. + """ + return self.nterms * PolynomialOperator( + polynomial_order=self.polynomial_order + ).operator_dimension(r, m) + + def datablock(self, parameters, states, inputs=None) -> np.ndarray: + r"""same as AffineOperator.datablock. + Implementation is just slightly different because + PolynomialOperator.datablock is not static. + """ + if not isinstance(self, InputMixin): + inputs = [None] * len(parameters) + blockcolumns = [] + for mu, Q, U in zip(parameters, states, inputs): + Di = PolynomialOperator( + polynomial_order=self.polynomial_order + ).datablock(Q, U) + theta_mus = self.coeffs(mu) + if self.nterms == 1 and np.isscalar(theta_mus): + theta_mus = [theta_mus] + blockcolumns.append(np.vstack([theta * Di for theta in theta_mus])) + return np.hstack(blockcolumns) + + def set_entries(self, entries, fromblock: bool = False) -> None: + r"""same as AffineOperator.set_entries. + Implementation is just slightly different because + PolynomialOperator.datablock is not static. + """ + # Extract / verify the entries. + nterms = self.nterms + if fromblock: + if not isinstance(entries, np.ndarray) or ( + entries.ndim not in (1, 2) + ): + raise ValueError( + "entries must be a 1- or 2-dimensional ndarray " + "when fromblock=True" + ) + entries = np.split(entries, nterms, axis=-1) + if np.ndim(entries) > 1: + self._check_shape_consistency(entries, "entries") + if (n_arrays := len(entries)) != nterms: + raise ValueError( + f"{nterms} = number of affine expansion terms " + f"!= len(entries) = {n_arrays}" + ) + + ParametricOpInfOperator.set_entries( + self, + [ + PolynomialOperator( + entries=A, polynomial_order=self.polynomial_order + ).entries + for A in entries + ], + ) + + @utils.requires("entries") + def evaluate(self, parameter): + r"""Evaluate the operator at the given parameter value. + Same as AffineOperator.evaluate, just implemented slightly differently. 
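A construction and evaluation sketch for the new class, following the constructor and `evaluate()` shown in this diff: two quadratic operator matrices combined with coefficients `theta(mu) = [1, mu[0]]`. Matrix values are random and purely illustrative.

```python
import numpy as np
from opinf.operators._affine import AffinePolynomialOperator

r = 3
d2 = 6   # number of non-redundant quadratic terms for r = 3
rng = np.random.default_rng(0)
A0 = rng.normal(size=(r, d2))
A1 = rng.normal(size=(r, d2))

op = AffinePolynomialOperator(
    coeffs=lambda mu: np.array([1.0, mu[0]]),
    polynomial_order=2,
    nterms=2,
    entries=[A0, A1],
)

op_mu = op.evaluate(np.array([0.5]))   # PolynomialOperator with A0 + 0.5 * A1
q = rng.normal(size=r)
print(op_mu.apply(q))                  # quadratic action at that parameter
```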
+ """ + if self.parameter_dimension is None: + self._set_parameter_dimension_from_values([parameter]) + self._check_parametervalue_dimension(parameter) + theta_mus = self.coeffs(parameter) + if self.nterms == 1 and np.isscalar(theta_mus): + theta_mus = [theta_mus] + entries = sum([tm * A for tm, A in zip(theta_mus, self.entries)]) + return self._OperatorClass( + entries=entries, polynomial_order=self.polynomial_order + ) + + def restrict_to_subspace(self, indices_trial, indices_test=None): + """ + - not checking for duplicate indices + """ + new_entries = [ + PolynomialOperator._restrict_matrix_to_subspace( + indices_trial=indices_trial, + indices_test=indices_test, + entries=self.entries[i], + polynomial_order=self.polynomial_order, + ) + for i in range(self.nterms) + ] + + return AffinePolynomialOperator( + coeffs=self.coeffs, + polynomial_order=self.polynomial_order, + nterms=self.nterms, + entries=new_entries, + ) + + def extend_to_dimension( + self, new_r, indices_trial=None, indices_test=None, new_r_test=None + ): + """ + - not checking for duplicate indices + """ + if indices_trial is None: + indices_trial = [*range(self.state_dimension)] + + new_entries = [ + PolynomialOperator._extend_matrix_to_dimension( + new_r=new_r, + indices_trial=indices_trial, + indices_test=indices_test, + old_entries=self.entries[i], + polynomial_order=self.polynomial_order, + new_r_test=new_r_test, + ) + for i in range(self.nterms) + ] + + return AffinePolynomialOperator( + coeffs=self.coeffs, + polynomial_order=self.polynomial_order, + nterms=self.nterms, + entries=new_entries, + ) + + def copy(self): + """Return a copy of the operator. Only the operator matrices are + copied, not the coefficient functions. + """ + As = None + if self.entries is not None: + As = [A.copy() for A in self.entries] + op = self.__class__( + coeffs=self.coeffs, + nterms=self.nterms, + entries=As, + fromblock=False, + polynomial_order=self.polynomial_order, + ) + if self.parameter_dimension is not None: + op.parameter_dimension = self.parameter_dimension + return op + + class AffineInputOperator(_AffineOperator, InputMixin): r"""Affine-parametric input operator :math:`\Ophat_{\ell}(\qhat,\u;\bfmu) diff --git a/src/opinf/operators/_base.py b/src/opinf/operators/_base.py index c4c7db20..df8ee6e5 100644 --- a/src/opinf/operators/_base.py +++ b/src/opinf/operators/_base.py @@ -780,7 +780,7 @@ def datablock(states: np.ndarray, inputs=None) -> np.ndarray: def copy(self): """Return a copy of the operator.""" entries = self.entries.copy() if self.entries is not None else None - return self.__class__(entries) + return self.__class__(entries=entries) def save(self, savefile: str, overwrite: bool = False) -> None: """Save the operator to an HDF5 file. diff --git a/src/opinf/operators/_nonparametric.py b/src/opinf/operators/_nonparametric.py index 84f892c7..82d3926b 100644 --- a/src/opinf/operators/_nonparametric.py +++ b/src/opinf/operators/_nonparametric.py @@ -1845,12 +1845,21 @@ class InputOperator(OpInfOperator, InputMixin): True """ + my_input_dimension = None + + def set_input_dimension(self, m): + self.my_input_dimension = m + @property def input_dimension(self): r"""Dimension :math:`m` of the input :math:`\u` that the operator acts on. 
""" - return None if self.entries is None else self.entries.shape[1] + return ( + self.my_input_dimension + if self.entries is None + else self.entries.shape[1] + ) @staticmethod def _str(statestr, inputstr): @@ -1990,6 +1999,135 @@ def operator_dimension(r, m): """ return m + def restrict_to_subspace(self, indices_trial, indices_test=None): + r""" + Creates a new operator of type `InputOperator` for the reduced + (test) dimension + ``len(indices_test)`` (Petrov-Galerkin setting). The new operator + is constructed by restricting testing + in :math:`span{\mathbf{v}_i: i \in indices_test}`. + + If ``indices_test`` + is not provided, defaults to the Galerkin setting + ``indices_test = indices_trial``. + + Currently, the more general restriction onto combinations of + basis vectors (e.g., onto :math:`span{(v_1+v_2)/2}`) is not supported. + + Parameters + ---------- + indices_trial : list of integers + indices of the (trial) basis vectors onto which the operator + shall be restricted. Needs to be in increasing order and + not contain dubplicates. + indices_test : list of integers + indices of the (test) basis vectors onto which the operator + shall be restricted in the Petrov-Galerkin setting in + increasing order. Needs to be in increasing order and + not contain dubplicates. + + Returns + ------- + InputOperator + Operator for test + dimension ``len(indices_test)``, and polynomial order + ``self.polynomial_order``. + """ + if indices_test is None: + indices_test = indices_trial + + if max(indices_test) >= self.state_dimension: + raise RuntimeError( + f""" + In InputOperator.restrict_to_subspace: + Encountered restriction onto unknown test basis + vector number {max(indices_test)}. + Reduced dimension is {self.state_dimension}""" + ) + + new_entries = self.entries[indices_test, :] + + return InputOperator(entries=new_entries) + + def extend_to_dimension( + self, new_r, indices_trial=None, indices_test=None, new_r_test=None + ): + r""" + Creates a new operator of type `InputOperator` of the same + input dimension as this one but for the reduced (test) dimension + ``new_r_test`` (defaulted to + ``new_r_test = new_r`` if not provided). The new operator is + created by mapping the current test basis vectors :math:`\mathbf{w}_i` + onto the new test vectors :math:`\tilde{\mathbf{w}}_j`, + :math:`j=` ``indices_test[i]`` of the new basis, :math`i=1, ..., r`. + The remaining actions of the new operator (i.e., all actions that + involve :math:`\tilde{\mathbf{v}}_j` with + :math:`j\notin` ``indices_trial`` or :math:`\tilde{\mathbf{w}}_j` + with :math:`j\notin` ``indices_test``) are defaulted to 0. + + If ``indices_trial`` is not provided, it is assumed that the + current basis is expanded and the current basis vectors are + to be mapped onto the first :math:`r` basis vectors of the + new basis, i.e., we default to ``indices_trial = [0, ..., r-1]``. + + If ``indices_test`` + is not provided, defaults to the Galerkin setting + ``indices_test = indices_trial``. + + Currently, the more general restriction onto combinations of + basis vectors (e.g., onto :math:`span{(v_1+v_2)/2}`) is not supported. + + Parameters + ---------- + new_r : int + target reduced dimension (trial space). Needs to be at + least as large as ``self.state_dimension`` + indices_trial : list of integers + indices of the (trial) basis vectors to which the previous + operator entries shall be mapped in the expanded basis. + Needs to be in increasing order and + not contain dubplicates. 
+ indices_test : list of integers + indices of the (test) basis vectors onto which the + previous operator entries shall be mapped in the + expanded basis (Petrov-Galerkin setting only). + Needs to be in increasing order and + not contain dubplicates. + new_r_test : int + target reduced dimension (test space). Defaulted to + ``new_r`` if not provided. + + Returns + ------- + InputOperator + Operator for trial dimension ``new_r``, test + dimension ``new_r_test``, and polynomial order + ``self.polynomial_order``. + """ + if indices_trial is None: + # default to extending the basis towards the right + indices_trial = [*range(self.state_dimension)] + + if indices_test is None: + # default to Galerking case + indices_test = indices_trial + + if new_r_test is None: + new_r_test = new_r + + if new_r_test < self.state_dimension: + raise RuntimeError( + f"""In InputOperator.extend_to_dimension: + Dimension mismatch. Expected new dimension ({new_r_test}) + to be larger than old dimension ({self.state_dimension}) + """ + ) + + new_entries = np.zeros((new_r_test, self.input_dimension)) + new_entries[indices_test, :] = self.entries + + return InputOperator(entries=new_entries) + # Dependent on both state and input =========================================== class StateInputOperator(OpInfOperator, InputMixin): diff --git a/src/opinf/operators/_polynomial_operator.py b/src/opinf/operators/_polynomial_operator.py index 7d2c4e39..78da44d2 100644 --- a/src/opinf/operators/_polynomial_operator.py +++ b/src/opinf/operators/_polynomial_operator.py @@ -4,6 +4,7 @@ import numpy as np import scipy.linalg as la +import itertools # import scipy.sparse as sparse from scipy.special import comb @@ -63,6 +64,13 @@ def __init__(self, polynomial_order: int, entries=None): super().__init__(entries=entries) + def copy(self): + """Return a copy of the operator.""" + entries = self.entries.copy() if self.entries is not None else None + return self.__class__( + entries=entries, polynomial_order=self.polynomial_order + ) + def operator_dimension(self, r: int, m=None) -> int: """ computes the number of non-redundant terms in a vector of length r @@ -79,12 +87,15 @@ def operator_dimension(self, r: int, m=None) -> int: raise ValueError( f"expected non-negative integer reduced dimension r. Got r={r}" ) + return PolynomialOperator.polynomial_operator_dimension( + r=r, polynomial_order=self.polynomial_order + ) - # for constant operators the dimension does not matter - if (p := self.polynomial_order) == 0: + @staticmethod + def polynomial_operator_dimension(r, polynomial_order) -> int: + if polynomial_order == 0: return 1 - - return comb(r, p, repetition=True, exact=True) + return comb(r, polynomial_order, repetition=True, exact=True) def datablock(self, states: np.ndarray, inputs=None) -> np.ndarray: r"""Return the data matrix block corresponding to @@ -189,10 +200,43 @@ def exp_p(x, p, kept=None): kept[p] ] + @staticmethod + def ckron_indices(r, p): + """Construct a mask for efficiently computing the compressed, p-times + repeated Kronecker product. + + This method provides a faster way to evaluate :meth:`ckron` + when the state dimension ``r`` is known *a priori*. + + Parameters + ---------- + r : int + State dimension. + p: int + polynomial order + + Returns + ------- + mask : ndarray + Compressed Kronecker product mask. 
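What the mask encodes: row ``i`` lists the ``p`` state indices whose product gives the ``i``-th non-redundant entry of the compressed ``p``-fold Kronecker product, so ``apply()`` reduces to ``np.prod(state[mask], axis=1)``. A small standalone check for ``r = 3``, ``p = 2``:

```python
import itertools

import numpy as np

r, p = 3, 2
mask = np.array(
    [*itertools.combinations_with_replacement([*range(r)][::-1], p)]
)[::-1, :]

q = np.array([1.0, 2.0, 3.0])
compressed = np.prod(q[mask], axis=1)
print(compressed)
# [1. 2. 4. 3. 6. 9.]  ==  [q0*q0, q1*q0, q1*q1, q2*q0, q2*q1, q2*q2]
```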
+ + """ + mask = np.array( + [*itertools.combinations_with_replacement([*range(r)][::-1], p)] + )[::-1, :] + return mask + def apply(self, state: np.ndarray, input_=None) -> np.ndarray: r"""Apply the operator to the given state. Input is not used. See OpInfOperator.apply for description. """ + # note: having all these if conditions here to distinguish between + # p=0, p=1, p>1 + # really makes things slower than necessary (apply gets called a lot) + # it might be worth excluding these special cases from the + # PolynomialOperator + # class and defaulting to ConstantOperator and LinearOperator instead. + if state.shape[0] != self.state_dimension: raise ValueError( f"Expected state of dimension r={self.state_dimension}." @@ -215,9 +259,8 @@ def apply(self, state: np.ndarray, input_=None) -> np.ndarray: # note: no need to go through the trouble of identifying the # non-redundant indices - # higher-order - restricted_kronecker_product = PolynomialOperator.exp_p( - x=state, p=self.polynomial_order, kept=self.nonredudant_entries + restricted_kronecker_product = np.prod( + state[self.nonredudant_entries_mask], axis=1 ) return self.entries @ restricted_kronecker_product @@ -228,8 +271,440 @@ def nonredudant_entries(self) -> list: when restricting the i-times Kronecker product of a vector of shape self.state_dimension() with itself. """ - # return self.__nonredudant_entries return [ PolynomialOperator.keptIndices_p(r=self.state_dimension, p=i) for i in range(self.polynomial_order + 1) ] + + @property + def nonredudant_entries_mask(self) -> np.ndarray: + r"""list containing at index i a list of the indices that are kept + when restricting the i-times Kronecker product of a vector of + shape self.state_dimension() with itself. + """ + return self.ckron_indices( + r=self.state_dimension, p=self.polynomial_order + ) + + def restrict_to_subspace(self, indices_trial, indices_test=None): + r""" + Creates a new operator of type `PolynomialOperator` of the same + polynomial order as this one but for the reduced (trial) dimension + :math:`r_1 :=` ``len(indices)`` and (test) dimension + ``len(indices_test)`` (Petrov-Galerkin setting). The new operator + is constructed by restricting the action of this operator (``self``) + onto :math:`span{\mathbf{v}_i: i \in indices_trial}`, and testing + in :math:`span{\mathbf{v}_i: i \in indices_test}`. + + If ``indices_test`` + is not provided, defaults to the Galerkin setting + ``indices_test = indices_trial``. + + Currently, the more general restriction onto combinations of + basis vectors (e.g., onto :math:`span{(v_1+v_2)/2}`) is not supported. + + Parameters + ---------- + indices_trial : list of integers + indices of the (trial) basis vectors onto which the operator + shall be restricted. Needs to be in increasing order and + not contain dubplicates. + indices_test : list of integers + indices of the (test) basis vectors onto which the operator + shall be restricted in the Petrov-Galerkin setting in + increasing order. Needs to be in increasing order and + not contain dubplicates. + + Returns + ------- + PolynomialOperator + Operator for trial dimension ``len(indices_trial)``, test + dimension ``len(indices_test)``, and polynomial order + ``self.polynomial_order``. + """ + if indices_test is None: + indices_test = indices_trial + + if max(indices_trial) >= self.state_dimension: + raise RuntimeError( + f""" + In PolynomialOperator.restrict_to_subspace: + Encountered restriction onto unknown trial basis + vector number {max(indices_trial)}. 
+ Reduced dimension is {self.state_dimension}""" + ) + + if max(indices_test) >= self.state_dimension: + raise RuntimeError( + f""" + In PolynomialOperator.restrict_to_subspace: + Encountered restriction onto unknown test basis + vector number {max(indices_test)}. + Reduced dimension is {self.state_dimension}""" + ) + + new_entries = PolynomialOperator._restrict_matrix_to_subspace( + indices_trial=indices_trial, + entries=self.entries, + polynomial_order=self.polynomial_order, + indices_test=indices_test, + ) + + return PolynomialOperator( + entries=new_entries, polynomial_order=self.polynomial_order + ) + + @staticmethod + def _restrict_matrix_to_subspace( + indices_trial, entries, polynomial_order, indices_test=None + ): + r""" + Treating the matrix `entries` as operator matrix for the + polynomial order `polynomial_order`, this function creates + a submatrix `entries_sub` by restricting `entries` onto + those columns that correspond to interactions of basis + vectors :math:`v_i` with :math:`i \in` ``indices_trial`` + and to the rows listed in `indices_test`. + + Defaults to ``indices_test=indices_trial`` if + ``indices_test = None``. + + Parameters + ---------- + indices_trial : list of integers + indices of the (trial) basis vectors onto which the operator + shall be restricted. Needs to be in increasing order and + not contain dubplicates. + entries : np.ndarray + operator entry matrix of shape :math:`(a,b)` with + :math:`a \ge ` ``len(indices_test)`` and + :math:`b \ge r^{(p)}` where :math:`r=` ``len(indices_trial)`` + and :math:`r^{(p)}` is the number of non-redundant entries for + a polynomial operator of order :math:`p` and dimension :math:`r`. + polynomial_order : int + polynomial order of the operator matrix to be extracted + indices_test : list of integers + indices of the (test) basis vectors onto which the operator + shall be restricted in the Petrov-Galerkin setting in + increasing order. Needs to be in increasing order and + not contain dubplicates. + + Returns + ------- + entries_sub : np.ndarray + of shape ``(len(indices_test), c)``, where `c` is the + number of non-redundant entries for the condensed + Kronecker product of a vector with dimension + ``len(indices_trial)``. + """ + if len(indices_trial) != len(set(indices_trial)): + raise RuntimeError( + f""" + In PolynomialOperator.restrict_matrix_to_subspace: + Received duplicate entries in + `indices_trial=`{indices_trial}""" + ) + + if indices_test is None: + indices_test = indices_trial + elif len(indices_test) != len(set(indices_test)): + raise RuntimeError( + f""" + In PolynomialOperator.restrict_matrix_to_subspace: + Received duplicate entries in + `indices_test=`{indices_test}""" + ) + + if indices_trial != sorted(indices_trial) or indices_test != sorted( + indices_test + ): + raise RuntimeError( + f""" + In PolynomialOperator._extend_matrix_to_dimension: Received unordered + indices {indices_trial} (trial) or {indices_test} (test). 
+ """ + ) + + # constant + if polynomial_order == 0: + if np.ndim(entries) == 2: + return entries[indices_test, 0] + return entries[indices_test] + + # higher-order polynomials + entries_sub = entries[indices_test, :] + # restrict to test indices only + + col_indices = PolynomialOperator._columnIndices_p( + indices=indices_trial, p=polynomial_order + ) + # find out which columns to keep + + return entries_sub[:, col_indices] + + @staticmethod + def _columnIndices_p(indices, p): + r""" + Identifies all (column) indices of a polynomial operator + of polynomial order :math:`p` that encode interactions + of basis vectors :math:`v_i` with :math:`i\in` ``indices``. + + Parameters + ---------- + indices : list of integers + indices of the basis vectors for which interactions + shall be identified + p : int + polynomial order of the interactions + """ + if p == 1: + return indices + + if p == 0: + return [0] + + sub = PolynomialOperator._columnIndices_p(indices, p - 1) + return [ + comb(indices[i], p, repetition=True, exact=True) + sub[j] + for i in range(len(indices)) + for j in range( + PolynomialOperator.polynomial_operator_dimension( + r=i + 1, polynomial_order=p - 1 + ) + ) + ] + + def extend_to_dimension( + self, new_r, indices_trial=None, indices_test=None, new_r_test=None + ): + r""" + Creates a new operator of type `PolynomialOperator` of the same + polynomial order as this one but for the reduced (trial) dimension + ``new_r`` and (test) dimension ``new_r_test`` (defaulted to + ``new_r_test = new_r`` if not provided). The new operator is + created by mapping the current (trial) basis vector + :math:`\mathbf{v}_i` + onto the basis vector :math:`\tilde{\mathbf{v}}_j`, + :math:`j=` ``indices_trial[i]`` of the new basis, :math`i=1, ..., r`. + Similarly, the current test basis vectors :math:`\mathbf{w}_i` are + mapped onto the new test vectors :math:`\tilde{\mathbf{w}}_j`, + :math:`j=` ``indices_test[i]`` of the new basis, :math`i=1, ..., r`. + The remaining actions of the new operator (i.e., all actions that + involve :math:`\tilde{\mathbf{v}}_j` with + :math:`j\notin` ``indices_trial`` or :math:`\tilde{\mathbf{w}}_j` + with :math:`j\notin` ``indices_test``) are defaulted to 0. + + If ``indices_trial`` is not provided, it is assumed that the + current basis is expanded and the current basis vectors are + to be mapped onto the first :math:`r` basis vectors of the + new basis, i.e., we default to ``indices_trial = [0, ..., r-1]``. + + If ``indices_test`` + is not provided, defaults to the Galerkin setting + ``indices_test = indices_trial``. + + Currently, the more general restriction onto combinations of + basis vectors (e.g., onto :math:`span{(v_1+v_2)/2}`) is not supported. + + Parameters + ---------- + new_r : int + target reduced dimension (trial space). Needs to be at + least as large as ``self.state_dimension`` + indices_trial : list of integers + indices of the (trial) basis vectors to which the previous + operator entries shall be mapped in the expanded basis. + Needs to be in increasing order and + not contain dubplicates. + indices_test : list of integers + indices of the (test) basis vectors onto which the + previous operator entries shall be mapped in the + expanded basis (Petrov-Galerkin setting only). + Needs to be in increasing order and + not contain dubplicates. + new_r_test : int + target reduced dimension (test space). Defaulted to + ``new_r`` if not provided. 
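The `_columnIndices_p` recursion above can be checked by hand for the quadratic case: with the compressed ordering [q0q0, q1q0, q1q1, q2q0, q2q1, q2q2], keeping basis vectors {0, 2} keeps exactly columns 0, 3, and 5 (the same indices used in the tests at the end of this diff). An illustrative standalone copy of the recursion:

```python
from scipy.special import comb


def column_indices(indices, p):
    """Columns of an order-p operator matrix touching only `indices`."""
    if p == 0:
        return [0]
    if p == 1:
        return indices
    sub = column_indices(indices, p - 1)
    return [
        comb(indices[i], p, repetition=True, exact=True) + sub[j]
        for i in range(len(indices))
        for j in range(comb(i + 1, p - 1, repetition=True, exact=True))
    ]


print(column_indices([0, 2], 2))   # [0, 3, 5]
```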
+ + Returns + ------- + PolynomialOperator + Operator for trial dimension ``new_r``, test + dimension ``new_r_test``, and polynomial order + ``self.polynomial_order``. + """ + if indices_trial is None: + # default to extending the basis towards the right + indices_trial = [*range(self.state_dimension)] + + if indices_test is None: + # default to Galerking case + indices_test = indices_trial + + if new_r < self.state_dimension: + raise RuntimeError( + f"""In PolynomialOperator.extend_to_dimension: + Dimension mismatch. Expected new dimension ({new_r}) + to be larger than old dimension ({self.state_dimension}) + """ + ) + + new_entries = PolynomialOperator._extend_matrix_to_dimension( + new_r=new_r, + indices_trial=indices_trial, + polynomial_order=self.polynomial_order, + old_entries=self.entries, + indices_test=indices_test, + new_r_test=new_r_test, + ) + + return PolynomialOperator( + polynomial_order=self.polynomial_order, entries=new_entries + ) + + @staticmethod + def _extend_matrix_to_dimension( + new_r, + indices_trial, + polynomial_order, + old_entries, + indices_test=None, + new_r_test=None, + ): + r""" + This is the reverse function to _restrict_marix_to_dimension. + + Treating the matrix ``old_entries`` as operator matrix for the + polynomial order ``polynomial_order``, this function creates + a larger matrix of shape ``(new_r_test, a)`` with + :math:`a=r_{new}^{(p)}` the number of non-redundant entries + in the :math:`p=` ``polynomial-oder``-fold Kronecker product + of an :math:`r_{new}=` ``new_r`` dimensional vector. + The new matrix constains ``old_entries`` as the submatrix + encoding operator actions for trial and test basis vectors + with indices in ``indices_trial`` and ``indices_test``. + All remaining entries are set to zero. + + Defaults to ``indices_test=indices_trial`` if + ``indices_test = None``. + + Parameters + ---------- + new_r : int + target reduced dimension (trial space). Needs to be at + least as large as ``self.state_dimension`` + indices_trial : list of integers + indices of the (trial) basis vectors to which the previous + operator entries shall be mapped in the expanded basis. + Needs to be in increasing order and + not contain dubplicates. + polynomial_order : int + polynomial order of the operator matrix to be extracted + old_entries : np.ndarray + operator entry matrix of shape ``(a, b)`` with + :math:`a =` ``len(indices_test)`` and + :math:`b \ge r^{(p)}` where :math:`r=` ``len(indices_trial)`` + and :math:`r^{(p)}` is the number of non-redundant entries for + a polynomial operator of order :math:`p` and dimension :math:`r`. + indices_test : list of integers + indices of the (test) basis vectors onto which the + previous operator entries shall be mapped in the + expanded basis (Petrov-Galerkin setting only). + Needs to be in increasing order and + not contain dubplicates. + new_r_test : int + target reduced dimension (test space). Defaulted to + ``new_r`` if not provided. + + Returns + ------- + new_matrix : np.ndarray + of shape ``(new_r_test, c)``, where `c` is the + number of non-redundant entries for the condensed + Kronecker product of a vector with dimension + ``new_r`` for polynomial order ``polynomial_order``. + Contains `old_entries` as submatrix. 
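For the linear (order-1) case the extension described above reduces to scattering the old entries into the rows and columns named by the index lists and zero-filling the rest. A minimal Galerkin example (``indices_test = indices_trial``, ``new_r_test = new_r``) with hypothetical values:

```python
import numpy as np

A_old = np.array([[1.0, 2.0],
                  [3.0, 4.0]])   # linear operator on basis vectors {0, 2}
indices = [0, 2]
new_r = 4

A_new = np.zeros((new_r, new_r))
A_new[np.ix_(indices, indices)] = A_old
# A_new acts on the extended state exactly as A_old did on span{v0, v2};
# all interactions involving the new basis vectors default to zero.
```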
+ """ + if indices_test is None: + indices_test = indices_trial + + if new_r_test is None: + new_r_test = new_r + + if indices_trial != sorted(indices_trial) or indices_test != sorted( + indices_test + ): + raise RuntimeError( + f""" + In PolynomialOperator._extend_matrix_to_dimension: + Received unordered indices + {indices_trial} (trial) or {indices_test} (test). + """ + ) + + old_r = len(indices_trial) + if not old_entries.shape == ( + len(indices_test), + PolynomialOperator.polynomial_operator_dimension( + r=old_r, polynomial_order=polynomial_order + ), + ): + raise RuntimeError( + f"""In PolynomialOperator._extend_matrix_to_dimension: + Mismatch in the dimension of the passed matrix. + Expected { + (len(indices_test), + PolynomialOperator.polynomial_operator_dimension(r=old_r, + polynomial_order=polynomial_order))}. + Got {old_entries.shape}. + """ + ) + + if old_r > new_r: + raise RuntimeError( + f"""In PolynomialOperator. + _extend_matrix_to_dimension: + Mismatch in passed indices: + Old dimension {old_r} is larger + than new dimension {new_r} + """ + ) + + if len(indices_test) > new_r_test: + raise RuntimeError( + f"""In PolynomialOperator. + _extend_matrix_to_dimension: + Mismatch in passed indices for test space: + Old dimension {len(indices_test)} is + larger than new test dimension {new_r_test} + """ + ) + + # initialize matrix for old test space dimension + new_marix_sub = np.zeros( + shape=( + len(indices_test), + PolynomialOperator.polynomial_operator_dimension( + r=new_r, polynomial_order=polynomial_order + ), + ) + ) + + # populate columns + col_indices = PolynomialOperator._columnIndices_p( + indices=indices_trial, p=polynomial_order + ) + new_marix_sub[:, col_indices] = old_entries + + # final matrix: fill remaining rows with zeros + new_matrix = np.zeros( + shape=( + new_r_test, + PolynomialOperator.polynomial_operator_dimension( + r=new_r, polynomial_order=polynomial_order + ), + ) + ) + new_matrix[indices_test, :] = new_marix_sub + + return new_matrix diff --git a/tests/basis/test_pod.py b/tests/basis/test_pod.py index d7adac10..82eda8be 100644 --- a/tests/basis/test_pod.py +++ b/tests/basis/test_pod.py @@ -358,6 +358,7 @@ def test_fit(self, n=60, k=20, r=4): f"only {k} singular vectors can be extracted from ({n} x {k}) " f"snapshots, setting max_vectors={k}" ) + assert out is basis assert basis.full_state_dimension == n assert basis.reduced_state_dimension == r @@ -377,7 +378,9 @@ def test_fit(self, n=60, k=20, r=4): weights=SP, minthresh=1e-20, ) + out = basis.fit(Q) + assert out is basis assert basis.full_state_dimension == n assert basis.reduced_state_dimension == r diff --git a/tests/lstsq/test_tikhonov.py b/tests/lstsq/test_tikhonov.py index 01331c5a..4f55a987 100644 --- a/tests/lstsq/test_tikhonov.py +++ b/tests/lstsq/test_tikhonov.py @@ -894,5 +894,90 @@ def test_save_load_copy_and_reset(self, k=20): return super().test_save_load_copy_and_reset(k=k, d=10, r=5) +@pytest.mark.parametrize( + "n_cols, n_dofs, n_red", + [ + (n_cols, n_dofs, n_red) + for n_cols in [20, 50, 100, 1000] + for n_dofs in [1, 2, 5, 10, 20] + for n_red in range(1, 5) + ], +) +def test_initial_guesses(n_cols, n_dofs, n_red): + D = np.random.normal(size=(n_cols, n_dofs)) + R = np.random.normal(size=(n_red, n_cols)) + guess = np.random.normal(size=(n_red, n_dofs)) + + def compare_to_std(solver_std, solver_ini): + solver_std.fit(D, R) + Ohat_std = solver_std.solve() + + solver_ini.fit(D, R) + Ohat_ini = solver_ini.solve() + + assert Ohat_std.shape == Ohat_ini.shape + assert Ohat_ini.shape == 
guess.shape + + a = solver_std.regresidual(Ohat=Ohat_std) + b = solver_std.regresidual(Ohat=Ohat_ini) + assert ((a <= b) | np.isclose(a, b)).all() + + a = solver_ini.regresidual(Ohat=Ohat_std) + b = solver_ini.regresidual(Ohat=Ohat_ini) + assert ((a >= b) | np.isclose(a, b)).all() + + def compare_to_stronger(solver_class, regularizer): + """Ensure that using a stronger regularization brings + you closer to the initial guess. + """ + scaling = np.logspace(-5, 5, 11) + val_prev = np.inf * np.ones(n_red) + for i, scale in enumerate(scaling): + solver_ini = solver_class( + regularizer=scale * regularizer, initial_guess=guess + ) + solver_ini.fit(D, R) + Ohat_ini = solver_ini.solve() + diff = la.norm(Ohat_ini - guess, axis=1) ** 2 + assert (diff <= val_prev).all + val_prev = diff + + # L2 solver + solver_std = opinf.lstsq.L2Solver(regularizer=1) + solver_ini = opinf.lstsq.L2Solver(regularizer=1, initial_guess=guess) + compare_to_std(solver_std, solver_ini) + compare_to_stronger(opinf.lstsq.L2Solver, regularizer=1) + + # L2 decoupled solver + reg = np.logspace(-2, 2, n_red) + solver_std = opinf.lstsq.L2DecoupledSolver(regularizer=reg) + solver_ini = opinf.lstsq.L2DecoupledSolver( + regularizer=reg, initial_guess=guess + ) + compare_to_std(solver_std, solver_ini) + compare_to_stronger(opinf.lstsq.L2DecoupledSolver, regularizer=reg) + + # Tikhonov solver + reg = np.random.normal(size=(5 * n_dofs, n_dofs)) + reg = reg.T @ reg + solver_std = opinf.lstsq.TikhonovSolver(regularizer=reg) + solver_ini = opinf.lstsq.TikhonovSolver( + regularizer=reg, initial_guess=guess + ) + compare_to_std(solver_std, solver_ini) + compare_to_stronger(opinf.lstsq.TikhonovSolver, regularizer=reg) + + # Tikhonov decoupled solver + reg = [None] * n_red + for i in range(n_red): + yolo = np.random.normal(size=(5 * n_dofs, n_dofs)) + reg[i] = yolo.T @ yolo + solver_std = opinf.lstsq.TikhonovDecoupledSolver(regularizer=reg) + solver_ini = opinf.lstsq.TikhonovDecoupledSolver( + regularizer=reg, initial_guess=guess + ) + compare_to_std(solver_std, solver_ini) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/operators/test_affine.py b/tests/operators/test_affine.py index 59c01d98..2a425308 100644 --- a/tests/operators/test_affine.py +++ b/tests/operators/test_affine.py @@ -387,7 +387,11 @@ def test_publics(): OpClass, _submodule._AffineOperator ): continue - op = OpClass(_TestAffineOperator.thetas1) + + if OpClassName == "AffinePolynomialOperator": + op = OpClass(_TestAffineOperator.thetas1, polynomial_order=1) + else: + op = OpClass(_TestAffineOperator.thetas1) assert issubclass( op._OperatorClass, opinf.operators.OpInfOperator, diff --git a/tests/operators/test_nonparametric.py b/tests/operators/test_nonparametric.py index 2e1760c0..8e027e6b 100644 --- a/tests/operators/test_nonparametric.py +++ b/tests/operators/test_nonparametric.py @@ -1249,6 +1249,40 @@ def test_operator_dimension(self): assert self.Operator.operator_dimension(5, 2) == 2 +@pytest.mark.parametrize( + "r_large, r_small, m", + [ + (r_large, r_small, m) + for r_large in range(1, 8) + for r_small in range(1, r_large + 1) + for m in range(1, 4) + ], +) +def test_extend_dimension(r_large, r_small, m): + matrix_original = np.random.uniform( + size=( + r_small, + _module.InputOperator.operator_dimension(r=r_small, m=m), + ) + ) + + # sample random test indices + indices_test = np.random.choice( + [*range(r_large)], r_small, replace=False + ).tolist() + indices_test.sort() + + # scale operator up and down + operator = 
_module.InputOperator(entries=matrix_original.copy()) + operator_extended = operator.extend_to_dimension( + new_r=r_large, indices_trial=indices_test + ) + operator_condensed = operator_extended.restrict_to_subspace( + indices_trial=indices_test + ) + assert (matrix_original == operator_condensed.entries).all() + + # Dependent on state and input ================================================ class TestStateInputOperator(_TestNonparametricOperator): """Test operators._nonparametric.StateInputOperator.""" diff --git a/tests/operators/test_poly_operator.py b/tests/operators/test_poly_operator.py index e6cb1e84..964d467d 100644 --- a/tests/operators/test_poly_operator.py +++ b/tests/operators/test_poly_operator.py @@ -6,6 +6,7 @@ import opinf from opinf.operators._polynomial_operator import PolynomialOperator +from opinf.operators._affine import AffinePolynomialOperator other_operators = opinf.operators._nonparametric @@ -159,3 +160,212 @@ def test_apply_against_reference(r, p): # compare assert action.shape == action_ref.shape == (r,) # same size assert np.isclose(action, action_ref).all() # same entries + + +@pytest.mark.parametrize("r", [(r) for r in range(3, 10)]) +def test_restrict_to_subspace(r): + + indices_test = np.random.randint(0, r, size=(1,)).tolist() + indices_trial = [0, 2] + + # constant + large_matrix = np.random.normal(size=(r, 1)) + small_matrix = large_matrix[indices_test, 0] + assert np.isclose( + small_matrix, + PolynomialOperator._restrict_matrix_to_subspace( + indices_trial=indices_trial, + indices_test=indices_test, + entries=large_matrix, + polynomial_order=0, + ), + ).all() + + operator = PolynomialOperator(polynomial_order=0) + operator.set_entries(large_matrix) + assert np.isclose( + small_matrix, + operator.restrict_to_subspace( + indices_trial=indices_trial, indices_test=indices_test + ).entries, + ).all() + + operator = AffinePolynomialOperator( + polynomial_order=0, coeffs=1, entries=[large_matrix] + ) + assert np.isclose( + small_matrix, + operator.restrict_to_subspace( + indices_trial=indices_trial, indices_test=indices_test + ).entries, + ).all() + + # linear + large_matrix = np.random.normal(size=(r, r)) + small_matrix = large_matrix[indices_test, :][:, indices_trial] + assert np.isclose( + small_matrix, + PolynomialOperator._restrict_matrix_to_subspace( + indices_trial=indices_trial, + indices_test=indices_test, + entries=large_matrix, + polynomial_order=1, + ), + ).all() + + operator = PolynomialOperator(polynomial_order=1) + operator.entries = large_matrix + assert np.isclose( + small_matrix, + operator.restrict_to_subspace( + indices_trial=indices_trial, indices_test=indices_test + ).entries, + ).all() + + operator = AffinePolynomialOperator( + polynomial_order=1, coeffs=1, entries=[large_matrix] + ) + assert np.isclose( + small_matrix, + operator.restrict_to_subspace( + indices_trial=indices_trial, indices_test=indices_test + ).entries, + ).all() + + # quadratic + large_matrix = np.random.normal( + size=( + r, + PolynomialOperator.polynomial_operator_dimension( + r=r, polynomial_order=2 + ), + ) + ) + small_matrix = large_matrix[indices_test, :][:, [0, 3, 5]] + assert np.isclose( + small_matrix, + PolynomialOperator._restrict_matrix_to_subspace( + indices_trial=indices_trial, + indices_test=indices_test, + entries=large_matrix, + polynomial_order=2, + ), + ).all() + + operator = PolynomialOperator(polynomial_order=2) + operator.entries = large_matrix + assert np.isclose( + small_matrix, + operator.restrict_to_subspace( + indices_trial=indices_trial, 
indices_test=indices_test + ).entries, + ).all() + + operator = AffinePolynomialOperator( + polynomial_order=2, coeffs=1, entries=[large_matrix] + ) + assert np.isclose( + small_matrix, + operator.restrict_to_subspace( + indices_trial=indices_trial, indices_test=indices_test + ).entries, + ).all() + + # cubic + large_matrix = np.random.normal( + size=( + r, + PolynomialOperator.polynomial_operator_dimension( + r=r, polynomial_order=3 + ), + ) + ) + small_matrix = large_matrix[indices_test, :][:, [0, 4, 7, 9]] + assert np.isclose( + small_matrix, + PolynomialOperator._restrict_matrix_to_subspace( + indices_trial=indices_trial, + indices_test=indices_test, + entries=large_matrix, + polynomial_order=3, + ), + ).all() + + operator = PolynomialOperator(polynomial_order=3) + operator.entries = large_matrix + assert np.isclose( + small_matrix, + operator.restrict_to_subspace( + indices_trial=indices_trial, indices_test=indices_test + ).entries, + ).all() + + operator = AffinePolynomialOperator( + polynomial_order=3, coeffs=1, entries=[large_matrix] + ) + assert np.isclose( + small_matrix, + operator.restrict_to_subspace( + indices_trial=indices_trial, indices_test=indices_test + ).entries, + ).all() + + +@pytest.mark.parametrize( + "r_large, r_small, p", + [ + (r_large, r_small, p) + for r_large in range(1, 8) + for r_small in range(1, r_large + 1) + for p in range(1, 4) + ], +) +def test_extend_dimension(r_large, r_small, p): + matrix_original = np.random.uniform( + size=( + r_small, + PolynomialOperator.polynomial_operator_dimension( + r=r_small, polynomial_order=p + ), + ) + ) + + # sample random trial and test samples + indices_trial = np.random.choice( + [*range(r_large)], r_small, replace=False + ).tolist() + indices_test = np.random.choice( + [*range(r_large)], r_small, replace=False + ).tolist() + indices_trial.sort() + indices_test.sort() + + # scale operator up and down + operator = PolynomialOperator( + polynomial_order=p, entries=matrix_original.copy() + ) + operator_extended = operator.extend_to_dimension( + new_r=r_large, indices_test=indices_test, indices_trial=indices_trial + ) + operator_condensed = operator_extended.restrict_to_subspace( + indices_trial=indices_trial, indices_test=indices_test + ) + assert (matrix_original == operator_condensed.entries).all() + + matrix_extended = PolynomialOperator._extend_matrix_to_dimension( + indices_test=indices_test, + indices_trial=indices_trial, + new_r=r_large, + polynomial_order=p, + old_entries=matrix_original, + ) + matrix_condensed = PolynomialOperator._restrict_matrix_to_subspace( + indices_test=indices_test, + indices_trial=indices_trial, + polynomial_order=p, + entries=matrix_extended, + ) + + assert (matrix_condensed == matrix_original).all() + assert (matrix_condensed == operator_condensed.entries).all() + assert (matrix_extended == operator_extended.entries).all() diff --git a/tests/regression/test_basics.py b/tests/regression/test_basics.py index 982bd807..f36d01b5 100644 --- a/tests/regression/test_basics.py +++ b/tests/regression/test_basics.py @@ -1,5 +1,5 @@ # test_basics.py -"""Regression test: linear heat equation, extrapolation to new initial = +"""Regression test: linear heat equation, extrapolation to new initial conditions. 
""" import os diff --git a/tox.ini b/tox.ini index 648e0aca..410b8433 100644 --- a/tox.ini +++ b/tox.ini @@ -1,10 +1,12 @@ [tox] requires = tox>=4 -env_list = py{39,310,311,312,313} +env_list = py{310,311,312,313,314} [testenv] description = Run unit tests with pytest +passenv = + MPLBACKEND # Fix a bug in GitHub Actions for Windows package = editable deps = pytest>=6.0.2