Skip to content

Commit

Permalink
A version of TransitionSolver that inherets from LinearSolver
Browse files Browse the repository at this point in the history
  • Loading branch information
dmnapolitano committed Mar 19, 2024
1 parent b5a2e07 commit 73f0c9e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/elexsolver/TransitionMatrixSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,6 @@ def fit(self, X: np.ndarray, Y: np.ndarray, sample_weight: np.ndarray | None = N

weights = self._check_and_prepare_weights(X, Y, sample_weight)

self._betas = self.__solve(X, Y, weights)
self._transitions = np.diag(X_expected_totals) @ self._betas
self.coefficients = self.__solve(X, Y, weights)
self._transitions = np.diag(X_expected_totals) @ self.coefficients
return self
19 changes: 6 additions & 13 deletions src/elexsolver/TransitionSolver.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
import logging
import warnings
from abc import ABC

import numpy as np

from elexsolver.logging import initialize_logging
from elexsolver.LinearSolver import LinearSolver

initialize_logging()

LOG = logging.getLogger(__name__)


class TransitionSolver(ABC):
class TransitionSolver(LinearSolver):
"""
Abstract class for transition solvers.
"""

def __init__(self):
self._betas = None
super().__init__()
self._transitions = None

def fit(self, X: np.ndarray, Y: np.ndarray, sample_weight: np.ndarray | None = None):
Expand Down Expand Up @@ -50,9 +50,9 @@ def predict(self, X: np.ndarray) -> np.ndarray:
-------
`Y_hat`, np.ndarray of float of the same shape as Y.
"""
if self._betas is None:
if self.coefficients is None:
raise RuntimeError("Solver must be fit before prediction can be performed.")
return X @ self._betas
return X @ self.coefficients

@property
def transitions(self) -> np.ndarray:
Expand All @@ -68,14 +68,7 @@ def betas(self) -> np.ndarray:
Each float represents the percent of how much of row x is part of column y.
Will return `None` if `fit()` hasn't been called yet.
"""
return self._betas

def _check_any_element_nan_or_inf(self, A: np.ndarray):
"""
Check whether any element in a matrix or vector is NaN or infinity
"""
if np.any(np.isnan(A)) or np.any(np.isinf(A)):
raise ValueError("Matrix contains NaN or Infinity.")
return self.coefficients

def _check_data_type(self, A: np.ndarray):
if not np.all(A.astype("int64") == A):
Expand Down

0 comments on commit 73f0c9e

Please sign in to comment.