Skip to content

Commit

Permalink
Adding the ability to compute a prediction (credible) interval with t…
Browse files Browse the repository at this point in the history
…he EI solver and also experimenting with different sampler options there
  • Loading branch information
dmnapolitano committed Oct 12, 2023
1 parent 79076a7 commit 2254a65
Showing 1 changed file with 39 additions and 9 deletions.
48 changes: 39 additions & 9 deletions src/elexsolver/EITransitionSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@ class EITransitionSolver(TransitionSolver):
Journal of Open Source Software, 6(64), 3397, https://doi.org/10.21105/joss.03397
"""

def __init__(self, n: np.ndarray, alpha=4, beta=0.5, sampling_chains=1):
def __init__(self, n: np.ndarray, alpha=4, beta=0.5, sampling_chains=1, random_seed=None):
super().__init__()
self._n = n
self._alpha = alpha # lmbda1 in PyEI
self._beta = beta # lmbda2 in PyEI, supplied as an int then used as 1 / lmbda2
self._chains = sampling_chains
self._sampled = None # will not be None after model-fit
self._seed = random_seed

# class members that are instantiated during model-fit
self._sampled = None
self._X_totals = None

def mean_absolute_error(self, X, Y):
y_pred = self._get_expected_totals(X)
Expand Down Expand Up @@ -78,7 +82,7 @@ def fit_predict(self, X, Y):
)
try:
# TODO: allow other samplers; this one is very good but slow
model_trace = pm.sample(chains=self._chains)
model_trace = pm.sample(chains=self._chains, random_seed=self._seed, nuts_sampler="numpyro")
except Exception as e:
LOG.debug(model.debug())
raise e
Expand All @@ -91,11 +95,37 @@ def fit_predict(self, X, Y):
self._sampled = np.transpose(samples_summed_across / X.T.sum(axis=0).values, axes=(1, 2, 0))

posterior_mean_rxc = self._sampled.mean(axis=0)
X_totals = self._get_expected_totals(np.transpose(X))
# to go from inferences to transitions
transitions = []
for col in posterior_mean_rxc.T:
transitions.append(col * X_totals)
transitions = np.array(transitions).T
self._X_totals = self._get_expected_totals(np.transpose(X))
transitions = self._get_transitions(posterior_mean_rxc)
LOG.info("MAE = {}".format(np.around(self.mean_absolute_error(transitions, Y), 4)))
return transitions

def _get_transitions(self, A: np.ndarray):
# to go from inferences to transitions
transitions = []
for col in A.T:
transitions.append(col * self._X_totals)
return np.array(transitions).T

def get_prediction_interval(self, pi):
"""
Note: this is actually a credible interval, not a prediction interval.
"""
if pi <= 1:
pi = pi * 100
if pi < 0 or pi > 100:
raise ValueError(f"Invalid prediction interval {pi}.")

lower = (100 - pi) / 2
upper = pi + lower
A_dict = {
lower: np.zeros((self._sampled.shape[1], self._sampled.shape[2])),
upper: np.zeros((self._sampled.shape[1], self._sampled.shape[2])),
}

for ci in [lower, upper]:
for i in range(0, self._sampled.shape[1]):
for j in range(0, self._sampled.shape[2]):
A_dict[ci][i][j] = np.percentile(self._sampled[:, i, j], ci)

return (self._get_transitions(A_dict[lower]), self._get_transitions(A_dict[upper]))

0 comments on commit 2254a65

Please sign in to comment.