From 2254a65005a83a1aa9ab52741f94defc254d6e71 Mon Sep 17 00:00:00 2001
From: Diane Napolitano
Date: Thu, 12 Oct 2023 09:54:45 -0400
Subject: [PATCH] Adding the ability to compute a prediction (credible) interval with the EI solver and also experimenting with different sampler options there

---
 src/elexsolver/EITransitionSolver.py | 48 ++++++++++++++++++++++------
 1 file changed, 39 insertions(+), 9 deletions(-)

diff --git a/src/elexsolver/EITransitionSolver.py b/src/elexsolver/EITransitionSolver.py
index acbae345..1c155709 100644
--- a/src/elexsolver/EITransitionSolver.py
+++ b/src/elexsolver/EITransitionSolver.py
@@ -19,13 +19,17 @@ class EITransitionSolver(TransitionSolver):
     Journal of Open Source Software, 6(64), 3397, https://doi.org/10.21105/joss.03397
     """
 
-    def __init__(self, n: np.ndarray, alpha=4, beta=0.5, sampling_chains=1):
+    def __init__(self, n: np.ndarray, alpha=4, beta=0.5, sampling_chains=1, random_seed=None):
         super().__init__()
         self._n = n
         self._alpha = alpha  # lmbda1 in PyEI
         self._beta = beta  # lmbda2 in PyEI, supplied as an int then used as 1 / lmbda2
         self._chains = sampling_chains
-        self._sampled = None  # will not be None after model-fit
+        self._seed = random_seed
+
+        # class members that are instantiated during model-fit
+        self._sampled = None
+        self._X_totals = None
 
     def mean_absolute_error(self, X, Y):
         y_pred = self._get_expected_totals(X)
@@ -78,7 +82,7 @@ def fit_predict(self, X, Y):
             )
             try:
                 # TODO: allow other samplers; this one is very good but slow
-                model_trace = pm.sample(chains=self._chains)
+                model_trace = pm.sample(chains=self._chains, random_seed=self._seed, nuts_sampler="numpyro")
             except Exception as e:
                 LOG.debug(model.debug())
                 raise e
@@ -91,11 +95,37 @@ def fit_predict(self, X, Y):
         self._sampled = np.transpose(samples_summed_across / X.T.sum(axis=0).values, axes=(1, 2, 0))
 
         posterior_mean_rxc = self._sampled.mean(axis=0)
-        X_totals = self._get_expected_totals(np.transpose(X))
-        # to go from inferences to transitions
-        transitions = []
-        for col in posterior_mean_rxc.T:
-            transitions.append(col * X_totals)
-        transitions = np.array(transitions).T
+        self._X_totals = self._get_expected_totals(np.transpose(X))
+        transitions = self._get_transitions(posterior_mean_rxc)
         LOG.info("MAE = {}".format(np.around(self.mean_absolute_error(transitions, Y), 4)))
         return transitions
+
+    def _get_transitions(self, A: np.ndarray):
+        # to go from inferences to transitions
+        transitions = []
+        for col in A.T:
+            transitions.append(col * self._X_totals)
+        return np.array(transitions).T
+
+    def get_prediction_interval(self, pi):
+        """
+        Note: this is actually a credible interval, not a prediction interval.
+        """
+        if pi <= 1:
+            pi = pi * 100
+        if pi < 0 or pi > 100:
+            raise ValueError(f"Invalid prediction interval {pi}.")
+
+        lower = (100 - pi) / 2
+        upper = pi + lower
+        A_dict = {
+            lower: np.zeros((self._sampled.shape[1], self._sampled.shape[2])),
+            upper: np.zeros((self._sampled.shape[1], self._sampled.shape[2])),
+        }
+
+        for ci in [lower, upper]:
+            for i in range(0, self._sampled.shape[1]):
+                for j in range(0, self._sampled.shape[2]):
+                    A_dict[ci][i][j] = np.percentile(self._sampled[:, i, j], ci)
+
+        return (self._get_transitions(A_dict[lower]), self._get_transitions(A_dict[upper]))