Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Feb 23, 2023
1 parent b3a7fea commit 5bff388
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion econpizza/examples/hank_with_comments.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ aux_equations: |
# `dist` here corresponds to the dist *at the beginning of the period*
aggr_a = jnp.sum(dist*a, axis=(0,1))
aggr_c = jnp.sum(dist*c, axis=(0,1))
# calculate consumption and wealth share of top-10% cumsumers
# calculate consumption and wealth share of top-10%
top10c = 1 - percentile(c, dist, .9)
top10a = 1 - percentile(a, dist, .9)
Expand Down
12 changes: 8 additions & 4 deletions econpizza/solvers/solve_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import time
import jax.numpy as jnp
from ..utilities.jacobian import get_stst_jacobian
from ..parser.checks import check_if_compiled, write_compiled_objects
from ..parser.build_functions import build_aggr_het_agent_funcs, get_stst_derivatives


def find_path_linear(model, shock=None, init_state=None, horizon=200, verbose=True):
def find_path_linear(model, shock=None, init_state=None, parameters=None, horizon=200, verbose=True):
"""Find the linear expected trajectory given an initial state.
Parameters
Expand All @@ -16,6 +17,8 @@ def find_path_linear(model, shock=None, init_state=None, horizon=200, verbose=Tr
PizzaModel instance
init_state : array
initial state
parameters : dict, optional
alternative parameters. Warning: do only change those parameters that are invariant to the steady state.
shock : tuple, optional
shock in period 0 as in `(shock_name_as_str, shock_size)`. NOTE: Not (yet) implemented.
horizon : int, optional
Expand Down Expand Up @@ -44,7 +47,8 @@ def find_path_linear(model, shock=None, init_state=None, horizon=200, verbose=Tr
# get model variables
stst = jnp.array(list(model['stst'].values()))
nvars = len(model["variables"])
pars = jnp.array(list(model["parameters"].values()))
pars = jnp.array(
list((parameters if parameters is not None else model["parameters"]).values()))
shocks = model.get("shocks") or ()
x_stst = jnp.ones((horizon + 1, nvars)) * stst

Expand All @@ -53,7 +57,7 @@ def find_path_linear(model, shock=None, init_state=None, horizon=200, verbose=Tr

x0 = jnp.array(list(init_state)) if init_state is not None else stst

if model['new_model_horizon'] != horizon:
if not check_if_compiled(model, horizon, pars):
# get derivatives via AD and compile functions
build_aggr_het_agent_funcs(
model, nvars, pars, stst, zero_shocks, horizon)
Expand All @@ -62,7 +66,7 @@ def find_path_linear(model, shock=None, init_state=None, horizon=200, verbose=Tr

# accumulate steady stat jacobian
get_stst_jacobian(model, derivatives, horizon, nvars, verbose)
model['new_model_horizon'] = horizon
write_compiled_objects(model, horizon, pars)

jacobian = model['jac']

Expand Down
1 change: 0 additions & 1 deletion econpizza/solvers/stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def find_path_stacking(
jav_func_eqns, XSS=stst, pars=pars, distributions=[], decisions_outputs=[])
model['jav_func'] = jav_func_eqns_partial
# mark as compiled
model['new_model_horizon'] = horizon
write_compiled_objects(model, horizon, pars)

# actual newton iterations
Expand Down

0 comments on commit 5bff388

Please sign in to comment.