Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libnestutil/nest_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ constexpr uint8_t NUM_BITS_SYN_ID = 6U;
constexpr uint8_t NUM_BITS_LCID = 27U;
constexpr uint8_t NUM_BITS_PROCESSED_FLAG = 1U;
constexpr uint8_t NUM_BITS_MARKER_SPIKE_DATA = 2U;
constexpr uint8_t NUM_BITS_MARKER_ACTIVATION = 1U;
constexpr uint8_t NUM_BITS_LAG = 14U;
constexpr uint8_t NUM_BITS_DELAY = 21U;
constexpr uint8_t NUM_BITS_NODE_ID = 61U;
Expand Down
108 changes: 67 additions & 41 deletions models/eprop_iaf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ eprop_iaf::Parameters_::Parameters_()
, kappa_( 0.97 )
, kappa_reg_( 0.97 )
, eprop_isi_trace_cutoff_( 1000.0 )
, activation_interval_( 3000.0 )
{
}

Expand Down Expand Up @@ -131,6 +132,7 @@ eprop_iaf::Parameters_::get( DictionaryDatum& d ) const
def< double >( d, names::kappa, kappa_ );
def< double >( d, names::kappa_reg, kappa_reg_ );
def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ );
def< double >( d, names::activation_interval, activation_interval_ );
}

double
Expand Down Expand Up @@ -168,6 +170,7 @@ eprop_iaf::Parameters_::set( const DictionaryDatum& d, Node* node )
updateValueParam< double >( d, names::kappa, kappa_, node );
updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node );
updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node );
updateValueParam< double >( d, names::activation_interval, activation_interval_, node );

if ( C_m_ <= 0 )
{
Expand Down Expand Up @@ -209,11 +212,16 @@ eprop_iaf::Parameters_::set( const DictionaryDatum& d, Node* node )
throw BadProperty( "Firing rate low-pass filter for regularization kappa_reg from range [0, 1] required." );
}

if ( eprop_isi_trace_cutoff_ < 0.0 )
if ( activation_interval_ < 0 )
{
throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." );
throw BadProperty( "Interval between activations activation_interval ≥ 0 required." );
}

if ( eprop_isi_trace_cutoff_ < 0.0 or eprop_isi_trace_cutoff_ > activation_interval_ )
{
throw BadProperty(
"Computation cutoff of eprop trace 0 ≤ eprop trace eprop_isi_trace_cutoff ≤ activation_interval required." );
}
return delta_EL;
}

Expand Down Expand Up @@ -271,6 +279,7 @@ eprop_iaf::pre_run_hook()

V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps();
V_.eprop_isi_trace_cutoff_steps_ = Time( Time::ms( P_.eprop_isi_trace_cutoff_ ) ).get_steps();
V_.activation_interval_steps_ = Time( Time::ms( P_.activation_interval_ ) ).get_steps();

// calculate the entries of the propagator matrix for the evolution of the state vector

Expand Down Expand Up @@ -314,6 +323,14 @@ eprop_iaf::update( Time const& origin, const long from, const long to )
S_.z_ = 1.0;
S_.v_m_ -= P_.V_th_ * S_.z_;
S_.r_ = V_.RefractoryCounts_;
set_last_event_time( t );
}
else if ( get_last_event_time() > 0 and t - get_last_event_time() >= V_.activation_interval_steps_ )
{
SpikeEvent se;
se.set_activation();
kernel().event_delivery_manager.send( *this, se, lag );
set_last_event_time( t );
}

append_new_eprop_history_entry( t );
Expand Down Expand Up @@ -380,61 +397,70 @@ eprop_iaf::compute_gradient( const long t_spike,
double& epsilon,
double& weight,
const CommonSynapseProperties& cp,
WeightOptimizer* optimizer )
WeightOptimizer* optimizer,
const bool activation,
const bool previous_event_was_activation,
double& sum_grad )
{
double e = 0.0; // eligibility trace
double z = 0.0; // spiking variable
double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration
double psi = 0.0; // surrogate gradient
double L = 0.0; // learning signal
double firing_rate_reg = 0.0; // firing rate regularization
double grad = 0.0; // gradient
const auto& ecp = static_cast< const EpropSynapseCommonProperties& >( cp );
const auto& opt_cp = *ecp.optimizer_cp_;
const bool optimize_each_step = opt_cp.optimize_each_step_;

const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp );
const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_;
if ( not previous_event_was_activation )
{
sum_grad = 0.0; // sum of gradients
}

auto eprop_hist_it = get_eprop_history( t_spike_previous - 1 );

const long t_compute_until = std::min( t_spike_previous + V_.eprop_isi_trace_cutoff_steps_, t_spike );
const long cutoff_end = t_spike_previous + V_.eprop_isi_trace_cutoff_steps_;
const long t_compute_until = std::min( cutoff_end, t_spike );

