Skip to content

Commit

Permalink
Enhance SympyConductanceVisitor - fixes #7 (#14)
Browse files Browse the repository at this point in the history
- now takes into account all previous statements in BREAKPOINT
- substitutes them (recursively) before differentiating
- extended unit tests
- also added unit tests for differentiate2c function in ode.py
  • Loading branch information
lkeegan authored and pramodk committed Feb 25, 2019
1 parent a0589ae commit fd3c6c6
Show file tree
Hide file tree
Showing 9 changed files with 266 additions and 49 deletions.
63 changes: 54 additions & 9 deletions nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def integrate2c(diff_string, t_var, dt_var, vars, use_pade_approx=False):
sympy_vars[t_var] = t

# parse string into SymPy equation
diffeq = sp.Eq(x.diff(t), sp.sympify(diff_string.split("=")[1], locals=sympy_vars))
diffeq = sp.Eq(
x.diff(t), sp.sympify(diff_string.split("=", 1)[1], locals=sympy_vars)
)

# classify ODE, if it is too hard then exit
ode_properties = set(sp.classify_ode(diffeq))
Expand All @@ -89,31 +91,46 @@ def integrate2c(diff_string, t_var, dt_var, vars, use_pade_approx=False):
_a0 = taylor_series.nth(0)
_a1 = taylor_series.nth(1)
_a2 = taylor_series.nth(2)
solution = ((_a0*_a1 + (_a1*_a1-_a0*_a2)*dt)/(_a1-_a2*dt)).simplify()
solution = (
(_a0 * _a1 + (_a1 * _a1 - _a0 * _a2) * dt) / (_a1 - _a2 * dt)
).simplify()

# return result as C code in NEURON format
return f"{sp.ccode(x_0)} = {sp.ccode(solution)}"


def differentiate2c(expression, dependent_var, vars):
def differentiate2c(expression, dependent_var, vars, prev_expressions=None):
"""Analytically differentiate supplied expression, return solution as C code.
Expression should be of the form "f(x)", where "x" is
the dependent variable, and the function returns df(x)/dx
vars should contain the set of all the variables
referenced by f(x), for example:
The set vars must contain all variables used in the expression.
Furthermore, if any of these variables are themselves functions that should
be substituted before differentiating, they can be supplied in the prev_expressions list.
Before differentiating each of these expressions will be substituted into expressions,
where possible, in reverse order - i.e. starting from the end of the list.
If the result coincides with one of the vars, or the LHS of one of
the prev_expressions, then it is simplified to this expression.
Some simple examples of use:
-differentiate2c("a*x", "x", {"a"}) == "a"
-differentiate2c("cos(y) + b*y**2", "y", {"a","b"}) == "Dy = 2*b*y - sin(y)"
Args:
expression: expression to be differentiated e.g. "a*x + b"
dependent_var: dependent variable, e.g. "x"
vars: set of all other variables used in expression, e.g. {"a", "b"}
vars: set of all other variables used in expression, e.g. {"a", "b", "c"}
prev_expressions: time-ordered list of preceeding expressions
to evaluate & substitute, e.g. ["b = x + c", "a = 12*b"]
Returns:
String containing analytic derivative as C code
String containing analytic derivative of expression (including any substitutions
of variables from supplied prev_expressions) w.r.t dependent_var as C code.
"""

prev_expressions = prev_expressions or []
# every symbol (a.k.a variable) that SymPy
# is going to manipulate needs to be declared
# explicitly
Expand All @@ -127,8 +144,36 @@ def differentiate2c(expression, dependent_var, vars):
# parse string into SymPy equation
expr = sp.sympify(expression, locals=sympy_vars)

# parse previous equations into (lhs, rhs) pairs & reverse order
prev_eqs = [
(
sp.sympify(e.split("=", 1)[0], locals=sympy_vars),
sp.sympify(e.split("=", 1)[1], locals=sympy_vars),
)
for e in prev_expressions
]
prev_eqs.reverse()

