Skip to content

Commit 2b883c2

Browse files
committed
add function parameters
1 parent fb3c63d commit 2b883c2

4 files changed

Lines changed: 78 additions & 64 deletions

File tree

docs/api/spiking/eval_go_nogo.rst

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@ Main Function
1919
Function Parameters
2020
-------------------
2121

22-
The main evaluation function can be called with optional parameters:
22+
The main evaluation function accepts:
2323

24-
* **model_path** (str, optional): Path to trained model file
25-
* **scaling_factor** (float, optional): Scaling factor for conversion
26-
* **n_trials** (int, optional): Number of trials to evaluate
27-
* **plot_results** (bool, optional): Whether to generate plots
24+
* **model_path** (str): Path to directory containing trained model files
25+
(default: '../models/go-nogo/P_rec_0.2_Taus_4.0_20.0')
26+
27+
The function will:
28+
1. Load the first .mat model file from the specified directory
29+
2. Extract the optimal scaling factor from the model
30+
3. Run example Go and NoGo trials
31+
4. Generate visualizations of network behavior
2832

2933
Example Usage
3034
-------------
@@ -33,34 +37,33 @@ Example Usage
3337
3438
from spiking import eval_go_nogo
3539
36-
# Evaluate with default parameters
40+
# Evaluate with default model path
3741
eval_go_nogo()
3842
39-
# Evaluate specific model with custom parameters
43+
# Evaluate specific model
4044
eval_go_nogo(
41-
model_path='models/go-nogo/trained_model.mat',
42-
scaling_factor=50.0,
43-
n_trials=100,
44-
plot_results=True
45+
model_path='models/go-nogo/my_trained_models'
4546
)
4647
4748
Output Metrics
4849
--------------
4950

5051
The evaluation generates the following metrics:
5152

52-
* **Go Trial Accuracy**: Percentage of correct responses to Go stimuli
53-
* **NoGo Trial Accuracy**: Percentage of correct response inhibition to NoGo stimuli
54-
* **Overall Accuracy**: Combined accuracy across all trials
55-
* **Response Time**: Average response time for Go trials
56-
* **Spike Count**: Total number of spikes generated during trials
53+
* **Network Output**: Response curves for both Go and NoGo trials
54+
* **Spike Patterns**: Detailed spike raster plots showing:
55+
- Excitatory neuron activity (red)
56+
- Inhibitory neuron activity (blue)
57+
* **Temporal Dynamics**: Network behavior over the full trial duration
5758

5859
Visualization
5960
-------------
6061

6162
The function generates several plots:
6263

63-
* Spike raster plots for Go and NoGo trials
64-
* Network output comparison between rate and spiking models
65-
* Performance summary statistics
66-
* Response time distributions
64+
* Network output comparison between Go and NoGo trials
65+
* Spike raster plots showing:
66+
- NoGo trial neural activity
67+
- Go trial neural activity
68+
* Color-coded neuron types (excitatory in red, inhibitory in blue)
69+
* Time-resolved network responses

docs/api/spiking/lambda_grid_search.rst

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,24 @@ Grid Search Parameters
2525

2626
The main grid search function accepts:
2727

28-
* **model_path** (str, optional): Path to trained rate RNN model
29-
* **scaling_range** (tuple, optional): Range of scaling factors to test (default: 20-75)
30-
* **n_trials_per_factor** (int, optional): Number of trials per scaling factor
31-
* **task_type** (str, optional): Task type ('go-nogo', 'xor', 'mante')
32-
* **parallel** (bool, optional): Whether to use parallel processing
28+
* **model_dir** (str): Directory containing trained rate RNN model .mat files
29+
(default: '../models/go-nogo/P_rec_0.2_Taus_4.0_20.0')
30+
* **n_trials** (int): Number of trials to evaluate each scaling factor
31+
(default: 100)
32+
* **scaling_factors** (list): List of scaling factors to test
33+
(default: [20, 25, 30, ..., 75])
34+
* **task_name** (str): Task type ('go-nogo', 'xor', or 'mante')
35+
(default: 'go-nogo')
3336

3437
Single Trial Evaluation
3538
----------------------------------------------------
3639

3740
The evaluate_single_trial function tests a specific scaling factor:
3841

39-
* **model_path** (str): Path to model file
42+
* **curr_full** (str): Full path to model file
4043
* **scaling_factor** (float): Scaling factor to test
4144
* **trial_params** (dict): Trial parameters including stimulus and task settings
45+
* **task_name** (str): Name of the task to evaluate
4246

4347
Returns performance metrics for the given scaling factor.
4448

@@ -54,28 +58,28 @@ Example Usage
5458
5559
# Grid search with custom parameters
5660
lambda_grid_search(
57-
model_path='models/go-nogo/trained_model.mat',
58-
scaling_range=(30, 80),
59-
n_trials_per_factor=50,
60-
task_type='go-nogo',
61-
parallel=True
61+
model_dir='models/go-nogo',
62+
n_trials=50,
63+
scaling_factors=list(range(30, 81, 5)),
64+
task_name='go-nogo'
6265
)
6366
64-
# Evaluate a single scaling factor
67+
# Evaluate a single trial
6568
from spiking.lambda_grid_search import evaluate_single_trial
6669
6770
performance = evaluate_single_trial(
6871
model_path='models/go-nogo/trained_model.mat',
6972
scaling_factor=50.0,
70-
trial_params={'stimulus': stimulus, 'task': 'go-nogo'}
73+
trial_params={},
74+
task_name='go-nogo'
7175
)
7276
7377
Optimization Process
7478
----------------------------------------------------
7579

