Skip to content

Commit

Permalink
Splitting fit_predict() into two methods, renaming 'percentages' to '…
Browse files Browse the repository at this point in the history
…self._betas', renaming 'weights' to 'sample_weight' to be consistent with sklearn
  • Loading branch information
dmnapolitano committed Mar 19, 2024
1 parent ffde8aa commit 71f29c9
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 37 deletions.
10 changes: 5 additions & 5 deletions src/elexsolver/TransitionMatrixSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __solve(self, A: np.ndarray, B: np.ndarray, weights: np.ndarray) -> np.ndarr

return transition_matrix.value

def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None = None) -> np.ndarray:
def fit(self, X: np.ndarray, Y: np.ndarray, sample_weight: np.ndarray | None = None) -> np.ndarray:
self._check_data_type(X)
self._check_data_type(Y)
self._check_any_element_nan_or_inf(X)
Expand All @@ -91,8 +91,8 @@ def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None =
X = self._rescale(X)
Y = self._rescale(Y)

weights = self._check_and_prepare_weights(X, Y, weights)
weights = self._check_and_prepare_weights(X, Y, sample_weight)

percentages = self.__solve(X, Y, weights)
self._transitions = np.diag(X_expected_totals) @ percentages
return percentages
self._betas = self.__solve(X, Y, weights)
self._transitions = np.diag(X_expected_totals) @ self._betas
return self
37 changes: 32 additions & 5 deletions src/elexsolver/TransitionSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,59 @@ class TransitionSolver(ABC):
"""

def __init__(self):
self._betas = None
self._transitions = None

def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None = None):
def fit(self, X: np.ndarray, Y: np.ndarray, sample_weight: np.ndarray | None = None):
"""
After this method finishes, transitions will be available in the `transitions` class member.
Parameters
----------
`X` : np.ndarray matrix or pandas.DataFrame of int
Must have the same number of rows as `Y` but can have any number of columns greater than the number of rows.
`Y` : np.ndarray matrix or pandas.DataFrame of int
Must have the same number of rows as `X` but can have any number of columns greater than the number of rows.
`weights` : list, np.ndarray, or pandas.Series of int, optional
`sample_weight` : list, np.ndarray, or pandas.Series of int, optional
Must have the same length (number of rows) as both `X` and `Y`.
Returns
-------
np.ndarray matrix of float of shape (number of columns in `X`) x (number of columns in `Y`).
`self` and populates `betas` with the beta coefficients determined by this solver.
`betas` is an np.ndarray matrix of float of shape (number of columns in `X`) x (number of columns in `Y`).
Each float represents the percent of how much of row x is part of column y.
"""
raise NotImplementedError

def predict(self, X: np.ndarray):
"""
Predict `Y_hat` by matrix-multiplying `X` with the stored beta coefficients.

Parameters
----------
`X` : np.ndarray matrix or pandas.DataFrame of int
Must have the same number of columns as the `X` supplied to `fit()`,
because it is multiplied on the right by the betas.
Returns
-------
`Y_hat`, the product of `X` and the betas: one row per row of `X` and
one column per column of the betas.
Raises
------
RuntimeError
If the betas are still `None`, i.e. the solver has not been fit.
"""
if self._betas is None:
raise RuntimeError("Solver must be fit before prediction can be performed.")
return X @ self._betas

@property
def transitions(self) -> np.ndarray:
"""
Returns
-------
The stored transitions matrix.
Will return `None` if no solver has assigned it yet (`__init__` sets it to `None`).
"""
return self._transitions

@property
def betas(self) -> np.ndarray:
"""
Read-only access to the beta coefficients stored on this solver.

Returns
-------
The solved coefficients, an np.ndarray matrix of float of shape
(number of columns in `X`) x (number of columns in `Y`).
Each float represents the percent of how much of row x is part of column y.
Will return `None` if `fit()` hasn't been called yet.
"""
return self._betas

def _check_any_element_nan_or_inf(self, A: np.ndarray):
"""
Check whether any element in a matrix or vector is NaN or infinity
Expand Down
56 changes: 31 additions & 25 deletions tests/test_transition_matrix_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,22 @@ def test_matrix_fit_predict():
]
)

expected = np.array([[0.760428, 0.239572], [0.216642, 0.783358]])
expected_betas = np.array([[0.760428, 0.239572], [0.216642, 0.783358]])
expected_yhat = np.array(
[
[1.19371187, 1.80628813],
[3.14785177, 3.85214823],
[5.10199167, 5.89800833],
[7.05613156, 7.94386844],
[9.01027146, 9.98972854],
[10.96441136, 12.03558864],
]
)

tms = TransitionMatrixSolver()
current = tms.fit_predict(X, Y)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)
tms = TransitionMatrixSolver().fit(X, Y)
current_yhat = tms.predict(X)
np.testing.assert_allclose(expected_betas, tms.betas, rtol=RTOL, atol=ATOL)
np.testing.assert_allclose(expected_yhat, current_yhat, rtol=RTOL, atol=ATOL)


