From c705fc65dd58752dbe04fe298dad71f51454a9e8 Mon Sep 17 00:00:00 2001 From: Gregor Boehl Date: Thu, 9 Feb 2023 17:26:59 +0100 Subject: [PATCH] add method for ev < 0 --- econpizza/parser/build_functions.py | 50 +++++++++++++++--------- econpizza/solvers/shooting.py | 18 +++++---- econpizza/solvers/solve_linear.py | 10 +++-- econpizza/solvers/stacking.py | 40 ++++++++++++------- econpizza/test_all.py | 23 +++++++++++ econpizza/test_storage/hank_solid.npy | Bin 0 -> 2984 bytes econpizza/utilities/jacobian.py | 54 +++++++++++++++++++++++++- econpizza/utilities/newton.py | 30 ++++++-------- 8 files changed, 164 insertions(+), 61 deletions(-) create mode 100644 econpizza/test_storage/hank_solid.npy diff --git a/econpizza/parser/build_functions.py b/econpizza/parser/build_functions.py index 3176a6d..baf4918 100644 --- a/econpizza/parser/build_functions.py +++ b/econpizza/parser/build_functions.py @@ -80,26 +80,18 @@ def get_stacked_func_dist(pars, func_backw, func_dist, func_eqns, stst, vfSS, di return stacked_func_dist, backwards_sweep, forwards_sweep, second_sweep -def get_derivatives(model, nvars, pars, stst, x_stst, zshocks, horizon, verbose): +def get_stst_derivatives(model, nvars, pars, stst, x_stst, zshocks, horizon, verbose): st = time.time() - shocks = model.get("shocks") or () - # get functions func_eqns = model['context']["func_eqns"] - func_backw = model['context'].get('func_backw') - func_dist = model['context'].get('func_dist') + backwards_sweep = model['context']['backwards_sweep'] + second_sweep = model['context']['second_sweep'] - # get stuff for het-agent models - vfSS = model['steady_state'].get('decisions') distSS = jnp.array(model['steady_state']['distributions'])[..., None] decisions_outputSS = jnp.array( model['steady_state']['decisions_output'])[..., None] - # get actual functions - func_raw, backwards_sweep, forwards_sweep, second_sweep = get_stacked_func_dist( - pars, func_backw, func_dist, func_eqns, stst, vfSS, distSS[..., 0], horizon, nvars) - # basis for steady state jacobian construction basis = jnp.zeros((nvars*(horizon-1), nvars)) basis = basis.at[-nvars:, -nvars:].set(jnp.eye(nvars)) @@ -116,16 +108,38 @@ def get_derivatives(model, nvars, pars, stst, x_stst, zshocks, horizon, verbose) f2X = jacrev_func_eqns(stst[:, None], stst[:, None], stst[:, None], stst, zshocks[:, 0], pars, distSS, decisions_outputSS) - # store everything - model['context']['func_raw'] = func_raw - model['context']['backwards_sweep'] = backwards_sweep - model['context']['forwards_sweep'] = forwards_sweep - model['jvp_func'] = lambda primals, tangens, x0, shocks: jax.jvp( - func_raw, (primals, x0, shocks), (tangens, jnp.zeros(nvars), zshocks)) - if verbose: duration = time.time() - st print( f"(get_derivatives:) Derivatives calculation done ({duration:1.3f}s).") return f2X, f2do, do2x + + +def build_aggr_het_agent_funcs(model, nvars, pars, stst, zshocks, horizon): + + shocks = model.get("shocks") or () + # get functions + func_eqns = model['context']["func_eqns"] + func_backw = model['context'].get('func_backw') + func_dist = model['context'].get('func_dist') + + # get stuff for het-agent models + vfSS = model['steady_state'].get('decisions') + distSS = jnp.array(model['steady_state']['distributions'])[..., None] + decisions_outputSS = jnp.array( + model['steady_state']['decisions_output'])[..., None] + + # get actual functions + func_raw, backwards_sweep, forwards_sweep, second_sweep = get_stacked_func_dist( + pars, func_backw, func_dist, func_eqns, stst, vfSS, distSS[..., 0], horizon, nvars) + + # store everything + model['context']['func_raw'] = func_raw + model['context']['backwards_sweep'] = backwards_sweep + model['context']['forwards_sweep'] = forwards_sweep + model['context']['second_sweep'] = second_sweep + model['context']['jvp_func'] = lambda primals, tangens, x0, shocks: jax.jvp( + func_raw, (primals, x0, shocks), (tangens, jnp.zeros(nvars), zshocks)) + model['context']['vjp_func'] = lambda primals, tangens, x0, shocks: jax.vjp( + lambda x: func_raw(x, x0, shocks), primals)[1](tangens)[0] diff --git a/econpizza/solvers/shooting.py b/econpizza/solvers/shooting.py index 2e49c90..3d39356 100644 --- a/econpizza/solvers/shooting.py +++ b/econpizza/solvers/shooting.py @@ -5,7 +5,7 @@ import time import jax import jax.numpy as jnp -from grgrlib.jaxed import newton_jax_jit, newton_jax_jittable +from grgrlib.jaxed import newton_jax_jit, jacfwd_and_val msgs = ( @@ -112,10 +112,13 @@ def solve_current(pars, shock, XLag, XLastGuess, XPrime): """Solves for one period. """ - res = newton_jax_jittable(lambda x: func( - XLag, x, XPrime, stst, shock, pars), XLastGuess) + # partial_func = jax.tree_util.Partial(func, XLag=XLag, XPrime=XPrime, XSS=stst, shocks=shock, pars=pars) + def partial_func(x): return func(XLag, x, XPrime, stst, shock, pars) + jav = jacfwd_and_val(partial_func) + partial_jav = jax.tree_util.Partial(jav) + res = newton_jax_jit(partial_jav, XLastGuess, verbose=False) - return res[0], res[2], res[3] + return res[0], res[3] try: for i in range(T): @@ -138,14 +141,15 @@ def solve_current(pars, shock, XLag, XLastGuess, XPrime): else: tshock.at[:].set(0) - x_new, flag_root, flag_ftol = solve_current( + x_new, flag_ftol = solve_current( pars, tshock, x[t], x[t + 1], x[t + 2]) + flag_ftol = ~flag_ftol x = x.at[t + 1].set(x_new) - flag_loc = flag_loc.at[0].set(flag_loc[0] or not flag_root) + flag_loc = flag_loc.at[0].set(flag_loc[0]) flag_loc = flag_loc.at[1].set( - flag_loc[2] or (not flag_ftol and flag_root)) + flag_loc[2] or not flag_ftol) flag = flag.at[2].set(flag[2] or jnp.any(jnp.isnan(x))) flag = flag.at[3].set(flag[3] or jnp.any(jnp.isinf(x))) diff --git a/econpizza/solvers/solve_linear.py b/econpizza/solvers/solve_linear.py index 9e78202..4229e95 100644 --- a/econpizza/solvers/solve_linear.py +++ b/econpizza/solvers/solve_linear.py @@ -5,7 +5,7 @@ import time import jax.numpy as jnp from ..utilities.jacobian import get_stst_jacobian -from ..parser.build_functions import get_derivatives +from ..parser.build_functions import build_aggr_het_agent_funcs, get_stst_derivatives def find_path_linear(model, shock=None, x0=None, horizon=300, verbose=True): @@ -50,14 +50,16 @@ def find_path_linear(model, shock=None, x0=None, horizon=300, verbose=True): x_stst = jnp.ones((horizon + 1, nvars)) * stst # deal with shocks - zero_shocks = jnp.zeros((horizon-1, len(shocks))) + zero_shocks = jnp.zeros((horizon-1, len(shocks))).T x0 = jnp.array(list(x0)) if x0 is not None else stst if model['new_model_horizon'] != horizon: # get derivatives via AD and compile functions - derivatives = get_derivatives( - model, nvars, pars, stst, x_stst, zero_shocks.T, horizon, verbose) + build_aggr_het_agent_funcs( + model, nvars, pars, stst, zero_shocks, horizon) + derivatives = get_stst_derivatives( + model, nvars, pars, stst, x_stst, zero_shocks, horizon, verbose) # accumulate steady stat jacobian get_stst_jacobian(model, derivatives, horizon, nvars, verbose) diff --git a/econpizza/solvers/stacking.py b/econpizza/solvers/stacking.py index 357fc17..d34a45b 100644 --- a/econpizza/solvers/stacking.py +++ b/econpizza/solvers/stacking.py @@ -5,9 +5,9 @@ import jax import time import jax.numpy as jnp -from grgrlib.jaxed import * -from ..parser.build_functions import * -from ..utilities.jacobian import get_stst_jacobian +from grgrlib.jaxed import newton_jax_jit, jacrev_and_val +from ..parser.build_functions import build_aggr_het_agent_funcs, get_stst_derivatives +from ..utilities.jacobian import get_stst_jacobian, get_jac_and_value_sliced from ..utilities.newton import newton_for_jvp, newton_for_banded_jac @@ -16,6 +16,7 @@ def find_path_stacking( shock=None, x0=None, horizon=300, + use_solid_solver=False, verbose=True, raise_errors=True, **newton_args @@ -89,22 +90,35 @@ def find_path_stacking( else: if model['new_model_horizon'] != horizon: # get derivatives via AD and compile functions - derivatives = get_derivatives( - model, nvars, pars, stst, x_stst, jnp.zeros_like(shock_series).T, horizon, verbose) + zero_shocks = jnp.zeros_like(shock_series).T + build_aggr_het_agent_funcs( + model, nvars, pars, stst, zero_shocks, horizon) + + if not use_solid_solver: + # get steady state partial jacobians + derivatives = get_stst_derivatives( + model, nvars, pars, stst, x_stst, zero_shocks, horizon, verbose) + # accumulate steady stat jacobian + get_stst_jacobian(model, derivatives, horizon, nvars, verbose) - # accumulate steady stat jacobian - get_stst_jacobian(model, derivatives, horizon, nvars, verbose) # mark as compiled model['new_model_horizon'] = horizon # get jvp function and steady state jacobian jvp_partial = jax.tree_util.Partial( - model['jvp_func'], x0=x0, shocks=shock_series.T) - jacobian = model['jac_factorized'] - - # actual newton iterations - x, flag, mess = newton_for_jvp( - jvp_partial, jacobian, x_init, verbose, **newton_args) + model['context']['jvp_func'], x0=x0, shocks=shock_series.T) + if not use_solid_solver: + jacobian = model['jac_factorized'] + # actual newton iterations + x, flag, mess = newton_for_jvp( + jvp_partial, jacobian, x_init, verbose, **newton_args) + else: + # define function returning value and jacobian calculated in slices + value_and_jac_func = get_jac_and_value_sliced( + (horizon-1)*nvars, jvp_partial, newton_args) + x, _, _, flag = newton_jax_jit( + value_and_jac_func, x_init[1:-1].flatten(), **newton_args) + mess = '' x_out = x_init.at[1:-1].set(x.reshape((horizon - 1, nvars))) # some informative print messages diff --git a/econpizza/test_all.py b/econpizza/test_all.py index 213d794..6a25591 100644 --- a/econpizza/test_all.py +++ b/econpizza/test_all.py @@ -152,3 +152,26 @@ def test_hank2(create=False): assert flag == 0 assert jnp.allclose(x, test_x) + + +def test_solid(create=False): + + mod_dict = ep.parse(example_hank) + mod = ep.load(mod_dict) + _ = mod.solve_stst(tol=1e-8) + + shocks = ('e_beta', .01) + + x, flag = mod.find_path(shocks, use_solid_solver=True, + horizon=20, chunk_size=90) + + path = os.path.join(filepath, "test_storage", "hank_solid.npy") + + if create: + jnp.save(path, x) + print(f'Test file updated at {path}') + else: + test_x = jnp.load(path) + + assert flag == 0 + assert jnp.allclose(x, test_x) diff --git a/econpizza/test_storage/hank_solid.npy b/econpizza/test_storage/hank_solid.npy new file mode 100644 index 0000000000000000000000000000000000000000..085db9d970e0635165101e7e7d2ac77f3cbf7b7e GIT binary patch literal 2984 zcmd7Udo)yQ9|v%xgd&z(BA0Q^5L1~Gre8XpN`p~L9a2dec~b~QDpW3cRp*pW2~AN+ zZV6Fxaw_TFcmKi|LJ{m*yJGi$!T&)U!IORD<@HxF-F znf)>c&3Ev_{lm;149%SaelTYmng@i0MTGeU`-X(=;7_fu^4lHGC)UFQ{X+S~XA71c z)6kA>$uvA{_&**>0e>A@%avYT3FRmu=-#taMn*;g>MPFWCSS&&x^;~X$6gG?;$Q17 zJ|)wm9VGD2Tq9kV6LiZTjVuGN?m~4%Wg%F}u9QUK806NW+DhubF_d21O_Lp78M?l z03&~TnNXij($V}`V6Mmg6HciOMqXAFLh8cI>?U>tZ1S-+T3ypa^+&z6U02)OM>*=G zj%n+(?KRdTDB=35^6)R)~-NndUreLn6>9rbl?g>7pQ zSlV%XmM%x&9qju3V(|-TE~sYAxhw)@my7$f?jaB`b$ND<5}1{iWARFhPSTMV-E=6o zq6{YGgcn;0$EVa1aShl5!Y>*f$5@ojTl7ODi`N%z^VSSaZCMI^|T{Z z>JWS@3f*XnVCxYfcIQFz8(C|CAGmQ~iGoZ8w@l>@#`)NR~Q=aZ{xP zCW4)f6VA8W5me#%PYs{}ik-sV4(N#CbeCuBrG5ll)v}K+mJ&$q@^>FFp_6oY^5eLD z9}D1!^J+%$EFrv&w|6jY#bEPB$1LtXoa#TNQm=(~(2k>02XA7_V@*p0Oc&#UmiL55 zx!z}|Nds6_=E>;K7em`fc2VnB;(NXLCubH&pwwDn_%^XWWRi~5wg*F|&kNzr(~yq$ zgd-;@puDFYg9jtUpNeB}s{cO@%lDl5Ks%zOjybY1H4~-?qS$9+19}m7RMpw5&uIYV z`9}ofMq+rRo~>~F8v>`A<7wl>^|Dgs`*RnYQH1+6kNp&v6cY4`>rUDNhj$@)X0i&e0~#JF0@u=5e^yu2pjG97&IIi3iMCG zsea=P#|`a?{~PJxOC1kAOkS5866e9C_Mw6V!K-e2L#uua%rE61d}uBP_H3c(<|qRD z!O~T)Z6wf--(dYcNr%EKixb!Q#jx)C=t`V$ME#s#yznCi3*s6-I{%7O{TYm184O|$ z>DVfDT$j(P-9I0}jm^fDwo(w;KWGc z#iqX{>5z-}7)$RchGpXoX=W-yxbpKL)3OHxPsUpLma{n3@6G96|ACl8IyOih7p^nI zL-i2YO5Vj*3?j%f-S0f734?oHxqDifVu;vn#m}8Upvg^-IY!*qA+*Oai$y2t5Ts1< zc6;6eMXxPGUW9|KV!Bt2cL(($9z5xPROYeWY@@rb-h>KFn%Ty2*I@tkzI zb*^7;kr=wS#YH#~??=q4wuO_6B(Prn7mJ)FbdnBHdgCUXRRV)5%hK}*M^pD!?NTuY zr+qoPLsxOCKcV94fvGv9!$Im;>|~6^X(M>HQTy{5VqSC5j;zdd45m5HbY069!`1Ei zfkDK&c;Nh|rNsPp+2=`b?dc>Pq0t$)J1&<%wpBFe1>tx*aOT;+5C*?AcG{gQz^VQ= zmU~z#(MLL1Qpf4*$ maxit: - return True, (False, f"Maximum number of {maxit} iterations reached.") - - return False, (False, "") + r = True, (True, "The solution converged.") + elif jnp.isnan(err): + r = True, (False, "Function returns 'NaN's.") + elif cnt > maxit: + r = True, (False, f"Maximum number of {maxit} iterations reached.") + else: + r = False, (False, "") + return r + + # return False, (False, "") def newton_for_jvp(jvp_func, jacobian, x_init, verbose, tol=1e-8, maxit=200, nsteps=2, factor=1.5):