Skip to content

Commit 515e36b

Browse files
committed
fix bug for generate target functions
1 parent fefbc86 commit 515e36b

1 file changed

Lines changed: 43 additions & 19 deletions

File tree

rate/model.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def generate_input_stim_go_nogo(settings: Dict[str, Any], seed: bool = False) ->
390390
- u: 1xT stimulus matrix
391391
- label: Either 1 (Go trial) or 0 (NoGo trial)
392392
"""
393-
if seed:
393+
if seed == True:
394394
np.random.seed(42)
395395

396396
T = settings['T']
@@ -522,7 +522,7 @@ def generate_target_continuous_go_nogo(settings: Dict[str, Any], label: int, see
522522
Returns:
523523
np.ndarray: 1xT target signal array.
524524
"""
525-
if seed:
525+
if seed == True:
526526
np.random.seed(42)
527527

528528
T = settings['T']
def generate_target_continuous_xor(settings: Dict[str, Any], label: str) -> np.ndarray:
    """Generate the target output signal for the XOR task.

    Args:
        settings (Dict[str, Any]): Dictionary containing task parameters
            ('T', 'stim_on', 'stim_dur', 'delay').
        label (str): Either 'same' or 'diff'.

    Returns:
        np.ndarray: A 1D target signal array of shape (T,).
    """
    trial_len = settings['T']
    onset = settings['stim_on']
    duration = settings['stim_dur']
    gap = settings['delay']

    # Time step at which the second stimulus presentation ends.
    second_stim_end = onset + 2 * duration + gap

    # Target window: opens 10 steps after the task ends and stays open 100 steps.
    win_start = second_stim_end + 10
    win_stop = win_start + 100

    # +1 on 'same' trials, -1 on 'diff' trials, zero elsewhere.
    amplitude = {'same': 1, 'diff': -1}.get(label, 0)

    z = np.zeros((1, trial_len))
    z[0, win_start:win_stop] = amplitude

    return np.squeeze(z)
563576
def generate_target_continuous_mante(settings: Dict[str, Any], label: int) -> np.ndarray:
    """Generate the target output signal for the sensory integration task.

    Args:
        settings (Dict[str, Any]): Dictionary containing task parameters
            ('T', 'stim_on', 'stim_dur').
        label (int): Either +1 or -1, the correct decision.

    Returns:
        np.ndarray: A 1D target signal array of shape (T,).
    """
    trial_len = settings['T']

    # The target switches on as soon as the stimulus presentation ends.
    onset = settings['stim_on'] + settings['stim_dur']

    target = np.zeros((1, trial_len))
    # +1 when the correct decision is 1; -1 for any other label value.
    target[0, onset:] = 1 if label == 1 else -1

    # Flatten (1, T) down to (T,) for the caller.
    return np.squeeze(target)
581605

582606
def loss_op(o: List[torch.Tensor], z: np.ndarray, training_params: Dict[str, Any]) -> torch.Tensor:
583607
"""
@@ -599,9 +623,9 @@ def loss_op(o: List[torch.Tensor], z: np.ndarray, training_params: Dict[str, Any
599623
z_tensor = torch.tensor(z, dtype=torch.float32, device=o[0].device)
600624

601625
for i in range(len(o)):
602-
if loss_fn.lower() == 'l1':
626+
if loss_fn.lower() == 'l1': # mean absolute error (MAE)
603627
loss = loss + torch.norm(o[i].squeeze() - z_tensor[i], p=1)
604-
elif loss_fn.lower() == 'l2':
628+
elif loss_fn.lower() == 'l2': # root mean squared error (RMSE)
605629
loss = loss + (o[i].squeeze() - z_tensor[i])**2
606630

607631
if loss_fn.lower() == 'l2':

0 commit comments

Comments
 (0)