# substitute each prev equation in reverse order: latest first
for eq in prev_eqs:
expr = expr.subs(eq[0], eq[1])

# differentiate w.r.t. x
diff = expr.diff(x)
diff = expr.diff(x).simplify()

# if expression is equal to one of the supplied vars, replace with this var
for v in sympy_vars:
if (diff - sympy_vars[v]).simplify() == 0:
diff = sympy_vars[v]
# or if equal to rhs of one of supplied equations, replace with lhs
for i_eq, eq in enumerate(prev_eqs):
# each supplied eq also needs recursive substitution of preceeding statements
# here, before comparison with diff expression
expr = eq[1]
for sub_eq in prev_eqs[i_eq:]:
expr = expr.subs(sub_eq[0], sub_eq[1])
if (diff - expr).simplify() == 0:
diff = eq[0]

# return result as C code in NEURON format
return sp.ccode(diff)
75 changes: 44 additions & 31 deletions src/visitors/sympy_conductance_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,29 @@ using namespace syminfo;
std::vector<std::string> SympyConductanceVisitor::generate_statement_strings(
BreakpointBlock* node) {
std::vector<std::string> statements;
// iterate over binary expressions from breakpoint
for (const auto& expr: binary_exprs) {
auto lhs_str = expr.first;
auto equation_string = expr.second;
// iterate over binary expression lhs's from breakpoint
for (const auto& lhs_str: ordered_binary_exprs_lhs) {
// look for a current name that matches lhs of expr (current write name)
auto it = i_name.find(lhs_str);
if (it != i_name.end()) {
const auto& equation_string = ordered_binary_exprs[binary_expr_index[lhs_str]];
std::string i_name_str = it->second;
// SymPy needs the current expression & all previous expressions
std::vector<std::string> expressions(ordered_binary_exprs.begin(),
ordered_binary_exprs.begin() +
binary_expr_index[lhs_str] + 1);
// differentiate dI/dV
auto locals = py::dict("equation_string"_a = equation_string, "vars"_a = vars);
auto locals = py::dict("expressions"_a = expressions, "vars"_a = vars);
py::exec(R"(
from nmodl.ode import differentiate2c
exception_message = ""
try:
rhs = equation_string.split("=")[1]
solution = differentiate2c(rhs, "v", vars)
rhs = expressions[-1].split("=", 1)[1]
solution = differentiate2c(rhs,
"v",
vars,
expressions[:-1]
)
except Exception as e:
# if we fail, fail silently and return empty string
solution = ""
Expand All @@ -47,7 +54,7 @@ std::vector<std::string> SympyConductanceVisitor::generate_statement_strings(
auto dIdV = locals["solution"].cast<std::string>();
auto exception_message = locals["exception_message"].cast<std::string>();
if (!exception_message.empty()) {
logger->warn("SympyConductance :: python exception: " + exception_message);
logger->warn("SympyConductance :: python exception: {}", exception_message);
}
if (dIdV.empty()) {
logger->warn(
Expand All @@ -67,41 +74,48 @@ std::vector<std::string> SympyConductanceVisitor::generate_statement_strings(
// declare it
add_local_variable(node->get_statement_block().get(), g_var);
// asign dIdV to it
std::string statement_str = g_var + " = " + dIdV;
std::string statement_str = g_var;
statement_str.append(" = ").append(dIdV);
statements.insert(statements.begin(), statement_str);
logger->debug("SympyConductance :: Adding BREAKPOINT statement: " +
logger->debug("SympyConductance :: Adding BREAKPOINT statement: {}",
statement_str);
}
std::string statement_str = "CONDUCTANCE " + g_var;
if (i_name_str != "") {
if (!i_name_str.empty()) {
statement_str += " USEION " + i_name_str;
}
statements.push_back(statement_str);
logger->debug("SympyConductance :: Adding BREAKPOINT statement: " + statement_str);
logger->debug("SympyConductance :: Adding BREAKPOINT statement: {}", statement_str);
}
}
}
return statements;
}

void SympyConductanceVisitor::visit_binary_expression(BinaryExpression* node) {
// only want binary expressions from breakpoint block
if (!breakpoint_block) {
return;
}
// only want binary expressions of form x = ...
if (node->lhs->is_var_name() && (node->op.get_value() == BinaryOp::BOP_ASSIGN)) {
auto lhs_str = std::dynamic_pointer_cast<VarName>(node->lhs)->get_name()->get_node_name();
binary_exprs[lhs_str] = nmodl::to_nmodl(node);
binary_expr_index[lhs_str] = ordered_binary_exprs.size();
ordered_binary_exprs.push_back(nmodl::to_nmodl(node));
ordered_binary_exprs_lhs.push_back(lhs_str);
}
}

void SympyConductanceVisitor::lookup_nonspecific_statements() {
// add NONSPECIFIC_CURRENT statements to i_name map between write vars and names
// note that they don't have an ion name, so we set it to ""
if (!NONSPECIFIC_CONDUCTANCE_ALREADY_EXISTS) {
for (auto ns_curr_ast: nonspecific_nodes) {
for (const auto& ns_curr_ast: nonspecific_nodes) {
logger->debug("SympyConductance :: Found NONSPECIFIC_CURRENT statement");
for (auto write_name:
for (const auto& write_name:
std::dynamic_pointer_cast<Nonspecific>(ns_curr_ast).get()->get_currents()) {
std::string curr_write = write_name->get_node_name();
logger->debug("SympyConductance :: -> Adding non-specific current write name: " +
logger->debug("SympyConductance :: -> Adding non-specific current write name: {}",
curr_write);
i_name[curr_write] = "";
}
Expand All @@ -111,18 +125,18 @@ void SympyConductanceVisitor::lookup_nonspecific_statements() {

void SympyConductanceVisitor::lookup_useion_statements() {
// add USEION statements to i_name map between write vars and names
for (auto useion_ast: use_ion_nodes) {
for (const auto& useion_ast: use_ion_nodes) {
auto ion = std::dynamic_pointer_cast<Useion>(useion_ast).get();
std::string ion_name = ion->get_node_name();
logger->debug("SympyConductance :: Found USEION statement " + nmodl::to_nmodl(ion));
logger->debug("SympyConductance :: Found USEION statement {}", nmodl::to_nmodl(ion));
if (i_ignore.find(ion_name) != i_ignore.end()) {
logger->debug("SympyConductance :: -> Ignoring ion current name: " + ion_name);
logger->debug("SympyConductance :: -> Ignoring ion current name: {}", ion_name);
} else {
auto wl = ion->get_writelist();
for (auto w: wl) {
for (const auto& w: ion->get_writelist()) {
std::string ion_write = w->get_node_name();
logger->debug("SympyConductance :: -> Adding ion write name: " + ion_write +
" for ion current name: " + ion_name);
logger->debug(
"SympyConductance :: -> Adding ion write name: {} for ion current name: {}",
ion_write, ion_name);
i_name[ion_write] = ion_name;
}
}
Expand All @@ -132,11 +146,10 @@ void SympyConductanceVisitor::lookup_useion_statements() {
void SympyConductanceVisitor::visit_conductance_hint(ConductanceHint* node) {
// find existing CONDUCTANCE statements - do not want to alter them
// so keep a set of ion names i_ignore that we should ignore later
logger->debug("SympyConductance :: Found existing CONDUCTANCE statement: " +
logger->debug("SympyConductance :: Found existing CONDUCTANCE statement: {}",
nmodl::to_nmodl(node));
auto ion = node->get_ion();
if (ion) {
logger->debug("SympyConductance :: -> Ignoring ion current name: " + ion->get_node_name());
if (auto ion = node->get_ion()) {
logger->debug("SympyConductance :: -> Ignoring ion current name: {}", ion->get_node_name());
i_ignore.insert(ion->get_node_name());
} else {
logger->debug("SympyConductance :: -> Ignoring all non-specific currents");
Expand All @@ -146,14 +159,15 @@ void SympyConductanceVisitor::visit_conductance_hint(ConductanceHint* node) {

void SympyConductanceVisitor::visit_breakpoint_block(BreakpointBlock* node) {
// add any breakpoint local variables to vars
if (auto symtab = node->get_statement_block()->get_symbol_table()) {
for (auto localvar: symtab->get_variables_with_properties(NmodlType::local_var)) {
if (auto* symtab = node->get_statement_block()->get_symbol_table()) {
for (const auto& localvar: symtab->get_variables_with_properties(NmodlType::local_var)) {
vars.insert(localvar->get_name());
}
}

// visit BREAKPOINT block statements
breakpoint_block = true;
node->visit_children(this);
breakpoint_block = false;

// lookup USEION and NONSPECIFIC statements from NEURON block
lookup_useion_statements();
Expand Down Expand Up @@ -181,7 +195,6 @@ void SympyConductanceVisitor::visit_breakpoint_block(BreakpointBlock* node) {

void SympyConductanceVisitor::visit_program(Program* node) {
vars = get_global_vars(node);

AstLookupVisitor ast_lookup_visitor;
use_ion_nodes = ast_lookup_visitor.lookup(node, AstNodeType::USEION);
nonspecific_nodes = ast_lookup_visitor.lookup(node, AstNodeType::NONSPECIFIC);
Expand Down
12 changes: 11 additions & 1 deletion src/visitors/sympy_conductance_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,23 @@

class SympyConductanceVisitor: public AstVisitor {
private:
/// true while visiting breakpoint block
bool breakpoint_block = false;
typedef std::map<std::string, std::string> string_map;
typedef std::set<std::string> string_set;
// set of all variables for SymPy
string_set vars;
// set of currents to ignore
string_set i_ignore;
// map between current write names and ion names
string_map i_name;
bool NONSPECIFIC_CONDUCTANCE_ALREADY_EXISTS = false;
string_map binary_exprs;
// list in order of binary expressions in breakpoint
std::vector<std::string> ordered_binary_exprs;
// ditto but for LHS of expression only
std::vector<std::string> ordered_binary_exprs_lhs;
// map from lhs of binary expression to index of expression in above vector
std::map<std::string, std::size_t> binary_expr_index;
std::vector<std::shared_ptr<ast::AST>> use_ion_nodes;
std::vector<std::shared_ptr<ast::AST>> nonspecific_nodes;

Expand Down
5 changes: 2 additions & 3 deletions src/visitors/sympy_solver_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,9 @@ void SympySolverVisitor::visit_diff_eq_expression(DiffEqExpression* node) {

void SympySolverVisitor::visit_derivative_block(ast::DerivativeBlock* node) {
// get any local vars
auto symtab = node->get_statement_block()->get_symbol_table();
if (symtab) {
if (auto symtab = node->get_statement_block()->get_symbol_table()) {
auto localvars = symtab->get_variables_with_properties(NmodlType::local_var);
for (auto v: localvars) {
for (const auto& v: localvars) {
vars.insert(v->get_name());
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/visitors/visitor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ std::shared_ptr<Statement> create_statement(const std::string& code_statement) {

std::set<std::string> get_global_vars(Program* node) {
std::set<std::string> vars;
if (auto symtab = node->get_symbol_table()) {
if (auto* symtab = node->get_symbol_table()) {
syminfo::NmodlType property =
syminfo::NmodlType::global_var | syminfo::NmodlType::range_var |
syminfo::NmodlType::param_assign | syminfo::NmodlType::extern_var |
Expand All @@ -96,7 +96,7 @@ std::set<std::string> get_global_vars(Program* node) {
syminfo::NmodlType::nonspecific_cur_var | syminfo::NmodlType::electrode_cur_var |
syminfo::NmodlType::section_var | syminfo::NmodlType::constant_var |
syminfo::NmodlType::extern_neuron_variable | syminfo::NmodlType::state_var;
for (auto globalvar: symtab->get_variables_with_properties(property)) {
for (const auto& globalvar: symtab->get_variables_with_properties(property)) {
vars.insert(globalvar->get_name());
}
}
Expand Down
3 changes: 3 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ set_tests_properties(Visitor PROPERTIES ENVIRONMENT PYTHONPATH=${CMAKE_BINARY_DI
# pybind11 tests
# =============================================================================

add_test(NAME Ode
COMMAND python3 -m pytest ${PROJECT_SOURCE_DIR}/test/ode)
set_tests_properties(Ode PROPERTIES ENVIRONMENT PYTHONPATH=${CMAKE_BINARY_DIR}:$ENV{PYTHONPATH})
add_test(NAME Pybind
COMMAND python3 -m pytest ${PROJECT_SOURCE_DIR}/test/pybind)
set_tests_properties(Pybind PROPERTIES ENVIRONMENT PYTHONPATH=${CMAKE_BINARY_DIR}:$ENV{PYTHONPATH})
Empty file added test/ode/__init__.py
Empty file.
59 changes: 59 additions & 0 deletions test/ode/test_ode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# ***********************************************************************
# Copyright (C) 2018-2019 Blue Brain Project
#
# This file is part of NMODL distributed under the terms of the GNU
# Lesser General Public License. See top-level LICENSE file for details.
# ***********************************************************************

from nmodl.ode import differentiate2c


def test_differentiation():

# simple examples, no prev_expressions
assert differentiate2c("0", "x", "") == "0"
assert differentiate2c("x", "x", "") == "1"
assert differentiate2c("a", "x", "a") == "0"
assert differentiate2c("a*x", "x", "a") == "a"
assert differentiate2c("a*x", "a", "x") == "x"
assert differentiate2c("a*x", "y", {"x", "y"}) == "0"
assert differentiate2c("a*x + b*x*x", "x", {"a", "b"}) == "a + 2*b*x"
assert differentiate2c("a*cos(x+b)", "x", {"a", "b"}) == "-a*sin(b + x)"
assert (
differentiate2c("a*cos(x+b) + c*x*x", "x", {"a", "b", "c"})
== "-a*sin(b + x) + 2*c*x"
)

# single prev_expression to substitute
assert differentiate2c("a*x + b", "x", {"a", "b", "c", "d"}, ["c = sqrt(d)"]) == "a"
assert differentiate2c("a*x + b", "x", {"a", "b"}, ["b = 2*x"]) == "a + 2"

# multiple prev_eqs to substitute
# (these statements should be in the same order as in the mod file)
assert differentiate2c("a*x + b", "x", {"a", "b"}, ["b = 2*x", "a = -2"]) == "0"
assert differentiate2c("a*x + b", "x", {"a", "b"}, ["b = 2*x", "a = -2"]) == "0"

# multiple prev_eqs to recursively substitute
# note prev_eqs always substituted in reverse order
assert differentiate2c("a*x + b", "x", {"a", "b"}, ["a=3", "b = 2*a*x"]) == "9"
# if we can return result in terms of supplied var, do so
# even in this case where the supplied var a is equal to 3:
assert (
differentiate2c(
"a*x + b*c", "x", {"a", "b", "c"}, ["a=3", "b = 2*a*x", "c = a/x"]
)
== "a"
)
assert (
differentiate2c("-a*x + b*c", "x", {"a", "b", "c"}, ["b = 2*x*x", "c = a/x"])
== "a"
)
assert (
differentiate2c(
"(g1 + g2)*(v-e)",
"v",
{"g", "e", "g1", "g2", "c", "d"},
["g2 = sqrt(d) + 3", "g1 = 2*c", "g = g1 + g2"],
)
== "g"
)
Loading

0 comments on commit fd3c6c6

Please sign in to comment.