Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/rsatoolbox/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from .noise import prec_from_measurements
from .noise import cov_from_unbalanced
from .noise import prec_from_unbalanced
from .noise import sigmak_from_measurements
74 changes: 71 additions & 3 deletions src/rsatoolbox/data/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
from rsatoolbox.data import average_dataset_by
from rsatoolbox.util.data_utils import get_unique_inverse
from scipy.linalg import sqrtm


def _check_demean(matrix):
Expand Down Expand Up @@ -147,7 +148,7 @@ def _covariance_eye(matrix, dof):
b2 = min(d2, b2)
# shrink covariance matrix
s_shrink = b2 / d2 * m * np.eye(s.shape[0]) \
+ (d2-b2) / d2 * s
+ (d2 - b2) / d2 * s
# correction for degrees of freedom
s_shrink = s_shrink * matrix.shape[0] / dof
return s_shrink
Expand Down Expand Up @@ -189,11 +190,11 @@ def _covariance_diag(matrix, dof):
s_mean = s_sum / np.expand_dims(std, 0) / np.expand_dims(std, 1) / (matrix.shape[0] - 1)
s2_mean = s2_sum / np.expand_dims(var, 0) / np.expand_dims(var, 1) / (matrix.shape[0] - 1)
var_hat = matrix.shape[0] / dof ** 2 \
* (s2_mean - s_mean ** 2)
* (s2_mean - s_mean ** 2)
mask = ~np.eye(s.shape[0], dtype=bool)
lamb = np.sum(var_hat[mask]) / np.sum(s_mean[mask] ** 2)
lamb = max(min(lamb, 1), 0)
scaling = np.eye(s.shape[0]) + (1-lamb) * mask
scaling = np.eye(s.shape[0]) + (1 - lamb) * mask
s_shrink = s * scaling
return s_shrink

Expand Down Expand Up @@ -434,3 +435,70 @@ def prec_from_unbalanced(dataset, obs_desc, dof=None, method='shrinkage_diag'):
else:
prec = np.linalg.inv(cov)
return prec


def sigmak_from_measurements(dataset, obs_descriptor, cv_descriptor, noise=None):
"""
Estimates sigma_k, the matrix encoding the noise variance/covariance among the k conditions when two
conditions are measured in the same partition (e.g., due to shared fMRI noise
from the sluggishness of the HRF when two conditions are adjacent in time). If a noise matrix is provided,
prewhitening is performed on the data before computing sigma_k (make sure to do this if using
Mahalanobis or crossnobis distance). Assumes that sigma_k is constant across partitions, implementing
equation 36 from Diedrichsen et al. (2016), "On the distribution of cross-validated Mahalanobis distances."

Args:
dataset(data.Dataset):
rsatoolbox Dataset object
obs_descriptor (String):
descriptor defining experimental conditions
cv_descriptor (String):
descriptor defining crossvalidation folds/partitions
noise (numpy.ndarray):
dataset.n_channel x dataset.n_channel
precision matrix for noise between channels
default: identity matrix, i.e. euclidean distance

Returns:
numpy.ndarray: sigma_k: noise covariance matrix over conditions
n_conditions x n_conditions

"""

n_channels = dataset.n_channel
if noise is None:
noise = np.eye(n_channels)
else:
if noise.shape != (n_channels, n_channels):
raise ValueError("noise must have shape n_channel x n_channel")
noise_sqrt = sqrtm(noise) # take matrix square root to get whitening matrix
dataset_whitened = dataset.copy()
dataset_whitened.measurements = dataset.measurements @ noise_sqrt

# Compute mean patterns per condition
U_mean, conds, _ = average_dataset_by(dataset_whitened, obs_descriptor)
n_cond = len(conds)
cv_folds = np.unique(np.array(dataset_whitened.obs_descriptors[cv_descriptor]))

pair_counts = np.zeros((n_cond, n_cond)) # tally how many partitions each condition pair occurs in

# Compute sigma_k per partition, then average
sigma_ks = []

for fold in cv_folds:
U_fold = np.zeros(U_mean.shape) * np.nan # fold activations
dataset_fold = dataset_whitened.subset_obs(cv_descriptor, fold)
dataset_fold, fold_conds, _ = average_dataset_by(dataset_fold, obs_descriptor)
# Get indices mapping from subsetted conditions to full set, fill out U_fold
inds = [np.where(conds == c)[0][0] for c in fold_conds]
U_fold[inds, :] = dataset_fold
sigma_k_fold = ((U_fold - U_mean) @ ((U_fold - U_mean).T))
sigma_ks.append(sigma_k_fold)

# Increment pair counts for all condition pairs present in this fold
pair_counts += np.isfinite(sigma_k_fold).astype(int)

