@@ -390,7 +390,7 @@ def generate_input_stim_go_nogo(settings: Dict[str, Any], seed: bool = False) ->
390390 - u: 1xT stimulus matrix
391391 - label: Either 1 (Go trial) or 0 (NoGo trial)
392392 """
393- if seed :
393+ if seed == True :
394394 np .random .seed (42 )
395395
396396 T = settings ['T' ]
@@ -522,7 +522,7 @@ def generate_target_continuous_go_nogo(settings: Dict[str, Any], label: int, see
522522 Returns:
523523 np.ndarray: 1xT target signal array.
524524 """
525- if seed :
525+ if seed == True :
526526 np .random .seed (42 )
527527
528528 T = settings ['T' ]
@@ -543,41 +543,65 @@ def generate_target_continuous_xor(settings: Dict[str, Any], label: str) -> np.n
543543 Generate the target output signal for the XOR task.
544544
545545 Args:
546- settings (Dict[str, Any]): Dictionary containing the following keys:
547- - T: Duration of a single trial (in steps)
546+ settings (Dict[str, Any]): Dictionary containing task parameters.
548547 label (str): Either 'same' or 'diff'.
549548
550549 Returns:
551- np.ndarray: 1xT target signal array.
550+ np.ndarray: A 1D target signal array of shape (T,) .
552551 """
553552 T = settings ['T' ]
553+ stim_on = settings ['stim_on' ]
554+ stim_dur = settings ['stim_dur' ]
555+ delay = settings ['delay' ]
554556
555- target = np .zeros ((T - 1 ,))
557+ # Calculate the time when the second stimulus presentation ends
558+ task_end_T = stim_on + (2 * stim_dur ) + delay
559+
560+ # Initialize the target signal array with shape (1, T)
561+ z = np .zeros ((1 , T ))
562+
563+ # Define the target window: starts 10 steps after the task ends and lasts for 100 steps
564+ target_onset = 10 + task_end_T
565+ target_offset = target_onset + 100
566+
567+ # Assign the target value based on the label
556568 if label == 'same' :
557- target [ 200 : ] = 1
569+ z [ 0 , target_onset : target_offset ] = 1
558570 elif label == 'diff' :
559- target [200 :] = - 1
571+ z [0 , target_onset :target_offset ] = - 1
572+
573+ return np .squeeze (z )
560574
561- return target
562575
563576def generate_target_continuous_mante (settings : Dict [str , Any ], label : int ) -> np .ndarray :
564577 """
565- Generate the target output signal for the sensory integration task from Mante et al (2013) .
578+ Generate the target output signal for the sensory integration task.
566579
567580 Args:
568- settings (Dict[str, Any]): Dictionary containing the following keys:
569- - T: Duration of a single trial (in steps)
570- label (int): Either +1 or -1 (the correct decision).
581+ settings (Dict[str, Any]): Dictionary containing task parameters.
582+ label (int): Either +1 or -1, the correct decision.
571583
572584 Returns:
573- np.ndarray: 1xT target signal array.
585+ np.ndarray: A 1D target signal array of shape (T,) .
574586 """
575587 T = settings ['T' ]
588+ stim_on = settings ['stim_on' ]
589+ stim_dur = settings ['stim_dur' ]
576590
577- target = np .zeros ((T - 1 ,))
578- target [- 200 :] = label
591+ # Initialize the target signal array with shape (1, T)
592+ z = np .zeros ((1 , T ))
593+
594+ # Calculate the target onset time dynamically
595+ target_onset = stim_on + stim_dur
579596
580- return target
597+ # Assign the target value from the onset time to the end of the trial
598+ if label == 1 :
599+ z [0 , target_onset :] = 1
600+ else :
601+ z [0 , target_onset :] = - 1
602+
603+ # Squeeze the array to shape (T,) to match the original's output
604+ return np .squeeze (z )
581605
582606def loss_op (o : List [torch .Tensor ], z : np .ndarray , training_params : Dict [str , Any ]) -> torch .Tensor :
583607 """
@@ -599,9 +623,9 @@ def loss_op(o: List[torch.Tensor], z: np.ndarray, training_params: Dict[str, Any
599623 z_tensor = torch .tensor (z , dtype = torch .float32 , device = o [0 ].device )
600624
601625 for i in range (len (o )):
602- if loss_fn .lower () == 'l1' :
626+ if loss_fn .lower () == 'l1' : # mean absolute error (MAE)
603627 loss = loss + torch .norm (o [i ].squeeze () - z_tensor [i ], p = 1 )
604- elif loss_fn .lower () == 'l2' :
628+ elif loss_fn .lower () == 'l2' : # root mean squared error (RMSE)
605629 loss = loss + (o [i ].squeeze () - z_tensor [i ])** 2
606630
607631 if loss_fn .lower () == 'l2' :
0 commit comments