def test_matrix_fit_predict_with_weights():
Expand Down Expand Up @@ -62,11 +73,10 @@ def test_matrix_fit_predict_with_weights():

weights = np.array([500, 250, 125, 62.5, 31.25, 15.625])

expected = np.array([[0.737329, 0.262671], [0.230589, 0.769411]])
expected_betas = np.array([[0.737329, 0.262671], [0.230589, 0.769411]])

tms = TransitionMatrixSolver()
current = tms.fit_predict(X, Y, weights=weights)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)
tms = TransitionMatrixSolver().fit(X, Y, sample_weight=weights)
np.testing.assert_allclose(expected_betas, tms.betas, rtol=RTOL, atol=ATOL)


def test_matrix_fit_predict_not_strict():
Expand All @@ -92,11 +102,10 @@ def test_matrix_fit_predict_not_strict():
]
)

expected = np.array([[0.760451, 0.239558], [0.216624, 0.783369]])
expected_betas = np.array([[0.760451, 0.239558], [0.216624, 0.783369]])

tms = TransitionMatrixSolver(strict=False)
current = tms.fit_predict(X, Y)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)
tms = TransitionMatrixSolver(strict=False).fit(X, Y)
np.testing.assert_allclose(expected_betas, tms.betas, rtol=RTOL, atol=ATOL)


def test_ridge_matrix_fit_predict():
Expand All @@ -122,11 +131,10 @@ def test_ridge_matrix_fit_predict():
]
)

expected = np.array([[0.479416, 0.520584], [0.455918, 0.544082]])
expected_betas = np.array([[0.479416, 0.520584], [0.455918, 0.544082]])

tms = TransitionMatrixSolver(lam=1)
current = tms.fit_predict(X, Y)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)
tms = TransitionMatrixSolver(lam=1).fit(X, Y)
np.testing.assert_allclose(expected_betas, tms.betas, rtol=RTOL, atol=ATOL)


def test_matrix_fit_predict_pivoted():
Expand All @@ -152,7 +160,7 @@ def test_matrix_fit_predict_pivoted():
]
).T

expected = np.array(
expected_betas = np.array(
[
[0.68274443, 0.18437159, 0.06760119, 0.03363495, 0.0197597, 0.01188814],
[0.13541428, 0.48122828, 0.22128163, 0.0960816, 0.04540571, 0.02058852],
Expand All @@ -163,9 +171,8 @@ def test_matrix_fit_predict_pivoted():
]
)

tms = TransitionMatrixSolver()
current = tms.fit_predict(X, Y)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)
tms = TransitionMatrixSolver().fit(X, Y)
np.testing.assert_allclose(expected_betas, tms.betas, rtol=RTOL, atol=ATOL)


def test_matrix_fit_predict_bad_dimensions():
Expand All @@ -191,7 +198,7 @@ def test_matrix_fit_predict_bad_dimensions():

tms = TransitionMatrixSolver()
with pytest.raises(ValueError):
tms.fit_predict(X, Y)
tms.fit(X, Y)


def test_matrix_fit_predict_pandas():
Expand Down Expand Up @@ -222,11 +229,10 @@ def test_matrix_fit_predict_pandas():
columns=["y1", "y2"],
)

expected = np.array([[0.760428, 0.239572], [0.216642, 0.783358]])
expected_betas = np.array([[0.760428, 0.239572], [0.216642, 0.783358]])

tms = TransitionMatrixSolver()
current = tms.fit_predict(X, Y)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)
tms = TransitionMatrixSolver().fit(X, Y)
np.testing.assert_allclose(expected_betas, tms.betas, rtol=RTOL, atol=ATOL)

except ImportError:
# pass this test through since pandas isn't a requirement for elex-solver
Expand Down
17 changes: 15 additions & 2 deletions tests/test_transition_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@


@patch.object(TransitionSolver, "__abstractmethods__", set())
def test_superclass_fit_predict():
def test_superclass_fit():
with pytest.raises(NotImplementedError):
ts = TransitionSolver()
ts.fit_predict(None, None)
ts.fit(None, None)


# Clearing `__abstractmethods__` lets the abstract base class be instantiated directly.
@patch.object(TransitionSolver, "__abstractmethods__", set())
def test_superclass_predict():
# A fresh solver has no betas, so predict() must raise before it ever uses its argument.
with pytest.raises(RuntimeError):
ts = TransitionSolver()
ts.predict(None)


@patch.object(TransitionSolver, "__abstractmethods__", set())
Expand All @@ -19,6 +26,12 @@ def test_superclass_get_transitions():
assert ts.transitions is None


# Clearing `__abstractmethods__` lets the abstract base class be instantiated directly.
@patch.object(TransitionSolver, "__abstractmethods__", set())
def test_superclass_get_betas():
# `betas` is expected to be None on a solver that has never been fit.
ts = TransitionSolver()
assert ts.betas is None


@patch.object(TransitionSolver, "__abstractmethods__", set())
def test_check_any_element_nan_or_inf_with_nan():
with pytest.raises(ValueError):
Expand Down

0 comments on commit 71f29c9

Please sign in to comment.