Skip to content

Commit 306c4b4

Browse files
authored
Merge pull request #2 from NuttidaLab/tmp
Refactor evaluation code for the spiking network
2 parents c448b6b + 2f61397 commit 306c4b4

17 files changed

Lines changed: 899 additions & 1319 deletions

README.md

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ The SpikeRNN framework consists of two complementary packages with a modern **ta
1818
### Spiking RNN Package
1919
- Rate-to-spike conversion maintaining task performance
2020
- Biologically realistic leaky integrate-and-fire (LIF) neurons
21-
- Spiking task evaluation classes (GoNogoSpikingTask, XORSpikingTask, ManteSpikingTask)
22-
- SpikingTaskFactory for task-based evaluation
21+
- Spiking task evaluation classes (GoNogoSpikingEvaluator, XORSpikingEvaluator, ManteSpikingEvaluator)
22+
- SpikingEvaluatorFactory for task-based evaluation
2323
- Scaling factor optimization for optimal conversion
2424

2525
## Task-Based Architecture
@@ -44,13 +44,13 @@ After installation, you can import both packages:
4444

4545
```python
4646
from rate import FR_RNN_dale, set_gpu, create_default_config
47-
from spiking import LIF_network_fnc, lambda_grid_search, evaluate_task
47+
from spiking import LIF_network_fnc, lambda_grid_search
4848

4949
# Task-based architecture
5050
from rate import TaskFactory
51-
from spiking import SpikingTaskFactory
51+
from spiking.eval_tasks import SpikingEvaluatorFactory
5252
from rate.tasks import GoNogoTask, XORTask, ManteTask
53-
from spiking.tasks import GoNogoSpikingTask, XORSpikingTask, ManteSpikingTask
53+
from spiking.eval_tasks import GoNogoSpikingEvaluator, XORSpikingEvaluator, ManteSpikingEvaluator
5454
```
5555

5656
## Quick Start: Task-Based Architecture
@@ -81,23 +81,22 @@ print(f"Generated {task.__class__.__name__} trial with label: {label}")
8181
The framework provides 2 levels of evaluation:
8282

8383
```python
84-
from spiking import SpikingTaskFactory, evaluate_task
84+
from spiking.eval_tasks import SpikingEvaluatorFactory, evaluate_task
8585

8686
# Direct task evaluation (when you have a network instance, not necessarily trained)
87-
spiking_task = SpikingTaskFactory.create_task('go_nogo')
88-
performance = spiking_task.evaluate_performance(spiking_rnn, n_trials=100)
89-
print(f"Accuracy: {performance['overall_accuracy']:.2f}")
87+
spiking_evaluator = SpikingEvaluatorFactory.create_evaluator('go_nogo', settings)
88+
performance = spiking_evaluator.evaluate_single_trial(model_path, scaling_factor)
89+
print(f"Trial result: {performance}")
9090

9191
# High-level interface (when you have model files with trained weights)
9292
performance = evaluate_task(
9393
task_name='go_nogo',
94-
model_dir='models/go-nogo',
95-
n_trials=100,
96-
save_plots=True
94+
model_path='models/go-nogo/model.mat',
95+
n_trials=50
9796
)
9897

9998
# Command line interface (for scripts and automation)
100-
# python -m spiking.eval_tasks --task go_nogo --model_dir models/go-nogo/
99+
# python -m spiking.eval_tasks --task go_nogo --model_path models/go-nogo/model.mat
101100
```
102101

103102
### Extending with Custom Tasks
@@ -128,28 +127,32 @@ stimulus, target, label = custom_task.simulate_trial()
128127
Create custom spiking evaluation tasks:
129128

130129
```python
131-
from spiking.tasks import AbstractSpikingTask, SpikingTaskFactory
130+
from spiking.eval_tasks import SpikingEvaluatorFactory
131+
from rate.tasks import AbstractTask
132132

133-
class MyCustomSpikingTask(AbstractSpikingTask):
134-
def get_default_settings(self):
135-
return {'T': 200, 'custom_param': 1.0}
136-
137-
def get_sample_trial_types(self):
138-
return ['type_a', 'type_b'] # For visualization
133+
class MyCustomSpikingEvaluator(AbstractTask):
134+
def __init__(self, settings):
135+
super().__init__(settings)
136+
self.eval_amp_thresh = settings.get('eval_amp_thresh', 0.7) # custom value
139137

140-
def generate_stimulus(self, trial_type=None):
141-
# Generate stimulus logic
142-
return stimulus, label
138+
def validate_settings(self):
139+
# Validation logic for custom task
140+
required_keys = ['T', 'custom_param']
141+
for key in required_keys:
142+
if key not in self.settings:
143+
raise ValueError(f"Missing required setting: {key}")
143144

144-
def evaluate_performance(self, spiking_rnn, n_trials=100):
145-
# Multi-trial performance evaluation
146-
return {'accuracy': 0.85, 'n_trials': n_trials}
145+
def evaluate_single_trial(self, model_path: str, scaling_factor: float) -> int:
146+
"""Evaluate a single trial for the custom task."""
147+
# Custom evaluation logic here
148+
# Return 1 if correct, 0 if incorrect
149+
pass
147150

148151
# Register and use with evaluation system
149-
SpikingTaskFactory.register_task('my_custom', MyCustomSpikingTask)
152+
SpikingEvaluatorFactory._registry['my_custom'] = MyCustomSpikingEvaluator
150153

151154
# Now works with eval_tasks.py
152-
python -m spiking.eval_tasks --task my_custom --model_dir models/custom/
155+
python -m spiking.eval_tasks --task my_custom --model_path models/custom/model.mat
153156
```
154157

