diff --git a/heat/regression/lasso.py b/heat/regression/lasso.py
index 8e9b2d45b7..ea9318821c 100644
--- a/heat/regression/lasso.py
+++ b/heat/regression/lasso.py
@@ -2,6 +2,7 @@
 Implementation of the LASSO regression
 """
 
+import math
 import heat as ht
 from heat.core.dndarray import DNDarray
 from typing import Union, Optional
@@ -19,6 +20,10 @@ class Lasso(ht.RegressionMixin, ht.BaseEstimator):
     .. math:: w\\_=(w_1,w_2,...,w_n), w=(w_0,w_1,w_2,...,w_n),
     .. math:: y \\in M(m \\times 1), w \\in M(n \\times 1), X \\in M(m \\times n)
 
+    The implementation uses FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) for optimization,
+    as described in Beck & Teboulle (2009): "A Fast Iterative Shrinkage-Thresholding Algorithm
+    for Linear Inverse Problems".
+
     Parameters
     ----------
     lam : float, optional
@@ -26,8 +31,11 @@ class Lasso(ht.RegressionMixin, ht.BaseEstimator):
         least square (OLS). For numerical reasons, using ``lam = 0.,`` with the ``Lasso`` object is not advised.
     max_iter : int, optional
         The maximum number of iterations. Default value: 100
-    tol : float, optional. Default value: 1e-8
+    tol : float, optional. Default value: 1e-6
         The tolerance for the optimization.
+    step_size : float, optional
+        The step size for the gradient descent step. If None, it will be computed as 1/L where L is the
+        Lipschitz constant of the gradient (largest eigenvalue of X^T X / m). Default: None
 
     Attributes
     ----------
@@ -37,7 +45,7 @@ class Lasso(ht.RegressionMixin, ht.BaseEstimator):
     intercept_ : float | array, shape (n_targets,)
         independent term in decision function.
     n_iter_ : int or None | array-like, shape (n_targets,)
-        number of iterations run by the coordinate descent solver to reach the specified tolerance.
+        number of iterations run by FISTA to reach the specified tolerance.
 
     Examples
     --------
@@ -48,12 +56,17 @@ class Lasso(ht.RegressionMixin, ht.BaseEstimator):
     """
 
     def __init__(
-        self, lam: Optional[float] = 0.1, max_iter: Optional[int] = 100, tol: Optional[float] = 1e-6
+        self,
+        lam: Optional[float] = 0.1,
+        max_iter: Optional[int] = 100,
+        tol: Optional[float] = 1e-6,
+        step_size: Optional[float] = None,
     ) -> None:
         """Initialize lasso parameters"""
         self.__lam = lam
         self.max_iter = max_iter
         self.tol = tol
+        self.step_size = step_size
         self.__theta = None
         self.n_iter = None
 
@@ -87,23 +100,23 @@ def theta(self):
         """Returns regularization term lambda"""
         return self.__theta
 
-    def soft_threshold(self, rho: DNDarray) -> Union[DNDarray, float]:
+    def soft_threshold(self, x: DNDarray, threshold: float) -> DNDarray:
         """
-        Soft threshold operator
+        Vectorized soft threshold operator (proximal operator for the L1 norm)
 
         Parameters
         ----------
-        rho : DNDarray
-            Input model data, Shape = (1,)
-        out : DNDarray or float
-            Thresholded model data, Shape = (1,)
+        x : DNDarray
+            Input data
+        threshold : float
+            Threshold value (lambda * step_size for FISTA)
+
+        Returns
+        -------
+        DNDarray
+            Thresholded data
         """
-        if rho < -self.__lam:
-            return rho + self.__lam
-        elif rho > self.__lam:
-            return rho - self.__lam
-        else:
-            return 0.0
+        return ht.sign(x) * ht.maximum(ht.abs(x) - threshold, 0.0)
 
     def rmse(self, gt: DNDarray, yest: DNDarray) -> DNDarray:
         """
@@ -120,7 +133,8 @@ def rmse(self, gt: DNDarray, yest: DNDarray) -> DNDarray:
 
     def fit(self, x: DNDarray, y: DNDarray) -> None:
         """
-        Fit lasso model with coordinate descent
+        Fit lasso model using FISTA (Fast Iterative Shrinkage-Thresholding
+        Algorithm)
 
         Parameters
         ----------
@@ -129,8 +143,8 @@ def fit(self, x: DNDarray, y: DNDarray) -> None:
         y : DNDarray
             Labels, Shape = (n_samples,)
         """
-        # Get number of model parameters
-        _, n = x.shape
+        # Get number of samples and features
+        m, n = x.shape
 
         if y.ndim > 2:
             raise ValueError(f"y.ndim must <= 2, currently: {y.ndim}")
@@ -140,28 +154,49 @@ def fit(self, x: DNDarray, y: DNDarray) -> None:
         if len(y.shape) == 1:
             y = ht.expand_dims(y, axis=1)
 
-        # Initialize model parameters
-        theta = ht.zeros((n, 1), dtype=float, device=x.device)
+        # Compute step size (1/L, where L is the Lipschitz constant of the gradient)
+        if self.step_size is None:
+            # L = largest eigenvalue of (X^T X) / m
+            # For efficiency, we approximate: L ≈ ||X||_F^2 / m
+            XtX_norm = ht.linalg.norm(x) ** 2 / m
+            L = XtX_norm.item()
+            step = 1.0 / L if L > 0 else 1.0
+        else:
+            step = self.step_size
+
+        # Initialize parameters (not split - these are model parameters)
+        theta = ht.zeros((n, 1), dtype=x.dtype, split=None, device=x.device)
+        y_k = theta.copy()  # Extrapolation point
+        t_k = 1.0  # Momentum parameter
 
-        # Looping until max number of iterations or convergence
+        # FISTA iterations
         for i in range(self.max_iter):
             theta_old = theta.copy()
 
-            # Looping through each coordinate
-            for j in range(n):
-                X_j = ht.array(x.larray[:, j : j + 1], is_split=0, device=x.device, comm=x.comm)
+            # Compute gradient at y_k: (1/m) * X^T (X y_k - y)
+            residual = x @ y_k - y
+            gradient = (x.T @ residual) / m
+
+            # Gradient descent step
+            z = y_k - step * gradient
+
+            # Proximal step: soft thresholding with threshold
+            # lambda * step (the proximal operator of the scaled L1 term)
+            theta_new = self.soft_threshold(z, self.__lam * step)
+
+            # Don't regularize the intercept (first element)
+            theta_new[0] = z[0]  # No thresholding for intercept
+            theta = theta_new
 
-                y_est = x @ theta
-                theta_j = theta.larray[j].item()
+            # Update momentum parameter (using math.sqrt for the scalar)
+            t_k_new = (1.0 + math.sqrt(1.0 + 4.0 * t_k**2)) / 2.0
 
-                rho = (X_j * (y - y_est + theta_j * X_j)).mean()
+            # Update extrapolation point
+            y_k = theta + ((t_k - 1.0) / t_k_new) * (theta - theta_old)
 
-                # Intercept parameter theta[0] not be regularized
-                if j == 0:
-                    theta[j] = rho
-                else:
-                    theta[j] = self.soft_threshold(rho)
+            t_k = t_k_new
 
+            # Check convergence
             diff = self.rmse(theta, theta_old)
             if self.tol is not None and diff < self.tol:
                 self.n_iter = i + 1
@@ -173,7 +208,8 @@ def fit(self, x: DNDarray, y: DNDarray) -> None:
 
     def predict(self, x: DNDarray) -> DNDarray:
         """
-        Apply lasso model to input data. First row data corresponds to interception
+        Apply lasso model to input data. The first model weight
+        corresponds to the intercept
 
         Parameters
         ----------
diff --git a/heat/regression/tests/test_lasso.py b/heat/regression/tests/test_lasso.py
index 8b2ed6908f..0ed9b0dbc6 100644
--- a/heat/regression/tests/test_lasso.py
+++ b/heat/regression/tests/test_lasso.py
@@ -14,7 +14,13 @@ def test_get_and_set_params(self):
         lasso = ht.regression.Lasso()
         params = lasso.get_params()
 
-        self.assertEqual(params, {"lam": 0.1, "max_iter": 100, "tol": 1e-6})
+        expected_params = {
+            "lam": 0.1,
+            "max_iter": 100,
+            "tol": 1e-6,
+            "step_size": None,
+        }
+        self.assertEqual(params, expected_params)
 
         params["max_iter"] = 200
         lasso.set_params(**params)
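A note on the `soft_threshold` rewrite: the closed form `sign(x) * max(|x| - threshold, 0)` is exactly the piecewise operator it replaces, applied elementwise, so nothing changes beyond vectorization. A minimal NumPy sketch checking the equivalence (NumPy stands in for Heat's array API, which mirrors it for these calls):

```python
import numpy as np

def soft_threshold_piecewise(rho: float, lam: float) -> float:
    # The removed scalar form: shift by lam outside [-lam, lam], zero inside
    if rho < -lam:
        return rho + lam
    elif rho > lam:
        return rho - lam
    return 0.0

def soft_threshold_vectorized(x: np.ndarray, threshold: float) -> np.ndarray:
    # The new closed form: shrink every entry toward zero by `threshold`
    return np.sign(x) * np.maximum(np.abs(x) - threshold, 0.0)

x = np.linspace(-2.0, 2.0, 101)
lam = 0.5
expected = np.array([soft_threshold_piecewise(v, lam) for v in x])
assert np.allclose(soft_threshold_vectorized(x, lam), expected)
```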
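On the default step size: `||X||_F^2` is the trace of `X^T X`, i.e. the sum of its eigenvalues, so the surrogate `L ≈ ||X||_F^2 / m` can only overestimate the largest eigenvalue. The resulting step `1/L` therefore stays valid for FISTA, just possibly conservative (smaller steps, more iterations), in exchange for a norm that is cheap to compute on a distributed `DNDarray`. A quick check of the bound on random data (hypothetical sizes):

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.standard_normal((200, 10))
m = X.shape[0]

L_exact = np.linalg.eigvalsh(X.T @ X / m)[-1]  # true Lipschitz constant of the gradient
L_frob = np.linalg.norm(X) ** 2 / m            # Frobenius surrogate used in this diff

# trace(X^T X) >= lambda_max(X^T X), so 1/L_frob is never too large a step
assert L_frob >= L_exact
print(f"exact L = {L_exact:.3f}, Frobenius surrogate = {L_frob:.3f}")
```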
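To sanity-check the update order in `fit` (gradient step at the extrapolation point, proximal step, momentum update, extrapolation), here is the same recursion as a self-contained NumPy sketch, not Heat code. It assumes the design matrix carries a leading all-ones column, which is why row 0 of the weights is exempt from thresholding:

```python
import numpy as np

def fista_lasso(X, y, lam=0.1, max_iter=100, tol=1e-6):
    m, n = X.shape
    step = m / np.linalg.norm(X) ** 2  # 1/L with L approximated as ||X||_F^2 / m
    theta = np.zeros((n, 1))
    y_k, t_k = theta.copy(), 1.0  # extrapolation point and momentum scalar
    for i in range(max_iter):
        theta_old = theta
        z = y_k - step * (X.T @ (X @ y_k - y)) / m  # gradient step on (1/2m)||Xw - y||^2
        theta = np.sign(z) * np.maximum(np.abs(z) - lam * step, 0.0)  # proximal step
        theta[0] = z[0]  # the intercept is not regularized
        t_new = (1.0 + np.sqrt(1.0 + 4.0 * t_k**2)) / 2.0  # momentum update
        y_k = theta + ((t_k - 1.0) / t_new) * (theta - theta_old)  # extrapolation
        t_k = t_new
        if np.sqrt(np.mean((theta - theta_old) ** 2)) < tol:  # RMSE stopping rule
            break
    return theta

# Toy usage: a sparse ground truth should be recovered approximately
rng = np.random.default_rng(0)
X = np.hstack([np.ones((200, 1)), rng.standard_normal((200, 5))])
w_true = np.array([[1.0], [2.0], [0.0], [0.0], [-3.0], [0.0]])
y = X @ w_true + 0.01 * rng.standard_normal((200, 1))
print(fista_lasso(X, y, lam=0.01).round(2))
```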