Skip to content

Commit

Permalink
Initial check-in of EI-based transition solver
Browse files Browse the repository at this point in the history
  • Loading branch information
dmnapolitano committed Oct 6, 2023
1 parent 3d6aef0 commit 9ebe5c5
Showing 1 changed file with 93 additions and 0 deletions.
93 changes: 93 additions & 0 deletions src/elexsolver/EITransitionSolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import logging

import pymc as pm
import numpy as np

from elexsolver.logging import initialize_logging
from elexsolver.TransitionSolver import TransitionSolver

initialize_logging()

LOG = logging.getLogger(__name__)


class EITransitionSolver(TransitionSolver):
"""
A (voter) transition solver based on RxC ecological inference.
Largely adapted from version 1.0.1 of
Knudson et al., (2021). PyEI: A Python package for ecological inference.
Journal of Open Source Software, 6(64), 3397, https://doi.org/10.21105/joss.03397
"""

def __init__(self, n: np.ndarray, alpha=4, beta=0.5, sampling_chains=1):
super().__init__()
self._n = n
self._alpha = alpha # lmbda1 in PyEI
self._beta = beta # lmbda2 in PyEI, supplied as an int then used as 1 / lmbda2
self._chains = sampling_chains
self._sampled = None # will not be None after model-fit

def mean_absolute_error(self, X, Y):
# x = self._get_expected_totals(X)
# y = self._get_expected_totals(Y)

# absolute_errors = np.abs(np.matmul(x, self._transition_matrix) - y)
# error_sum = np.sum(absolute_errors)
# mae = error_sum / len(absolute_errors)

# return mae
return 0 # TODO

def fit_predict(self, X, Y):
self._check_any_element_nan_or_inf(X)
self._check_any_element_nan_or_inf(Y)
self._check_percentages(X)
self._check_percentages(Y)

# TODO: check if these matrices are (long x short), then transpose
# currently assuming this is the case since the other solver expects (long x short)
X = np.transpose(X)
Y = np.transpose(Y)

num_units = len(self._n) # should be the same as the number of units in Y
num_rows = X.shape[0] # number of things in X that are being transitioned "from"
num_cols = Y.shape[0] # number of things in Y that are being transitioned "to"

# reshaping and rounding
Y_obs = np.swapaxes(Y * self._n, 0, 1).round()
X_extended = np.expand_dims(X, axis=2)
X_extended = np.repeat(X_extended, num_cols, axis=2)
X_extended = np.swapaxes(X_extended, 0, 1)

with pm.Model() as model:
conc_params = pm.Gamma(
"conc_params", alpha=self._alpha, beta=self._beta, shape=(num_rows, num_cols)
)
beta = pm.Dirichlet("beta", a=conc_params, shape=(num_units, num_rows, num_cols))
theta = (X_extended * beta).sum(axis=1)
yhat = pm.Multinomial(
"transfers",
n=self._n,
p=theta,
observed=Y_obs,
shape=(num_units, num_cols),
)
# TODO: allow other samplers; this one is very good but slow
model_trace = pm.sample(chains=self._chains)

b_values = np.transpose(
model_trace["posterior"]["beta"].stack(all_draws=["chain", "draw"]).values, axes=(3, 0, 1, 2))
samples_converted = np.transpose(b_values, axes=(3, 0, 1, 2)) * X.T.values
samples_summed_across = samples_converted.sum(axis=2)
self._sampled = np.transpose(samples_summed_across / X.T.sum(axis=0).values, axes=(1, 2, 0))

posterior_mean_rxc = self._sampled.mean(axis=0)
X_totals = self._get_expected_totals(np.transpose(X))
# TODO
# LOG.info("MAE = {}".format(np.around(self.mean_absolute_error(X, Y), 4)))
# to go from inferences to transitions
transitions = []
for col in posterior_mean_rxc.T:
transitions.append(col * X_totals)
return np.array(transitions).T

0 comments on commit 9ebe5c5

Please sign in to comment.