155158
## Requirements
@@ -201,16 +204,20 @@ import numpy as np
201204

202205
# Optimize scaling factor
203206
lambda_grid_search(
204-
model_dir='models/go-nogo',
205-
task_name='go-nogo',
207+
model_path='models/go-nogo/model.mat',
208+
task_name='go_nogo',
206209
n_trials=100,
207-
scaling_factors=list(np.arange(25, 76, 5))
210+
scaling_factors=list(np.arange(25, 76, 5)),
211+
task_settings=settings
208212
)
209213

210214
# Evaluate performance
211215
performance = evaluate_task(
212216
task_name='go_nogo',
213-
model_dir='models/go-nogo/'
217+
model_path='models/go-nogo/model.mat',
218+
n_trials=50,
219+
task_settings=settings,
220+
all_trial_types=True
214221
)
215222
```
216223

__init__.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,6 @@
5454
LIF_network_fnc,
5555
create_default_spiking_config,
5656
lambda_grid_search,
57-
SpikingTaskFactory,
58-
GoNogoSpikingTask,
59-
XORSpikingTask,
60-
ManteSpikingTask
6157
)
6258

6359
def check_packages():
@@ -87,12 +83,6 @@ def check_packages():
8783
"create_default_spiking_config",
8884
"lambda_grid_search",
8985

90-
# Task-based architecture (spiking)
91-
"SpikingTaskFactory",
92-
"GoNogoSpikingTask",
93-
"XORSpikingTask",
94-
"ManteSpikingTask",
95-
9686
# Subpackages
9787
"rate",
9888
"spiking",

docs/api/spiking/eval_tasks.rst

Lines changed: 36 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Adapter Classes
2424
Overview
2525
----------------------------------------------------------------------------------
2626

27-
The eval_tasks module provides a high-level evaluation interface that standardizes the process of evaluating trained spiking RNN models across different cognitive tasks. The system is designed to be fully extensible, automatically supporting any task registered with the ``SpikingTaskFactory``.
27+
The eval_tasks module provides a high-level evaluation interface that standardizes the process of evaluating trained spiking RNN models across different cognitive tasks. The system is designed to be fully extensible, automatically supporting any task registered with the ``SpikingEvaluatorFactory``.
2828

2929
**Key Features:**
3030

@@ -35,13 +35,10 @@ The eval_tasks module provides a high-level evaluation interface that standardiz
3535
* **Robust Error Handling**: Graceful handling of evaluation failures
3636
* **Flexible Visualization**: Generic visualization system for any task type
3737

38-
**Evaluation Layers:**
38+
**Evaluation:**
3939

40-
The framework provides three levels of evaluation:
41-
42-
1. **Core Task Methods**: Direct task evaluation (``task.evaluate_performance()``)
43-
2. **High-Level Interface**: Complete workflow (``evaluate_task()``)
44-
3. **Command-Line Interface**: Batch processing (``python -m spiking.eval_tasks``)
40+
1. **High-Level Interface**: Complete workflow (``evaluate_task()``)
41+
2. **Command-Line Interface**: Batch processing (``python -m spiking.eval_tasks``)
4542

4643
Usage Examples
4744
----------------------------------------------------------------------------------
@@ -54,49 +51,52 @@ Usage Examples
5451
5552
# Evaluate any registered task
5653
performance = evaluate_task(
57-
task_name='go_nogo', # or 'xor', 'mante', custom tasks
58-
model_dir='models/go-nogo/',
59-
save_plots=True
54+
task_name='go_nogo',
55+
model_path='models/go-nogo/model.mat',
56+
n_trials=50
6057
)
6158
62-
print(f"Accuracy: {performance['overall_accuracy']:.3f}")
59+
print(f"Performance: {performance}")
6360
6461
**Command-Line Interface:**
6562

6663
.. code-block:: bash
6764
6865
# Basic evaluation
69-
python -m spiking.eval_tasks --task go_nogo --model_dir models/go-nogo/
66+
python -m spiking.eval_tasks --task go_nogo --model_path models/go-nogo/model.mat
7067
7168
# With custom parameters
7269
python -m spiking.eval_tasks \
7370
--task xor \
74-
--model_dir models/xor/ \
71+
--model_path models/xor/model.mat \
7572
--scaling_factor 45.0 \
76-
--no_plots
73+
--n_trials 50
7774
7875
# Custom task (after registration)
79-
python -m spiking.eval_tasks --task my_custom --model_dir models/custom/
76+
python -m spiking.eval_tasks --task my_custom --model_path models/custom/model.mat
8077
8178
**Custom Task Integration:**
8279

8380
.. code-block:: python
8481
85-
from spiking.tasks import SpikingTaskFactory, AbstractSpikingTask
86-
from spiking.eval_tasks import evaluate_task
87-
88-
# 1. Define custom task
89-
class WorkingMemoryTask(AbstractSpikingTask):
90-
# ... implementation ...
91-
pass
92-
82+
from spiking.eval_tasks import SpikingEvaluatorFactory, evaluate_task
83+
from rate.tasks import AbstractTask
84+
85+
# 1. Define custom evaluator (inheriting from a rate task class)
86+
class WorkingMemoryEvaluator(AbstractTask):
87+
def validate_settings(self):
88+
pass
89+
def evaluate_single_trial(self, model_path, scaling_factor, model_data=None):
90+
# ... implementation ...
91+
pass
92+
9393
# 2. Register with factory
94-
SpikingTaskFactory.register_task('working_memory', WorkingMemoryTask)
95-
94+
SpikingEvaluatorFactory._registry['working_memory'] = WorkingMemoryEvaluator
95+
9696
# 3. Evaluate using unified interface
9797
performance = evaluate_task(
98-
task_name='working_memory', # Now supported automatically
99-
model_dir='models/working_memory/',
98+
task_name='working_memory',
99+
model_path='models/working_memory/model.mat',
100100
)
101101
102102
Command-Line Arguments
@@ -108,22 +108,26 @@ Command-Line Arguments
108108

109109
Task to evaluate. Available tasks are dynamically determined from the factory registry.
110110

111-
.. option:: --model_dir MODEL_DIR
111+
.. option:: --model_path MODEL_PATH
112112

113-
Directory containing the trained model .mat file.
113+
Path to the trained model .mat file.
114114

115115
.. option:: --scaling_factor SCALING_FACTOR
116116

117117
Override scaling factor (uses value from .mat file if not provided).
118118

119-
.. option:: --no_plots
119+
.. option:: --n_trials N_TRIALS
120120

121-
Skip generating visualization plots.
121+
Number of trials to evaluate.
122122

123123
.. option:: --T T
124124

125125
Trial duration (timesteps) - overrides task default.
126126

127+
.. option:: --delay DELAY
128+
129+
Delay time (timesteps) - overrides task default.
130+
127131
.. option:: --stim_on STIM_ON
128132

129133
Stimulus onset time - overrides task default.
@@ -141,21 +145,4 @@ The system automatically loads trained rate RNN models from `.mat` files and ext
141145

142146
* Network weights and connectivity matrices
143147
* Optimal scaling factors for rate-to-spike conversion
144-
* Task-specific parameters and configurations
145-
146-
**Generic Visualization:**
147-
148-
The visualization system uses each task's ``get_sample_trial_types()`` method to determine what trial types to generate for plotting. This allows custom tasks to specify their own visualization patterns without modifying the evaluation code.
149-
150-
**Error Handling:**
151-
152-
The evaluation system includes comprehensive error handling:
153-
154-
* Graceful handling of missing model files
155-
* Validation of task names against factory registry
156-
* Recovery from trial generation failures
157-
* Informative error messages for debugging
158-
159-
**Extensibility:**
160-
161-
The system is designed to be fully extensible. Any task that inherits from ``AbstractSpikingTask`` and is registered with ``SpikingTaskFactory`` can be evaluated using this unified interface.
148+
* Task-specific parameters and configurations

docs/api/spiking/index.rst

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ Core Modules
1111
:maxdepth: 1
1212

1313
lif_network
14-
tasks
1514
eval_tasks
1615
lambda_grid_search
1716
utils
@@ -22,9 +21,6 @@ Module Overview
2221
**LIF_network_fnc.py**
2322
Core function for converting rate RNNs to spiking networks and running LIF simulations.
2423

25-
**tasks.py**
26-
Task-based architecture for spiking neural network evaluation with abstract base classes and concrete task implementations.
27-
2824
**eval_tasks.py**
2925
Unified, extensible evaluation interface for spiking neural networks on cognitive tasks.
3026

@@ -43,14 +39,6 @@ Quick Reference
4339
* ``evaluate_task()``: Unified evaluation interface for all tasks
4440
* ``lambda_grid_search()``: Optimize scaling factors
4541

46-
**Task Classes:**
47-
48-
* ``AbstractSpikingTask``: Base class for spiking task evaluation
49-
* ``GoNogoSpikingTask``: Go-NoGo task for spiking networks
50-
* ``XORSpikingTask``: XOR task for spiking networks
51-
* ``ManteSpikingTask``: Mante task for spiking networks
52-
* ``SpikingTaskFactory``: Factory for creating spiking task instances
53-
5442
**Configuration:**
5543

5644
* ``SpikingConfig``: Configuration dataclass for spiking RNN parameters
@@ -60,4 +48,3 @@ Quick Reference
6048

6149
* ``load_rate_model()``: Load rate model from `.mat` file
6250
* ``format_spike_data()``: Format spike data for analysis
63-
* ``SpikingTaskFactory.register_task()``: Register custom tasks

0 commit comments

Comments
 (0)