61 changes: 55 additions & 6 deletions ibllib/qc/task_metrics.py
@@ -47,6 +47,7 @@
>>> outcome, results = qc.run()
"""
import logging
import json
import sys
from packaging import version
from pathlib import PurePosixPath
@@ -61,6 +62,7 @@
from brainbox.behavior.wheel import cm_to_rad, traces_by_trial
from ibllib.io.extractors import ephys_fpga
from one.alf import spec

from . import base

_log = logging.getLogger(__name__)
@@ -359,7 +361,7 @@ def run(self, update=False, **kwargs):
self.update(outcome, kwargs.get('namespace', self.namespace))
return outcome, results

def compute_session_status(self):
def compute_session_status(self, use_custom=True):
"""
Compute the overall session QC for each key and aggregate into a single value.

@@ -377,9 +379,36 @@ def compute_session_status(self):
# Get mean passed of each check, or None if passed is None or all NaN
results = {k: None if v is None or np.isnan(v).all() else np.nanmean(v)
for k, v in self.passed.items()}

# If custom criteria are defined for this session, use them
if use_custom:
custom_criteria = self.get_custom_session_criteria()
if custom_criteria is not None:
self.criteria = custom_criteria

session_outcome, outcomes = compute_session_status_from_dict(results, self.criteria)
return session_outcome, results, outcomes

def get_custom_session_criteria(self):
"""
Load custom QC criteria from an Alyx note associated with the session, if one exists.

Returns
-------
dict or None
The custom QC criteria to use, or None if no matching note is found.
"""
if self.one:
note_title = f'=== SESSION QC CRITERIA {self.namespace} ==='
query = f'text__icontains,{note_title},object_id,{self.eid}'
notes = self.one.alyx.rest('notes', 'list', django=query)
if len(notes) > 0:
note_text = json.loads(notes[0]['text'])
criteria = {k if k == 'default' else f'_{self.namespace}_{k}': v
for k, v in note_text['criteria'].items()}
self.log.info('Using custom QC criteria found in an Alyx note associated with the session')
return criteria

@staticmethod
def compute_dataset_qc_status(outcomes, namespace='task'):
"""Return map of dataset specific QC values.
@@ -1124,6 +1153,7 @@ def check_n_trial_events(data, **_):
intervals = data['intervals']
correct = data['correct']
err_trig = data['errorCueTrigger_times']
stim_trig = data['stimFreezeTrigger_times']

# Exclude these fields; valve and errorCue times are the same as feedback_times and we must
# test errorCueTrigger_times separately
@@ -1132,13 +1162,25 @@
'wheelMoves_peakVelocity_times', 'valveOpen_times', 'wheelMoves_peakAmplitude',
'wheelMoves_intervals', 'wheel_timestamps', 'stimFreeze_times']
events = [k for k in data.keys() if k.endswith('_times') and k not in exclude]
exclude_nogo = exclude + ['stimFreezeTrigger_times', 'firstMovement_times']
events_nogo = [k for k in data.keys() if k.endswith('_times') and k not in exclude_nogo]

metric = np.zeros(data['intervals'].shape[0], dtype=bool)

# For each trial interval check that one of each trial event occurred. For incorrect trials,
# check the error cue trigger occurred within the interval, otherwise check it is nan.
# For each trial interval check that one of each trial event occurred.
# For correct trials, check the error cue trigger is nan; for incorrect trials, check it occurred within the interval.
# For no-go trials, additionally check the stimFreeze trigger is nan (the stimFreeze and firstMovement events are excluded from the interval check).
for i, (start, end) in enumerate(intervals):
metric[i] = (all([start < data[k][i] < end for k in events]) and
(np.isnan(err_trig[i]) if correct[i] else start < err_trig[i] < end))
if correct[i]:
met = all(start < data[k][i] < end for k in events) and np.isnan(err_trig[i])
elif data['choice'][i] != 0:  # incorrect go trial
met = all(start < data[k][i] < end for k in events) and (start < err_trig[i] < end)
else:  # no-go trial
met = (all(start < data[k][i] < end for k in events_nogo) and
np.isnan(stim_trig[i]) and (start < err_trig[i] < end))
metric[i] = met

passed = metric.astype(bool)
assert intervals.shape[0] == len(metric) == len(passed)
return metric, passed
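To make the three branches above concrete, here is a toy walk-through of a single no-go trial; all values are made up and merely stand in for the arrays the check receives:

import numpy as np

start, end = 0.0, 5.0        # trial interval
correct, choice = False, 0   # no-go trials give negative feedback and a choice code of 0
err_trig = 2.0               # error cue trigger fires inside the interval
stim_trig = np.nan           # stimFreeze is never triggered on a no-go trial
events_in_interval = True    # stands in for all(start < data[k][i] < end for k in events_nogo)

if correct:
    met = events_in_interval and np.isnan(err_trig)
elif choice != 0:            # incorrect go trial: error cue must fall inside the interval
    met = events_in_interval and start < err_trig < end
else:                        # no-go trial: stimFreeze trigger must be nan as well
    met = events_in_interval and np.isnan(stim_trig) and start < err_trig < end

print(met)  # True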
@@ -1228,9 +1270,13 @@ def check_errorCue_delays(data, audio_output='harp', **_):
percentile of delays over 500 training sessions using the Xonar soundcard.
"""
threshold = 0.0015 if audio_output.lower() == 'harp' else 0.062
# In some instances the mouse responds before the goCue TTL has finished; in these cases the errorCue
# tone is delayed. # TODO can this be a metric and also check xonar
idx = data['response_times'] - data['goCue_times'] < 0.105  # 0.1 s is the length of the goCue tone, plus a small buffer
metric = np.nan_to_num(data['errorCue_times'] - data['errorCueTrigger_times'], nan=np.inf)
passed = ((metric <= threshold) & (metric > 0)).astype(float)
passed[data['correct']] = metric[data['correct']] = np.nan
passed[idx] = metric[idx] = np.nan
assert data['intervals'].shape[0] == len(metric) == len(passed)
return metric, passed
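A small sketch of the new exclusion in isolation (timestamps invented, harp threshold of 0.0015 s): a trial where the response lands within ~0.105 s of goCue onset is masked to nan instead of being counted as a failure, so it no longer drags down the nanmean used for the session aggregate.

import numpy as np

response_times = np.array([2.50, 1.08])
goCue_times = np.array([2.00, 1.00])             # trial 1 responds 0.08 s after goCue onset
errorCue_times = np.array([2.551, 1.20])
errorCueTrigger_times = np.array([2.550, 1.08])  # the tone on trial 1 is delayed by 0.12 s

idx = response_times - goCue_times < 0.105       # [False, True]
metric = np.nan_to_num(errorCue_times - errorCueTrigger_times, nan=np.inf)
passed = ((metric <= 0.0015) & (metric > 0)).astype(float)  # [1., 0.]
passed[idx] = metric[idx] = np.nan               # delayed-tone trial is excluded
print(np.nanmean(passed))                        # 1.0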

@@ -1306,7 +1352,10 @@ def check_stimFreeze_delays(data, **_):
'intervals')
"""
metric = np.nan_to_num(data['stimFreeze_times'] - data['stimFreezeTrigger_times'], nan=np.inf)
passed = (metric <= 0.15) & (metric > 0)
passed = ((metric <= 0.15) & (metric > 0)).astype(float)
# Remove no-go trials (stimFreeze is not triggered on no-go trials)
passed[data['choice'] == 0] = np.nan

assert data['intervals'].shape[0] == len(metric) == len(passed)
return metric, passed
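The nan-to-inf substitution is what the updated tests below assert: on a no-go trial stimFreeze_times is nan, so the metric becomes inf and fails the threshold, and the passed value is then masked to nan so the trial is excluded from the aggregate. A quick sketch with made-up values:

import numpy as np

stimFreeze_times = np.array([np.nan, 10.0101, 12.0102])      # trial 0 is the no-go trial
stimFreezeTrigger_times = np.array([np.nan, 10.0100, 12.0100])
choice = np.array([0, 1, -1])

metric = np.nan_to_num(stimFreeze_times - stimFreezeTrigger_times, nan=np.inf)  # [inf, 1e-4, 2e-4]
passed = ((metric <= 0.15) & (metric > 0)).astype(float)     # [0., 1., 1.]
passed[choice == 0] = np.nan                                 # [nan, 1., 1.]
print(np.nanmean(passed))                                    # 1.0 -> the no-go trial is ignored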

5 changes: 4 additions & 1 deletion ibllib/tests/qc/test_critical_reasons.py
@@ -218,7 +218,10 @@ def test_upload_existing_note(self):
assert expected_dict == note_dict

def tearDown(self) -> None:
self.one.alyx.rest('sessions', 'delete', id=self.eid)

notes = self.one.alyx.rest('notes', 'list', django=f'object_id,{self.eid}', no_cache=True)
for n in notes:
self.one.alyx.rest('notes', 'delete', id=n['id'])


if __name__ == '__main__':
75 changes: 71 additions & 4 deletions ibllib/tests/qc/test_task_metrics.py
@@ -10,6 +10,7 @@
from one.api import ONE
from one.alf import spec
from ibllib.tests import TEST_DB
from ibllib.tests.fixtures.utils import register_new_session
from ibllib.qc import task_metrics as qcmetrics

from brainbox.behavior.wheel import cm_to_rad
@@ -128,7 +129,8 @@ def test_update_dataset_qc(self):
self.assertRaises(AssertionError, qcmetrics.update_dataset_qc, qc, registered_datasets.copy(), one)


class TestTaskMetrics(unittest.TestCase):
class TaskQCTestData(unittest.TestCase):

def setUp(self):
self.data = self.load_fake_bpod_data()
self.wheel_gain = 4
@@ -188,6 +190,8 @@ def load_fake_bpod_data(n=5):
)
data['feedback_times'] = data['response_times'] + resp_feeback_delay
data['stimFreeze_times'] = data['response_times'] + 1e-2
# stimFreeze for no-go trials is nan
data['stimFreeze_times'][0] = np.nan
data['stimFreezeTrigger_times'] = data['stimFreeze_times'] - trigg_delay
data['feedbackType'] = np.vectorize(lambda x: -1 if x == 0 else x)(data['correct'])
outcome = data['feedbackType'].copy()
Expand Down Expand Up @@ -281,6 +285,9 @@ def add_frag(t, p):
'firstMovement_times': np.array(movement_times)
}


class TestTaskMetrics(TaskQCTestData):

def test_check_stimOn_goCue_delays(self):
metric, passed = qcmetrics.check_stimOn_goCue_delays(self.data)
self.assertTrue(np.allclose(metric, 0.0011), 'failed to return correct metric')
@@ -303,7 +310,9 @@ def test_check_response_feedback_delays(self):

def test_check_response_stimFreeze_delays(self):
metric, passed = qcmetrics.check_response_stimFreeze_delays(self.data)
self.assertTrue(np.allclose(metric, 1e-2), 'failed to return correct metric')
self.assertTrue(np.allclose(metric[1:], 1e-2), 'failed to return correct metric')
# The no-go trial has an inf value because its stimFreeze time is nan
self.assertEqual(metric[0], np.inf)
# Set incorrect timestamp (stimFreeze occurs before response)
self.data['stimFreeze_times'][-1] = self.data['response_times'][-1] - 1e-4
metric, passed = qcmetrics.check_response_stimFreeze_delays(self.data)
@@ -414,11 +423,13 @@ def test_check_stimOff_delays(self):

def test_check_stimFreeze_delays(self):
metric, passed = qcmetrics.check_stimFreeze_delays(self.data)
self.assertTrue(np.allclose(metric, 1e-4), 'failed to return correct metric')
self.assertTrue(np.allclose(metric[1:], 1e-4), 'failed to return correct metric')
# The no-go trial has an inf value because its stimFreeze time is nan
self.assertEqual(metric[0], np.inf)
# Set incorrect timestamp
self.data['stimFreeze_times'][-1] = self.data['stimFreezeTrigger_times'][-1] + 0.2
metric, passed = qcmetrics.check_stimFreeze_delays(self.data)
n = len(self.data['stimFreeze_times'])
n = len(self.data['stimFreeze_times']) - 1  # exclude the no-go trial, which we expect to be nan
expected = (n - 1) / n
self.assertEqual(np.nanmean(passed), expected, 'failed to detect dodgy timestamp')

@@ -669,5 +680,61 @@ def test_compute(self):
self.assertEqual(outcomes['_task_habituation_time'], spec.QC.NOT_SET)


class TestTaskQCWithCustomCriteria(TaskQCTestData):
"""
Test running the task QC with custom criteria stored in a note attached to the session.
"""

def setUp(self):

super().setUp()

self.one = ONE(**TEST_DB)
_, self.eid = register_new_session(self.one)
self.qc = qcmetrics.TaskQC(self.eid, one=self.one)
self.qc.extractor = Bunch({'data': self.data, 'settings': {}})

import json
note_title = '=== SESSION QC CRITERIA task ==='
note_text = {
"title": note_title,
"criteria": {'default': {'WARNING': 1, 'FAIL': 0.5}}
}

note_data = {'user': self.one.alyx.user,
'content_type': 'session',
'object_id': self.eid,
'text': json.dumps(note_text)}

self.note = self.one.alyx.rest('notes', 'create', data=note_data)

def test_compute(self):

# Build a dict of metrics and passed values for a few of the task QC checks
metrics = dict()
passed = dict()

metrics['_task_stimOff_itiIn_delays'], passed['_task_stimOff_itiIn_delays'] = (
qcmetrics.check_stimOff_itiIn_delays(self.data))

metrics['_task_stimOn_delays'], passed['_task_stimOn_delays'] = (
qcmetrics.check_stimOn_delays(self.data))

self.qc.metrics = metrics
self.qc.passed = passed

# Outcome using default BWM criteria
outcome, *_ = self.qc.compute_session_status(use_custom=False)
self.assertEqual(spec.QC.PASS, outcome)

# Outcome using custom note criteria (no PASS level defined, so the best possible outcome is WARNING)
outcome, *_ = self.qc.compute_session_status(use_custom=True)
self.assertEqual(spec.QC.WARNING, outcome)

def tearDown(self):
self.one.alyx.rest('notes', 'delete', id=self.note['id'])
self.one.alyx.rest('sessions', 'delete', id=self.eid)


if __name__ == '__main__':
unittest.main(exit=False, verbosity=2)
3 changes: 3 additions & 0 deletions ibllib/tests/test_base_tasks.py
@@ -82,6 +82,9 @@ def tearDownClass(cls) -> None:
if cls.tmpdir:
cls.tmpdir.cleanup()
if cls.one and cls.eid:
notes = cls.one.alyx.rest('notes', 'list', django=f'object_id,{cls.eid}', no_cache=True)
for n in notes:
cls.one.alyx.rest('notes', 'delete', id=n['id'])
cls.one.alyx.rest('sessions', 'delete', id=cls.eid)

