Skip to content

Commit

Permalink
add method for ev < 0
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Feb 9, 2023
1 parent a47c4e4 commit c705fc6
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 61 deletions.
50 changes: 32 additions & 18 deletions econpizza/parser/build_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,18 @@ def get_stacked_func_dist(pars, func_backw, func_dist, func_eqns, stst, vfSS, di
return stacked_func_dist, backwards_sweep, forwards_sweep, second_sweep


def get_derivatives(model, nvars, pars, stst, x_stst, zshocks, horizon, verbose):
def get_stst_derivatives(model, nvars, pars, stst, x_stst, zshocks, horizon, verbose):

st = time.time()

shocks = model.get("shocks") or ()
# get functions
func_eqns = model['context']["func_eqns"]
func_backw = model['context'].get('func_backw')
func_dist = model['context'].get('func_dist')
backwards_sweep = model['context']['backwards_sweep']
second_sweep = model['context']['second_sweep']

# get stuff for het-agent models
vfSS = model['steady_state'].get('decisions')
distSS = jnp.array(model['steady_state']['distributions'])[..., None]
decisions_outputSS = jnp.array(
model['steady_state']['decisions_output'])[..., None]

# get actual functions
func_raw, backwards_sweep, forwards_sweep, second_sweep = get_stacked_func_dist(
pars, func_backw, func_dist, func_eqns, stst, vfSS, distSS[..., 0], horizon, nvars)

# basis for steady state jacobian construction
basis = jnp.zeros((nvars*(horizon-1), nvars))
basis = basis.at[-nvars:, -nvars:].set(jnp.eye(nvars))
Expand All @@ -116,16 +108,38 @@ def get_derivatives(model, nvars, pars, stst, x_stst, zshocks, horizon, verbose)
f2X = jacrev_func_eqns(stst[:, None], stst[:, None], stst[:, None],
stst, zshocks[:, 0], pars, distSS, decisions_outputSS)

# store everything
model['context']['func_raw'] = func_raw
model['context']['backwards_sweep'] = backwards_sweep
model['context']['forwards_sweep'] = forwards_sweep
model['jvp_func'] = lambda primals, tangens, x0, shocks: jax.jvp(
func_raw, (primals, x0, shocks), (tangens, jnp.zeros(nvars), zshocks))

if verbose:
duration = time.time() - st
print(
f"(get_derivatives:) Derivatives calculation done ({duration:1.3f}s).")

return f2X, f2do, do2x


def build_aggr_het_agent_funcs(model, nvars, pars, stst, zshocks, horizon):

shocks = model.get("shocks") or ()
# get functions
func_eqns = model['context']["func_eqns"]
func_backw = model['context'].get('func_backw')
func_dist = model['context'].get('func_dist')

# get stuff for het-agent models
vfSS = model['steady_state'].get('decisions')
distSS = jnp.array(model['steady_state']['distributions'])[..., None]
decisions_outputSS = jnp.array(
model['steady_state']['decisions_output'])[..., None]

# get actual functions
func_raw, backwards_sweep, forwards_sweep, second_sweep = get_stacked_func_dist(
pars, func_backw, func_dist, func_eqns, stst, vfSS, distSS[..., 0], horizon, nvars)

# store everything
model['context']['func_raw'] = func_raw
model['context']['backwards_sweep'] = backwards_sweep
model['context']['forwards_sweep'] = forwards_sweep
model['context']['second_sweep'] = second_sweep
model['context']['jvp_func'] = lambda primals, tangens, x0, shocks: jax.jvp(
func_raw, (primals, x0, shocks), (tangens, jnp.zeros(nvars), zshocks))
model['context']['vjp_func'] = lambda primals, tangens, x0, shocks: jax.vjp(
lambda x: func_raw(x, x0, shocks), primals)[1](tangens)[0]
18 changes: 11 additions & 7 deletions econpizza/solvers/shooting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
import jax
import jax.numpy as jnp
from grgrlib.jaxed import newton_jax_jit, newton_jax_jittable
from grgrlib.jaxed import newton_jax_jit, jacfwd_and_val


msgs = (
Expand Down Expand Up @@ -112,10 +112,13 @@ def solve_current(pars, shock, XLag, XLastGuess, XPrime):
"""Solves for one period.
"""

res = newton_jax_jittable(lambda x: func(
XLag, x, XPrime, stst, shock, pars), XLastGuess)
# partial_func = jax.tree_util.Partial(func, XLag=XLag, XPrime=XPrime, XSS=stst, shocks=shock, pars=pars)
def partial_func(x): return func(XLag, x, XPrime, stst, shock, pars)
jav = jacfwd_and_val(partial_func)
partial_jav = jax.tree_util.Partial(jav)
res = newton_jax_jit(partial_jav, XLastGuess, verbose=False)

return res[0], res[2], res[3]
return res[0], res[3]

try:
for i in range(T):
Expand All @@ -138,14 +141,15 @@ def solve_current(pars, shock, XLag, XLastGuess, XPrime):
else:
tshock.at[:].set(0)

x_new, flag_root, flag_ftol = solve_current(
x_new, flag_ftol = solve_current(
pars, tshock, x[t], x[t + 1], x[t + 2])
flag_ftol = ~flag_ftol

x = x.at[t + 1].set(x_new)

flag_loc = flag_loc.at[0].set(flag_loc[0] or not flag_root)
flag_loc = flag_loc.at[0].set(flag_loc[0])
flag_loc = flag_loc.at[1].set(
flag_loc[2] or (not flag_ftol and flag_root))
flag_loc[2] or not flag_ftol)

flag = flag.at[2].set(flag[2] or jnp.any(jnp.isnan(x)))
flag = flag.at[3].set(flag[3] or jnp.any(jnp.isinf(x)))
Expand Down
10 changes: 6 additions & 4 deletions econpizza/solvers/solve_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
import jax.numpy as jnp
from ..utilities.jacobian import get_stst_jacobian
from ..parser.build_functions import get_derivatives
from ..parser.build_functions import build_aggr_het_agent_funcs, get_stst_derivatives


def find_path_linear(model, shock=None, x0=None, horizon=300, verbose=True):
Expand Down Expand Up @@ -50,14 +50,16 @@ def find_path_linear(model, shock=None, x0=None, horizon=300, verbose=True):
x_stst = jnp.ones((horizon + 1, nvars)) * stst

# deal with shocks
zero_shocks = jnp.zeros((horizon-1, len(shocks)))
zero_shocks = jnp.zeros((horizon-1, len(shocks))).T

x0 = jnp.array(list(x0)) if x0 is not None else stst

if model['new_model_horizon'] != horizon:
# get derivatives via AD and compile functions
derivatives = get_derivatives(
model, nvars, pars, stst, x_stst, zero_shocks.T, horizon, verbose)
build_aggr_het_agent_funcs(
model, nvars, pars, stst, zero_shocks, horizon)
derivatives = get_stst_derivatives(
model, nvars, pars, stst, x_stst, zero_shocks, horizon, verbose)

# accumulate steady stat jacobian
get_stst_jacobian(model, derivatives, horizon, nvars, verbose)
Expand Down
40 changes: 27 additions & 13 deletions econpizza/solvers/stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import jax
import time
import jax.numpy as jnp
from grgrlib.jaxed import *
from ..parser.build_functions import *
from ..utilities.jacobian import get_stst_jacobian
from grgrlib.jaxed import newton_jax_jit, jacrev_and_val
from ..parser.build_functions import build_aggr_het_agent_funcs, get_stst_derivatives
from ..utilities.jacobian import get_stst_jacobian, get_jac_and_value_sliced
from ..utilities.newton import newton_for_jvp, newton_for_banded_jac


Expand All @@ -16,6 +16,7 @@ def find_path_stacking(
shock=None,
x0=None,
horizon=300,
use_solid_solver=False,
verbose=True,
raise_errors=True,
**newton_args
Expand Down Expand Up @@ -89,22 +90,35 @@ def find_path_stacking(
else:
if model['new_model_horizon'] != horizon:
# get derivatives via AD and compile functions
derivatives = get_derivatives(
model, nvars, pars, stst, x_stst, jnp.zeros_like(shock_series).T, horizon, verbose)
zero_shocks = jnp.zeros_like(shock_series).T
build_aggr_het_agent_funcs(
model, nvars, pars, stst, zero_shocks, horizon)

if not use_solid_solver:
# get steady state partial jacobians
derivatives = get_stst_derivatives(
model, nvars, pars, stst, x_stst, zero_shocks, horizon, verbose)
# accumulate steady stat jacobian
get_stst_jacobian(model, derivatives, horizon, nvars, verbose)

# accumulate steady stat jacobian
get_stst_jacobian(model, derivatives, horizon, nvars, verbose)
# mark as compiled
model['new_model_horizon'] = horizon

# get jvp function and steady state jacobian
jvp_partial = jax.tree_util.Partial(
model['jvp_func'], x0=x0, shocks=shock_series.T)
jacobian = model['jac_factorized']

# actual newton iterations
x, flag, mess = newton_for_jvp(
jvp_partial, jacobian, x_init, verbose, **newton_args)
model['context']['jvp_func'], x0=x0, shocks=shock_series.T)
if not use_solid_solver:
jacobian = model['jac_factorized']
# actual newton iterations
x, flag, mess = newton_for_jvp(
jvp_partial, jacobian, x_init, verbose, **newton_args)
else:
# define function returning value and jacobian calculated in slices
value_and_jac_func = get_jac_and_value_sliced(
(horizon-1)*nvars, jvp_partial, newton_args)
x, _, _, flag = newton_jax_jit(
value_and_jac_func, x_init[1:-1].flatten(), **newton_args)
mess = ''
x_out = x_init.at[1:-1].set(x.reshape((horizon - 1, nvars)))

# some informative print messages
Expand Down
23 changes: 23 additions & 0 deletions econpizza/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,26 @@ def test_hank2(create=False):

assert flag == 0
assert jnp.allclose(x, test_x)


def test_solid(create=False):

mod_dict = ep.parse(example_hank)
mod = ep.load(mod_dict)
_ = mod.solve_stst(tol=1e-8)

shocks = ('e_beta', .01)

x, flag = mod.find_path(shocks, use_solid_solver=True,
horizon=20, chunk_size=90)

path = os.path.join(filepath, "test_storage", "hank_solid.npy")

if create:
jnp.save(path, x)
print(f'Test file updated at {path}')
else:
test_x = jnp.load(path)

assert flag == 0
assert jnp.allclose(x, test_x)
Binary file added econpizza/test_storage/hank_solid.npy
Binary file not shown.
54 changes: 53 additions & 1 deletion econpizza/utilities/jacobian.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import jax
import time
import jax.numpy as jnp
from jax._src.api import partial


def accumulate(i_and_j, carry):
Expand All @@ -13,7 +14,8 @@ def accumulate(i_and_j, carry):


def get_stst_jacobian(model, derivatives, horizon, nvars, verbose):

"""Calculate the steady state jacobian
"""
st = time.time()

# load derivatives
Expand Down Expand Up @@ -43,3 +45,53 @@ def get_stst_jacobian(model, derivatives, horizon, nvars, verbose):
f"(get_jacobian:) Jacobian accumulation and decomposition done ({duration:1.3f}s).")

return 0


def vmapped_jvp(jvp, primals, tangents):
"""Compact version of jvp_vmap from grgrlib
"""
pushfwd = partial(jvp, primals)
y, jac = jax.vmap(pushfwd, out_axes=(None, -1), in_axes=-1)(tangents)
return y, jac


def jac_slicer(i, carry):
"""Calclulates a chunk of the jacobian
"""
(_, jac), (x, jvp, zeros_slice, marginal_base, chunk_size) = carry
# get base slice
base_slice = jax.lax.dynamic_update_slice(
zeros_slice, marginal_base, (i*chunk_size, len(x)))
# calculate slice of the jacobian
f, jac_slice = vmapped_jvp(jvp, x, base_slice)
# update jacobian
jac = jax.lax.dynamic_update_slice(jac, jac_slice, (0, i*chunk_size))
return (f, jac), (x, jvp, zeros_slice, marginal_base, chunk_size)


def jac_and_value_sliced(jvp, chunk_size, zero_slice, eye_chunk, x):
"""Calculate the value and jacobian at `x` while only evaluating chunks of the full jacobian at times. May be necessary due to memmory requirements.
"""
x_shape = len(x)
nloops = jnp.ceil(x_shape/chunk_size).astype(int)
init_vals = x, jnp.zeros((x_shape, x_shape))
args = x, jvp, zero_slice, eye_chunk, chunk_size
# in essence a wrapper around a for loop over `jac_slicer`
(f, jac), _ = jax.lax.fori_loop(0, nloops, jac_slicer, (init_vals, args))
return f, jac


def get_jac_and_value_sliced(dimx, jvp, newton_args):
"""Get the jac_and_value_sliced function. This is necessary because objects depending on chunk_size must be defined outsite the jitted function
"""
# get chunk_size from optional dictionary
if 'chunk_size' in newton_args:
chunk_size = newton_args['chunk_size']
newton_args.pop('chunk_size')
else:
chunk_size = 100

# define objects that depend on chunk_size
zero_slice = jnp.zeros((dimx, chunk_size))
eye_chunk = jnp.eye(chunk_size)
return jax.tree_util.Partial(jac_and_value_sliced, jvp, chunk_size, zero_slice, eye_chunk)
30 changes: 12 additions & 18 deletions econpizza/utilities/newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,9 @@
"""

import jax
import time
import jax.numpy as jnp
from grgrlib.jaxed import *


def callback_func(cnt, err, dampening=None, ltime=None, verbose=True):
mess = f' Iteration {cnt:3d} | max error {err:.2e}'
if dampening is not None:
mess += f' | dampening {dampening:1.3f}'
if ltime is not None:
mess += f' | lapsed {ltime:3.4f}s'
if verbose:
print(mess)
from grgrlib.jaxed import callback_func, amax


def iteration_step(dummy, carry):
Expand Down Expand Up @@ -70,13 +61,16 @@ def check_status(err, cnt, maxit, tol):

# exit causes
if err < tol:
return True, (True, "The solution converged.")
if jnp.isnan(err):
return True, (False, "Function returns 'NaN's.")
if cnt > maxit:
return True, (False, f"Maximum number of {maxit} iterations reached.")

return False, (False, "")
r = True, (True, "The solution converged.")
elif jnp.isnan(err):
r = True, (False, "Function returns 'NaN's.")
elif cnt > maxit:
r = True, (False, f"Maximum number of {maxit} iterations reached.")
else:
r = False, (False, "")
return r

# return False, (False, "")


def newton_for_jvp(jvp_func, jacobian, x_init, verbose, tol=1e-8, maxit=200, nsteps=2, factor=1.5):
Expand Down

0 comments on commit c705fc6

Please sign in to comment.