for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it )
if ( not previous_event_was_activation )
{
z = z_previous_buffer;
z_previous_buffer = z_current_buffer;
z_current_buffer = 0.0;

psi = eprop_hist_it->surrogate_gradient_;
L = eprop_hist_it->learning_signal_;
firing_rate_reg = eprop_hist_it->firing_rate_reg_;
double z_current_buffer = 1.0; // spike that triggered current computation

z_bar = V_.P_v_m_ * z_bar + z;
e = psi * z_bar;
e_bar = P_.kappa_ * e_bar + e;
e_bar_reg = P_.kappa_reg_ * e_bar_reg + ( 1.0 - P_.kappa_reg_ ) * e;

if ( optimize_each_step )
for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it )
{
grad = L * e_bar + firing_rate_reg * e_bar_reg;
weight = optimizer->optimized_weight( *ecp.optimizer_cp_, t, grad, weight );
}
else
{
grad += L * e_bar + firing_rate_reg * e_bar_reg;
const double z = z_previous_buffer; // spiking variable
z_previous_buffer = z_current_buffer;
z_current_buffer = 0.0;

const double psi = eprop_hist_it->surrogate_gradient_; // surrogate gradient
const double L = eprop_hist_it->learning_signal_; // learning signal
const double firing_rate_reg = eprop_hist_it->firing_rate_reg_; // firing rate regularization

z_bar = V_.P_v_m_ * z_bar + z;
const double e = psi * z_bar; // eligibility trace
e_bar = P_.kappa_ * e_bar + e;
e_bar_reg = P_.kappa_reg_ * e_bar_reg + ( 1.0 - P_.kappa_reg_ ) * e;

const double grad = L * e_bar + firing_rate_reg * e_bar_reg;

if ( optimize_each_step )
{
sum_grad = grad;
weight = optimizer->optimized_weight( opt_cp, t, sum_grad, weight );
}
else
{
sum_grad += grad;
}
}
}

if ( not optimize_each_step )
const long trace_decay_interval = t_spike - ( previous_event_was_activation ? t_spike_previous : t_compute_until );

if ( trace_decay_interval > 0 )
{
weight = optimizer->optimized_weight( *ecp.optimizer_cp_, t_compute_until, grad, weight );
z_bar *= std::pow( V_.P_v_m_, trace_decay_interval );
e_bar *= std::pow( P_.kappa_, trace_decay_interval );
e_bar_reg *= std::pow( P_.kappa_reg_, trace_decay_interval );
}

const long cutoff_to_spike_interval = t_spike - t_compute_until;

if ( cutoff_to_spike_interval > 0 )
if ( not( activation or optimize_each_step ) )
{
z_bar *= std::pow( V_.P_v_m_, cutoff_to_spike_interval );
e_bar *= std::pow( P_.kappa_, cutoff_to_spike_interval );
e_bar_reg *= std::pow( P_.kappa_reg_, cutoff_to_spike_interval );
weight = optimizer->optimized_weight( opt_cp, t_compute_until, sum_grad, weight );
}
}

Expand Down
13 changes: 12 additions & 1 deletion models/eprop_iaf.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ Parameter Unit Math equivalent Default Des
----------------------------------------------------------------------------------------------------------------
Parameter Unit Math equivalent Default Description
=============================== ======= =========================== ================== =========================
``activation_interval`` ms 3000.0 Interval between two
activations
``c_reg`` :math:`c_\text{reg}` 0.0 Coefficient of firing
rate regularization
``eprop_isi_trace_cutoff`` ms :math:`{\Delta t}_\text{c}` maximum value Cutoff for integration of
Expand Down Expand Up @@ -397,7 +399,10 @@ class eprop_iaf : public EpropArchivingNodeRecurrent< false >
double&,
double&,
const CommonSynapseProperties&,
WeightOptimizer* ) override;
WeightOptimizer*,
const bool,
const bool,
double& ) override;

long get_shift() const override;
bool is_eprop_recurrent_node() const override;
Expand Down Expand Up @@ -458,6 +463,9 @@ class eprop_iaf : public EpropArchivingNodeRecurrent< false >
//! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms).
double eprop_isi_trace_cutoff_;

//! Interval between two activations.
long activation_interval_;

//! Default constructor.
Parameters_();

Expand Down Expand Up @@ -535,6 +543,9 @@ class eprop_iaf : public EpropArchivingNodeRecurrent< false >

//! Time steps from the previous spike until the cutoff of e-prop update integration between two spikes.
long eprop_isi_trace_cutoff_steps_;

//! Time steps of activation interval.
long activation_interval_steps_;
};

//! Get the current value of the membrane voltage.
Expand Down
Loading
Loading