# Finally add all sigma_ks and divide by pair counts to get average
sigma_k = np.nansum(sigma_ks, axis=0) / (pair_counts - 1) / n_channels
# Any pairs that never occurred together should be set to zero
sigma_k[np.isnan(sigma_k)] = 0.0
return sigma_k
56 changes: 52 additions & 4 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import rsatoolbox.data as rsd
import numpy as np
from numpy.testing import assert_array_equal
from scipy import stats


class TestData(unittest.TestCase):
Expand Down Expand Up @@ -403,7 +404,7 @@ def test_temporaldataset_time_as_observations(self):

def test_temporaldataset_time_as_channels(self):
from rsatoolbox.data.dataset import TemporalDataset
measurements = np.zeros((3, 2, 4)) # 3 trials, 2 channels, 4 timepoints
measurements = np.zeros((3, 2, 4)) # 3 trials, 2 channels, 4 timepoints
des = {'session': 0, 'subj': 0}
obs_des = {'conds': np.array([0, 1, 1])}
chn_des = {'electrode': np.array(['A1', 'B2'])}
Expand All @@ -418,8 +419,8 @@ def test_temporaldataset_time_as_channels(self):
)
data = data_temporal.time_as_channels()
self.assertEqual(data.n_obs, 3)
self.assertEqual(data.n_channel, 2*4)
self.assertEqual(len(data.channel_descriptors['time']), 2*4)
self.assertEqual(data.n_channel, 2 * 4)
self.assertEqual(len(data.channel_descriptors['time']), 2 * 4)
assert_array_equal(
data.channel_descriptors['time'],
np.concatenate([tim_des['time'], tim_des['time']])
Expand All @@ -428,7 +429,7 @@ def test_temporaldataset_time_as_channels(self):
data.channel_descriptors['time_formatted'],
tim_des['time_formatted'] + tim_des['time_formatted']
)
self.assertEqual(len(data.channel_descriptors['electrode']), 2*4)
self.assertEqual(len(data.channel_descriptors['electrode']), 2 * 4)
assert_array_equal(
data.channel_descriptors['electrode'],
['A1', 'A1', 'A1', 'A1', 'B2', 'B2', 'B2', 'B2']
Expand Down Expand Up @@ -582,6 +583,53 @@ def test_equal(self):
np.testing.assert_allclose(cov1, cov2)


class TestSigmaK(unittest.TestCase):
def setUp(self):
self.rng = np.random.default_rng(42)
self.n_voxels = 10
self.n_cond = 3
self.n_fold = 1000
self.n_patterns = self.n_cond * self.n_fold
self.sigma_k_true = np.array([[1.0, 0.5, 0.2],
[0.5, 2.0, 0.3],
[0.2, 0.3, 3.0]])
self.measurements_full = stats.matrix_normal(
mean=np.zeros((self.n_cond, self.n_voxels)),
rowcov=self.sigma_k_true, colcov=np.eye(self.n_voxels), seed=42).rvs(self.n_fold).reshape(
self.n_fold * self.n_cond, self.n_voxels)
self.obs_full = np.tile(np.arange(self.n_cond), self.n_fold)
self.fold_full = np.repeat(np.arange(self.n_fold), self.n_cond)
self.dataset_full = rsd.Dataset(
self.measurements_full,
obs_descriptors={'obs': self.obs_full,
'fold': self.fold_full})
# Discard 1000 trials to make sure nothing breaks
trials_to_remove = np.sort(self.rng.choice(np.arange(self.n_cond, self.n_patterns),
size=1000,
replace=False))
self.dataset_subset = rsd.Dataset(
np.delete(self.measurements_full, trials_to_remove, axis=0),
obs_descriptors={
'obs': np.delete(self.obs_full, trials_to_remove, axis=0),
'fold': np.delete(self.fold_full, trials_to_remove, axis=0)}
)

def test_shape(self):
from rsatoolbox.data import sigmak_from_measurements
sigmak = sigmak_from_measurements(self.dataset_full, 'obs', 'fold')
np.testing.assert_equal(sigmak.shape, [self.n_cond, self.n_cond])

def test_values_full(self):
from rsatoolbox.data import sigmak_from_measurements
sigmak = sigmak_from_measurements(self.dataset_full, 'obs', 'fold')
np.testing.assert_allclose(sigmak, self.sigma_k_true, rtol=0.2)

def test_values_unbalanced(self):
from rsatoolbox.data import sigmak_from_measurements
sigmak = sigmak_from_measurements(self.dataset_subset, 'obs', 'fold')
np.testing.assert_allclose(sigmak, self.sigma_k_true, rtol=0.2)


class TestSave(unittest.TestCase):

def test_dict_conversion(self):
Expand Down
Loading