7680
The grid search follows these steps:
7781

78-
1. **Load trained rate model** from the specified path
82+
1. **Load trained rate models** from the specified directory
7983
2. **Generate test stimuli** appropriate for the task type
8084
3. **Iterate through scaling factors** in the specified range
8185
4. **Convert to spiking network** for each scaling factor
@@ -101,18 +105,15 @@ Different metrics are used depending on the task:
101105
* Context-dependent decision accuracy
102106
* Sensory integration performance
103107

104-
Parallel Processing
108+
Output Files
105109
----------------------------------------------------
106110

107-
The module supports parallel processing using Python's multiprocessing:
108-
109-
.. code-block:: python
111+
The function modifies each input .mat file to include:
110112

111-
# Enable parallel processing (default)
112-
lambda_grid_search(parallel=True)
113+
* **opt_scaling_factor**: The optimal scaling factor found
114+
* **all_perfs**: Performance scores for all tested scaling factors
115+
* **scaling_factors**: List of all scaling factors that were tested
113116

114-
# Disable for debugging
115-
lambda_grid_search(parallel=False)
116117

117118
Output
118119
----------------------------------------------------

spiking/eval_go_nogo.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,10 @@
2828

2929
from .LIF_network_fnc import LIF_network_fnc
3030

31-
def eval_go_nogo():
31+
def eval_go_nogo(model_path= '../models/go-nogo/P_rec_0.2_Taus_4.0_20.0'):
3232
# First, load one trained rate RNN
3333
# Make sure lambda_grid_search.py was performed on the model.
3434
# Update model_path to point where the trained model is
35-
model_path = '../models/go-nogo/P_rec_0.2_Taus_4.0_20.0'
3635
mat_files = [f for f in os.listdir(model_path) if f.endswith('.mat')]
3736
model_name = mat_files[0]
3837
model_path = os.path.join(model_path, model_name)

spiking/lambda_grid_search.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -162,41 +162,52 @@ def evaluate_single_trial(args):
162162

163163
return 0, np.zeros((1, 1000))
164164

165-
def lambda_grid_search():
165+
def lambda_grid_search(model_dir = '../models/go-nogo/P_rec_0.2_Taus_4.0_20.0', n_trials = 100, scaling_factors = list(range(20, 76, 5)), task_name = 'go-nogo'):
166+
"""
167+
Perform a grid search to find the optimal scaling factor for a given task.
168+
169+
Args:
170+
model_dir (str): The directory containing the trained rate RNN model .mat files.
171+
Default: '../models/go-nogo/P_rec_0.2_Taus_4.0_20.0'
172+
n_trials (int): The number of trials to use to evaluate the LIF RNN.
173+
Default: 100
174+
scaling_factors (list): The scaling factor values to try for grid search.
175+
Default: [20, 25, 30, ..., 75]
176+
task_name (str): The name of the task to evaluate ('go-nogo', 'xor', or 'mante').
177+
Default: 'go-nogo'
178+
179+
The optimal scaling factor and performance metrics are saved back to the original .mat file
180+
with the following new fields:
181+
- opt_scaling_factor: The scaling factor that achieved best performance
182+
- all_perfs: Performance scores for all tested scaling factors
183+
- scaling_factors: List of all scaling factors that were tested
184+
"""
166185
# Directory containing all the trained rate RNN model .mat files
167-
model_dir = '../models/go-nogo/P_rec_0.2_Taus_4.0_20.0'
186+
# model_dir = '../models/go-nogo/P_rec_0.2_Taus_4.0_20.0'
187+
168188
mat_files = [f for f in os.listdir(model_dir) if f.endswith('.mat')]
169189

170190
# Whether to use the initial random connectivity weights
171191
# This should be set to False unless you want to compare
172192
# the effects of pre-trained vs post-trained weights
173193
use_initial_weights = False
174194

175-
# Number of trials to use to evaluate the LIF RNN
176-
n_trials = 100
177-
178-
# Scaling factor values to try for grid search
179-
# The more values it has, the longer the search
180-
scaling_factors = list(range(20, 76, 5)) # [20, 25, 30, ..., 75]
181-
182195
# Grid search
183196
for i, mat_file in enumerate(mat_files):
184197
curr_fname = mat_file
185198
curr_full = os.path.join(model_dir, curr_fname)
186199
print(f'Analyzing {curr_fname}')
187200

188-
# Get the task name
189-
task_name = None
190-
if 'go-nogo' in curr_full:
191-
task_name = 'go-nogo'
192-
elif 'mante' in curr_full:
193-
task_name = 'mante'
194-
elif 'xor' in curr_full:
195-
task_name = 'xor'
196-
197201
if task_name is None:
198-
print(f"Unknown task type for {curr_fname}")
199-
continue
202+
if 'go-nogo' in curr_full:
203+
task_name = 'go-nogo'
204+
elif 'mante' in curr_full:
205+
task_name = 'mante'
206+
elif 'xor' in curr_full:
207+
task_name = 'xor'
208+
else:
209+
print(f"Unknown task type for {curr_fname}")
210+
continue
200211

201212
# Load the model
202213
model_data = sio.loadmat(curr_full)

0 commit comments

Comments
 (0)