diff --git a/src/sequence_jacobian/blocks/solved_block.py b/src/sequence_jacobian/blocks/solved_block.py index 048f1ad..ac335c7 100644 --- a/src/sequence_jacobian/blocks/solved_block.py +++ b/src/sequence_jacobian/blocks/solved_block.py @@ -63,7 +63,7 @@ def _steady_state(self, calibration, dissolve, options, **kwargs): unknowns = {k: v for k, v in calibration.items() if k in self.unknowns} else: unknowns = self.unknowns - if 'solver' not in kwargs: + if not kwargs['solver']: # TODO: replace this with default option kwargs['solver'] = self.solver diff --git a/src/sequence_jacobian/blocks/support/simple_displacement.py b/src/sequence_jacobian/blocks/support/simple_displacement.py index 32e8708..6b0b77a 100644 --- a/src/sequence_jacobian/blocks/support/simple_displacement.py +++ b/src/sequence_jacobian/blocks/support/simple_displacement.py @@ -120,6 +120,7 @@ def __call__(self, index): return self def apply(self, f, **kwargs): + kwargs.update({arg: kwargs[arg].f_value for arg in kwargs if isinstance(kwargs[arg], AccumulatedDerivative)}) return ignore(f(numeric_primitive(self), **kwargs)) def __pos__(self): diff --git a/src/sequence_jacobian/hetblocks/hh_twoasset.py b/src/sequence_jacobian/hetblocks/hh_twoasset.py index a5b6e47..320db2e 100644 --- a/src/sequence_jacobian/hetblocks/hh_twoasset.py +++ b/src/sequence_jacobian/hetblocks/hh_twoasset.py @@ -179,4 +179,4 @@ def lhs_equals_rhs_interpolate(lhs, rhs, iout, piout): err_upper = rhs[i, j] - lhs[i] err_lower = rhs[i - 1, j] - lhs[i - 1] piout[j] = err_upper / (err_upper - err_lower) - \ No newline at end of file + diff --git a/src/sequence_jacobian/utilities/function.py b/src/sequence_jacobian/utilities/function.py index dd51281..62dd408 100644 --- a/src/sequence_jacobian/utilities/function.py +++ b/src/sequence_jacobian/utilities/function.py @@ -40,7 +40,11 @@ def output_list(f): Important to write functions in this way when they will be scanned by output_list, for either SimpleBlock or HetBlock. """ - return OrderedSet(re.findall('return (.*?)\n', inspect.getsource(f))[-1].replace(' ', '').split(',')) + source = inspect.getsource(f) + source_no_comments = re.sub(r'(?m)^ *#.*\n?', '', source) + return_statements = re.findall('return (.*?)\n', source_no_comments) + + return OrderedSet(return_statements[0].replace(' ', '').split(',')) def metadata(f):