102 changes: 69 additions & 33 deletions heat/regression/lasso.py
@@ -2,6 +2,7 @@
Implementation of the LASSO regression
"""

+import math
import heat as ht
from heat.core.dndarray import DNDarray
from typing import Union, Optional
@@ -19,15 +20,22 @@ class Lasso(ht.RegressionMixin, ht.BaseEstimator):
.. math:: w\\_=(w_1,w_2,...,w_n), w=(w_0,w_1,w_2,...,w_n),
.. math:: y \\in M(m \\times 1), w \\in M(n \\times 1), X \\in M(m \\times n)

+The implementation uses FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) for optimization,
+as described in Beck & Teboulle (2009): "A Fast Iterative Shrinkage-Thresholding Algorithm
+for Linear Inverse Problems".

Parameters
----------
lam : float, optional
Constant that multiplies the L1 term. Default value: 0.1. ``lam = 0.`` is equivalent to an
ordinary least squares (OLS) fit. For numerical reasons, using ``lam = 0.`` with the ``Lasso`` object is not advised.
max_iter : int, optional
The maximum number of iterations. Default value: 100
-tol : float, optional. Default value: 1e-8
+tol : float, optional. Default value: 1e-6
The tolerance for the optimization.
+step_size : float, optional
+The step size for the gradient descent step. If None, it will be computed as 1/L where L is the
+Lipschitz constant of the gradient (largest eigenvalue of X^T X / m). Default: None

Attributes
----------
@@ -37,7 +45,7 @@ class Lasso(ht.RegressionMixin, ht.BaseEstimator):
intercept_ : float | array, shape (n_targets,)
independent term in decision function.
n_iter_ : int or None | array-like, shape (n_targets,)
-number of iterations run by the coordinate descent solver to reach the specified tolerance.
+number of iterations run by FISTA to reach the specified tolerance.

Examples
--------
@@ -48,12 +56,17 @@
"""

def __init__(
-self, lam: Optional[float] = 0.1, max_iter: Optional[int] = 100, tol: Optional[float] = 1e-6
+self,
+lam: Optional[float] = 0.1,
+max_iter: Optional[int] = 100,
+tol: Optional[float] = 1e-6,
+step_size: Optional[float] = None,
Comment on lines +60 to +63 (Collaborator):
Are lam, max_iter, and tol allowed to be None? The docstring doesn't mention it.
Since Python 3.10, you can also use float | None instead of Optional[float] if it confuses you.
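
For illustration, a sketch of what the signature would look like with the reviewer's suggested union syntax. Whether ``lam`` and ``max_iter`` should actually admit ``None`` is exactly the open question above, so this is an assumption, not the PR's code:

```python
class Lasso:
    # Hypothetical signature with PEP 604 unions (requires Python >= 3.10).
    # Assumes only tol and step_size are meant to accept None; lam and
    # max_iter would then be plain float/int, answering the question above.
    def __init__(
        self,
        lam: float = 0.1,
        max_iter: int = 100,
        tol: float | None = 1e-6,
        step_size: float | None = None,
    ) -> None: ...
```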

) -> None:
"""Initialize lasso parameters"""
self.__lam = lam
self.max_iter = max_iter
self.tol = tol
+self.step_size = step_size
self.__theta = None
self.n_iter = None

@@ -87,23 +100,23 @@ def theta(self):
"""Returns regularization term lambda"""
return self.__theta

-def soft_threshold(self, rho: DNDarray) -> Union[DNDarray, float]:
+def soft_threshold(self, x: DNDarray, threshold: float) -> DNDarray:
"""
-Soft threshold operator
+Vectorized soft threshold operator (proximal operator for L1 norm)

Parameters
----------
-rho : DNDarray
-Input model data, Shape = (1,)
-out : DNDarray or float
-Thresholded model data, Shape = (1,)
+x : DNDarray
+Input data
+threshold : float
+Threshold value (lambda * step_size for FISTA)

+Returns
+-------
+DNDarray
+Thresholded data
"""
-if rho < -self.__lam:
-return rho + self.__lam
-elif rho > self.__lam:
-return rho - self.__lam
-else:
-return 0.0
+return ht.sign(x) * ht.maximum(ht.abs(x) - threshold, 0.0)
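
As a sanity check on the new vectorized form, the same proximal operator in plain NumPy (NumPy is used here purely for illustration and is not part of the PR):

```python
import numpy as np

def soft_threshold(x: np.ndarray, threshold: float) -> np.ndarray:
    # prox of threshold * ||.||_1: shrink every entry toward zero by `threshold`
    return np.sign(x) * np.maximum(np.abs(x) - threshold, 0.0)

x = np.array([-2.0, -0.3, 0.0, 0.3, 2.0])
print(soft_threshold(x, 0.5))  # entries shrink by 0.5: [-1.5, -0., 0., 0., 1.5]
```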

def rmse(self, gt: DNDarray, yest: DNDarray) -> DNDarray:
"""
@@ -120,7 +133,8 @@ def rmse(self, gt: DNDarray, yest: DNDarray) -> DNDarray:

def fit(self, x: DNDarray, y: DNDarray) -> None:
"""
-Fit lasso model with coordinate descent
+Fit lasso model using FISTA (Fast Iterative Shrinkage-Thresholding
+Algorithm)

Parameters
----------
@@ -129,8 +143,8 @@ def fit(self, x: DNDarray, y: DNDarray) -> None:
y : DNDarray
Labels, Shape = (n_samples,)
"""
-# Get number of model parameters
-_, n = x.shape
+# Get number of samples and features
+m, n = x.shape

if y.ndim > 2:
raise ValueError(f"y.ndim must <= 2, currently: {y.ndim}")
@@ -140,28 +154,49 @@ def fit(self, x: DNDarray, y: DNDarray) -> None:
if len(y.shape) == 1:
y = ht.expand_dims(y, axis=1)

-# Initialize model parameters
-theta = ht.zeros((n, 1), dtype=float, device=x.device)
+# Compute step size (1/L where L is the Lipschitz constant of the gradient)
+if self.step_size is None:
+# L = largest eigenvalue of (X^T X) / m; approximated by the upper bound
+# ||X||_F^2 / m >= L, which yields a smaller but still convergent step size
+XtX_norm = ht.linalg.norm(x) ** 2 / m
+L = XtX_norm.item()
+step = 1.0 / L if L > 0 else 1.0
+else:
+step = self.step_size

+# Initialize parameters (not split - these are model parameters)
+theta = ht.zeros((n, 1), dtype=x.dtype, split=None, device=x.device)
+y_k = theta.copy()  # Extrapolation point
+t_k = 1.0  # Momentum parameter
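
The Frobenius-norm bound above can be loose when features are correlated, which makes the step size smaller than necessary. If a tighter estimate were ever wanted, a few power iterations on X^T X give lambda_max directly; a hypothetical NumPy sketch, not part of the PR (function name and iteration count are illustrative):

```python
import numpy as np

def estimate_lipschitz(X: np.ndarray, iters: int = 50) -> float:
    # Power iteration: estimate lambda_max(X^T X) / m without an eigendecomposition
    m, n = X.shape
    v = np.random.default_rng(0).standard_normal(n)
    v /= np.linalg.norm(v)
    for _ in range(iters):
        w = X.T @ (X @ v)          # apply X^T X as two matvecs
        v = w / np.linalg.norm(w)
    return float(v @ (X.T @ (X @ v))) / m  # Rayleigh quotient, divided by m
```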

-# Looping until max number of iterations or convergence
+# FISTA iterations
for i in range(self.max_iter):
theta_old = theta.copy()

-# Looping through each coordinate
-for j in range(n):
-X_j = ht.array(x.larray[:, j : j + 1], is_split=0, device=x.device, comm=x.comm)
+# Compute gradient at y_k: (1/m) * X^T (X y_k - y)
+residual = x @ y_k - y
+gradient = (x.T @ residual) / m

+# Gradient descent step
+z = y_k - step * gradient

+# Proximal step: soft thresholding with lambda * step_size
+theta_new = self.soft_threshold(z, self.__lam * step)

+# Don't regularize the intercept (first element)
+theta_new[0] = z[0]  # No thresholding for intercept
+theta = theta_new

-y_est = x @ theta
-theta_j = theta.larray[j].item()
+# Update momentum parameter (math.sqrt, since t_k is a Python scalar)
+t_k_new = (1.0 + math.sqrt(1.0 + 4.0 * t_k**2)) / 2.0

-rho = (X_j * (y - y_est + theta_j * X_j)).mean()
+# Update extrapolation point
+y_k = theta + ((t_k - 1.0) / t_k_new) * (theta - theta_old)

-# Intercept parameter theta[0] not be regularized
-if j == 0:
-theta[j] = rho
-else:
-theta[j] = self.soft_threshold(rho)
+t_k = t_k_new

# Check convergence
diff = self.rmse(theta, theta_old)
if self.tol is not None and diff < self.tol:
self.n_iter = i + 1
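
For reference, the loop above implements the standard FISTA recursion from Beck & Teboulle (2009); with f(theta) = (1/2m)||X theta - y||^2 and step size 1/L, one iteration is:

```latex
\theta_{k+1} = \operatorname{prox}_{\frac{\lambda}{L}\|\cdot\|_1}\Big(y_k - \tfrac{1}{L}\nabla f(y_k)\Big),
\qquad \nabla f(y_k) = \tfrac{1}{m} X^\top (X y_k - y), \\
t_{k+1} = \frac{1 + \sqrt{1 + 4 t_k^2}}{2},
\qquad y_{k+1} = \theta_{k+1} + \frac{t_k - 1}{t_{k+1}} \, (\theta_{k+1} - \theta_k)
```

The prox step is exactly ``soft_threshold(z, lam * step)``, with the intercept excluded from thresholding, matching the code.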
@@ -173,7 +208,8 @@ def fit(self, x: DNDarray, y: DNDarray) -> None:

def predict(self, x: DNDarray) -> DNDarray:
"""
-Apply lasso model to input data. First row data corresponds to interception
+Apply lasso model to input data. The first row of the data corresponds to
+the intercept

Parameters
----------
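A minimal usage sketch of the updated estimator, assuming a design matrix whose first column is ones for the intercept; the synthetic data and the ``split=0`` distribution are illustrative assumptions, not taken from the PR:

```python
import heat as ht

# Illustrative synthetic data: 100 samples, 4 features, plus an intercept column
X = ht.random.randn(100, 4, split=0)
X = ht.concatenate([ht.ones((100, 1), split=0), X], axis=1)
true_theta = ht.array([[0.5], [2.0], [-1.0], [0.0], [0.0]])
y = X @ true_theta

lasso = ht.regression.Lasso(lam=0.1, max_iter=500, tol=1e-6)  # step_size=None -> 1/L
lasso.fit(X, y)
print(lasso.n_iter)      # iterations until the RMSE between iterates fell below tol
y_pred = lasso.predict(X)
```
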
8 changes: 7 additions & 1 deletion heat/regression/tests/test_lasso.py
@@ -14,7 +14,13 @@ def test_get_and_set_params(self):
lasso = ht.regression.Lasso()
params = lasso.get_params()

self.assertEqual(params, {"lam": 0.1, "max_iter": 100, "tol": 1e-6})
expected_params = {
"lam": 0.1,
"max_iter": 100,
"tol": 1e-6,
"step_size": None
}
self.assertEqual(params, expected_params)

params["max_iter"] = 200
lasso.set_params(**params)
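
A sketch of an additional test that could accompany the change, checking that FISTA shrinks truly-zero coefficients on synthetic data; the sizes, seed, and tolerances here are illustrative assumptions, not part of the PR:

```python
def test_fista_sparse_recovery(self):
    ht.random.seed(42)
    # design matrix with an intercept column; only two features carry signal
    X = ht.concatenate([ht.ones((200, 1)), ht.random.randn(200, 3)], axis=1)
    true_theta = ht.array([[0.5], [2.0], [0.0], [0.0]])
    y = X @ true_theta

    lasso = ht.regression.Lasso(lam=0.01, max_iter=1000, tol=1e-8)
    lasso.fit(X, y)

    # coefficients of the two inactive features should be driven near zero
    self.assertLess(ht.abs(lasso.theta[2]).item(), 0.1)
    self.assertLess(ht.abs(lasso.theta[3]).item(), 0.1)
```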