Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion heat/core/linalg/qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def qr(
r"""
Calculates the QR decomposition of a 2D ``DNDarray``.
Factor the matrix ``A`` as *QR*, where ``Q`` is orthonormal and ``R`` is upper-triangular.
If ``mode = "reduced``, function returns ``QR(Q=Q, R=R)``, if ``mode = "r"`` function returns ``QR(Q=None, R=R)``
If ``mode = "reduced"``, function returns ``QR(Q=Q, R=R)``, if ``mode = "r"`` function returns ``QR(Q=None, R=R)``

This function also works for batches of matrices; in this case, the last two dimensions of the input array are considered as the matrix dimensions.
The output arrays have the same leading batch dimensions as the input array.
Expand Down
10 changes: 5 additions & 5 deletions heat/core/linalg/randomized.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def rsvd(
) -> Union[Tuple[DNDarray, DNDarray, DNDarray], Tuple[DNDarray, DNDarray]]:
r"""
Randomized SVD (rSVD) with prescribed truncation rank `svd_rank`.
If :math:`A = U \operatorname{diag}(S) V^T` is the true SVD of A, this routine computes an approximation for U[:,:svd_rank] (and S[:svd_rank], V[:,:svd_rank]).
If :math:`A = U \operatorname{diag}(S) V^T` is the true SVD of A, this routine computes an approximation for U[:,:svd_rank] (and S[:svd_rank], V.T[:,:svd_rank]).

The accuracy of this approximation depends on the structure of A ("low-rank" is best) and appropriate choice of parameters.

Expand Down Expand Up @@ -130,13 +130,13 @@ def rsvd(
B.resplit_(
None
) # B will be of size ell x n and thus small enough to fit into memory of a single process
U, sigma, V = svd(B) # actually just torch svd as input is not split anymore
U, sigma, Vt = svd(B)
U = matmul(Q, U)[:, :svd_rank]
U.balance_()
S = sigma[:svd_rank]
V = V[:, :svd_rank]
V.balance_()
return U, S, V
Vt = Vt[:svd_rank, :]
Vt.balance_()
return U, S, Vt


def reigh(
Expand Down
112 changes: 30 additions & 82 deletions heat/core/linalg/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def svd(
) -> Tuple[DNDarray, DNDarray, DNDarray]:
"""
Computes the singular value decomposition of a matrix (the input array ``A``).
For an input DNDarray ``A`` of shape ``(M, N)``, the function returns DNDarrays ``U``, ``S``, and ``V`` such that ``A = U @ ht.diag(S) @ V.T``
For an input DNDarray ``A`` of shape ``(M, N)``, the function returns DNDarrays ``U``, ``S``, and ``V.T`` such that ``A = U @ ht.diag(S) @ V.T``
with shapes ``(M, min(M,N))``, ``(min(M, N),)``, and ``(min(M,N),N)``, respectively, in the case that ``compute_uv=True``, or
only the vector containing the singular values ``S`` of shape ``(min(M, N),)`` in the case that ``compute_uv=False``. By definition of the singular value decomposition,
the matrix ``U`` is orthogonal, the matrix ``V`` is orthogonal, and the entries of the vector ``S``are non-negative real numbers.
Expand All @@ -39,7 +39,7 @@ def svd(
full_matrices : bool, optional
currently, only the default value ``False`` is supported. This argument is included for compatibility with NumPy.
compute_uv : bool, optional
if ``True``, the matrices ``U`` and ``V`` are computed and returned together with the singular values ``S``.
if ``True``, the matrices ``U`` and ``V.T`` are computed and returned together with the singular values ``S``.
If ``False``, only the vector ``S`` containing the singular values is returned.
qr_procs_to_merge : int, optional
the number of processes to merge in the tall skinny QR decomposition that is applied if the input array is tall skinny (``M > N``) or short fat (``M < N``).
Expand All @@ -54,7 +54,7 @@ def svd(
Unlike in NumPy, we currently do not support the option ``full_matrices=True``, since this can result in heavy memory consumption (in particular for tall skinny
and short fat matrices) that should be avoided in the context Heat is designed for. If you nevertheless require this feature, please open an issue on GitHub.

The algorithm used for the computation of the singular value depens on the shape of the input array ``A``.
The algorithm used for the computation of the singular value depends on the shape of the input array ``A``.
For tall and skinny matrices (``M > N``), the algorithm is based on the tall-skinny QR decomposition. For the remaining cases we use the approach based on
Zolotarev-Polar Decomposition and a symmetric eigenvalue decomposition based on Zolotarev-Polar Decomposition; see Algorithm 5.3 in:

Expand Down Expand Up @@ -97,49 +97,29 @@ def svd(
f"Array ``A`` must have a datatype of float32 or float64, but has {A.dtype}"
)

def _toDNDarray(array):
"""Returns an unsplit heat DNDarray that inherits properties from the `A` matrix to be decomposed"""
return DNDarray(
array,
tuple(array.shape),
dtype=A.dtype,
split=None,
device=A.device,
comm=A.comm,
balanced=A.balanced,
)

if not A.is_distributed():
# this is the non-distributed case
if compute_uv:
U_loc, S_loc, Vt_loc = torch.linalg.svd(A.larray, full_matrices=full_matrices)
U = DNDarray(
U_loc,
tuple(U_loc.shape),
dtype=A.dtype,
split=None,
device=A.device,
comm=A.comm,
balanced=A.balanced,
)
S = DNDarray(
S_loc,
tuple(S_loc.shape),
dtype=A.dtype,
split=None,
device=A.device,
comm=A.comm,
balanced=A.balanced,
)
V = DNDarray(
Vt_loc.T,
tuple(Vt_loc.T.shape),
dtype=A.dtype,
split=None,
device=A.device,
comm=A.comm,
balanced=A.balanced,
)
return U, S, V
U = _toDNDarray(U_loc)
S = _toDNDarray(S_loc)
Vt = _toDNDarray(Vt_loc)
return U, S, Vt
else:
S_loc = torch.linalg.svdvals(A.larray)
S = DNDarray(
S_loc,
tuple(S_loc.shape),
dtype=A.dtype,
split=None,
device=A.device,
comm=A.comm,
balanced=A.balanced,
)
S = _toDNDarray(S_loc)
return S
elif A.split == 0 and A.lshape_map[:, 0].max().item() >= A.shape[1]:
# this is the distributed, tall skinny case
Expand All @@ -148,60 +128,28 @@ def svd(
# compute full SVD: first full QR, then SVD of R
Q, R = qr(A, mode="reduced", procs_to_merge=qr_procs_to_merge)
Utilde_loc, S_loc, Vt_loc = torch.linalg.svd(R.larray, full_matrices=False)
Utilde = DNDarray(
Utilde_loc,
tuple(Utilde_loc.shape),
dtype=A.dtype,
split=None,
device=A.device,
comm=A.comm,
balanced=A.balanced,
)
S = DNDarray(
S_loc,
tuple(S_loc.shape),
dtype=A.dtype,
split=None,
device=A.device,
comm=A.comm,
balanced=A.balanced,
)
V = DNDarray(
Vt_loc.T,
tuple(Vt_loc.T.shape),
dtype=A.dtype,
split=None,
device=A.device,
comm=A.comm,
balanced=A.balanced,
)
Utilde = _toDNDarray(Utilde_loc)
S = _toDNDarray(S_loc)
Vt = _toDNDarray(Vt_loc)
U = (Utilde.T @ Q.T).T
return U, S, V
return U, S, Vt
else:
# compute only singular values: first only R of QR, then singular values only of R
_, R = qr(A, mode="r", procs_to_merge=qr_procs_to_merge)
S_loc = torch.linalg.svdvals(R.larray)
S = DNDarray(
S_loc,
tuple(S_loc.shape),
dtype=A.dtype,
split=None,
device=A.device,
comm=A.comm,
balanced=A.balanced,
)
S = _toDNDarray(S_loc)
return S
elif A.split == 1 and A.lshape_map[:, 1].max().item() >= A.shape[0]:
# this is the distributed, short fat case
# apply the tall skinny SVD to the transpose of A
if compute_uv:
V, S, U = svd(
V, S, Ut = svd(
A.T,
full_matrices=full_matrices,
compute_uv=True,
qr_procs_to_merge=qr_procs_to_merge,
)
return U, S, V
return Ut.T, S, V.T
else:
S = svd(
A.T,
Expand All @@ -217,13 +165,13 @@ def svd(
if A.shape[0] < A.shape[1]:
# Zolo-PD requires A.shape[0] >= A.shape[1], so we need to transpose in this case
if compute_uv:
V, S, U = svd(
V, S, Ut = svd(
A.T,
full_matrices=full_matrices,
compute_uv=True,
qr_procs_to_merge=qr_procs_to_merge,
)
return U, S, V
return Ut.T, S, V.T
else:
S = svd(
A.T,
Expand All @@ -241,4 +189,4 @@ def svd(
if not compute_uv:
return S
else:
return U @ V, S, V
return U @ V, S, V.T
Loading