Skip to content

Commit

Permalink
Merge pull request #175 from ViCCo-Group/cka_fix
Browse files Browse the repository at this point in the history
fixed rbf kernel in cka
  • Loading branch information
LukasMut authored May 19, 2024
2 parents 16940b8 + 0a8789d commit ab01641
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 10 deletions.
2 changes: 1 addition & 1 deletion thingsvision/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.6.6"
__version__ = "2.6.7"
6 changes: 4 additions & 2 deletions thingsvision/core/cka/cka_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,15 @@ def linear_kernel(self, X: Array) -> Array:
"""Use a linear kernel for computing the gram matrix."""
return X @ X.T

def rbf_kernel(self, X: Array, sigma: float = 1.0) -> Array:
def rbf_kernel(self, X: Array) -> Array:
"""Use an rbf kernel for computing the gram matrix. Sigma defines the width."""
GX = X @ X.T
KX = np.diag(GX) - GX + (np.diag(GX) - GX).T
if sigma is None:
if self.sigma is None:
mdist = np.median(KX[KX != 0])
sigma = math.sqrt(mdist)
else:
sigma = self.sigma
KX *= -0.5 / sigma**2
KX = np.exp(KX)
return KX
Expand Down
15 changes: 9 additions & 6 deletions thingsvision/core/cka/cka_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
kernel: str,
unbiased: bool = False,
device: str = "cpu",
verbose: bool = False,
sigma: Optional[float] = 1.0,
) -> None:
"""
Expand All @@ -30,7 +31,7 @@ def __init__(
sigma (float) - for 'rbf' kernel sigma defines the width of the Gaussian;
"""
super().__init__(m=m, kernel=kernel, unbiased=unbiased, sigma=sigma)
device = self._check_device(device)
device = self._check_device(device, verbose)
if device == "cpu":
self.hsic = self._hsic
else:
Expand All @@ -39,7 +40,7 @@ def __init__(
self.device = torch.device(device)

@staticmethod
def _check_device(device: str) -> str:
def _check_device(device: str, verbose: bool) -> str:
"""Check whether the selected device is available on current compute node."""
if device.startswith("cuda"):
gpu_index = re.search(r"cuda:(\d+)", device)
Expand All @@ -58,8 +59,8 @@ def _check_device(device: str) -> str:
category=UserWarning,
)
device = "cuda:0"

print(f"\nUsing device: {device}\n")
if verbose:
print(f"\nUsing device: {device}\n")
return device

def centering(self, K: TensorType["m", "m"]) -> TensorType["m", "m"]:
Expand Down Expand Up @@ -104,14 +105,16 @@ def linear_kernel(
return X @ X.T

def rbf_kernel(
self, X: TensorType["m", "d"], sigma: Optional[float] = 1.0
self, X: Union[TensorType["m", "d"], TensorType["m", "p"]]
) -> TensorType["m", "m"]:
"""Use an rbf kernel for computing the gram matrix. Sigma defines the width."""
GX = X @ X.T
KX = torch.diag(GX) - GX + (torch.diag(GX) - GX).T
if sigma is None:
if self.sigma is None:
mdist = torch.median(KX[KX != 0])
sigma = torch.sqrt(mdist)
else:
sigma = self.sigma
KX *= -0.5 / sigma**2
KX = KX.exp()
return KX
Expand Down
8 changes: 7 additions & 1 deletion thingsvision/core/cka/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def get_cka(
unbiased: bool = False,
sigma: Optional[float] = 1.0,
device: Optional[str] = None,
verbose: Optional[bool] = False,
) -> Union[CKANumPy, CKATorch]:
"""Return a NumPy or PyTorch implementation of CKA."""
assert backend in BACKENDS, f"\nSupported backends are: {BACKENDS}\n"
Expand All @@ -23,6 +24,11 @@ def get_cka(
device, str
), "\nDevice must be set for using PyTorch backend.\n"
cka = CKATorch(
m=m, kernel=kernel, unbiased=unbiased, device=device, sigma=sigma
m=m,
kernel=kernel,
unbiased=unbiased,
device=device,
sigma=sigma,
verbose=verbose,
)
return cka

0 comments on commit ab01641

Please sign in to comment.