Skip to content

Commit 2db85a0

Browse files
committed
Update templates
1 parent 7e1d4ac commit 2db85a0

File tree

1 file changed

+105
-69
lines changed

1 file changed

+105
-69
lines changed

pynestml/codegeneration/resources_nest_gpu/@[email protected]

Lines changed: 105 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -45,28 +45,28 @@ extern __constant__ float NESTGPUTimeResolution;
4545
#define {{ printer_no_origin.print(variable) }} param[i_{{ printer_no_origin.print(variable) }}]
4646
{%- endfor %}
4747

48-
__device__
49-
double propagator_32( double tau_syn, double tau, double C, double h )
50-
{
51-
const double P32_linear = 1.0 / ( 2.0 * C * tau * tau ) * h * h
52-
* ( tau_syn - tau ) * exp( -h / tau );
53-
const double P32_singular = h / C * exp( -h / tau );
54-
const double P32 =
55-
-tau / ( C * ( 1.0 - tau / tau_syn ) ) * exp( -h / tau_syn )
56-
* expm1( h * ( 1.0 / tau_syn - 1.0 / tau ) );
57-
58-
const double dev_P32 = fabs( P32 - P32_singular );
59-
60-
if ( tau == tau_syn || ( fabs( tau - tau_syn ) < 0.1 && dev_P32 > 2.0
61-
* fabs( P32_linear ) ) )
62-
{
63-
return P32_singular;
64-
}
65-
else
66-
{
67-
return P32;
68-
}
69-
}
48+
{#__device__#}
49+
{#double propagator_32( double tau_syn, double tau, double C, double h )#}
50+
{#{#}
51+
{# const double P32_linear = 1.0 / ( 2.0 * C * tau * tau ) * h * h#}
52+
{# * ( tau_syn - tau ) * exp( -h / tau );#}
53+
{# const double P32_singular = h / C * exp( -h / tau );#}
54+
{# const double P32 =#}
55+
{# -tau / ( C * ( 1.0 - tau / tau_syn ) ) * exp( -h / tau_syn )#}
56+
{# * expm1( h * ( 1.0 / tau_syn - 1.0 / tau ) );#}
57+
{##}
58+
{# const double dev_P32 = fabs( P32 - P32_singular );#}
59+
{##}
60+
{# if ( tau == tau_syn || ( fabs( tau - tau_syn ) < 0.1 && dev_P32 > 2.0#}
61+
{# * fabs( P32_linear ) ) )#}
62+
{# {#}
63+
{# return P32_singular;#}
64+
{# }#}
65+
{# else#}
66+
{# {#}
67+
{# return P32;#}
68+
{# }#}
69+
{#}#}
7070

7171

7272
__global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr,
@@ -76,12 +76,22 @@ __global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr,
7676
if (i_neuron < n_node) {
7777
float *param = param_arr + n_param*i_neuron;
7878

79-
P11ex = exp( -h / tau_ex );
80-
P11in = exp( -h / tau_in );
81-
P22 = exp( -h / tau_m );
82-
P21ex = (float)propagator_32( tau_ex, tau_m, C_m, h );
83-
P21in = (float)propagator_32( tau_in, tau_m, C_m, h );
84-
P20 = tau_m / C_m * ( 1.0 - P22 );
79+
{# P11ex = exp( -h / tau_ex );#}
80+
{# P11in = exp( -h / tau_in );#}
81+
{# P22 = exp( -h / tau_m );#}
82+
{# P21ex = (float)propagator_32( tau_ex, tau_m, C_m, h );#}
83+
{# P21in = (float)propagator_32( tau_in, tau_m, C_m, h );#}
84+
{# P20 = tau_m / C_m * ( 1.0 - P22 );#}
85+
{%- filter indent(4,True) %}
86+
{%- for internals_block in neuron.get_internals_blocks() %}
87+
{%- for decl in internals_block.get_declarations() %}
88+
{%- for variable in decl.get_variables() %}
89+
{%- set variable_symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE) %}
90+
{%- include "directives/MemberInitialization.jinja2" %}
91+
{%- endfor %}
92+
{%- endfor %}
93+
{%- endfor %}
94+
{%- endfilter %}
8595
}
8696
}
8797

@@ -94,22 +104,35 @@ __global__ void {{ neuronName }}_Update(int n_node, int i_node_0, float *var_arr
94104
float *var = var_arr + n_var*i_neuron;
95105
float *param = param_arr + n_param*i_neuron;
96106

97-
if ( refractory_step > 0.0 ) {
98-
// neuron is absolute refractory
99-
refractory_step -= 1.0;
100-
}
101-
else { // neuron is not refractory, so evolve V
102-
V_m_rel = V_m_rel * P22 + I_syn_ex * P21ex + I_syn_in * P21in + I_e * P20;
103-
}
104-
// exponential decaying PSCs
105-
I_syn_ex *= P11ex;
106-
I_syn_in *= P11in;
107-
108-
if (V_m_rel >= Theta_rel ) { // threshold crossing
109-
PushSpike(i_node_0 + i_neuron, 1.0);
110-
V_m_rel = V_reset_rel;
111-
refractory_step = (int)round(t_ref/NESTGPUTimeResolution);
112-
}
107+
{# if ( refractory_step > 0.0 ) {#}
108+
{# // neuron is absolute refractory#}
109+
{# refractory_step -= 1.0;#}
110+
{# }#}
111+
{# else { // neuron is not refractory, so evolve V#}
112+
{# V_m_rel = V_m_rel * P22 + I_syn_ex * P21ex + I_syn_in * P21in + I_e * P20;#}
113+
{# }#}
114+
{# // exponential decaying PSCs#}
115+
{# I_syn_ex *= P11ex;#}
116+
{# I_syn_in *= P11in;#}
117+
{##}
118+
{# if (V_m_rel >= Theta_rel ) { // threshold crossing#}
119+
{# PushSpike(i_node_0 + i_neuron, 1.0);#}
120+
{# V_m_rel = V_reset_rel;#}
121+
{# refractory_step = (int)round(t_ref/NESTGPUTimeResolution);#}
122+
{# }#}
123+
{%- if neuron.get_update_blocks() %}
124+
{%- filter indent(2) %}
125+
{%- for block in neuron.get_update_blocks() %}
126+
{%- set ast = block.get_block() %}
127+
{%- if ast.print_comment('*')|length > 1 %}
128+
/*
129+
{{ast.print_comment('*')}}
130+
*/
131+
{%- endif %}
132+
{%- include "directives/Block.jinja2" %}
133+
{%- endfor %}
134+
{%- endfilter %}
135+
{%- endif %}
113136
}
114137
}
115138

@@ -136,29 +159,42 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
136159
scal_var_name_ = {{ neuronName }}_scal_var_name;
137160
scal_param_name_ = {{ neuronName }}_scal_param_name;
138161

139-
SetScalParam(0, n_node, "tau_m", 10.0 ); // in ms
140-
SetScalParam(0, n_node, "C_m", 250.0 ); // in pF
141-
SetScalParam(0, n_node, "E_L", -70.0 ); // in mV
142-
SetScalParam(0, n_node, "I_e", 0.0 ); // in pA
143-
SetScalParam(0, n_node, "Theta_rel", -55.0 - (-70.0) ); // relative to E_L_
144-
SetScalParam(0, n_node, "V_reset_rel", -70.0 - (-70.0) ); // relative to E_L_
145-
SetScalParam(0, n_node, "tau_ex", 2.0 ); // in ms
146-
SetScalParam(0, n_node, "tau_in", 2.0 ); // in ms
147-
// SetScalParam(0, n_node, "rho", 0.01 ); // in 1/s
148-
// SetScalParam(0, n_node, "delta", 0.0 ); // in mV
149-
SetScalParam(0, n_node, "t_ref", 2.0 ); // in ms
150-
SetScalParam(0, n_node, "den_delay", 0.0); // in ms
151-
SetScalParam(0, n_node, "P20", 0.0);
152-
SetScalParam(0, n_node, "P11ex", 0.0);
153-
SetScalParam(0, n_node, "P11in", 0.0);
154-
SetScalParam(0, n_node, "P21ex", 0.0);
155-
SetScalParam(0, n_node, "P21in", 0.0);
156-
SetScalParam(0, n_node, "P22", 0.0);
157-
158-
SetScalVar(0, n_node, "I_syn_ex", 0.0 );
159-
SetScalVar(0, n_node, "I_syn_in", 0.0 );
160-
SetScalVar(0, n_node, "V_m_rel", -70.0 - (-70.0) ); // in mV, relative to E_L
161-
SetScalVar(0, n_node, "refractory_step", 0 );
162+
{# SetScalParam(0, n_node, "tau_m", 10.0 ); // in ms#}
163+
{# SetScalParam(0, n_node, "C_m", 250.0 ); // in pF#}
164+
{# SetScalParam(0, n_node, "E_L", -70.0 ); // in mV#}
165+
{# SetScalParam(0, n_node, "I_e", 0.0 ); // in pA#}
166+
{# SetScalParam(0, n_node, "Theta_rel", -55.0 - (-70.0) ); // relative to E_L_#}
167+
{# SetScalParam(0, n_node, "V_reset_rel", -70.0 - (-70.0) ); // relative to E_L_#}
168+
{# SetScalParam(0, n_node, "tau_ex", 2.0 ); // in ms#}
169+
{# SetScalParam(0, n_node, "tau_in", 2.0 ); // in ms#}
170+
{# // SetScalParam(0, n_node, "rho", 0.01 ); // in 1/s#}
171+
{# // SetScalParam(0, n_node, "delta", 0.0 ); // in mV#}
172+
{# SetScalParam(0, n_node, "t_ref", 2.0 ); // in ms#}
173+
{# SetScalParam(0, n_node, "den_delay", 0.0); // in ms#}
174+
{# SetScalParam(0, n_node, "P20", 0.0);#}
175+
{# SetScalParam(0, n_node, "P11ex", 0.0);#}
176+
{# SetScalParam(0, n_node, "P11in", 0.0);#}
177+
{# SetScalParam(0, n_node, "P21ex", 0.0);#}
178+
{# SetScalParam(0, n_node, "P21in", 0.0);#}
179+
{# SetScalParam(0, n_node, "P22", 0.0);#}
180+
{##}
181+
{# SetScalVar(0, n_node, "I_syn_ex", 0.0 );#}
182+
{# SetScalVar(0, n_node, "I_syn_in", 0.0 );#}
183+
{# SetScalVar(0, n_node, "V_m_rel", -70.0 - (-70.0) ); // in mV, relative to E_L#}
184+
{# SetScalVar(0, n_node, "refractory_step", 0 );#}
185+
186+
{%- filter indent(2) %}
187+
{%- for variable in neuron.get_parameter_symbols() %}
188+
SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, {{printer.print(variable.get_declaring_expression())}}); // as {{variable.get_type_symbol().print_symbol()}}
189+
{%- endfor %}
190+
{%- endfilter %}
191+
192+
193+
{%- filter indent(2) %}
194+
{%- for variable in neuron.get_internal_symbols() %}
195+
SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, 0.0);
196+
{%- endfor %}
197+
{%- endfilter %}
162198

163199
// multiplication factor of input signal is always 1 for all nodes
164200
float input_weight = 1.0;
@@ -169,11 +205,11 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
169205
port_weight_port_step_ = 0;
170206

171207
// input spike signal is stored in I_syn_ex, I_syn_in
172-
port_input_arr_ = GetVarArr() + GetScalVarIdx("I_syn_ex");
208+
port_input_arr_ = GetVarArr() + GetScalVarIdx("I_kernel_exc__X__exc_spikes");
173209
port_input_arr_step_ = n_var_;
174210
port_input_port_step_ = 1;
175211

176-
den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");
212+
{# den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");#}
177213

178214
return 0;
179215
}

0 commit comments

Comments
 (0)