diff --git a/docs/contents/longitudinal_diffusion.rst b/docs/contents/longitudinal_diffusion.rst new file mode 100644 index 000000000..305ddc2d1 --- /dev/null +++ b/docs/contents/longitudinal_diffusion.rst @@ -0,0 +1,33 @@ +Longitudinal Diffusion +====================== + +The idea behind ``LONGITUDINAL_DIFFUSION`` is to allow a ``STATE`` variable to +diffuse along a section, i.e. from one segment into a neighbouring segment. + +This problem is solved by registering callbacks. In particular, NEURON needs to +be informed of the volume and diffusion rate. Additionally, the implicit +time-stepping requires information about certain derivatives. + +Implementation in NMODL +----------------------- + +The following ``KINETIC`` block + + .. code-block:: + + KINETIC state { + COMPARTMENT vol {X} + LONGITUDINAL_DIFFUSION mu {X} + + ~ X << (ica) + } + +Will undergo two transformations. The first is to create a system of ODEs that +can be solved. This consumed the AST node. However, to print the code for +longitudinal diffusion we require information from the ``COMPARTMENT`` and +``LONGITUDINAL_DIFFUSION`` statements. This is why there's a second +transformation, that runs before the other transformation, to extract the +required information and store it a AST node called +``LONGITUDINAL_DIFFUSION_BLOCK``. This block can then be converted into an +"info" object, which is then used to print the callbacks. + diff --git a/docs/index.rst b/docs/index.rst index 15125ef4a..8ed98af9b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -24,6 +24,7 @@ About NMODL contents/pointers contents/cable_equations contents/globals + contents/longitudinal_diffusion contents/cvode .. toctree:: diff --git a/src/codegen/codegen_helper_visitor.cpp b/src/codegen/codegen_helper_visitor.cpp index 6f8f85d01..391dbc6d3 100644 --- a/src/codegen/codegen_helper_visitor.cpp +++ b/src/codegen/codegen_helper_visitor.cpp @@ -9,8 +9,10 @@ #include #include +#include #include "ast/all.hpp" +#include "ast/constant_var.hpp" #include "codegen/codegen_naming.hpp" #include "parser/c11_driver.hpp" #include "visitors/visitor_utils.hpp" @@ -853,5 +855,63 @@ void CodegenHelperVisitor::visit_after_block(const ast::AfterBlock& node) { info.before_after_blocks.push_back(&node); } +static std::shared_ptr find_compartment( + const ast::LongitudinalDiffusionBlock& node, + const std::string& var_name) { + const auto& compartment_block = node.get_compartment_statements(); + for (const auto& stmt: compartment_block->get_statements()) { + auto comp = std::dynamic_pointer_cast(stmt); + + auto species = comp->get_species(); + auto it = std::find_if(species.begin(), species.end(), [&var_name](auto var) { + return var->get_node_name() == var_name; + }); + + if (it != species.end()) { + return comp; + } + } + + return nullptr; +} + +void CodegenHelperVisitor::visit_longitudinal_diffusion_block( + const ast::LongitudinalDiffusionBlock& node) { + auto longitudinal_diffusion_block = node.get_longitudinal_diffusion_statements(); + for (auto stmt: longitudinal_diffusion_block->get_statements()) { + auto diffusion = std::dynamic_pointer_cast(stmt); + auto rate_index_name = diffusion->get_index_name(); + auto rate_expr = diffusion->get_rate(); + auto species = diffusion->get_species(); + + auto process_compartment = [](const std::shared_ptr& compartment) + -> std::pair, std::shared_ptr> { + std::shared_ptr volume_expr; + std::shared_ptr volume_index_name; + if (!compartment) { + volume_index_name = nullptr; + volume_expr = std::make_shared("1.0"); + } else { + volume_index_name = compartment->get_index_name(); + volume_expr = std::shared_ptr(compartment->get_volume()->clone()); + } + return {std::move(volume_index_name), std::move(volume_expr)}; + }; + + for (auto var: species) { + std::string state_name = var->get_value()->get_value(); + auto compartment = find_compartment(node, state_name); + auto [volume_index_name, volume_expr] = process_compartment(compartment); + + info.longitudinal_diffusion_info.insert( + {state_name, + LongitudinalDiffusionInfo(volume_index_name, + std::shared_ptr(volume_expr), + rate_index_name, + std::shared_ptr(rate_expr->clone()))}); + } + } +} + } // namespace codegen } // namespace nmodl diff --git a/src/codegen/codegen_helper_visitor.hpp b/src/codegen/codegen_helper_visitor.hpp index 9258f1cf7..3468802c6 100644 --- a/src/codegen/codegen_helper_visitor.hpp +++ b/src/codegen/codegen_helper_visitor.hpp @@ -115,6 +115,7 @@ class CodegenHelperVisitor: public visitor::ConstAstVisitor { void visit_verbatim(const ast::Verbatim& node) override; void visit_before_block(const ast::BeforeBlock& node) override; void visit_after_block(const ast::AfterBlock& node) override; + void visit_longitudinal_diffusion_block(const ast::LongitudinalDiffusionBlock& node) override; }; /** @} */ // end of codegen_details diff --git a/src/codegen/codegen_info.cpp b/src/codegen/codegen_info.cpp index b7d40125a..10dec0ddd 100644 --- a/src/codegen/codegen_info.cpp +++ b/src/codegen/codegen_info.cpp @@ -8,6 +8,8 @@ #include "codegen/codegen_info.hpp" #include "ast/all.hpp" +#include "ast/longitudinal_diffusion_block.hpp" +#include "visitors/rename_visitor.hpp" #include "visitors/var_usage_visitor.hpp" #include "visitors/visitor_utils.hpp" @@ -17,6 +19,48 @@ namespace codegen { using visitor::VarUsageVisitor; +LongitudinalDiffusionInfo::LongitudinalDiffusionInfo( + const std::shared_ptr& volume_index_name, + std::shared_ptr volume_expr, + const std::shared_ptr& rate_index_name, + std::shared_ptr rate_expr) + : volume_index_name(volume_index_name ? volume_index_name->get_node_name() : std::string{}) + , volume_expr(std::move(volume_expr)) + , rate_index_name(rate_index_name ? rate_index_name->get_node_name() : std::string{}) + , rate_expr(std::move(rate_expr)) {} + +std::shared_ptr LongitudinalDiffusionInfo::volume( + const std::string& index_name) const { + return substitute_index(index_name, volume_index_name, volume_expr); +} +std::shared_ptr LongitudinalDiffusionInfo::diffusion_rate( + const std::string& index_name) const { + return substitute_index(index_name, rate_index_name, rate_expr); +} + +double LongitudinalDiffusionInfo::dfcdc(const std::string& /* index_name */) const { + // Needed as part of the Jacobian to stabalize + // the implicit time-integration. However, + // currently, it's set to `0.0` for simplicity. + return 0.0; +} + +std::shared_ptr LongitudinalDiffusionInfo::substitute_index( + const std::string& index_name, + const std::string& old_index_name, + const std::shared_ptr& old_expr) const { + if (old_index_name == "") { + // The expression doesn't contain an index that needs substituting. + return old_expr; + } + auto new_expr = old_expr->clone(); + + auto v = visitor::RenameVisitor(old_index_name, index_name); + new_expr->accept(v); + + return std::shared_ptr(dynamic_cast(new_expr)); +} + /// if any ion has write variable bool CodegenInfo::ion_has_write_variable() const noexcept { return std::any_of(ions.begin(), ions.end(), [](auto const& ion) { diff --git a/src/codegen/codegen_info.hpp b/src/codegen/codegen_info.hpp index 6d39479e1..0e86ac0c3 100644 --- a/src/codegen/codegen_info.hpp +++ b/src/codegen/codegen_info.hpp @@ -284,6 +284,44 @@ struct IndexSemantics { , size(size) {} }; +/** + * \brief Information required to print LONGITUDINAL_DIFFUSION callbacks. + */ +class LongitudinalDiffusionInfo { + public: + LongitudinalDiffusionInfo(const std::shared_ptr& index_name, + std::shared_ptr volume_expr, + const std::shared_ptr& rate_index_name, + std::shared_ptr rate_expr); + /// Volume of this species. + /// + /// If the volume expression is an indexed expression, the index in the + /// expression is substituted with `index_name`. + std::shared_ptr volume(const std::string& index_name) const; + + /// Difusion rate of this species. + /// + /// If the diffusion expression is an indexed expression, the index in the + /// expression is substituted with `index_name`. + std::shared_ptr diffusion_rate(const std::string& index_name) const; + + /// The value of what NEURON calls `dfcdc`. + double dfcdc(const std::string& /* index_name */) const; + + protected: + std::shared_ptr substitute_index( + const std::string& index_name, + const std::string& old_index_name, + const std::shared_ptr& old_expr) const; + + private: + std::string volume_index_name; + std::shared_ptr volume_expr; + + std::string rate_index_name; + std::shared_ptr rate_expr; +}; + /** * \class CodegenInfo @@ -447,6 +485,9 @@ struct CodegenInfo { /// all factors defined in the mod file std::vector factor_definitions; + /// for each state, the information needed to print the callbacks. + std::map longitudinal_diffusion_info; + /// ions used in the mod file std::vector ions; diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 7e012f6b0..55933e2a5 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -36,7 +36,6 @@ using visitor::VarUsageVisitor; using symtab::syminfo::NmodlType; - /****************************************************************************************/ /* Generic information getters */ /****************************************************************************************/ @@ -436,6 +435,79 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_definitions() { } } +CodegenNeuronCppVisitor::ParamVector CodegenNeuronCppVisitor::ldifusfunc1_parameters() const { + return ParamVector{{"", "ldifusfunc2_t", "", "_f"}, + {"", "const _nrn_model_sorted_token&", "", "_sorted_token"}, + {"", "NrnThread&", "", "_nt"}}; +} + + +CodegenNeuronCppVisitor::ParamVector CodegenNeuronCppVisitor::ldifusfunc3_parameters() const { + return ParamVector{{"", "int", "", "_i"}, + {"", "Memb_list*", "", "_ml_arg"}, + {"", "size_t", "", "id"}, + {"", "Datum*", "", "_ppvar"}, + {"", "double*", "", "_pdvol"}, + {"", "double*", "", "_pdfcdc"}, + {"", "Datum*", "", "/* _thread */"}, + {"", "NrnThread*", "", "nt"}, + {"", "const _nrn_model_sorted_token&", "", "_sorted_token"}}; +} + +void CodegenNeuronCppVisitor::print_longitudinal_diffusion_callbacks() { + auto coeff_callback_name = [](const std::string& var_name) { + return fmt::format("_diffusion_coefficient_{}", var_name); + }; + + auto space_name = [](const std::string& var_name) { + return fmt::format("_diffusion_space_{}", var_name); + }; + + for (auto [var_name, values]: info.longitudinal_diffusion_info) { + printer->fmt_line("static void* {};", space_name(var_name)); + printer->fmt_push_block("static double {}({})", + coeff_callback_name(var_name), + get_parameter_str(ldifusfunc3_parameters())); + + print_entrypoint_setup_code_from_memb_list(); + + auto volume_expr = values.volume("_i"); + auto mu_expr = values.diffusion_rate("_i"); + + printer->add_indent(); + printer->add_text("*_pdvol= "); + volume_expr->accept(*this); + printer->add_text(";"); + printer->add_newline(); + + printer->add_line("*_pdfcdc = 0.0;"); + printer->add_indent(); + printer->add_text("return "); + mu_expr->accept(*this); + printer->add_text(";"); + printer->add_newline(); + + printer->pop_block(); + } + + printer->fmt_push_block("static void _apply_diffusion_function({})", + get_parameter_str(ldifusfunc1_parameters())); + for (auto [var_name, values]: info.longitudinal_diffusion_info) { + auto var = program_symtab->lookup(var_name); + size_t array_size = var->get_length(); + printer->fmt_push_block("for(size_t _i = 0; _i < {}; ++_i)", array_size); + printer->fmt_line( + "(*_f)(mech_type, {}, &{}, _i, /* x pos */ {}, /* Dx pos */ {}, _sorted_token, _nt);", + coeff_callback_name(var_name), + space_name(var_name), + position_of_float_var(var_name), + position_of_float_var("D" + var_name)); + printer->pop_block(); + } + printer->pop_block(); + printer->add_newline(); +} + /****************************************************************************************/ /* Code-specific helper routines */ /****************************************************************************************/ @@ -1301,6 +1373,11 @@ void CodegenNeuronCppVisitor::print_mechanism_register() { info.semantics[i].name); } + if (!info.longitudinal_diffusion_info.empty()) { + printer->fmt_line("hoc_register_ldifus1(_apply_diffusion_function);"); + } + + if (info.write_concentration) { printer->fmt_line("nrn_writes_conc(mech_type, 0);"); } @@ -1589,6 +1666,41 @@ void CodegenNeuronCppVisitor::print_initial_block(const InitialBlock* node) { } } +void CodegenNeuronCppVisitor::print_entrypoint_setup_code_from_memb_list() { + printer->add_line( + "_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _ml_arg->type()};"); + printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); + if (!info.artificial_cell) { + printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml_arg);", info.mod_suffix); + } + printer->add_line("auto* _thread = _ml_arg->_thread;"); + if (!codegen_thread_variables.empty()) { + printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());", + thread_variables_struct(), + info.thread_var_thread_id); + } +} + + +void CodegenNeuronCppVisitor::print_entrypoint_setup_code_from_prop() { + printer->add_line("Datum* _ppvar = _nrn_mechanism_access_dparam(prop);"); + printer->add_line("_nrn_mechanism_cache_instance _lmc{prop};"); + printer->add_line("const size_t id = 0;"); + + printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); + if (!info.artificial_cell) { + printer->fmt_line("auto node_data = make_node_data_{}(prop);", info.mod_suffix); + } + + if (!codegen_thread_variables.empty()) { + printer->fmt_line("auto _thread_vars = {}({}_global.thread_data);", + thread_variables_struct(), + info.mod_suffix); + } + + printer->add_newline(); +} + void CodegenNeuronCppVisitor::print_global_function_common_code(BlockType type, const std::string& function_name) { @@ -1598,18 +1710,8 @@ void CodegenNeuronCppVisitor::print_global_function_common_code(BlockType type, {"", "Memb_list*", "", "_ml_arg"}, {"", "int", "", "_type"}}; printer->fmt_push_block("void {}({})", method, get_parameter_str(args)); - - printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); - printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); - printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml_arg);", info.mod_suffix); - + print_entrypoint_setup_code_from_memb_list(); printer->add_line("auto nodecount = _ml_arg->nodecount;"); - printer->add_line("auto* _thread = _ml_arg->_thread;"); - if (!codegen_thread_variables.empty()) { - printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());", - thread_variables_struct(), - info.thread_var_thread_id); - } } @@ -1659,11 +1761,7 @@ void CodegenNeuronCppVisitor::print_nrn_jacob() { get_parameter_str(args)); // begin function - printer->add_multi_line( - "_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); - - printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); - printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml_arg);", info.mod_suffix); + print_entrypoint_setup_code_from_memb_list(); printer->fmt_line("auto nodecount = _ml_arg->nodecount;"); printer->push_block("for (int id = 0; id < nodecount; id++)"); // begin for @@ -1680,25 +1778,6 @@ void CodegenNeuronCppVisitor::print_nrn_jacob() { } -void CodegenNeuronCppVisitor::print_callable_preamble_from_prop() { - printer->add_line("Datum* _ppvar = _nrn_mechanism_access_dparam(prop);"); - printer->add_line("_nrn_mechanism_cache_instance _lmc{prop};"); - printer->add_line("const size_t id = 0;"); - - printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); - if (!info.artificial_cell) { - printer->fmt_line("auto node_data = make_node_data_{}(prop);", info.mod_suffix); - } - - if (!codegen_thread_variables.empty()) { - printer->fmt_line("auto _thread_vars = {}({}_global.thread_data);", - thread_variables_struct(), - info.mod_suffix); - } - - printer->add_newline(); -} - void CodegenNeuronCppVisitor::print_nrn_constructor_declaration() { if (info.constructor_node) { printer->fmt_line("void {}(Prop* prop);", method_name(naming::NRN_CONSTRUCTOR_METHOD)); @@ -1709,7 +1788,7 @@ void CodegenNeuronCppVisitor::print_nrn_constructor() { if (info.constructor_node) { printer->fmt_push_block("void {}(Prop* prop)", method_name(naming::NRN_CONSTRUCTOR_METHOD)); - print_callable_preamble_from_prop(); + print_entrypoint_setup_code_from_prop(); auto block = info.constructor_node->get_statement_block(); print_statement_block(*block, false, false); @@ -1725,7 +1804,7 @@ void CodegenNeuronCppVisitor::print_nrn_destructor_declaration() { void CodegenNeuronCppVisitor::print_nrn_destructor() { printer->fmt_push_block("void {}(Prop* prop)", method_name(naming::NRN_DESTRUCTOR_METHOD)); - print_callable_preamble_from_prop(); + print_entrypoint_setup_code_from_prop(); for (const auto& rv: info.random_variables) { printer->fmt_line("nrnran123_deletestream((nrnran123_State*) {});", @@ -2273,6 +2352,7 @@ void CodegenNeuronCppVisitor::print_codegen_routines() { print_nrn_destructor_declaration(); print_nrn_alloc(); print_function_prototypes(); + print_longitudinal_diffusion_callbacks(); print_point_process_function_definitions(); print_setdata_functions(); print_check_table_entrypoint(); @@ -2455,6 +2535,15 @@ void CodegenNeuronCppVisitor::visit_watch_statement(const ast::WatchStatement& / return; } +void CodegenNeuronCppVisitor::visit_longitudinal_diffusion_block( + const ast::LongitudinalDiffusionBlock& /* node */) { + // These are handled via `print_longitdudinal_*`. +} + +void CodegenNeuronCppVisitor::visit_lon_diffuse(const ast::LonDiffuse& /* node */) { + // These are handled via `print_longitdudinal_*`. +} + void CodegenNeuronCppVisitor::visit_for_netcon(const ast::ForNetcon& node) { // The setup for enabling this loop is: // double ** _fornetcon_data = ...; diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp index b1a02fc7d..794c7dfbf 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.hpp +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -250,6 +250,21 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { void print_hoc_py_wrapper_function_definitions(); + /** + * Prints the callbacks required for LONGITUDINAL_DIFFUSION. + */ + void print_longitudinal_diffusion_callbacks(); + + /** + * Parameters for what NEURON calls `ldifusfunc1_t`. + */ + ParamVector ldifusfunc1_parameters() const; + + /** + * Parameters for what NEURON calls `ldifusfunc3_t`. + */ + ParamVector ldifusfunc3_parameters() const; + /****************************************************************************************/ /* Code-specific helper routines */ @@ -492,6 +507,26 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { void print_global_function_common_code(BlockType type, const std::string& function_name = "") override; + /** + * Prints setup code for entrypoints from NEURON. + * + * The entrypoints typically receive a `sorted_token` and a bunch of other things, which then + * need to be converted into the default arguments for functions called (recursively) from the + * entrypoint. + * + * This variation prints the fast entrypoint, where NEURON is fully initialized and setup. + */ + void print_entrypoint_setup_code_from_memb_list(); + + + /** + * Prints setup code for entrypoints NEURON. + * + * See `print_entrypoint_setup_code_from_memb_list`. This variation should be used when one only + * has access to a `Prop`, but not the full `Memb_list`. + */ + void print_entrypoint_setup_code_from_prop(); + /** * Print the \c nrn\_init function definition @@ -509,11 +544,6 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { void print_nrn_constructor() override; void print_nrn_constructor_declaration(); - /** - * Print the set of common variables from a `Prop` only. - */ - void print_callable_preamble_from_prop(); - /** * Print nrn_destructor function definition * @@ -693,8 +723,8 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { void visit_watch_statement(const ast::WatchStatement& node) override; void visit_for_netcon(const ast::ForNetcon& node) override; - - + void visit_longitudinal_diffusion_block(const ast::LongitudinalDiffusionBlock& node) override; + void visit_lon_diffuse(const ast::LonDiffuse& node) override; public: diff --git a/src/language/code_generator.cmake b/src/language/code_generator.cmake index 5803d4ae9..ae8d5e292 100644 --- a/src/language/code_generator.cmake +++ b/src/language/code_generator.cmake @@ -116,6 +116,7 @@ set(AST_GENERATED_SOURCES ${PROJECT_BINARY_DIR}/src/ast/local_list_statement.hpp ${PROJECT_BINARY_DIR}/src/ast/local_var.hpp ${PROJECT_BINARY_DIR}/src/ast/lon_diffuse.hpp + ${PROJECT_BINARY_DIR}/src/ast/longitudinal_diffusion_block.hpp ${PROJECT_BINARY_DIR}/src/ast/model.hpp ${PROJECT_BINARY_DIR}/src/ast/mutex_lock.hpp ${PROJECT_BINARY_DIR}/src/ast/mutex_unlock.hpp diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index 292cb567c..67c9efe21 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -103,6 +103,22 @@ brief: "Block with statements of the form Dvar = Dvar / (1 - dt * J(f))" type: StatementBlock brief: "Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks" + - LongitudinalDiffusionBlock: + brief: "Extracts information required for LONGITUDINAL_DIFFUSION for each KINETIC block." + nmodl: "LONGITUDINAL_DIFFUSION_BLOCK" + members: + - name: + brief: "Name of the longitudinal diffusion block" + type: Name + node_name: true + prefix: { value: " "} + suffix: { value: " "} + - longitudinal_diffusion_statements: + brief: "All LONGITUDINAL_DIFFUSION statements in the KINETIC block." + type: StatementBlock + - compartment_statements: + brief: "All (required) COMPARTMENT statements in the KINETIC block." + type: StatementBlock - WrappedExpression: brief: "Wrap any other expression type" diff --git a/src/language/nmodl.yaml b/src/language/nmodl.yaml index dbfbaea6e..0f76b01a3 100644 --- a/src/language/nmodl.yaml +++ b/src/language/nmodl.yaml @@ -1152,15 +1152,15 @@ type: Expression - LinEquation: - brief: "TODO" + brief: "One equation in a system of equations tha collectively form a LINEAR block." nmodl: "~ " members: - - left_linxpression: - brief: "TODO" + - lhs: + brief: "Left-hand-side of the equation." type: Expression suffix: {value: " = "} - - linxpression: - brief: "TODO" + - rhs: + brief: "Right-hand-side of the equation." type: Expression - FunctionCall: @@ -1568,6 +1568,7 @@ separator: " " brief: "Represent LONGITUDINAL_DIFFUSION statement in NMODL" + - ReactionStatement: brief: "TODO" nmodl: "~ " @@ -1750,14 +1751,14 @@ separator: ", " description: | Here is an example of RANDOM statement - - \code{.mod} + + \code{.mod} NEURON { THREADSAFE POINT_PROCESS NetStim RANDOM ranvar \endcode - + - Pointer: nmodl: "POINTER " members: diff --git a/src/language/node_info.py b/src/language/node_info.py index ff6d72faf..d2cfe50c9 100644 --- a/src/language/node_info.py +++ b/src/language/node_info.py @@ -80,6 +80,7 @@ "ProcedureBlock", "DerivativeBlock", "LinearBlock", + "LongitudinalDiffusionBlock", "NonLinearBlock", "DiscreteBlock", "KineticBlock", diff --git a/src/main.cpp b/src/main.cpp index 97045dad7..27386c655 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -36,6 +36,7 @@ #include "visitors/local_to_assigned_visitor.hpp" #include "visitors/local_var_rename_visitor.hpp" #include "visitors/localize_visitor.hpp" +#include "visitors/longitudinal_diffusion_visitor.hpp" #include "visitors/loop_unroll_visitor.hpp" #include "visitors/neuron_solve_visitor.hpp" #include "visitors/nmodl_visitor.hpp" @@ -431,6 +432,13 @@ int run_nmodl(int argc, const char* argv[]) { SymtabVisitor(update_symtab).visit_program(*ast); } + if (neuron_code) { + CreateLongitudinalDiffusionBlocks().visit_program(*ast); + ast_to_nmodl(*ast, filepath("londifus")); + SymtabVisitor(update_symtab).visit_program(*ast); + } + + /// note that we can not symtab visitor in update mode as we /// replace kinetic block with derivative block of same name /// in global scope diff --git a/src/symtab/symbol_properties.hpp b/src/symtab/symbol_properties.hpp index 8bf635919..7cc57c7b7 100644 --- a/src/symtab/symbol_properties.hpp +++ b/src/symtab/symbol_properties.hpp @@ -223,7 +223,10 @@ enum class NmodlType : enum_type { random_var = 1LL << 34, /// FUNCTION or PROCEDURE needs setdata check - use_range_ptr_var = 1LL << 35 + use_range_ptr_var = 1LL << 35, + + /// Internal LONGITUDINAL_DIFFUSION block + longitudinal_diffusion_block = 1LL << 36 }; template diff --git a/src/symtab/symbol_table.hpp b/src/symtab/symbol_table.hpp index 9b64157f8..e4c100ccf 100644 --- a/src/symtab/symbol_table.hpp +++ b/src/symtab/symbol_table.hpp @@ -134,10 +134,10 @@ class SymbolTable { /** * get variables * - * \param with variables with properties. 0 matches everything - * \param without variables without properties. 0 matches nothing + * \param with variables with properties. `syminfo::NmodlType::empty` matches everything + * \param without variables without properties. `syminfo::NmodlType::empty` matches nothing * - * The two different behaviors for 0 depend on the fact that we get + * The two different behaviors for `syminfo::NmodlType::empty` depend on the fact that we get * get variables with ALL the with properties and without ANY of the * without properties */ diff --git a/src/visitors/CMakeLists.txt b/src/visitors/CMakeLists.txt index f51a65b73..10a1ded54 100644 --- a/src/visitors/CMakeLists.txt +++ b/src/visitors/CMakeLists.txt @@ -23,6 +23,7 @@ add_library( local_to_assigned_visitor.cpp local_var_rename_visitor.cpp localize_visitor.cpp + longitudinal_diffusion_visitor.cpp loop_unroll_visitor.cpp neuron_solve_visitor.cpp perf_visitor.cpp diff --git a/src/visitors/longitudinal_diffusion_visitor.cpp b/src/visitors/longitudinal_diffusion_visitor.cpp new file mode 100644 index 000000000..104affc71 --- /dev/null +++ b/src/visitors/longitudinal_diffusion_visitor.cpp @@ -0,0 +1,51 @@ +#include "longitudinal_diffusion_visitor.hpp" + +#include "ast/ast_decl.hpp" +#include "ast/kinetic_block.hpp" +#include "ast/longitudinal_diffusion_block.hpp" +#include "ast/name.hpp" +#include "ast/program.hpp" +#include "ast/statement.hpp" +#include "ast/statement_block.hpp" +#include "ast/string.hpp" +#include "visitor_utils.hpp" + +#include + + +namespace nmodl { +namespace visitor { + +static std::shared_ptr make_statement_block( + ast::KineticBlock& kinetic_block, + nmodl::ast::AstNodeType node_type) { + auto nodes = collect_nodes(kinetic_block, {node_type}); + + ast::StatementVector statements; + statements.reserve(nodes.size()); + for (auto& node: nodes) { + statements.push_back(std::dynamic_pointer_cast(node)); + } + + return std::make_shared(std::move(statements)); +} + + +static std::shared_ptr create_block(ast::KineticBlock& node) { + return std::make_shared( + std::make_shared(std::make_shared("ld_" + node.get_node_name())), + make_statement_block(node, nmodl::ast::AstNodeType::LON_DIFFUSE), + make_statement_block(node, nmodl::ast::AstNodeType::COMPARTMENT)); +} + +void CreateLongitudinalDiffusionBlocks::visit_program(ast::Program& node) { + auto kinetic_blocks = collect_nodes(node, {nmodl::ast::AstNodeType::KINETIC_BLOCK}); + + for (const auto& ast_node: kinetic_blocks) { + auto kinetic_block = std::dynamic_pointer_cast(ast_node); + node.emplace_back_node(create_block(*kinetic_block)); + } +} + +} // namespace visitor +} // namespace nmodl diff --git a/src/visitors/longitudinal_diffusion_visitor.hpp b/src/visitors/longitudinal_diffusion_visitor.hpp new file mode 100644 index 000000000..7caec2a6f --- /dev/null +++ b/src/visitors/longitudinal_diffusion_visitor.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include "visitors/ast_visitor.hpp" + + +namespace nmodl { +namespace ast { +class Program; +} + +namespace visitor { + +class CreateLongitudinalDiffusionBlocks: public AstVisitor { + public: + + void visit_program(ast::Program& node) override; +}; + +} // namespace visitor +} // namespace nmodl diff --git a/src/visitors/sympy_replace_solutions_visitor.cpp b/src/visitors/sympy_replace_solutions_visitor.cpp index 9906c5577..a30e8f00e 100644 --- a/src/visitors/sympy_replace_solutions_visitor.cpp +++ b/src/visitors/sympy_replace_solutions_visitor.cpp @@ -161,8 +161,7 @@ void SympyReplaceSolutionsVisitor::visit_statement_block(ast::StatementBlock& no void SympyReplaceSolutionsVisitor::try_replace_tagged_statement( const ast::Node& node, - std::shared_ptr get_lhs(const ast::Node& node), - std::shared_ptr get_rhs(const ast::Node& node)) { + std::shared_ptr get_lhs(const ast::Node& node)) { interleaves_counter.new_equation(true); const auto& statement = std::static_pointer_cast( @@ -176,8 +175,7 @@ void SympyReplaceSolutionsVisitor::try_replace_tagged_statement( switch (policy) { case ReplacePolicy::VALUE: { - const auto dependencies = statement_dependencies(get_lhs(node), get_rhs(node)); - const auto& key = dependencies.first; + const auto key = statement_dependencies_key(get_lhs(node)); if (solution_statements.is_var_assigned_here(key)) { logger->debug("SympyReplaceSolutionsVisitor :: marking for replacement {}", @@ -216,24 +214,16 @@ void SympyReplaceSolutionsVisitor::visit_diff_eq_expression(ast::DiffEqExpressio return dynamic_cast(node).get_expression()->get_lhs(); }; - auto get_rhs = [](const ast::Node& node) -> std::shared_ptr { - return dynamic_cast(node).get_expression()->get_rhs(); - }; - - try_replace_tagged_statement(node, get_lhs, get_rhs); + try_replace_tagged_statement(node, get_lhs); } void SympyReplaceSolutionsVisitor::visit_lin_equation(ast::LinEquation& node) { logger->debug("SympyReplaceSolutionsVisitor :: visit {}", to_nmodl(node)); auto get_lhs = [](const ast::Node& node) -> std::shared_ptr { - return dynamic_cast(node).get_left_linxpression(); - }; - - auto get_rhs = [](const ast::Node& node) -> std::shared_ptr { - return dynamic_cast(node).get_left_linxpression(); + return dynamic_cast(node).get_lhs(); }; - try_replace_tagged_statement(node, get_lhs, get_rhs); + try_replace_tagged_statement(node, get_lhs); } @@ -243,11 +233,7 @@ void SympyReplaceSolutionsVisitor::visit_non_lin_equation(ast::NonLinEquation& n return dynamic_cast(node).get_lhs(); }; - auto get_rhs = [](const ast::Node& node) -> std::shared_ptr { - return dynamic_cast(node).get_rhs(); - }; - - try_replace_tagged_statement(node, get_lhs, get_rhs); + try_replace_tagged_statement(node, get_lhs); } diff --git a/src/visitors/sympy_replace_solutions_visitor.hpp b/src/visitors/sympy_replace_solutions_visitor.hpp index 42bc4da2d..c055d57c0 100644 --- a/src/visitors/sympy_replace_solutions_visitor.hpp +++ b/src/visitors/sympy_replace_solutions_visitor.hpp @@ -249,12 +249,10 @@ class SympyReplaceSolutionsVisitor: public AstVisitor { * * \param node it can be Diff_Eq_Expression/LinEquation/NonLinEquation * \param get_lhs method with witch we may get the lhs (in case we need it) - * \param get_rhs method with witch we may get the rhs (in case we need it) */ void try_replace_tagged_statement( const ast::Node& node, - std::shared_ptr get_lhs(const ast::Node& node), - std::shared_ptr get_rhs(const ast::Node& node)); + std::shared_ptr get_lhs(const ast::Node& node)); /** * \struct InterleavesCounter diff --git a/src/visitors/sympy_solver_visitor.cpp b/src/visitors/sympy_solver_visitor.cpp index 42936ae5e..f57a70aab 100644 --- a/src/visitors/sympy_solver_visitor.cpp +++ b/src/visitors/sympy_solver_visitor.cpp @@ -580,9 +580,9 @@ void SympySolverVisitor::visit_derivative_block(ast::DerivativeBlock& node) { void SympySolverVisitor::visit_lin_equation(ast::LinEquation& node) { check_expr_statements_in_same_block(); - std::string lin_eq = to_nmodl_for_sympy(*node.get_left_linxpression()); + std::string lin_eq = to_nmodl_for_sympy(*node.get_lhs()); lin_eq += " = "; - lin_eq += to_nmodl_for_sympy(*node.get_linxpression()); + lin_eq += to_nmodl_for_sympy(*node.get_rhs()); eq_system.push_back(lin_eq); expression_statements.insert(current_expression_statement); last_expression_statement = current_expression_statement; diff --git a/src/visitors/visitor_utils.cpp b/src/visitors/visitor_utils.cpp index ca9e1a0b6..b279e2087 100644 --- a/src/visitors/visitor_utils.cpp +++ b/src/visitors/visitor_utils.cpp @@ -250,19 +250,24 @@ std::string to_json(const ast::Ast& node, bool compact, bool expand, bool add_nm return stream.str(); } +std::string statement_dependencies_key(const std::shared_ptr& lhs) { + if (!lhs->is_var_name()) { + return ""; + } + + const auto& lhs_var_name = std::dynamic_pointer_cast(lhs); + return get_full_var_name(*lhs_var_name); +} + std::pair> statement_dependencies( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - std::string key; + std::string key = statement_dependencies_key(lhs); std::unordered_set out; - if (!lhs->is_var_name()) { return {key, out}; } - const auto& lhs_var_name = std::dynamic_pointer_cast(lhs); - key = get_full_var_name(*lhs_var_name); - visitor::AstLookupVisitor lookup_visitor; lookup_visitor.lookup(*rhs, ast::AstNodeType::VAR_NAME); auto rhs_nodes = lookup_visitor.get_nodes(); diff --git a/src/visitors/visitor_utils.hpp b/src/visitors/visitor_utils.hpp index aa89d3b84..e1e7c9989 100644 --- a/src/visitors/visitor_utils.hpp +++ b/src/visitors/visitor_utils.hpp @@ -82,11 +82,6 @@ std::shared_ptr create_statement_block( const std::vector& code_statements); -/// Remove statements from given statement block if they exist -void remove_statements_from_block(ast::StatementBlock& block, - const std::set& statements); - - /// Return set of strings with the names of all global variables std::set get_global_vars(const ast::Program& node); @@ -129,6 +124,9 @@ std::string to_json(const ast::Ast& node, bool expand = false, bool add_nmodl = false); +/// The `result.first` of `statement_dependencies`. +std::string statement_dependencies_key(const std::shared_ptr& lhs); + /// If \p lhs and \p rhs combined represent an assignment (we assume to have an "=" in between them) /// we extract the variables on which the assigned variable depends on. We provide the input with /// lhs and rhs because there are a few nodes that have this similar structure but slightly diff --git a/test/usecases/longitudinal_diffusion/heat_eqn_array.mod b/test/usecases/longitudinal_diffusion/heat_eqn_array.mod new file mode 100644 index 000000000..74117389f --- /dev/null +++ b/test/usecases/longitudinal_diffusion/heat_eqn_array.mod @@ -0,0 +1,47 @@ +NEURON { + SUFFIX heat_eqn_array + RANGE x +} + +DEFINE N 4 + +PARAMETER { + kf = 0.0 + kb = 0.0 +} + +ASSIGNED { + x + mu[N] + vol[N] +} + +STATE { + X[N] +} + +INITIAL { + FROM i=0 TO N-1 { + mu[i] = 1.0 + i + vol[i] = 0.01 / (i + 1.0) + + if(x < 0.5) { + X[i] = 1.0 + i + } else { + X[i] = 0.0 + } + } +} + +BREAKPOINT { + SOLVE state METHOD sparse +} + +KINETIC state { + COMPARTMENT i, vol[i] {X} + LONGITUDINAL_DIFFUSION i, mu[i] {X} + + FROM i=0 TO N-2 { + ~ X[i] <-> X[i+1] (kf, kb) + } +} diff --git a/test/usecases/longitudinal_diffusion/heat_eqn_function.mod b/test/usecases/longitudinal_diffusion/heat_eqn_function.mod new file mode 100644 index 000000000..70cf7e009 --- /dev/null +++ b/test/usecases/longitudinal_diffusion/heat_eqn_function.mod @@ -0,0 +1,56 @@ +NEURON { + SUFFIX heat_eqn_function + RANGE x + GLOBAL g_mu, g_vol + THREADSAFE +} + +ASSIGNED { + x + g_mu + g_vol +} + +STATE { + X +} + +INITIAL { + g_mu = 1.1 + g_vol = 0.01 + + if(x < 0.5) { + X = 1.0 + } else { + X = 0.0 + } +} + +BREAKPOINT { + SOLVE state METHOD sparse +} + +FUNCTION factor(x) { + if(x < 0.25) { + factor = 0.0 + } else { + factor = 10*(x - 0.25) + } +} + +FUNCTION vol(x) { + vol = (1 + x) * g_vol +} + +FUNCTION mu(x) { + mu = x * g_mu +} + +KINETIC state { + COMPARTMENT vol(x) {X} + LONGITUDINAL_DIFFUSION mu(factor(x)) {X} + + : There must be a reaction equation, but + : we only want to test diffusion. + ~ X << (0.0) +} diff --git a/test/usecases/longitudinal_diffusion/heat_eqn_global.mod b/test/usecases/longitudinal_diffusion/heat_eqn_global.mod new file mode 100644 index 000000000..8e8842856 --- /dev/null +++ b/test/usecases/longitudinal_diffusion/heat_eqn_global.mod @@ -0,0 +1,38 @@ +NEURON { + SUFFIX heat_eqn_global + RANGE x +} + +PARAMETER { + mu = 2.0 + vol = 0.01 +} + +ASSIGNED { + x +} + +STATE { + X +} + +INITIAL { + if(x < 0.5) { + X = 1.0 + } else { + X = 0.0 + } +} + +BREAKPOINT { + SOLVE state METHOD sparse +} + +KINETIC state { + COMPARTMENT vol {X} + LONGITUDINAL_DIFFUSION mu {X} + + : There must be a reaction equation, but + : we only want to test diffusion. + ~ X << (0.0) +} diff --git a/test/usecases/longitudinal_diffusion/heat_eqn_scalar.mod b/test/usecases/longitudinal_diffusion/heat_eqn_scalar.mod new file mode 100644 index 000000000..5d337c631 --- /dev/null +++ b/test/usecases/longitudinal_diffusion/heat_eqn_scalar.mod @@ -0,0 +1,38 @@ +NEURON { + SUFFIX heat_eqn_scalar + RANGE x +} + +ASSIGNED { + x + mu + vol +} + +STATE { + X +} + +INITIAL { + mu = 1.1 + vol = 0.01 + + if(x < 0.5) { + X = 1.0 + } else { + X = 0.0 + } +} + +BREAKPOINT { + SOLVE state METHOD sparse +} + +KINETIC state { + COMPARTMENT vol {X} + LONGITUDINAL_DIFFUSION mu {X} + + : There must be a reaction equation, but + : we only want to test diffusion. + ~ X << (0.0) +} diff --git a/test/usecases/longitudinal_diffusion/heat_eqn_thread_vars.mod b/test/usecases/longitudinal_diffusion/heat_eqn_thread_vars.mod new file mode 100644 index 000000000..b90599f0e --- /dev/null +++ b/test/usecases/longitudinal_diffusion/heat_eqn_thread_vars.mod @@ -0,0 +1,40 @@ +NEURON { + SUFFIX heat_eqn_thread_vars + RANGE x + GLOBAL mu, vol + THREADSAFE +} + +ASSIGNED { + x + mu + vol +} + +STATE { + X +} + +INITIAL { + mu = 1.1 + vol = 0.01 + + if(x < 0.5) { + X = 1.0 + } else { + X = 0.0 + } +} + +BREAKPOINT { + SOLVE state METHOD sparse +} + +KINETIC state { + COMPARTMENT vol {X} + LONGITUDINAL_DIFFUSION mu {X} + + : There must be a reaction equation, but + : we only want to test diffusion. + ~ X << (0.0) +} diff --git a/test/usecases/longitudinal_diffusion/test_heat_eqn.py b/test/usecases/longitudinal_diffusion/test_heat_eqn.py new file mode 100644 index 000000000..b454371ba --- /dev/null +++ b/test/usecases/longitudinal_diffusion/test_heat_eqn.py @@ -0,0 +1,145 @@ +import os +import sys +import pickle + +from neuron import h, gui + +import numpy as np + + +def reference_filename(mech_name): + return f"diffuse-{mech_name}.pkl" + + +def save_state(mech_name, t, X): + with open(reference_filename(mech_name), "wb") as f: + pickle.dump((t, X), f) + + +def load_state(mech_name): + filename = reference_filename(mech_name) + if not os.path.exists(filename): + raise RuntimeError("References unavailable. Try running with NOCMODL first.") + + with open(filename, "rb") as f: + return pickle.load(f) + + +def simulator_name(): + return sys.argv[1] if len(sys.argv) >= 2 else None + + +def run_simulation(mech_name, record_states): + nseg = 50 + + s = h.Section() + s.nseg = nseg + s.insert(mech_name) + + t_hoc = h.Vector().record(h._ref_t) + X_hoc = [] + for i in range(nseg): + x = (0.5 + i) / nseg + inst = getattr(s(x), mech_name) + + inst.x = x + X_hoc.append(record_states(inst)) + + h.finitialize() + h.continuerun(1.0) + + t = np.array(t_hoc.as_numpy()) + X = np.array([[np.array(xx.as_numpy()) for xx in x] for x in X_hoc]) + + # The axes are: + # time, spatial position, state variable + X = np.transpose(X, axes=(2, 0, 1)) + + return t, X + + +def check_timeseries(mech_name, t, X): + t_noc, X_noc = load_state(mech_name) + + np.testing.assert_allclose(t, t_noc, atol=1e-10, rtol=0.0) + np.testing.assert_allclose(X, X_noc, atol=1e-10, rtol=0.0) + + +def plot_timeseries(mech_name, t, X, i_state): + try: + import matplotlib.pyplot as plt + except ImportError: + return + + nseg = X.shape[1] + frames_with_label = [0, 1, len(t) - 1] + + fig = plt.figure() + for i_time, t in enumerate(t): + kwargs = {"label": f"t = {t:.3f}"} if i_time in frames_with_label else dict() + + x = [(0.5 + i) / nseg for i in range(nseg)] + plt.plot(x, X[i_time, :, i_state], **kwargs) + + plt.xlabel("Spatial position") + plt.ylabel(f"STATE value: X[{i_state}]") + plt.title(f"Simulator: {simulator_name()}") + plt.legend() + plt.savefig(f"diffusion-{mech_name}-{simulator_name()}-state{i_state}.png", dpi=300) + plt.close(fig) + + +def check_heat_equation(mech_name, record_states): + t, X = run_simulation(mech_name, record_states) + + for i_state in range(X.shape[2]): + plot_timeseries(mech_name, t, X, i_state) + + simulator = sys.argv[1] + if simulator == "nocmodl": + save_state(mech_name, t, X) + + else: + check_timeseries(mech_name, t, X) + + +def record_states_factory(array_size, get_reference): + return lambda inst: [ + h.Vector().record(get_reference(inst, k)) for k in range(array_size) + ] + + +def check_heat_equation_scalar(mech_name): + check_heat_equation( + mech_name, record_states_factory(1, lambda inst, k: inst._ref_X) + ) + + +def test_heat_equation_scalar(): + check_heat_equation_scalar("heat_eqn_scalar") + + +def test_heat_equation_global(): + check_heat_equation_scalar("heat_eqn_global") + + +def test_heat_equation_function(): + check_heat_equation_scalar("heat_eqn_function") + + +def test_heat_equation_thread_vars(): + check_heat_equation_scalar("heat_eqn_thread_vars") + + +def test_heat_equation_array(): + check_heat_equation( + "heat_eqn_array", record_states_factory(4, lambda inst, k: inst._ref_X[k]) + ) + + +if __name__ == "__main__": + test_heat_equation_scalar() + test_heat_equation_global() + test_heat_equation_thread_vars() + test_heat_equation_function() + test_heat_equation_array()