Skip to content

Commit

Permalink
also check calls for vals defined in t
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Mar 31, 2023
1 parent 59676a1 commit 55d08e4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
5 changes: 3 additions & 2 deletions econpizza/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def compile_stst_inputs(model):
evars, par_names, init_guesses, fixed_evaluated)

# get the initial decision functions
if model.get('decisions'):
if model.get('distributions'):
dist_names = list(model['distributions'].keys())
# for now assume that this must be present
init_wf = jnp.array([init_guesses[dec_input]
Expand Down Expand Up @@ -330,7 +330,8 @@ def load(
check_dublicates(model.get("parameters"))
evars = check_determinancy(model["variables"], eqns)
# check if each variable is defined in time t (only defining xSS does not give a valid root)
check_if_defined(evars, eqns, model.get('skip_check_if_defined'))
check_if_defined(evars, eqns, model.get('decisions'),
model.get('skip_check_if_defined'))

# create fixed (time invariant) grids
grids.create_grids(model.get('distributions'), model["context"], verbose)
Expand Down
18 changes: 16 additions & 2 deletions econpizza/parser/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,34 @@
import sys
import jax.numpy as jnp
from .build_functions import func_pre_stst
import re


def check_if_defined(evars, eqns, skipped_vars):
def _strip_comments(code):
code = str(code)
return re.sub(r'(?m)^ *#.*\n?', '', code)


def check_if_defined(evars, eqns, decisions, skipped_vars):
"""Check if all variables are defined in period t.
"""
skipped_vars = [] if skipped_vars is None else skipped_vars
calls = _strip_comments(
decisions['calls']) if decisions is not None else ''

for v in evars:
v_in_eqns = [
v in e.replace(v + "SS", "").replace(v + "Lag",
"").replace(v + "Prime", "")
for e in eqns
]
if not any(v_in_eqns) and not v in skipped_vars:
v_in_calls = v in calls.replace(
v + "SS", "").replace(v + "Lag", "").replace(v + "Prime", "")
print(v)
print(calls.replace(v + "SS", "").replace(v +
"Lag", "").replace(v + "Prime", ""))

if not any(v_in_eqns) and not v_in_calls and not v in skipped_vars:
raise Exception(f"Variable `{v}` is not defined for time t.")
return

Expand Down

0 comments on commit 55d08e4

Please sign in to comment.