diff --git a/doc/running/running_nest.rst b/doc/running/running_nest.rst index 2edbee3d3..530daf33a 100644 --- a/doc/running/running_nest.rst +++ b/doc/running/running_nest.rst @@ -25,7 +25,7 @@ Event-based updating of synapses The synapse is allowed to contain an ``update`` block. Statements in the ``update`` block are executed whenever the internal state of the synapse is updated from one timepoint to the next; these updates are typically triggered by incoming spikes. The NESTML ``timestep()`` function will return the time that has elapsed since the last event was handled. -Synapses in NEST are not allowed to have any nonlinear time-based internal dynamics (ODEs). This is due to the fact that synapses are, unlike nodes, not updated on a regular time grid. Linear ODEs are allowed, because they admit an analytical solution, which can be updated in a single step from the previous event time to the current event time. However, nonlinear dynamics are not allowed because they would require a numeric solver evaluating the dynamics on a regular time grid. +Synapses can have ODEs with linear and non-linear dynamics. In the case of linear dynamics, the ODEs are solved with the propagators provided by the ODE-toolbox; for non-linear dynamics, the ODEs are solved using a fourth order Runge-Kutta solver with adaptive timestep. If ODE-toolbox is not successful in finding the propagator solver to a system of ODEs that is, however, solvable, the propagators may be entered "by hand" in the ``update`` block of the model. This block may contain any series of statements to update the state of the system from the current timestep to the next, for example, multiplications of state variables by the propagators. 
diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py index 066109e3d..6ca0d641f 100644 --- a/pynestml/codegeneration/nest_code_generator.py +++ b/pynestml/codegeneration/nest_code_generator.py @@ -223,14 +223,23 @@ def setup_printers(self): self._gsl_variable_printer = GSLVariablePrinter(None) if self.option_exists("nest_version") and (self.get_option("nest_version").startswith("2") or self.get_option("nest_version").startswith("v2")): self._gsl_function_call_printer = NEST2GSLFunctionCallPrinter(None) + self._gsl_function_call_printer_no_origin = NEST2GSLFunctionCallPrinter(None) else: self._gsl_function_call_printer = NESTGSLFunctionCallPrinter(None) + self._gsl_function_call_printer_no_origin = NEST2GSLFunctionCallPrinter(None) self._gsl_printer = CppExpressionPrinter(simple_expression_printer=CppSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer, constant_printer=self._constant_printer, function_call_printer=self._gsl_function_call_printer)) self._gsl_function_call_printer._expression_printer = self._gsl_printer + self._gsl_variable_printer_no_origin = GSLVariablePrinter(None, with_origin=False) + self._gsl_printer_no_origin = CppExpressionPrinter(simple_expression_printer=CppSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer_no_origin, + constant_printer=self._constant_printer, + function_call_printer=self._gsl_function_call_printer)) + self._gsl_variable_printer_no_origin._expression_printer = self._gsl_printer_no_origin + self._gsl_function_call_printer_no_origin._expression_printer = self._gsl_printer_no_origin + # ODE-toolbox printers self._ode_toolbox_variable_printer = ODEToolboxVariablePrinter(None) self._ode_toolbox_function_call_printer = ODEToolboxFunctionCallPrinter(None) @@ -521,6 +530,7 @@ def _get_model_namespace(self, astnode: ASTModel) -> Dict: namespace["printer"] = self._nest_printer namespace["printer_no_origin"] = self._printer_no_origin 
namespace["gsl_printer"] = self._gsl_printer + namespace["gsl_printer_no_origin"] = self._gsl_printer_no_origin namespace["nestml_printer"] = NESTMLPrinter() namespace["type_symbol_printer"] = self._type_symbol_printer @@ -666,6 +676,9 @@ def _get_synapse_model_namespace(self, synapse: ASTModel) -> Dict: expr_ast.accept(ASTSymbolTableVisitor()) namespace["numeric_update_expressions"][sym] = expr_ast + ASTUtils.assign_numeric_non_numeric_state_variables(synapse, namespace["numeric_state_variables"], + namespace["numeric_update_expressions"] if "numeric_update_expressions" in namespace.keys() else None, namespace["update_expressions"] if "update_expressions" in namespace.keys() else None) + namespace["spike_updates"] = synapse.spike_updates # special case for NEST delay variable (state or parameter) diff --git a/pynestml/codegeneration/nest_code_generator_utils.py b/pynestml/codegeneration/nest_code_generator_utils.py index 342c2321e..4ff5c7e9a 100644 --- a/pynestml/codegeneration/nest_code_generator_utils.py +++ b/pynestml/codegeneration/nest_code_generator_utils.py @@ -58,9 +58,6 @@ def print_symbol_origin(cls, variable_symbol: VariableSymbol, variable: ASTVaria if variable_symbol.block_type == BlockType.INTERNALS: return "V_.%s" - if variable_symbol.block_type == BlockType.INPUT: - return "B_.%s" - return "" @classmethod diff --git a/pynestml/codegeneration/printers/gsl_variable_printer.py b/pynestml/codegeneration/printers/gsl_variable_printer.py index c9cfbc46f..64797a3e2 100644 --- a/pynestml/codegeneration/printers/gsl_variable_printer.py +++ b/pynestml/codegeneration/printers/gsl_variable_printer.py @@ -18,12 +18,13 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . 
+from pynestml.codegeneration.nest_code_generator_utils import NESTCodeGeneratorUtils from pynestml.codegeneration.nest_unit_converter import NESTUnitConverter from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter +from pynestml.codegeneration.printers.expression_printer import ExpressionPrinter from pynestml.meta_model.ast_variable import ASTVariable from pynestml.symbols.predefined_units import PredefinedUnits from pynestml.symbols.symbol import SymbolKind -from pynestml.utils.ast_utils import ASTUtils from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.messages import Messages @@ -33,46 +34,42 @@ class GSLVariablePrinter(CppVariablePrinter): Variable printer for C++ syntax and using the GSL (GNU Scientific Library) API from inside the ``extern "C"`` stepping function. """ - def print_variable(self, node: ASTVariable) -> str: + def __init__(self, expression_printer: ExpressionPrinter, with_origin: bool = True, ): + super().__init__(expression_printer) + self.with_origin = with_origin + + def print_variable(self, variable: ASTVariable) -> str: """ Converts a single name reference to a gsl processable format. 
- :param node: a single variable + :param variable: a single variable :return: a gsl processable format of the variable """ - assert isinstance(node, ASTVariable) - symbol = node.get_scope().resolve_to_symbol(node.get_complete_name(), SymbolKind.VARIABLE) + assert isinstance(variable, ASTVariable) + symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE) if symbol is None: # test if variable name can be resolved to a type - if PredefinedUnits.is_unit(node.get_complete_name()): - return str(NESTUnitConverter.get_factor(PredefinedUnits.get_unit(node.get_complete_name()).get_unit())) + if PredefinedUnits.is_unit(variable.get_complete_name()): + return str( + NESTUnitConverter.get_factor(PredefinedUnits.get_unit(variable.get_complete_name()).get_unit())) - code, message = Messages.get_could_not_resolve(node.get_name()) + code, message = Messages.get_could_not_resolve(variable.get_name()) Logger.log_message(log_level=LoggingLevel.ERROR, code=code, message=message, - error_position=node.get_source_position()) + error_position=variable.get_source_position()) return "" - if node.is_delay_variable(): - return self._print_delay_variable(node) + if variable.is_delay_variable(): + return self._print_delay_variable(variable) if symbol.is_state() and not symbol.is_inline_expression: - if "_is_numeric" in dir(node) and node._is_numeric: + if "_is_numeric" in dir(variable) and variable._is_numeric: # ode_state[] here is---and must be---the state vector supplied by the integrator, not the state vector in the node, node.S_.ode_state[]. - return "ode_state[State_::" + CppVariablePrinter._print_cpp_name(node.get_complete_name()) + "]" - - # non-ODE state symbol - return "node.S_." + CppVariablePrinter._print_cpp_name(node.get_complete_name()) - - if symbol.is_parameters(): - return "node.P_." + super().print_variable(node) - - if symbol.is_internals(): - return "node.V_." 
+ super().print_variable(node) + return "ode_state[State_::" + CppVariablePrinter._print_cpp_name(variable.get_complete_name()) + "]" if symbol.is_input(): - return "node.B_." + self._print_buffer_value(node) + return "node.B_." + self._print_buffer_value(variable) - raise Exception("Unknown node type") + return self._print(variable, symbol, with_origin=self.with_origin) def _print_delay_variable(self, variable: ASTVariable) -> str: """ @@ -105,3 +102,9 @@ def _print_buffer_value(self, variable: ASTVariable) -> str: assert variable_symbol.is_continuous_input_port() return "continuous_inputs_grid_sum_[" + variable.get_name().upper() + "]" + + def _print(self, variable, symbol, with_origin: bool = True): + variable_name = CppVariablePrinter._print_cpp_name(variable.get_complete_name()) + if with_origin: + return "node." + NESTCodeGeneratorUtils.print_symbol_origin(symbol, variable) % variable_name + return "node." + variable_name diff --git a/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 index b3cd2142b..bc87e1bea 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 @@ -69,6 +69,19 @@ along with NEST. If not, see . #include "volume_transmitter.h" {%- endif %} +{%- if uses_numeric_solver %} +{%- if numeric_solver == "rk45" %} + +#ifndef HAVE_GSL +#error "The GSL library is required for the Runge-Kutta solver." 
+#endif + +// External includes: +#include <gsl/gsl_errno.h> +#include <gsl/gsl_matrix.h> +#include <gsl/gsl_odeiv.h> +{%- endif %} +{%- endif %} // Includes from sli: #include "dictdatum.h" @@ -101,9 +114,22 @@ namespace {{names_namespace}} const Name _{{sym.get_symbol_name()}}( "{{sym.get_symbol_name()}}" ); {%- endfor %} {%- endif %} -} +{%- if uses_numeric_solver %} + const Name gsl_abs_error_tol("gsl_abs_error_tol"); + const Name gsl_rel_error_tol("gsl_rel_error_tol"); +{%- endif %} +} // end namespace {{names_namespace}}; -class {{synapseName}}CommonSynapseProperties : public CommonSynapseProperties { +namespace {{ synapseName }} +{ +{%- if uses_numeric_solver %} +{%- for s in utils.create_integrate_odes_combinations(astnode) %} +extern "C" inline int {{synapseName}}_dynamics{% if s | length > 0 %}_{{ s }}{% endif %}( double, const double ode_state[], double f[], void* pnode ); +{%- endfor %} +{%- endif %} + +class {{synapseName}}CommonSynapseProperties : public CommonSynapseProperties +{ public: {{synapseName}}CommonSynapseProperties() @@ -214,43 +240,7 @@ public: } {%- endif %} -}; - -template < typename targetidentifierT > -class {{synapseName}} : public Connection< targetidentifierT > -{ -{%- if paired_neuron_name | length > 0 %} - typedef {{ paired_neuron_name }} post_neuron_t; - -{% endif %} -{%- if vt_ports is defined and vt_ports|length > 0 %} -public: -{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4")) %} - void trigger_update_weight( size_t t, - const std::vector< spikecounter >& vt_spikes, - double t_trig, - const {{synapseName}}CommonSynapseProperties& cp ); -{%- else %} - void trigger_update_weight( thread t, - const std::vector< spikecounter >& vt_spikes, - double t_trig, - const {{synapseName}}CommonSynapseProperties& cp ); -{%- endif %} -{%- endif %} -private: - double t_lastspike_; -{%- if vt_ports is defined and vt_ports|length > 0 
%} - // time of last update, which is either time of last presyn. spike or time-driven update - double t_last_update_; - - // vt_spikes_idx_ refers to the vt spike that has just been processed after trigger_update_weight - // a pseudo vt spike at t_trig is stored at index 0 and vt_spikes_idx_ = 0 -{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4")) %} - size_t vt_spikes_idx_; -{%- else %} - index vt_spikes_idx_; -{%- endif %} -{%- endif %} +}; // end class {{synapseName}}CommonSynapseProperties /** * Dynamic state of the synapse. @@ -293,11 +283,13 @@ private: //! state vector, must be C-array for GSL solver double ode_state[STATE_VEC_SIZE]; - // state variables from state block +{# // state variables from state block#} {%- filter indent(4,True) %} {%- for variable_symbol in synapse.get_state_symbols() %} -{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- include "directives_cpp/MemberDeclaration.jinja2" %} +{% if variable_symbol.get_symbol_name() not in numeric_state_variables %} +{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} +{%- include "directives_cpp/MemberDeclaration.jinja2" %} +{%- endif %} {%- endfor %} {%- endfilter %} {%- endif %} @@ -338,10 +330,49 @@ private: {%- endif %} {%- endfor %} {%- endfilter %} + double __gsl_abs_error_tol; + double __gsl_rel_error_tol; /** Initialize parameters to their default values. 
*/ Parameters_() {}; - }; + }; // end Parameters_ + + +template < typename targetidentifierT > +class {{synapseName}} : public Connection< targetidentifierT > +{ +{%- if paired_neuron_name | length > 0 %} + typedef {{ paired_neuron_name }} post_neuron_t; + +{% endif %} +{%- if vt_ports is defined and vt_ports|length > 0 %} +public: +{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4")) %} + void trigger_update_weight( size_t t, + const std::vector< spikecounter >& vt_spikes, + double t_trig, + const {{synapseName}}CommonSynapseProperties& cp ); +{%- else %} + void trigger_update_weight( thread t, + const std::vector< spikecounter >& vt_spikes, + double t_trig, + const {{synapseName}}CommonSynapseProperties& cp ); +{%- endif %} +{%- endif %} +private: + double t_lastspike_; +{%- if vt_ports is defined and vt_ports|length > 0 %} + // time of last update, which is either time of last presyn. spike or time-driven update + double t_last_update_; + + // vt_spikes_idx_ refers to the vt spike that has just been processed after trigger_update_weight + // a pseudo vt spike at t_trig is stored at index 0 and vt_spikes_idx_ = 0 +{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4")) %} + size_t vt_spikes_idx_; +{%- else %} + index vt_spikes_idx_; +{%- endif %} +{%- endif %} /** * Internal variables of the synapse. 
@@ -361,10 +392,30 @@ private: {%- endfor %} }; +{%- if uses_numeric_solver %} +{%- if numeric_solver == "rk45" %} + gsl_odeiv_step* __s = nullptr; //!< stepping function + gsl_odeiv_control* __c = nullptr; //!< adaptive stepsize control function + gsl_odeiv_evolve* __e = nullptr; //!< evolution function + gsl_odeiv_system __sys; //!< struct describing system + + // __integration_step should be reset with the neuron on ResetNetwork, + // but remain unchanged during calibration. Since it is initialized with + // step_, and the resolution cannot change after nodes have been created, + // it is safe to place both here. + double __step; //!< step size in ms + double __integration_step; //!< current integration time step, updated by GSL +{%- endif %} +{%- endif %} + Parameters_ P_; //!< Free parameters. State_ S_; //!< Dynamic state. Variables_ V_; //!< Internal Variables -{%- if synapse.get_state_symbols()|length > 0 or synapse.get_parameter_symbols()|length > 0 %} + +{%- for s in utils.create_integrate_odes_combinations(astnode) %} + friend int {{synapseName}}_dynamics{% if s | length > 0 %}_{{ s }}{% endif %}( double, const double ode_state[], double f[], void* pnode ); +{%- endfor %} + // ------------------------------------------------------------------------- // Getters/setters for parameters and state variables // ------------------------------------------------------------------------- @@ -390,7 +441,6 @@ inline void set_{{ variable.get_name() }}(const {{ declarations.print_variable_t {%- endif %} {%- endfor %} {%- endfilter %} -{%- endif %} // ------------------------------------------------------------------------- // Getters/setters for inline expressions @@ -424,6 +474,8 @@ inline void set_{{ variable.get_name() }}(const {{ declarations.print_variable_t void recompute_internal_variables(); + std::string get_name() const; + public: // this line determines which common properties to use typedef {{synapseName}}CommonSynapseProperties CommonPropertiesType; @@ 
-676,7 +728,7 @@ void get_entry_from_continuous_variable_history(double t, runner = start; while ( runner != finish ) { - if ( fabs( t - runner->t_ ) < nest::kernel().connection_manager.get_stdp_eps() ) + if ( fabs( t - runner->t_ ) < kernel().connection_manager.get_stdp_eps() ) { histentry = *runner; return; @@ -704,7 +756,7 @@ void get_entry_from_continuous_variable_history(double t, send( Event& e, const thread tid, const {{synapseName}}CommonSynapseProperties& cp ) {%- endif %} { - const double __timestep = nest::Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function + const double __timestep = Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function auto get_thread = [tid]() { @@ -980,7 +1032,7 @@ void get_entry_from_continuous_variable_history(double t, {%- if nest_version.startswith("v2") %} librandom::NormalRandomDev normal_dev_; //!< random deviate generator {%- else %} - nest::normal_distribution normal_dev_; //!< random deviate generator + normal_distribution normal_dev_; //!< random deviate generator {%- endif %} {%- endif %} }; @@ -990,7 +1042,7 @@ void get_entry_from_continuous_variable_history(double t, void register_{{ synapseName }}( const std::string& name ) { - nest::register_connection_model< {{ synapseName }} >( name ); + register_connection_model< {{ synapseName }} >( name ); } {%- endif %} @@ -1070,6 +1122,14 @@ void } {%- endif %} +/* +** Synapse dynamics +*/ +{% if uses_numeric_solver %} +{%- for ast in utils.get_all_integrate_odes_calls_unique(synapse) %} +{%- include "directives_cpp/GSLDifferentiationFunction.jinja2" %} +{%- endfor %} +{%- endif %} template < typename targetidentifierT > void @@ -1087,17 +1147,29 @@ void {%- if variable.get_name() == nest_codegen_opt_delay_variable %} {#- special case for NEST special variable delay #} def< {{ declarations.print_variable_type(variable_symbol) }} >( __d, names::delay, {{ printer.print(variable) }} ); // NEST 
special case for delay variable -def(__d, nest::{{ synapseName }}_names::_{{ nest_codegen_opt_delay_variable }}, {{ printer.print(variable) }}); +def(__d, {{ synapseName }}_names::_{{ nest_codegen_opt_delay_variable }}, {{ printer.print(variable) }}); {#- special case for NEST special variable weight #} {%- elif variable.get_name() == synapse_weight_variable %} def< {{ declarations.print_variable_type(variable_symbol) }} >( __d, names::weight, {{ printer.print(variable) }} ); // NEST special case for weight variable -def< {{ declarations.print_variable_type(variable_symbol) }} >( __d, nest::{{ synapseName }}_names::_{{ synapse_weight_variable }}, {{ printer.print(variable) }} ); // NEST special case for weight variable +def< {{ declarations.print_variable_type(variable_symbol) }} >( __d, {{ synapseName }}_names::_{{ synapse_weight_variable }}, {{ printer.print(variable) }} ); // NEST special case for weight variable {%- else %} {%- include "directives_cpp/WriteInDictionary.jinja2" %} {%- endif %} {%- endif %} {%- endfor %} {%- endfilter %} +{%- if uses_numeric_solver %} +{%- if numeric_solver == "rk45" %} + def< double >(__d, nest::{{ synapseName }}_names::gsl_abs_error_tol, P_.__gsl_abs_error_tol); + if ( P_.__gsl_abs_error_tol <= 0. ){ + throw nest::BadProperty( "The gsl_abs_error_tol must be strictly positive." ); + } + def< double >(__d, nest::{{ synapseName }}_names::gsl_rel_error_tol, P_.__gsl_rel_error_tol); + if ( P_.__gsl_rel_error_tol < 0. ){ + throw nest::BadProperty( "The gsl_rel_error_tol must be zero or positive." 
); + } +{%- endif %} +{%- endif %} } template < typename targetidentifierT > @@ -1106,14 +1178,14 @@ void ConnectorModel& cm ) { {%- if synapse_weight_variable|length > 0 and synapse_weight_variable != "weight" %} - if (__d->known(nest::{{ synapseName }}_names::_{{ synapse_weight_variable }}) and __d->known(nest::names::weight)) + if (__d->known({{ synapseName }}_names::_{{ synapse_weight_variable }}) and __d->known(names::weight)) { throw BadProperty( "To prevent inconsistencies, please set either 'weight' or '{{ synapse_weight_variable }}' variable; not both at the same time." ); } {%- endif %} {%- if nest_codegen_opt_delay_variable != "delay" %} - if (__d->known(nest::{{ synapseName }}_names::_{{ nest_codegen_opt_delay_variable }}) and __d->known(nest::names::delay)) + if (__d->known({{ synapseName }}_names::_{{ nest_codegen_opt_delay_variable }}) and __d->known(names::delay)) { throw BadProperty( "To prevent inconsistencies, please set either 'delay' or '{{ nest_codegen_opt_delay_variable }}' variable; not both at the same time." 
); } @@ -1131,17 +1203,17 @@ void {%- if variable.get_name() == nest_codegen_opt_delay_variable %} // special treatment of NEST delay double tmp_{{ nest_codegen_opt_delay_variable }} = get_delay(); -updateValue(__d, nest::{{ synapseName }}_names::_{{ nest_codegen_opt_delay_variable }}, tmp_{{nest_codegen_opt_delay_variable}}); +updateValue(__d, {{ synapseName }}_names::_{{ nest_codegen_opt_delay_variable }}, tmp_{{nest_codegen_opt_delay_variable}}); {%- elif variable.get_name() == synapse_weight_variable %} // special treatment of NEST weight double tmp_{{ synapse_weight_variable }} = get_weight(); -if (__d->known(nest::{{ synapseName }}_names::_{{ synapse_weight_variable }})) +if (__d->known({{ synapseName }}_names::_{{ synapse_weight_variable }})) { - updateValue(__d, nest::{{ synapseName }}_names::_{{ synapse_weight_variable }}, tmp_{{synapse_weight_variable}}); + updateValue(__d, {{ synapseName }}_names::_{{ synapse_weight_variable }}, tmp_{{synapse_weight_variable}}); } -if (__d->known(nest::names::weight)) +if (__d->known(names::weight)) { - updateValue(__d, nest::names::weight, tmp_{{synapse_weight_variable}}); + updateValue(__d, names::weight, tmp_{{synapse_weight_variable}}); } {%- else %} {%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %} @@ -1179,13 +1251,38 @@ set_delay(tmp_{{ nest_codegen_opt_delay_variable }}); {% for invariant in synapse.get_parameter_invariants() %} if ( !({{printer.print(invariant)}}) ) { - throw nest::BadProperty("The constraint '{{nestml_printer.print(invariant)}}' is violated!"); + throw BadProperty("The constraint '{{nestml_printer.print(invariant)}}' is violated!"); } {%- endfor %} +{%- endif %} + +{% if uses_numeric_solver %} +{%- if numeric_solver == "rk45" %} + updateValue< double >(__d, nest::{{ synapseName }}_names::gsl_abs_error_tol, P_.__gsl_abs_error_tol); + if ( P_.__gsl_abs_error_tol <= 0. ) + { + throw nest::BadProperty( "The gsl_abs_error_tol must be strictly positive." 
); + } + updateValue< double >(__d, nest::{{ synapseName }}_names::gsl_rel_error_tol, P_.__gsl_rel_error_tol); + if ( P_.__gsl_rel_error_tol < 0. ) + { + throw nest::BadProperty( "The gsl_rel_error_tol must be zero or positive." ); + } + + // Reinitialize the control function of the solver with new values of tolerance + if ( not __c ) + { + __c = gsl_odeiv_control_y_new( P_.__gsl_abs_error_tol, P_.__gsl_rel_error_tol ); + } + else + { + gsl_odeiv_control_init( __c, P_.__gsl_abs_error_tol, P_.__gsl_rel_error_tol, 1.0, 0.0 ); + } +{%- endif %} {%- endif %} // recompute internal variables in case they are dependent on parameters or state that might have been updated in this call to set_status() - V_.__h = nest::Time::get_resolution().get_ms(); + V_.__h = Time::get_resolution().get_ms(); recompute_internal_variables(); } @@ -1195,7 +1292,7 @@ set_delay(tmp_{{ nest_codegen_opt_delay_variable }}); template < typename targetidentifierT > void {{synapseName}}< targetidentifierT >::recompute_internal_variables() { - const double __timestep = nest::Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function + const double __timestep = Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function {% filter indent(2) %} {%- for variable_symbol in synapse.get_internal_symbols() %} @@ -1207,14 +1304,20 @@ void {{synapseName}}< targetidentifierT >::recompute_internal_variables() {%- endfilter %} } +template < typename targetidentifierT > +std::string {{synapseName}}< targetidentifierT >::get_name() const +{ + std::string s ("{{ synapseName }}"); + return s; +} + /** * constructor **/ template < typename targetidentifierT > {{synapseName}}< targetidentifierT >::{{synapseName}}() : ConnectionBase() { - const double __timestep = nest::Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function - + const double __timestep = Time::get_resolution().get_ms(); // do not remove, 
this is necessary for the timestep() function // initial values for parameters {%- filter indent(2, True) %} {%- for variable_symbol in synapse.get_parameter_symbols() %} @@ -1226,6 +1329,10 @@ template < typename targetidentifierT > {%- endif %} {%- endif %} {%- endfor %} +{%- if uses_numeric_solver and numeric_solver == "rk45" %} +P_.__gsl_abs_error_tol = 1e-6; +P_.__gsl_rel_error_tol = 1e-6; +{%- endif %} {%- endfilter %} // initial values for internal variables @@ -1261,6 +1368,41 @@ template < typename targetidentifierT > {%- endif %} {%- endif %} +{%- if uses_numeric_solver and numeric_solver == "rk45" %} + if ( not __s ) + { + __s = gsl_odeiv_step_alloc( gsl_odeiv_step_rkf45, State_::STATE_VEC_SIZE ); + } + else + { + gsl_odeiv_step_reset( __s ); + } + + if ( not __c ) + { + __c = gsl_odeiv_control_y_new( P_.__gsl_abs_error_tol, P_.__gsl_rel_error_tol ); + } + else + { + gsl_odeiv_control_init( __c, P_.__gsl_abs_error_tol, P_.__gsl_rel_error_tol, 1.0, 0.0 ); + } + + if ( not __e ) + { + __e = gsl_odeiv_evolve_alloc( State_::STATE_VEC_SIZE ); + } + else + { + gsl_odeiv_evolve_reset( __e ); + } + + __sys.jacobian = nullptr; + __sys.dimension = State_::STATE_VEC_SIZE; + __sys.params = reinterpret_cast< void* >( &P_ ); + __step = Time::get_resolution().get_ms(); + __integration_step = Time::get_resolution().get_ms(); +{%- endif %} + t_lastspike_ = 0.; {%- if vt_ports is defined and vt_ports|length > 0 %} t_last_update_ = 0.; @@ -1296,7 +1438,7 @@ template < typename targetidentifierT > {%- for variable_symbol in synapse.get_state_symbols() %} {%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} {%- if variable.get_name() != synapse_weight_variable and variable.get_name() != nest_codegen_opt_delay_variable %} - S_.{{ printer_no_origin.print(variable) }} = rhs.S_.{{ printer_no_origin.print(variable) }}; + {{ nest_codegen_utils.print_symbol_origin(variable_symbol, variable) % printer_no_origin.print(variable) }} = 
rhs.{{ nest_codegen_utils.print_symbol_origin(variable_symbol, variable) % printer_no_origin.print(variable) }}; {%- endif %} {%- endfor %} @@ -1305,6 +1447,16 @@ template < typename targetidentifierT > {%- endif %} t_lastspike_ = rhs.t_lastspike_; +{%- if uses_numeric_solver and numeric_solver == "rk45" %} + // Numeric solver variables + __s = rhs.__s; + __c = rhs.__c; + __e = rhs.__e; + __sys = rhs.__sys; + __step = rhs.__step; + __integration_step = rhs.__integration_step; +{%- endif %} + // special treatment of NEST delay set_delay(rhs.get_delay()); {%- if synapse_weight_variable | length > 0 %} @@ -1473,6 +1625,7 @@ inline void {%- endif %} -} // namespace +} // namespace {{ synapseName }}; +} // end namespace nest; #endif /* #ifndef {{synapseName.upper()}}_H */ diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLDifferentiationFunction.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLDifferentiationFunction.jinja2 index 60aadcca3..956d117ff 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLDifferentiationFunction.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLDifferentiationFunction.jinja2 @@ -1,19 +1,30 @@ {# Creates GSL implementation of the differentiation step for the system of ODEs. 
-#} -extern "C" inline int {{neuronName}}_dynamics{% if ast.get_args() | length > 0 %}_{{ utils.integrate_odes_args_str_from_function_call(ast) }}{% endif %}(double __time, const double ode_state[], double f[], void* pnode) +{%- if neuronName is defined %} +{%- set modelName = neuronName %} +{%- else %} +{%- set modelName = synapseName %} +{%- endif %} +extern "C" inline int {{modelName}}_dynamics{% if ast.get_args() | length > 0 %}_{{ utils.integrate_odes_args_str_from_function_call(ast) }}{% endif %}(double __time, const double ode_state[], double f[], void* pnode) { - typedef {{neuronName}}::State_ State_; - // get access to node so we can almost work as in a member function +{%- if neuronName is defined %} + typedef {{modelName}}::State_ State_; + // get access to node so we can almost work as in a member function assert( pnode ); const {{neuronName}}& node = *( reinterpret_cast< {{neuronName}}* >( pnode ) ); {%- for port in neuron.get_continuous_input_ports() %} constexpr int {{ port.get_symbol_name().upper() }} = {{ neuronName }}::{{ port.get_symbol_name().upper() }}; {%- endfor %} - +{%- else %} + // get access to node so we can almost work as in a member function + assert( pnode ); + const Parameters_& node = *( reinterpret_cast< Parameters_* >( pnode ) ); +{%- endif %} // ode_state[] here is---and must be---the state vector supplied by the integrator, // not the state vector in the node, node.S_.ode_state[]. 
+{%- if neuronName is defined %} {%- for eq_block in neuron.get_equations_blocks() %} {%- for ode in eq_block.get_declarations() %} {%- for inline_expr in utils.get_inline_expression_symbols(ode) %} @@ -25,7 +36,20 @@ extern "C" inline int {{neuronName}}_dynamics{% if ast.get_args() | length > 0 % {%- endfor %} {%- endfor %} -{%- if use_gap_junctions %} +{%- else %} +{%- for eq_block in synapse.get_equations_blocks() %} +{%- for ode in eq_block.get_declarations() %} +{%- for inline_expr in utils.get_inline_expression_symbols(ode) %} +{%- if not inline_expr.is_equation() %} +{%- set declaring_expr = inline_expr.get_declaring_expression() %} + double {{ printer.print(utils.get_state_variable_by_name(astnode, inline_expr)) }} = {{ gsl_printer.print(declaring_expr) }}; +{%- endif %} +{%- endfor %} +{%- endfor %} +{%- endfor %} +{%- endif %} + +{%- if use_gap_junctions and neuronName is defined %} // set I_gap depending on interpolation order double __I_gap = 0.0; @@ -54,6 +78,7 @@ extern "C" inline int {{neuronName}}_dynamics{% if ast.get_args() | length > 0 % } {%- endif %} +{%- if neuronName is defined %} {% set numeric_state_variables_to_be_integrated = numeric_state_variables + purely_numeric_state_variables_moved %} {%- if ast.get_args() | length > 0 %} {%- set numeric_state_variables_to_be_integrated = utils.filter_variables_list(numeric_state_variables_to_be_integrated, ast.get_args()) %} @@ -68,9 +93,18 @@ extern "C" inline int {{neuronName}}_dynamics{% if ast.get_args() | length > 0 % {%- endif %} {%- endfor %} +{%- else %} +{%- for variable_name in numeric_state_variables %} +{%- set update_expr = numeric_update_expressions[variable_name] %} +{%- set variable_symbol = variable_symbols[variable_name] %} + f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in utils.integrate_odes_args_strs_from_function_call(ast) + utils.all_convolution_variable_names(astnode) %}{{ 
gsl_printer_no_origin.print(update_expr) }}{% else %}0{% endif %}{% else %}{{ gsl_printer_no_origin.print(update_expr) }}{% endif %}; +{%- endfor %} +{%- endif %} + {%- if numeric_solver == "rk45" %} return GSL_SUCCESS; {%- else %} return 0; {%- endif %} } + diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLIntegrationStep.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLIntegrationStep.jinja2 index 4a8090537..c0617ded1 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLIntegrationStep.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLIntegrationStep.jinja2 @@ -5,7 +5,11 @@ {%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} {%- if numeric_solver == "rk45" %} double __t = 0; +{%- if neuronName is defined %} B_.__sys.function = {{neuronName}}_dynamics{% if ast.get_args() | length > 0 %}_{{ utils.integrate_odes_args_str_from_function_call(ast) }}{% endif %}; +{%- else %} +__sys.function = {{synapseName}}_dynamics{% if ast.get_args() | length > 0 %}_{{ utils.integrate_odes_args_str_from_function_call(ast) }}{% endif %}; +{%- endif %} // numerical integration with adaptive step size control: // ------------------------------------------------------ // gsl_odeiv_evolve_apply performs only a single numerical @@ -18,11 +22,12 @@ B_.__sys.function = {{neuronName}}_dynamics{% if ast.get_args() | length > 0 %}_ // enforce setting IntegrationStep to step-t; this is of advantage // for a consistent and efficient integration across subsequent // simulation intervals +{%- if neuronName is defined %} while ( __t < B_.__step ) { -{%- if use_gap_junctions %} +{%- if use_gap_junctions %} gap_junction_step = B_.__step; -{%- endif %} +{%- endif %} const int status = gsl_odeiv_evolve_apply(B_.__e, B_.__c, @@ -38,6 +43,25 @@ while ( __t < B_.__step ) throw nest::GSLSolverFailure( get_name(), status ); } } +{%- else %} 
+while ( __t < timestep ) +{ + const int status = gsl_odeiv_evolve_apply(__e, + __c, + __s, + &__sys, // system of ODE + &__t, // from t + timestep, // to t <= step + &__integration_step, // integration step size + S_.ode_state); // neuronal state + + if ( status != GSL_SUCCESS ) + { + throw nest::GSLSolverFailure( get_name(), status ); + } + } +{%- endif %} + {%- elif numeric_solver == "forward-Euler" %} double f[State_::STATE_VEC_SIZE]; {{neuronName}}_dynamics{% if ast.get_args() | length > 0 %}_{{ utils.integrate_odes_args_str_from_function_call(ast) }}{% endif %}( get_t(), S_.ode_state, f, reinterpret_cast< void* >( this ) ); diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/PredefinedFunction_integrate_odes.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/PredefinedFunction_integrate_odes.jinja2 index 65f8b218e..12f8bc0e2 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/PredefinedFunction_integrate_odes.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/PredefinedFunction_integrate_odes.jinja2 @@ -25,7 +25,12 @@ {%- if uses_numeric_solver %} +{%- if neuronName is defined %} {% set numeric_state_variables_to_be_integrated = numeric_state_variables + purely_numeric_state_variables_moved %} +{%- else %} +{% set numeric_state_variables_to_be_integrated = numeric_state_variables %} +{%- endif %} + {%- if ast.get_args() | length > 0 %} {%- set numeric_state_variables_to_be_integrated = utils.filter_variables_list(numeric_state_variables_to_be_integrated, ast.get_args()) %} {%- endif %} diff --git a/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClass.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClass.jinja2 index 3f6646d42..f9175e2c9 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClass.jinja2 +++ 
b/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClass.jinja2 @@ -132,7 +132,7 @@ void {%- if synapses %} // register synapses {%- for synapse in synapses %} - nest::register_connection_model< nest::{{synapse.get_name()}} >( "{{synapse.get_name()}}" ); + nest::register_connection_model< nest::{{synapse.get_name()}}::{{ synapse.get_name() }} >( "{{synapse.get_name()}}" ); {%- endfor %} {%- endif %} } // {{moduleName}}::init() diff --git a/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClassMaster.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClassMaster.jinja2 index fe2d49582..a43c3912c 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClassMaster.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClassMaster.jinja2 @@ -78,7 +78,7 @@ void {{moduleName}}::initialize() {%- if synapses %} // register synapses {%- for synapse in synapses %} - nest::register_{{synapse.get_name()}}( "{{synapse.get_name()}}" ); + nest::{{synapse.get_name()}}::register_{{synapse.get_name()}}( "{{synapse.get_name()}}" ); {%- endfor %} {%- endif %} -} \ No newline at end of file +} diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index ef51e0812..29c4179a6 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -2555,7 +2555,7 @@ def get_spike_input_ports_in_pairs(cls, neuron: ASTModel) -> Dict[int, List[Vari return rport_to_port_map @classmethod - def assign_numeric_non_numeric_state_variables(cls, neuron, numeric_state_variable_names, numeric_update_expressions, update_expressions): + def assign_numeric_non_numeric_state_variables(cls, model, numeric_state_variable_names, numeric_update_expressions, update_expressions): r"""For each ASTVariable, set the ``node._is_numeric`` member to True or False based on whether this variable will be solved with the analytic or numeric solver. 
Ideally, this would not be a property of the ASTVariable as it is an implementation detail (that only emerges during code generation) and not an intrinsic part of the model itself. However, this approach is preferred over setting it as a property of the variable printers as it would have to make each printer aware of all models and variables therein.""" @@ -2574,10 +2574,10 @@ def visit_variable(self, node): visitor = ASTVariableOriginSetterVisitor() visitor._numeric_state_variables = numeric_state_variable_names - neuron.accept(visitor) + model.accept(visitor) - if "extra_on_emit_spike_stmts_from_synapse" in dir(neuron): - for expr in neuron.extra_on_emit_spike_stmts_from_synapse: + if "extra_on_emit_spike_stmts_from_synapse" in dir(model): + for expr in model.extra_on_emit_spike_stmts_from_synapse: expr.accept(visitor) if update_expressions: @@ -2588,15 +2588,17 @@ def visit_variable(self, node): for expr in numeric_update_expressions.values(): expr.accept(visitor) - for update_expr_list in neuron.spike_updates.values(): + for update_expr_list in model.spike_updates.values(): for update_expr in update_expr_list: update_expr.accept(visitor) - for update_expr in neuron.post_spike_updates.values(): - update_expr.accept(visitor) + if "post_spike_updates" in dir(model): + for update_expr in model.post_spike_updates.values(): + update_expr.accept(visitor) - for node in neuron.equations_with_delay_vars + neuron.equations_with_vector_vars: - node.accept(visitor) + if "equations_with_delay_vars" in dir(model): + for node in model.equations_with_delay_vars + model.equations_with_vector_vars: + node.accept(visitor) @classmethod def depends_only_on_vars(cls, expr, vars): diff --git a/tests/nest_tests/resources/non_linear_synapse.nestml b/tests/nest_tests/resources/non_linear_synapse.nestml new file mode 100644 index 000000000..d7281fa09 --- /dev/null +++ b/tests/nest_tests/resources/non_linear_synapse.nestml @@ -0,0 +1,60 @@ +# non_linear_synapse.nestml +# 
######################### +# +# +# Description +# +++++++++++ +# +# This model is used to test a synapse with non-linear ODE dynamics (a Lorenz attractor), which requires a numeric solver. +# +# +# Copyright statement +# +++++++++++++++++++ +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see <https://www.gnu.org/licenses/>. + +model non_linear_synapse: + state: + x real = 1. + y real = 1. + z real = 1. + w real = 0. + d ms = 1.0 ms + + equations: + x' = (sigma * (y - x)) / ms + y' = (x * (rho - z) - y) / ms + z' = (x * y - beta * z) / ms + + parameters: + sigma real = 10. + beta real = 8/3 + rho real = 28 + + input: + pre_spikes <- spike + + output: + spike(weight real, delay ms) + + onReceive(pre_spikes): + w += x * y / z + emit_spike(w, d) + + update: + integrate_odes() diff --git a/tests/nest_tests/test_synapse_numeric_solver.py b/tests/nest_tests/test_synapse_numeric_solver.py new file mode 100644 index 000000000..652c4fbc6 --- /dev/null +++ b/tests/nest_tests/test_synapse_numeric_solver.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# +# test_synapse_numeric_solver.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version.
+# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see <https://www.gnu.org/licenses/>. +import os + +import nest +import numpy as np +import pytest +from scipy.integrate import solve_ivp + +from pynestml.codegeneration.nest_tools import NESTTools +from pynestml.frontend.pynestml_frontend import generate_nest_target + +try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.ticker + import matplotlib.pyplot as plt + + TEST_PLOTS = True +except Exception: + TEST_PLOTS = False + + +@pytest.mark.skipif(NESTTools.detect_nest_version().startswith("v2"), + reason="This test does not support NEST 2") +class TestSynapseNumericSolver: + """ + Tests a synapse with non-linear dynamics requiring a numeric solver for ODEs. + """ + + def lorenz_attractor_system(self, t, state, sigma, rho, beta): + x, y, z = state + dxdt = (sigma * (y - x)) + dydt = (x * (rho - z) - y) + dzdt = (x * y - beta * z) + return [dxdt, dydt, dzdt] + + def evaluate_odes_scipy(self, func, params, initial_state, spike_times, sim_time): + """ + Evaluate ODEs using SciPy. + """ + sol = [] + + # integrate the ODEs from one spike time to the next, until the end of the simulation. + t_last_spike = 0.
+ spike_idx = 0 + for i in np.arange(1., sim_time + 0.01, 1.0): + if spike_idx < len(spike_times) and i == spike_times[spike_idx]: + t_spike = spike_times[spike_idx] + t_span = [t_last_spike, t_spike] + # Solve using RK45 + solution = solve_ivp( + fun=func, + t_span=t_span, # interval of integration + y0=initial_state, # initial state + args=params, # parameters + method='RK45', + rtol=1e-14, # relative tolerance + atol=1e-14 # absolute tolerance + ) + initial_state = solution.y[:, -1] + t_last_spike = t_spike + spike_idx += 1 + + sol += [initial_state] + + return sol + + def test_non_linear_synapse(self): + nest.ResetKernel() + nest.set_verbosity("M_WARNING") + dt = 0.1 + nest.resolution = dt + sim_time = 30. + spike_times = [3.0, 5.0, 9.0, 11.0, 22.0, 28.0] + + files = ["models/neurons/iaf_psc_exp_neuron.nestml", "tests/nest_tests/resources/non_linear_synapse.nestml"] + input_paths = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join( + os.pardir, os.pardir, s))) for s in files] + target_path = "target_nl" + modulename = "nl_syn_module" + + generate_nest_target(input_path=input_paths, + target_path=target_path, + logging_level="INFO", + suffix="_nestml", + module_name=modulename, + codegen_opts={"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron", + "synapse": "non_linear_synapse"}], + "delay_variable": {"non_linear_synapse": "d"}, + "weight_variable": {"non_linear_synapse": "w"}, + "strictly_synaptic_vars": {"non_linear_synapse": ["x", "y", "z"]}}) + nest.Install(modulename) + + neuron_model = "iaf_psc_exp_neuron_nestml__with_non_linear_synapse_nestml" + synapse_model = "non_linear_synapse_nestml__with_iaf_psc_exp_neuron_nestml" + + neuron = nest.Create(neuron_model) + sg = nest.Create("spike_generator", params={"spike_times": spike_times}) + + syn_spec = {"synapse_model": synapse_model, "gsl_abs_error_tol": 1E-14, "gsl_rel_error_tol": 1E-14} + nest.Connect(sg, neuron, syn_spec=syn_spec) + connections = nest.GetConnections(source=sg, 
synapse_model=synapse_model) + + # Get the parameter values + sigma = connections.get("sigma") + rho = connections.get("rho") + beta = connections.get("beta") + + # Initial values of state variables + inital_state = [connections.get("x"), connections.get("y"), connections.get("z")] + + # Scipy simulation + params = (sigma, rho, beta) + sol = self.evaluate_odes_scipy(self.lorenz_attractor_system, params, inital_state, spike_times, sim_time) + sol_arr = np.array(sol) + + # NEST simulation + x = [] + y = [] + z = [] + sim_step_size = 1. + for i in np.arange(0., sim_time, sim_step_size): + nest.Simulate(sim_step_size) + syn_stats = connections.get() + x += [syn_stats["x"]] + y += [syn_stats["y"]] + z += [syn_stats["z"]] + + # Plotting + if TEST_PLOTS: + fig, ax = plt.subplots(nrows=3, ncols=1, figsize=(7, 5)) + times = np.arange(0., sim_time, sim_step_size) + + ax[0].plot(times, x, label="NESTML") + ax[0].scatter(times, x, marker='x') + ax[0].plot(times, sol_arr[:, 0], '--', label="scipy") + ax[0].scatter(times, sol_arr[:, 0], marker='o') + ax[0].set_ylabel("x") + + ax[1].plot(times, y, label="NESTML") + ax[1].scatter(times, y, marker='x') + ax[1].plot(times, sol_arr[:, 1], '--', label="scipy") + ax[1].scatter(times, sol_arr[:, 1], marker='o') + ax[1].set_ylabel("y") + + ax[2].plot(times, z, label="NESTML") + ax[2].scatter(times, z, marker='x') + ax[2].plot(times, sol_arr[:, 2], '--', label="scipy") + ax[2].scatter(times, sol_arr[:, 2], marker='o') + ax[2].set_ylabel("z") + for _ax in ax: + _ax.set_xlabel("time") + _ax.scatter(spike_times, np.zeros_like(spike_times), marker='d', color='r') + + handles, labels = ax[-1].get_legend_handles_labels() + fig.legend(handles, labels, loc='upper center') + plt.savefig("non_linear_synapse.png") + + np.testing.assert_allclose(x, sol_arr[:, 0], rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(y, sol_arr[:, 1], rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(z, sol_arr[:, 2], rtol=1e-3, atol=1e-3)