diff --git a/heat/core/linalg/qr.py b/heat/core/linalg/qr.py index 4ca0c3fc01..70cad9e100 100644 --- a/heat/core/linalg/qr.py +++ b/heat/core/linalg/qr.py @@ -10,7 +10,7 @@ from ..manipulations import concatenate from .. import factories from .. import communication -from ..types import float32, float64 +from ..types import float32, float64, complex64 __all__ = ["qr"] @@ -94,11 +94,10 @@ def qr( if procs_to_merge == 0: procs_to_merge = A.comm.size - if A.dtype not in [float32, float64]: + if A.dtype not in [float32, float64, complex64]: raise TypeError(f"Array 'A' must have a datatype of float32 or float64, but has {A.dtype}") QR = collections.namedtuple("QR", "Q, R") - if A.ndim == 3: single_proc_qr = torch.vmap(torch.linalg.qr, in_dims=0, out_dims=0) else: