@@ -45,28 +45,28 @@ extern __constant__ float NESTGPUTimeResolution;
4545#define {{ printer_no_origin.print(variable) }} param[i_{{ printer_no_origin.print(variable) }}]
4646{% - endfor %}
4747
48- __device__
49- double propagator_32( double tau_syn, double tau, double C, double h )
50- {
51- const double P32_linear = 1.0 / ( 2.0 * C * tau * tau ) * h * h
52- * ( tau_syn - tau ) * exp( -h / tau );
53- const double P32_singular = h / C * exp( -h / tau );
54- const double P32 =
55- -tau / ( C * ( 1.0 - tau / tau_syn ) ) * exp( -h / tau_syn )
56- * expm1( h * ( 1.0 / tau_syn - 1.0 / tau ) );
57-
58- const double dev_P32 = fabs( P32 - P32_singular );
59-
60- if ( tau == tau_syn || ( fabs( tau - tau_syn ) < 0.1 && dev_P32 > 2.0
61- * fabs( P32_linear ) ) )
62- {
63- return P32_singular;
64- }
65- else
66- {
67- return P32;
68- }
69- }
48+ {# __device__#}
49+ {# double propagator_32( double tau_syn, double tau, double C, double h )#}
50+ {#{#}
51+ {# const double P32_linear = 1.0 / ( 2.0 * C * tau * tau ) * h * h#}
52+ {# * ( tau_syn - tau ) * exp( -h / tau );#}
53+ {# const double P32_singular = h / C * exp( -h / tau );#}
54+ {# const double P32 =#}
55+ {# -tau / ( C * ( 1.0 - tau / tau_syn ) ) * exp( -h / tau_syn )#}
56+ {# * expm1( h * ( 1.0 / tau_syn - 1.0 / tau ) );#}
57+ {##}
58+ {# const double dev_P32 = fabs( P32 - P32_singular );#}
59+ {##}
60+ {# if ( tau == tau_syn || ( fabs( tau - tau_syn ) < 0.1 && dev_P32 > 2.0#}
61+ {# * fabs( P32_linear ) ) )#}
62+ {# {#}
63+ {# return P32_singular;#}
64+ {# }# }
65+ {# else#}
66+ {# {#}
67+ {# return P32;#}
68+ {# }# }
69+ {#}# }
7070
7171
7272__global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr,
@@ -76,12 +76,22 @@ __global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr,
7676 if (i_neuron < n_node) {
7777 float *param = param_arr + n_param*i_neuron;
7878
79- P11ex = exp( -h / tau_ex );
80- P11in = exp( -h / tau_in );
81- P22 = exp( -h / tau_m );
82- P21ex = (float)propagator_32( tau_ex, tau_m, C_m, h );
83- P21in = (float)propagator_32( tau_in, tau_m, C_m, h );
84- P20 = tau_m / C_m * ( 1.0 - P22 );
79+ {# P11ex = exp( -h / tau_ex );#}
80+ {# P11in = exp( -h / tau_in );#}
81+ {# P22 = exp( -h / tau_m );#}
82+ {# P21ex = (float)propagator_32( tau_ex, tau_m, C_m, h );#}
83+ {# P21in = (float)propagator_32( tau_in, tau_m, C_m, h );#}
84+ {# P20 = tau_m / C_m * ( 1.0 - P22 );#}
85+ {% - filter indent (4,True ) %}
86+ {% - for internals_block in neuron .get_internals_blocks () %}
87+ {% - for decl in internals_block .get_declarations () %}
88+ {% - for variable in decl .get_variables () %}
89+ {% - set variable_symbol = variable .get_scope ().resolve_to_symbol (variable .get_complete_name (), SymbolKind .VARIABLE ) %}
90+ {% - include "directives/MemberInitialization.jinja2" %}
91+ {% - endfor %}
92+ {% - endfor %}
93+ {% - endfor %}
94+ {% - endfilter %}
8595 }
8696}
8797
@@ -94,22 +104,35 @@ __global__ void {{ neuronName }}_Update(int n_node, int i_node_0, float *var_arr
94104 float *var = var_arr + n_var*i_neuron;
95105 float *param = param_arr + n_param*i_neuron;
96106
97- if ( refractory_step > 0.0 ) {
98- // neuron is absolute refractory
99- refractory_step -= 1.0;
100- }
101- else { // neuron is not refractory, so evolve V
102- V_m_rel = V_m_rel * P22 + I_syn_ex * P21ex + I_syn_in * P21in + I_e * P20;
103- }
104- // exponential decaying PSCs
105- I_syn_ex *= P11ex;
106- I_syn_in *= P11in;
107-
108- if (V_m_rel >= Theta_rel ) { // threshold crossing
109- PushSpike(i_node_0 + i_neuron, 1.0);
110- V_m_rel = V_reset_rel;
111- refractory_step = (int)round(t_ref/NESTGPUTimeResolution);
112- }
107+ {# if ( refractory_step > 0.0 ) {#}
108+ {# // neuron is absolute refractory#}
109+ {# refractory_step -= 1.0;#}
110+ {# }#}
111+ {# else { // neuron is not refractory, so evolve V#}
112+ {# V_m_rel = V_m_rel * P22 + I_syn_ex * P21ex + I_syn_in * P21in + I_e * P20;#}
113+ {# }#}
114+ {# // exponential decaying PSCs#}
115+ {# I_syn_ex *= P11ex;#}
116+ {# I_syn_in *= P11in;#}
117+ {##}
118+ {# if (V_m_rel >= Theta_rel ) { // threshold crossing#}
119+ {# PushSpike(i_node_0 + i_neuron, 1.0);#}
120+ {# V_m_rel = V_reset_rel;#}
121+ {# refractory_step = (int)round(t_ref/NESTGPUTimeResolution);#}
122+ {# }#}
123+ {% - if neuron .get_update_blocks () %}
124+ {% - filter indent (2) %}
125+ {% - for block in neuron .get_update_blocks () %}
126+ {% - set ast = block .get_block () %}
127+ {% - if ast .print_comment ('*' )|length > 1 %}
128+ /*
129+ {{ast.print_comment('*')}}
130+ */
131+ {% - endif %}
132+ {% - include "directives/Block.jinja2" %}
133+ {% - endfor %}
134+ {% - endfilter %}
135+ {% - endif %}
113136 }
114137}
115138
@@ -136,29 +159,42 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
136159 scal_var_name_ = {{ neuronName }}_scal_var_name;
137160 scal_param_name_ = {{ neuronName }}_scal_param_name;
138161
139- SetScalParam(0, n_node, "tau_m", 10.0 ); // in ms
140- SetScalParam(0, n_node, "C_m", 250.0 ); // in pF
141- SetScalParam(0, n_node, "E_L", -70.0 ); // in mV
142- SetScalParam(0, n_node, "I_e", 0.0 ); // in pA
143- SetScalParam(0, n_node, "Theta_rel", -55.0 - (-70.0) ); // relative to E_L_
144- SetScalParam(0, n_node, "V_reset_rel", -70.0 - (-70.0) ); // relative to E_L_
145- SetScalParam(0, n_node, "tau_ex", 2.0 ); // in ms
146- SetScalParam(0, n_node, "tau_in", 2.0 ); // in ms
147- // SetScalParam(0, n_node, "rho", 0.01 ); // in 1/s
148- // SetScalParam(0, n_node, "delta", 0.0 ); // in mV
149- SetScalParam(0, n_node, "t_ref", 2.0 ); // in ms
150- SetScalParam(0, n_node, "den_delay", 0.0); // in ms
151- SetScalParam(0, n_node, "P20", 0.0);
152- SetScalParam(0, n_node, "P11ex", 0.0);
153- SetScalParam(0, n_node, "P11in", 0.0);
154- SetScalParam(0, n_node, "P21ex", 0.0);
155- SetScalParam(0, n_node, "P21in", 0.0);
156- SetScalParam(0, n_node, "P22", 0.0);
157-
158- SetScalVar(0, n_node, "I_syn_ex", 0.0 );
159- SetScalVar(0, n_node, "I_syn_in", 0.0 );
160- SetScalVar(0, n_node, "V_m_rel", -70.0 - (-70.0) ); // in mV, relative to E_L
161- SetScalVar(0, n_node, "refractory_step", 0 );
162+ {# SetScalParam(0, n_node, "tau_m", 10.0 ); // in ms#}
163+ {# SetScalParam(0, n_node, "C_m", 250.0 ); // in pF#}
164+ {# SetScalParam(0, n_node, "E_L", -70.0 ); // in mV#}
165+ {# SetScalParam(0, n_node, "I_e", 0.0 ); // in pA#}
166+ {# SetScalParam(0, n_node, "Theta_rel", -55.0 - (-70.0) ); // relative to E_L_#}
167+ {# SetScalParam(0, n_node, "V_reset_rel", -70.0 - (-70.0) ); // relative to E_L_#}
168+ {# SetScalParam(0, n_node, "tau_ex", 2.0 ); // in ms#}
169+ {# SetScalParam(0, n_node, "tau_in", 2.0 ); // in ms#}
170+ {# // SetScalParam(0, n_node, "rho", 0.01 ); // in 1/s#}
171+ {# // SetScalParam(0, n_node, "delta", 0.0 ); // in mV#}
172+ {# SetScalParam(0, n_node, "t_ref", 2.0 ); // in ms#}
173+ {# SetScalParam(0, n_node, "den_delay", 0.0); // in ms#}
174+ {# SetScalParam(0, n_node, "P20", 0.0);#}
175+ {# SetScalParam(0, n_node, "P11ex", 0.0);#}
176+ {# SetScalParam(0, n_node, "P11in", 0.0);#}
177+ {# SetScalParam(0, n_node, "P21ex", 0.0);#}
178+ {# SetScalParam(0, n_node, "P21in", 0.0);#}
179+ {# SetScalParam(0, n_node, "P22", 0.0);#}
180+ {##}
181+ {# SetScalVar(0, n_node, "I_syn_ex", 0.0 );#}
182+ {# SetScalVar(0, n_node, "I_syn_in", 0.0 );#}
183+ {# SetScalVar(0, n_node, "V_m_rel", -70.0 - (-70.0) ); // in mV, relative to E_L#}
184+ {# SetScalVar(0, n_node, "refractory_step", 0 );#}
185+
186+ {% - filter indent (2) %}
187+ {% - for variable in neuron .get_parameter_symbols () %}
188+ SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, {{printer.print(variable.get_declaring_expression())}}); // as {{variable.get_type_symbol().print_symbol()}}
189+ {% - endfor %}
190+ {% - endfilter %}
191+
192+
193+ {% - filter indent (2) %}
194+ {% - for variable in neuron .get_internal_symbols () %}
195+ SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, 0.0);
196+ {% - endfor %}
197+ {% - endfilter %}
162198
163199 // multiplication factor of input signal is always 1 for all nodes
164200 float input_weight = 1.0;
@@ -169,11 +205,11 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
169205 port_weight_port_step_ = 0;
170206
171207 // input spike signal is stored in I_syn_ex, I_syn_in
172- port_input_arr_ = GetVarArr() + GetScalVarIdx("I_syn_ex ");
208+ port_input_arr_ = GetVarArr() + GetScalVarIdx("I_kernel_exc__X__exc_spikes ");
173209 port_input_arr_step_ = n_var_;
174210 port_input_port_step_ = 1;
175211
176- den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");
212+ {# den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");#}
177213
178214 return 0;
179215}
0 commit comments