diff --git a/src/rsatoolbox/data/__init__.py b/src/rsatoolbox/data/__init__.py index d66379d6..4fed4a08 100644 --- a/src/rsatoolbox/data/__init__.py +++ b/src/rsatoolbox/data/__init__.py @@ -1,9 +1,9 @@ +from .computations import average_dataset +from .computations import average_dataset_by from .dataset import Dataset from .dataset import TemporalDataset from .dataset import load_dataset from .dataset import dataset_from_dict -from .computations import average_dataset -from .computations import average_dataset_by from .noise import cov_from_residuals from .noise import prec_from_residuals from .noise import cov_from_measurements diff --git a/src/rsatoolbox/data/dataset.py b/src/rsatoolbox/data/dataset.py index e565bc28..cdab58ea 100644 --- a/src/rsatoolbox/data/dataset.py +++ b/src/rsatoolbox/data/dataset.py @@ -11,8 +11,11 @@ from warnings import warn from copy import deepcopy import numpy as np +from itertools import product from pandas import DataFrame +from scipy.linalg import sqrtm from rsatoolbox.data.ops import merge_datasets +from rsatoolbox.data.noise import sigmak_from_measurements from rsatoolbox.util.data_utils import get_unique_unsorted from rsatoolbox.util.data_utils import get_unique_inverse from rsatoolbox.util.descriptor_utils import check_descriptor_length_error @@ -765,9 +768,9 @@ def time_as_observations(self, by='time') -> Dataset: for key in self.time_descriptors: obs_descriptors[key] = np.concatenate(( obs_descriptors[key], np.repeat( - [self.time_descriptors[key][s] - for s in selection], - self.n_obs)), + [self.time_descriptors[key][s] + for s in selection], + self.n_obs)), axis=0) dataset = Dataset(measurements=measurements, @@ -890,3 +893,153 @@ def merge_subsets(dataset_list): warn('Deprecated: [rsatoolbox.data.dataset.merge_subsets()]. 
class FramedDataset(Dataset):
    """
    Dataset with added functionality for running framed RSA analyses.

    On construction, "framing" patterns are inserted into the data: an
    all-zero pattern (condition label ``'all_z'``) placed before the original
    observations and/or a constant pattern (condition label ``'all_c'``)
    placed after them. One copy of each framing pattern is inserted for every
    combination of the unique values of the obs_descriptors other than
    ``cond_descriptor`` (i.e., for any possible crossvalidation fold). If
    there are no other obs_descriptors, exactly one copy is inserted.

    Args:
        measurements (numpy.ndarray): n_obs x n_channel 2d-array
        descriptors (dict): descriptors (metadata); copied, never modified
            in place
        obs_descriptors (dict): observation descriptors (all
            are array-like with shape = (n_obs,...)); copied, never modified
            in place. Must contain ``cond_descriptor``.
        channel_descriptors (dict): channel descriptors (all are
            array-like with shape = (n_channel,...))
        cond_descriptor (str): the obs_descriptor that delineates conditions
            in subsequent RSA analyses
        include_all_zeros (bool): whether to add all-zero vectors to the
            dataset (default: True)
        all_c_scale (float or None): ratio of the norm of the all-c vector
            to the mean norm of the original patterns, both measured after
            multiplying by the matrix square root of ``noise``.
            Pass None to omit this vector.
        noise (numpy.ndarray): n_channel x n_channel matrix used for this
            scaling and forwarded to ``sigmak_from_measurements`` by
            ``get_sigmak`` (default: identity matrix).
            NOTE(review): patterns are multiplied by ``sqrtm(noise)``, which
            whitens the data if ``noise`` is a precision matrix; an earlier
            version of this docstring called it a covariance - confirm.
        check_dims (bool): whether to check the dimensions of the
            descriptors (default: True)

    Returns:
        FramedDataset object

    Raises:
        AttributeError: if measurements is not 2-dimensional
        ValueError: if cond_descriptor is missing or noise has a wrong
            type or shape
    """

    def __init__(self,
                 measurements,
                 descriptors=None,
                 obs_descriptors=None,
                 channel_descriptors=None,
                 cond_descriptor=None,
                 include_all_zeros=True,
                 all_c_scale=1.0,
                 noise=None,
                 check_dims=True):

        if measurements.ndim != 2:
            raise AttributeError(
                "measurements must be in dimension n_obs x n_channel")
        if (obs_descriptors is None or cond_descriptor is None
                or cond_descriptor not in obs_descriptors):
            raise ValueError(
                "A cond_descriptor must be provided; this should be the "
                "obs_descriptor that will be used to define experimental "
                "conditions in subsequent RSA analyses.")

        n_obs_orig, n_channel = measurements.shape
        if noise is None:
            noise = np.eye(n_channel)
        elif (not isinstance(noise, np.ndarray)
              or noise.shape != (n_channel, n_channel)):
            raise ValueError(
                "Noise must be a square np.ndarray with shape "
                "(n_channel x n_channel), or None for identity matrix.")

        # Descriptor lengths are validated against the original data, i.e.
        # before any framing pattern is inserted.
        if check_dims:
            check_descriptor_length_error(obs_descriptors,
                                          "obs_descriptors",
                                          n_obs_orig)
            check_descriptor_length_error(channel_descriptors,
                                          "channel_descriptors",
                                          n_channel)

        # Work on copies so the caller's dictionaries are left untouched.
        descriptors = {} if descriptors is None else dict(descriptors)
        obs_descriptors = dict(obs_descriptors)

        # All combinations of the values of every obs_descriptor except the
        # condition descriptor. product() without arguments yields a single
        # empty combination, so one framing pattern is still inserted when
        # the condition descriptor is the only obs_descriptor.
        other_descriptors = [key for key in obs_descriptors
                             if key != cond_descriptor]
        combinations = list(product(
            *(get_unique_unsorted(obs_descriptors[key])
              for key in other_descriptors)))
        num_combs = len(combinations)
        other_obs_combs = {key: [comb[i] for comb in combinations]
                           for i, key in enumerate(other_descriptors)}

        add_all_c = all_c_scale is not None
        if include_all_zeros or add_all_c:
            # Plain lists allow concatenation with the framing labels
            # regardless of whether the caller passed lists or numpy arrays.
            obs_descriptors = {key: list(values)
                               for key, values in obs_descriptors.items()}

        measurements_list = [measurements]

        if include_all_zeros:  # all-zeros inserted as the first pattern(s)
            measurements_list.insert(0, np.zeros((num_combs, n_channel)))
            obs_descriptors[cond_descriptor] = \
                ['all_z'] * num_combs + obs_descriptors[cond_descriptor]
            for desc in other_descriptors:
                obs_descriptors[desc] = \
                    other_obs_combs[desc] + obs_descriptors[desc]

        if add_all_c:  # tune the value of c, insert as last pattern(s)
            noise_sqrt = sqrtm(noise)
            mean_norm = np.mean(
                np.linalg.norm(measurements @ noise_sqrt, axis=1))
            ones_norm = np.linalg.norm(
                np.ones((1, n_channel)) @ noise_sqrt)
            c_val = mean_norm * all_c_scale / ones_norm
            measurements_list.append(
                c_val * np.ones((num_combs, n_channel)))
            obs_descriptors[cond_descriptor] = \
                obs_descriptors[cond_descriptor] + ['all_c'] * num_combs
            for desc in other_descriptors:
                obs_descriptors[desc] = \
                    obs_descriptors[desc] + other_obs_combs[desc]

        self.measurements = np.vstack(measurements_list)
        # Keep n_obs in sync with the stacked data and the obs_descriptors.
        self.n_obs, self.n_channel = self.measurements.shape

        # Framing settings are stored as one-element lists.
        descriptors['noise'] = [noise]
        descriptors['is_framed'] = [True]
        descriptors['cond_descriptor'] = [cond_descriptor]
        descriptors['all_c_scale'] = [all_c_scale]
        descriptors['include_all_zeros'] = [include_all_zeros]
        descriptors['include_all_c'] = [add_all_c]
        descriptors['n_framing_patterns'] = \
            [int(include_all_zeros) + int(add_all_c)]
        self.descriptors = parse_input_descriptor(descriptors)
        self.obs_descriptors = parse_input_descriptor(obs_descriptors)
        self.channel_descriptors = parse_input_descriptor(channel_descriptors)

    def _framing_setting(self, key):
        """Returns the framing setting stored under `key` in the dataset
        descriptors, unwrapping the one-element list it is stored in."""
        value = self.descriptors[key]
        if isinstance(value, (list, tuple)) and len(value) == 1:
            return value[0]
        return value

    def get_sigmak(self, from_data=False, cv_desc=None):
        """Returns the sigma_k matrix for this dataset.

        Args:
            from_data (bool): if False (default), return the identity matrix
                over conditions with the diagonal entries of the all-zero
                and all-c conditions set to zero. If True, estimate sigma_k
                with sigmak_from_measurements, using the stored noise matrix.
            cv_desc (str): obs_descriptor defining the crossvalidation
                folds; required if from_data is True.

        Returns:
            numpy.ndarray: n_cond x n_cond matrix

        Raises:
            ValueError: if from_data is True and cv_desc is None
        """
        if not from_data:
            framed_inds, n_cond = self._get_framed_inds()
            sigma_k = np.eye(n_cond)
            sigma_k[framed_inds, framed_inds] = 0
            return sigma_k
        if cv_desc is None:
            raise ValueError("If estimating sigma_k from data, a "
                             "cv_descriptor must be provided.")
        return sigmak_from_measurements(
            self,
            self._framing_setting('cond_descriptor'),
            cv_desc,
            self._framing_setting('noise'))

    def get_framed_rdm_mask(self):
        """Returns an n_cond x n_cond np.ndarray mask with 1 for all
        entries in a row or column of the all-zeros or all-c condition, and
        0 otherwise, with conditions defined by cond_descriptor.
        This is useful when using whitened RDM comparators in case we wish
        to only use the distances involving the framed patterns
        when computing V."""
        framed_inds, n_cond = self._get_framed_inds()
        mask = np.zeros((n_cond, n_cond))
        mask[framed_inds, :] = 1
        mask[:, framed_inds] = 1
        return mask

    def _get_framed_inds(self):
        """Utility function for getting the indices of the framing
        conditions within the unique (unsorted) condition labels;
        returns n_cond as well for convenience."""
        cond_key = self._framing_setting('cond_descriptor')
        cond_order = list(get_unique_unsorted(self.obs_descriptors[cond_key]))
        framed_inds = [cond_order.index(label)
                       for label in ('all_z', 'all_c')
                       if label in cond_order]
        return framed_inds, len(cond_order)
def sigmak_from_measurements(dataset, obs_desc, cv_descriptor, noise=None):
    """
    Estimates sigma_k, the matrix encoding the noise variance/covariance
    among the k conditions when two conditions are measured in the same
    partition (e.g., due to shared fMRI noise from the sluggishness of the
    HRF when two conditions are adjacent in time). If a noise matrix is
    provided, the data are multiplied by its matrix square root
    (prewhitening) prior to computing sigma_k (make sure to do this if using
    Mahalanobis or crossnobis distance). Assumes that sigma_k is constant
    across partitions, implementing equation 36 from Diedrichsen et al.
    (2016), "On the distribution of cross-validated Mahalanobis distances."

    Args:
        dataset (data.Dataset):
            rsatoolbox Dataset object
        obs_desc (String):
            descriptor defining experimental conditions
        cv_descriptor (String):
            descriptor defining crossvalidation folds/partitions
        noise (numpy.ndarray):
            dataset.n_channel x dataset.n_channel
            precision matrix for noise between channels
            default: identity matrix, i.e. euclidean distance

    Returns:
        numpy.ndarray: sigma_k: noise covariance matrix over conditions
            n_conditions x n_conditions. Entries for condition pairs that
            occur together in fewer than two partitions cannot be
            estimated and are set to zero.

    Raises:
        ValueError: if noise does not have shape n_channel x n_channel
    """
    n_channels = dataset.n_channel
    if noise is None:
        noise = np.eye(n_channels)
    elif np.shape(noise) != (n_channels, n_channels):
        raise ValueError("noise must have shape n_channel x n_channel")
    # matrix square root of the precision gives the whitening matrix
    noise_sqrt = sqrtm(noise)
    dataset_whitened = dataset.copy()
    dataset_whitened.measurements = dataset.measurements @ noise_sqrt

    # Mean pattern per condition over all observations
    U_mean, conds, _ = average_dataset_by(dataset_whitened, obs_desc)
    conds = np.asarray(conds)
    n_cond = len(conds)
    cv_folds = np.unique(
        np.array(dataset_whitened.obs_descriptors[cv_descriptor]))

    # How many partitions each condition pair occurs in, and the sum of
    # the per-partition deviation products over those partitions
    pair_counts = np.zeros((n_cond, n_cond))
    sigma_sum = np.zeros((n_cond, n_cond))

    for fold in cv_folds:
        # Conditions missing from this fold stay NaN and are ignored below
        U_fold = np.full(U_mean.shape, np.nan)
        dataset_fold = dataset_whitened.subset_obs(cv_descriptor, fold)
        fold_means, fold_conds, _ = average_dataset_by(dataset_fold, obs_desc)
        # Map the conditions of this fold onto rows of the full set
        inds = [np.where(conds == c)[0][0] for c in fold_conds]
        U_fold[inds, :] = fold_means
        deviations = U_fold - U_mean
        sigma_k_fold = deviations @ deviations.T

        present = np.isfinite(sigma_k_fold)
        pair_counts += present
        sigma_sum += np.where(present, sigma_k_fold, 0.0)

    # Average with n - 1 degrees of freedom; pairs with fewer than two
    # joint partitions would divide by zero (or by -1) and are left at zero.
    dof = pair_counts - 1
    estimable = dof > 0
    sigma_k = np.zeros((n_cond, n_cond))
    sigma_k[estimable] = sigma_sum[estimable] / dof[estimable] / n_channels
    return sigma_k
average_dataset_by from rsatoolbox.util.rdm_utils import _extract_triu_ from rsatoolbox.util.build_rdm import _build_rdms +from rsatoolbox.util.matrix import pairwise_contrast if TYPE_CHECKING: from rsatoolbox.rdm.rdms import RDMs @@ -30,6 +31,8 @@ def calc_rdm( cv_descriptor: Optional[str] = None, prior_lambda: float = 1.0, prior_weight: float = 0.1, + degree: float = 2, + root: bool = True, remove_mean: bool = False) -> Union[RDMs, List[RDMs]]: """ calculates an RDM from an input dataset @@ -50,6 +53,10 @@ def calc_rdm( precision matrix used to calculate the RDM used only for Mahalanobis and Crossnobis estimators defaults to an identity matrix, i.e. euclidean distance + degree: float + degree of the minkowski distance + root: bool + whether to take the root of the minkowski distance remove_mean (bool): whether the mean of each pattern shall be removed before distance calculation. This has no effect on poisson based and correlation distances. @@ -67,7 +74,7 @@ def calc_rdm( noise_i = noise rdms.append(_calc_rdm_single(ds_i, method, descriptor, noise_i, cv_descriptor, prior_lambda, - prior_weight, remove_mean)) + prior_weight, degree, root, remove_mean)) if descriptor is None: return concat(rdms) else: @@ -75,7 +82,7 @@ def calc_rdm( else: return _calc_rdm_single(dataset, method, descriptor, noise, cv_descriptor, prior_lambda, - prior_weight, remove_mean) + prior_weight, degree, root, remove_mean) def _calc_rdm_single( @@ -86,6 +93,8 @@ def _calc_rdm_single( cv_descriptor: Optional[str], prior_lambda: float, prior_weight: float, + degree: float, + root: bool, remove_mean: bool) -> RDMs: """Create RDMs object for a single Dataset """ @@ -97,16 +106,19 @@ def _calc_rdm_single( rdm = calc_rdm_mahalanobis(dataset, descriptor, noise, remove_mean) elif method == 'crossnobis': rdm = calc_rdm_crossnobis(dataset, descriptor, noise, - cv_descriptor, remove_mean) + cv_descriptor, remove_mean) elif method == 'poisson': rdm = calc_rdm_poisson(dataset, descriptor, - 
def calc_rdm_minkowski(
        dataset: DatasetBase,
        descriptor: Optional[str] = None,
        degree: float = 2,
        root: bool = True,
        remove_mean: bool = False) -> RDMs:
    """
    calculates an RDM from an input dataset using the minkowski distance,
    i.e. sum(|x_i - y_i| ** degree) over channels, optionally followed by
    taking the degree-th root

    Args:
        dataset (rsatoolbox.data.DatasetBase):
            The dataset the RDM is computed from
        descriptor (String):
            obs_descriptor used to define the rows/columns of the RDM
            defaults to one row/column per row in the dataset
        degree (float):
            degree (exponent) of the minkowski distance, must be positive
        root (bool):
            whether to take the degree-th root of the summed differences
        remove_mean (bool):
            whether the mean of each pattern shall be removed
            before calculating distances.

    Returns:
        rsatoolbox.rdm.rdms.RDMs: RDMs object with the one RDM

    Raises:
        ValueError: if degree is not a positive number
    """
    # also rejects NaN, for which every comparison is False
    if not degree > 0:
        raise ValueError('degree of the minkowski distance must be positive')
    measurements, desc = _parse_input(dataset, descriptor, remove_mean)
    # One contrast row per pair of patterns: multiplying with the
    # measurements yields all pairwise difference vectors at once
    n_cond = measurements.shape[0]
    contrasts = pairwise_contrast(np.arange(n_cond))
    deltas = contrasts @ measurements
    rdm = np.sum(np.abs(deltas) ** degree, axis=1)
    if root:
        rdm = rdm ** (1 / degree)
    return _build_rdms(rdm, dataset, 'minkowski distance', descriptor, desc)
average_dataset_by(data_test, descriptor) measurements_train = (measurements_train + prior_lambda * prior_weight) \ - / (1 + prior_weight) + / (1 + prior_weight) measurements_test = (measurements_test + prior_lambda * prior_weight) \ - / (1 + prior_weight) + / (1 + prior_weight) kernel = measurements_train @ np.log(measurements_test).T rdm = np.expand_dims(np.diag(kernel), 0) + \ - np.expand_dims(np.diag(kernel), 1) - kernel - kernel.T + np.expand_dims(np.diag(kernel), 1) - kernel - kernel.T rdm = _extract_triu_(rdm) / measurements_train.shape[1] return _build_rdms(rdm, dataset, 'poisson_cv', descriptor) @@ -473,7 +515,7 @@ def calc_rdm_poisson_cv(dataset, descriptor=None, prior_lambda=1.0, def _calc_rdm_crossnobis_single(meas1, meas2, noise) -> NDArray: kernel = meas1 @ noise @ meas2.T rdm = np.expand_dims(np.diag(kernel), 0) + \ - np.expand_dims(np.diag(kernel), 1) - kernel - kernel.T + np.expand_dims(np.diag(kernel), 1) - kernel - kernel.T return _extract_triu_(rdm) / meas1.shape[1] @@ -485,8 +527,8 @@ def _gen_default_cv_descriptor(dataset, descriptor) -> np.ndarray: desc = np.asarray(dataset.obs_descriptors[descriptor]) values, counts = np.unique(desc, return_counts=True) assert np.all(counts == counts[0]), ( - 'cv_descriptor generation failed:\n' - + 'different number of observations per pattern') + 'cv_descriptor generation failed:\n' + + 'different number of observations per pattern') n_repeats = counts[0] cv_descriptor = np.zeros_like(desc) for i_val in values: @@ -495,10 +537,10 @@ def _gen_default_cv_descriptor(dataset, descriptor) -> np.ndarray: def _parse_input( - dataset: DatasetBase, - descriptor: Optional[str], - remove_mean: bool = False - ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + dataset: DatasetBase, + descriptor: Optional[str], + remove_mean: bool = False +) -> Tuple[np.ndarray, Optional[np.ndarray]]: if descriptor is None: measurements = dataset.measurements desc = None diff --git a/tests/test_data.py b/tests/test_data.py index 
class TestFramedDataset(unittest.TestCase):
    """Tests for the insertion of framing patterns (all-zero / all-c) by
    FramedDataset and for the sigma_k and RDM-mask helpers built on them."""

    def setUp(self):
        # Imported from the defining module: the class is not re-exported
        # by rsatoolbox.data.
        from rsatoolbox.data.dataset import FramedDataset
        self.rng = np.random.default_rng(0)
        self.n_stim = 10
        self.n_channel = 20
        self.n_fold = 5
        self.n_trials = self.n_stim * self.n_fold
        self.patterns_orig = self.rng.random((self.n_trials, self.n_channel))
        stim = np.repeat(np.arange(self.n_stim), self.n_fold)
        folds = np.tile(np.arange(self.n_fold), self.n_stim)

        def make_dataset(include_all_zeros, all_c_scale):
            return FramedDataset(measurements=self.patterns_orig,
                                 obs_descriptors={'stim': stim,
                                                  'fold': folds},
                                 cond_descriptor='stim',
                                 include_all_zeros=include_all_zeros,
                                 all_c_scale=all_c_scale)

        # no framing patterns / all-zeros only / all-zeros and all-c
        self.dataset_basic = make_dataset(False, None)
        self.dataset_zeros = make_dataset(True, None)
        self.dataset_full = make_dataset(True, 1)

    def test_basic_shape(self):
        self.assertEqual(self.dataset_basic.measurements.shape[0],
                         self.n_trials)

    def test_zeros_shape(self):
        # one all-zero pattern per fold
        self.assertEqual(self.dataset_zeros.measurements.shape[0],
                         self.n_trials + self.n_fold)

    def test_full_shape(self):
        # one all-zero and one all-c pattern per fold
        self.assertEqual(self.dataset_full.measurements.shape[0],
                         self.n_trials + self.n_fold * 2)

    def test_sigmak(self):
        n_cond = self.n_stim + 2
        sigma_k_fromdata = self.dataset_full.get_sigmak(
            from_data=True, cv_desc='fold')
        sigma_k_notdata = self.dataset_full.get_sigmak(
            from_data=False, cv_desc='fold')
        self.assertEqual(sigma_k_fromdata.shape, (n_cond, n_cond))
        self.assertEqual(sigma_k_notdata.shape, (n_cond, n_cond))
        # framing patterns are identical in every fold: no variance
        self.assertEqual(sigma_k_fromdata[0, 0], 0)
        self.assertEqual(sigma_k_fromdata[-1, -1], 0)
        self.assertEqual(sigma_k_notdata[0, 0], 0)
        self.assertEqual(sigma_k_notdata[-1, -1], 0)
        self.assertEqual(sigma_k_notdata[1, 1], 1)

    def test_mask(self):
        n_cond = self.n_stim + 2
        n_framed = 2  # all_z and all_c
        mask = self.dataset_full.get_framed_rdm_mask()
        self.assertEqual(mask.shape, (n_cond, n_cond))
        # n_framed full rows plus n_framed full columns, minus the
        # n_framed x n_framed entries counted twice
        self.assertEqual(mask.sum(), 2 * n_framed * n_cond - n_framed ** 2)
        self.assertEqual(mask[0, 0], 1)
        self.assertEqual(mask[-1, -1], 1)
        self.assertEqual(mask[1, 1], 0)