4 changes: 2 additions & 2 deletions src/rsatoolbox/data/__init__.py
@@ -1,9 +1,9 @@
from .computations import average_dataset
from .computations import average_dataset_by
from .dataset import Dataset
from .dataset import TemporalDataset
from .dataset import load_dataset
from .dataset import dataset_from_dict
from .computations import average_dataset
from .computations import average_dataset_by
from .noise import cov_from_residuals
from .noise import prec_from_residuals
from .noise import cov_from_measurements
159 changes: 156 additions & 3 deletions src/rsatoolbox/data/dataset.py
@@ -11,8 +11,11 @@
from warnings import warn
from copy import deepcopy
import numpy as np
from itertools import product
from pandas import DataFrame
from scipy.linalg import sqrtm
from rsatoolbox.data.ops import merge_datasets
from rsatoolbox.data.noise import sigmak_from_measurements
from rsatoolbox.util.data_utils import get_unique_unsorted
from rsatoolbox.util.data_utils import get_unique_inverse
from rsatoolbox.util.descriptor_utils import check_descriptor_length_error
@@ -765,9 +768,9 @@ def time_as_observations(self, by='time') -> Dataset:
for key in self.time_descriptors:
obs_descriptors[key] = np.concatenate((
obs_descriptors[key], np.repeat(
[self.time_descriptors[key][s]
for s in selection],
self.n_obs)),
[self.time_descriptors[key][s]
for s in selection],
self.n_obs)),
axis=0)

dataset = Dataset(measurements=measurements,
@@ -890,3 +893,153 @@ def merge_subsets(dataset_list):
warn('Deprecated: [rsatoolbox.data.dataset.merge_subsets()]. Replace by '
'[rsatoolbox.data.ops.merge_datasets()]', DeprecationWarning)
return merge_datasets(dataset_list)


class FramedDataset(Dataset):
"""
    Dataset subclass with added functionality for running framed RSA analyses. Automatically inserts
    all-zero and all-c vectors according to the specified options. A cond_descriptor must be provided
    to indicate which obs_descriptor delineates conditions in subsequent RSA analyses; the all-zero and
    all-c vectors are then inserted for each combination of values of the remaining obs_descriptors
    (i.e., once per possible crossvalidation fold). Pass a noise precision matrix to scale the all-c
    vector based on the whitened patterns.

Args:
measurements (numpy.ndarray): n_obs x n_channel x time 3d-array,
descriptors (dict): descriptors (metadata)
obs_descriptors (dict): observation descriptors (all
are array-like with shape = (n_obs,...))
channel_descriptors (dict): channel descriptors (all are
array-like with shape = (n_channel,...))
cond_descriptor (str): the obs_descriptor that delineates conditions in subsequent RSA analyses
include_all_zeros (bool): whether to add all-zero vectors to the dataset (default: True)
        all_c_scale (float or None): scaling factor for the all-c vector to be inserted; this is the
            ratio of the norm of the all-c vector to the mean norm of the (whitened) stimulus vectors.
            Pass None to omit this vector.
        noise (np.ndarray): n_channel x n_channel precision matrix used to whiten the patterns when
            scaling the all-c vector (default: identity matrix)
check_dims (bool): whether to check the dimensions of the descriptors (default: True)


Returns:
FramedDataset object
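
    Example (a minimal sketch; descriptor names and shapes are illustrative):

        >>> import numpy as np
        >>> data = np.random.randn(8, 20)  # 8 observations x 20 channels
        >>> obs_des = {'cond': ['c%d' % i for i in range(4)] * 2,
        ...            'run': [1] * 4 + [2] * 4}
        >>> fds = FramedDataset(data, obs_descriptors=obs_des,
        ...                     channel_descriptors={'voxel': list(range(20))},
        ...                     cond_descriptor='cond')
        >>> fds.measurements.shape  # 8 original + 2 all-zero + 2 all-c rows
        (12, 20)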
"""

def __init__(self,
measurements,
descriptors=None,
obs_descriptors=None,
channel_descriptors=None,
cond_descriptor=None,
include_all_zeros=True,
all_c_scale=1.0,
noise=None,
check_dims=True):

if measurements.ndim != 2:
raise AttributeError(
"measurements must be in dimension n_obs x n_channel")
        if descriptors is None:
            descriptors = {}
        if (cond_descriptor is None) or (obs_descriptors is None) or (cond_descriptor not in obs_descriptors):
            raise ValueError("A cond_descriptor must be provided; this should be the obs_descriptor that will "
                             "be used to define experimental conditions in subsequent RSA analyses.")
if noise is None:
noise = np.eye(measurements.shape[1])
        elif (not isinstance(noise, np.ndarray)) or (noise.ndim != 2) or (noise.shape[0] != noise.shape[1]):
            raise ValueError("noise must be a square np.ndarray with shape (n_channel x n_channel), or "
                             "None for the identity matrix.")

        self.n_obs, self.n_channel = measurements.shape  # original counts, used for the dimension checks below

if check_dims:
check_descriptor_length_error(obs_descriptors,
"obs_descriptors",
self.n_obs
)
check_descriptor_length_error(channel_descriptors,
"channel_descriptors",
self.n_channel
)
# Add the all-zero and all-c vectors.

# Get all combinations of obs_descriptors that aren't the cond_descriptor.
        other_descriptors = [key for key in obs_descriptors.keys() if key != cond_descriptor]
        # work on list copies so the caller's obs_descriptors are not mutated in place
        obs_descriptors = {key: list(value) for key, value in obs_descriptors.items()}
        if len(other_descriptors) == 0:
            # a single (empty) combination: insert one framing pattern of each kind
            other_obs_combs = {}
            num_combs = 1
else:
# Make dict of lists with all combinations of other descriptors
unique_values = {key: get_unique_unsorted(obs_descriptors[key]) for key in other_descriptors}
combinations = list(zip(*product(*unique_values.values())))
num_combs = len(combinations[0])
other_obs_combs = {other_descriptors[i]: list(combinations[i]) for i in range(len(other_descriptors))}
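            # e.g. with other_descriptors = ['run', 'session'] and unique values
            # {'run': [1, 2], 'session': ['a']}, this yields num_combs = 2 and
            # other_obs_combs = {'run': [1, 2], 'session': ['a', 'a']}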

measurements_list = [measurements]

if include_all_zeros: # all-zeros inserted as the first pattern
all_zeros_measurements = np.zeros((num_combs, measurements.shape[1]))
measurements_list = [all_zeros_measurements, measurements]
obs_descriptors[cond_descriptor] = ['all_z'] * num_combs + obs_descriptors[cond_descriptor]
for desc in other_descriptors: # fill out remaining obs_descriptors
obs_descriptors[desc] = other_obs_combs[desc] + obs_descriptors[desc]

if all_c_scale is not None: # tune the value of c, insert as last pattern
all_ones = np.ones((1, self.n_channel))
noise_sqrt = sqrtm(noise)
measurements_whitened = measurements @ noise_sqrt
all_ones_whitened = all_ones @ noise_sqrt
mean_norm = np.mean(np.linalg.norm(measurements_whitened, axis=1))
c_val = mean_norm * all_c_scale / np.linalg.norm(all_ones_whitened)
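            # c is chosen so that, after whitening with W = sqrtm(noise), the norm of the
            # all-c vector matches all_c_scale times the mean norm of the whitened patterns:
            # ||c * ones @ W|| = all_c_scale * mean_i ||x_i @ W||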
all_c_measurements = c_val * np.ones((num_combs, self.n_channel))
measurements_list.append(all_c_measurements)
obs_descriptors[cond_descriptor].extend(['all_c'] * num_combs)
for desc in other_descriptors: # fill out remaining obs_descriptors
obs_descriptors[desc].extend(other_obs_combs[desc])

        self.measurements = np.vstack(measurements_list)
        # update n_obs so that it includes the inserted framing patterns
        self.n_obs = self.measurements.shape[0]

        # store framing metadata as scalar dataset-level descriptors,
        # as expected by get_sigmak() and _get_framed_inds()
        descriptors['noise'] = noise
        descriptors['is_framed'] = True
        descriptors['cond_descriptor'] = cond_descriptor
        descriptors['all_c_scale'] = all_c_scale
        descriptors['include_all_zeros'] = include_all_zeros
        descriptors['include_all_c'] = all_c_scale is not None
        descriptors['n_framing_patterns'] = int(include_all_zeros) + int(all_c_scale is not None)
self.descriptors = parse_input_descriptor(descriptors)
self.obs_descriptors = parse_input_descriptor(obs_descriptors)
self.channel_descriptors = parse_input_descriptor(channel_descriptors)

def get_sigmak(self, from_data=False, cv_desc=None):
"""Returns the sigma_k matrix for this dataset; if from_data is False, this will be the identity matrix
with the rows/columns corresponding to the all-zero and all-c patterns set to zero. If set to true,
sigma_k will be estimated from the data, and a cv_descriptor must be provided."""
if not from_data:
framed_inds, n_cond = self._get_framed_inds()
sigma_k = np.eye(n_cond)
sigma_k[framed_inds, framed_inds] = 0
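            # since sigma_k starts as the identity, zeroing these diagonal entries is
            # equivalent to zeroing the full rows/columns of the framing patterns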
else:
if cv_desc is None:
raise ValueError("If estimating sigma_k from data, a cv_descriptor must be provided.")
sigma_k = sigmak_from_measurements(self,
self.descriptors['cond_descriptor'],
cv_desc,
self.descriptors['noise'])
return sigma_k

def get_framed_rdm_mask(self):
"""Returns a binary np.ndarray mask with 1 for all RDM entries involving the all-zeros or all-c, and
0 otherwise, with RDM entries defined by cond_descriptor. This is useful when using
whitened RDM comparators in case we wish to only use the distances involving the framed patterns
when computing V."""
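        # e.g. with n_cond = 4 and framing patterns at indices 0 and 3, the mask is
        # [[1, 1, 1, 1],
        #  [1, 0, 0, 1],
        #  [1, 0, 0, 1],
        #  [1, 1, 1, 1]]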
framed_inds, n_cond = self._get_framed_inds()
mask = np.zeros((n_cond, n_cond))
mask[framed_inds, :] = mask[:, framed_inds] = 1
return mask

def _get_framed_inds(self):
"""Utility function for getting the indices corresponding to the framing patterns; returns n_cond
as well for convenience."""
framed_inds = []
cond_order = list(get_unique_unsorted(self.obs_descriptors[self.descriptors['cond_descriptor']]))
if 'all_z' in cond_order:
framed_inds.append(cond_order.index('all_z'))
if 'all_c' in cond_order:
framed_inds.append(cond_order.index('all_c'))
return framed_inds, len(cond_order)
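

# A minimal usage sketch (variable names are illustrative): build a FramedDataset and
# retrieve the fixed sigma_k and the RDM mask for the framing patterns.
#
#     fds = FramedDataset(betas, obs_descriptors={'cond': conds, 'run': runs},
#                         channel_descriptors={}, cond_descriptor='cond',
#                         all_c_scale=1.0, noise=prec)  # prec: channel precision matrix
#     sigma_k = fds.get_sigmak()            # identity with framing rows/columns zeroed
#     mask = fds.get_framed_rdm_mask()      # 1 for RDM entries involving all_z / all_c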
74 changes: 71 additions & 3 deletions src/rsatoolbox/data/noise.py
@@ -10,6 +10,7 @@
import numpy as np
from rsatoolbox.data import average_dataset_by
from rsatoolbox.util.data_utils import get_unique_inverse
from scipy.linalg import sqrtm


def _check_demean(matrix):
@@ -147,7 +148,7 @@ def _covariance_eye(matrix, dof):
b2 = min(d2, b2)
# shrink covariance matrix
s_shrink = b2 / d2 * m * np.eye(s.shape[0]) \
+ (d2-b2) / d2 * s
+ (d2 - b2) / d2 * s
# correction for degrees of freedom
s_shrink = s_shrink * matrix.shape[0] / dof
return s_shrink
@@ -189,11 +190,11 @@ def _covariance_diag(matrix, dof):
s_mean = s_sum / np.expand_dims(std, 0) / np.expand_dims(std, 1) / (matrix.shape[0] - 1)
s2_mean = s2_sum / np.expand_dims(var, 0) / np.expand_dims(var, 1) / (matrix.shape[0] - 1)
var_hat = matrix.shape[0] / dof ** 2 \
* (s2_mean - s_mean ** 2)
* (s2_mean - s_mean ** 2)
mask = ~np.eye(s.shape[0], dtype=bool)
lamb = np.sum(var_hat[mask]) / np.sum(s_mean[mask] ** 2)
lamb = max(min(lamb, 1), 0)
scaling = np.eye(s.shape[0]) + (1-lamb) * mask
scaling = np.eye(s.shape[0]) + (1 - lamb) * mask
s_shrink = s * scaling
return s_shrink

@@ -434,3 +435,70 @@ def prec_from_unbalanced(dataset, obs_desc, dof=None, method='shrinkage_diag'):
else:
prec = np.linalg.inv(cov)
return prec


def sigmak_from_measurements(dataset, obs_desc, cv_descriptor, noise=None):
"""
    Estimates sigma_k, the matrix encoding the noise variance/covariance among the k conditions that
    arises when conditions are measured in the same partition (e.g., due to shared fMRI noise from the
    sluggishness of the HRF when two conditions are adjacent in time). If a noise matrix is provided,
    the data are prewhitened before computing sigma_k (make sure to do this if using the Mahalanobis or
    crossnobis distance). Assumes that sigma_k is constant across partitions, implementing equation 36
    from Diedrichsen et al. (2016), "On the distribution of cross-validated Mahalanobis distances."
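
    In the notation used here, with U_m the condition-by-channel mean patterns of partition m,
    U_bar the grand means, M the number of partitions in which a given pair of conditions
    co-occurs, and P the number of channels, the estimate is

        sigma_k = sum_m (U_m - U_bar)(U_m - U_bar)^T / (M - 1) / P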

Args:
dataset(data.Dataset):
rsatoolbox Dataset object
obs_desc (String):
descriptor defining experimental conditions
        cv_descriptor (String):
descriptor defining crossvalidation folds/partitions
noise (numpy.ndarray):
dataset.n_channel x dataset.n_channel
precision matrix for noise between channels
default: identity matrix, i.e. euclidean distance

Returns:
numpy.ndarray: sigma_k: noise covariance matrix over conditions
n_conditions x n_conditions

"""

n_channels = dataset.n_channel
if noise is None:
noise = np.eye(n_channels)
else:
if noise.shape != (n_channels, n_channels):
raise ValueError("noise must have shape n_channel x n_channel")
noise_sqrt = sqrtm(noise) # take matrix square root to get whitening matrix
dataset_whitened = dataset.copy()
dataset_whitened.measurements = dataset.measurements @ noise_sqrt

# Compute mean patterns per condition
U_mean, conds, _ = average_dataset_by(dataset_whitened, obs_desc)
n_cond = len(conds)
cv_folds = np.unique(np.array(dataset_whitened.obs_descriptors[cv_descriptor]))

pair_counts = np.zeros((n_cond, n_cond)) # tally how many partitions each condition pair occurs in

# Compute sigma_k per partition, then average
sigma_ks = []

for fold in cv_folds:
        dataset_fold = dataset_whitened.subset_obs(cv_descriptor, fold)
        fold_means, fold_conds, _ = average_dataset_by(dataset_fold, obs_desc)
        # per-fold condition means; NaN marks conditions absent from this fold
        U_fold = np.full(U_mean.shape, np.nan)
        # map the subsetted conditions back to their positions in the full set
        inds = [np.where(conds == c)[0][0] for c in fold_conds]
        U_fold[inds, :] = fold_means
        sigma_k_fold = (U_fold - U_mean) @ (U_fold - U_mean).T
sigma_ks.append(sigma_k_fold)

# Increment pair counts for all condition pairs present in this fold
pair_counts += np.isfinite(sigma_k_fold).astype(int)

    # Sum over folds, then normalize by (pair count - 1) and the number of channels
    with np.errstate(divide='ignore', invalid='ignore'):
        sigma_k = np.nansum(sigma_ks, axis=0) / (pair_counts - 1) / n_channels
    # Pairs observed in fewer than two partitions yield inf/NaN here; set them to zero
    sigma_k[~np.isfinite(sigma_k)] = 0.0
return sigma_k
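

# A minimal usage sketch (names are illustrative): estimate sigma_k for a dataset with a
# 'cond' condition descriptor and a 'run' partition descriptor, prewhitening with a channel
# precision matrix such as one obtained from prec_from_residuals.
#
#     prec = prec_from_residuals(residuals)
#     sigma_k = sigmak_from_measurements(dataset, 'cond', 'run', noise=prec)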