From 5bff388ac23ff4db4aa0146a1c227fced2af80b2 Mon Sep 17 00:00:00 2001 From: Gregor Boehl Date: Thu, 23 Feb 2023 16:48:10 +0100 Subject: [PATCH] fix test --- econpizza/examples/hank_with_comments.yml | 2 +- econpizza/solvers/solve_linear.py | 12 ++++++++---- econpizza/solvers/stacking.py | 1 - 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/econpizza/examples/hank_with_comments.yml b/econpizza/examples/hank_with_comments.yml index 7de657a..3058d93 100644 --- a/econpizza/examples/hank_with_comments.yml +++ b/econpizza/examples/hank_with_comments.yml @@ -79,7 +79,7 @@ aux_equations: | # `dist` here corresponds to the dist *at the beginning of the period* aggr_a = jnp.sum(dist*a, axis=(0,1)) aggr_c = jnp.sum(dist*c, axis=(0,1)) - # calculate consumption and wealth share of top-10% cumsumers + # calculate consumption and wealth share of top-10% top10c = 1 - percentile(c, dist, .9) top10a = 1 - percentile(a, dist, .9) diff --git a/econpizza/solvers/solve_linear.py b/econpizza/solvers/solve_linear.py index e45fafd..a4500e4 100644 --- a/econpizza/solvers/solve_linear.py +++ b/econpizza/solvers/solve_linear.py @@ -4,10 +4,11 @@ import time import jax.numpy as jnp from ..utilities.jacobian import get_stst_jacobian +from ..parser.checks import check_if_compiled, write_compiled_objects from ..parser.build_functions import build_aggr_het_agent_funcs, get_stst_derivatives -def find_path_linear(model, shock=None, init_state=None, horizon=200, verbose=True): +def find_path_linear(model, shock=None, init_state=None, parameters=None, horizon=200, verbose=True): """Find the linear expected trajectory given an initial state. Parameters @@ -16,6 +17,8 @@ def find_path_linear(model, shock=None, init_state=None, horizon=200, verbose=Tr PizzaModel instance init_state : array initial state + parameters : dict, optional + alternative parameters. Warning: do only change those parameters that are invariant to the steady state. shock : tuple, optional shock in period 0 as in `(shock_name_as_str, shock_size)`. NOTE: Not (yet) implemented. horizon : int, optional @@ -44,7 +47,8 @@ def find_path_linear(model, shock=None, init_state=None, horizon=200, verbose=Tr # get model variables stst = jnp.array(list(model['stst'].values())) nvars = len(model["variables"]) - pars = jnp.array(list(model["parameters"].values())) + pars = jnp.array( + list((parameters if parameters is not None else model["parameters"]).values())) shocks = model.get("shocks") or () x_stst = jnp.ones((horizon + 1, nvars)) * stst @@ -53,7 +57,7 @@ def find_path_linear(model, shock=None, init_state=None, horizon=200, verbose=Tr x0 = jnp.array(list(init_state)) if init_state is not None else stst - if model['new_model_horizon'] != horizon: + if not check_if_compiled(model, horizon, pars): # get derivatives via AD and compile functions build_aggr_het_agent_funcs( model, nvars, pars, stst, zero_shocks, horizon) @@ -62,7 +66,7 @@ def find_path_linear(model, shock=None, init_state=None, horizon=200, verbose=Tr # accumulate steady stat jacobian get_stst_jacobian(model, derivatives, horizon, nvars, verbose) - model['new_model_horizon'] = horizon + write_compiled_objects(model, horizon, pars) jacobian = model['jac'] diff --git a/econpizza/solvers/stacking.py b/econpizza/solvers/stacking.py index cf12deb..209fa61 100644 --- a/econpizza/solvers/stacking.py +++ b/econpizza/solvers/stacking.py @@ -89,7 +89,6 @@ def find_path_stacking( jav_func_eqns, XSS=stst, pars=pars, distributions=[], decisions_outputs=[]) model['jav_func'] = jav_func_eqns_partial # mark as compiled - model['new_model_horizon'] = horizon write_compiled_objects(model, horizon, pars) # actual newton iterations