Weight accessor (#11)
* defined accessor for Gibbs sampler

* Fix long_description (as it fails to upload)

* add accessor for variational fm

* fix test
tohtsky authored May 28, 2022
1 parent 8ca5881 commit 5e77e1f
Showing 5 changed files with 178 additions and 20 deletions.
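
For context, here is a minimal usage sketch of the accessors this commit introduces. The synthetic dataset, the rank, and the iteration counts below are illustrative assumptions, not part of the change; the shapes in the comments are what the new accessors are expected to return under those assumptions.

```python
import numpy as np
import scipy.sparse as sps

from myfm import MyFMGibbsRegressor, VariationalFMRegressor

rng = np.random.default_rng(0)
# Tiny synthetic regression problem: 200 rows, 8 binary features.
X = sps.csr_matrix(rng.binomial(1, 0.3, size=(200, 8)).astype(np.float64))
y = np.asarray(X.sum(axis=1)).ravel() + rng.normal(scale=0.1, size=200)

gibbs = MyFMGibbsRegressor(rank=3).fit(X, y, n_iter=100, n_kept_samples=100)
# The Gibbs accessors return one entry per kept sample, or None before fit().
print(gibbs.w0_samples.shape)  # expected: (100,)
print(gibbs.w_samples.shape)   # expected: (100, 8)
print(gibbs.V_samples.shape)   # expected: (100, 8, 3)

vfm = VariationalFMRegressor(rank=3).fit(X, y, n_iter=50)
# The variational accessors return posterior means / variances, or None before fit().
print(vfm.w0_mean, vfm.w0_var)            # floats
print(vfm.w_mean.shape, vfm.w_var.shape)  # expected: (8,), (8,)
print(vfm.V_mean.shape, vfm.V_var.shape)  # expected: (8, 3), (8, 3)
```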
setup.py: 4 additions & 3 deletions
@@ -13,8 +13,8 @@
"typing-extensions>=4.0.0",
]


TEST_BUILD = os.environ.get("TEST_BUILD", None) is not None
CURRENT_DIR = Path(__file__).resolve().parent
README_FILE = CURRENT_DIR / "README.md"


class get_eigen_include(object):
@@ -87,7 +87,8 @@ def local_scheme(version: Any) -> str:
url="https://github.com/tohtsky/myfm",
author_email="[email protected]",
description="Yet another Bayesian factorization machines.",
long_description="",
long_description=README_FILE.read_text(),
long_description_content_type="text/markdown",
ext_modules=ext_modules,
install_requires=install_requires,
cmdclass={"build_ext": build_ext},
src/myfm/gibbs.py: 35 additions & 0 deletions
@@ -37,6 +37,41 @@ class MyFMGibbsBase(
LearningHistory,
]
):
@property
def w0_samples(self) -> Optional[DenseArray]:
r"""Obtain samples for global bias `w0`. If the model is not fit yet, return `None`.
Returns:
Samples for lienar coefficients.
"""
if self.predictor_ is None:
return None
return np.asfarray([fm.w0 for fm in self.predictor_.samples])

@property
def w_samples(self) -> Optional[DenseArray]:
r"""Obtain the Gibbs samples for linear coefficients `w`. Returns `None` if the model is not fit yet.
Returns:
Samples for lienar coefficients.
The first dimension is for the sample index, and the second for the feature index.
"""
if self.predictor_ is None:
return None
return np.asfarray([fm.w for fm in self.predictor_.samples])

@property
def V_samples(self) -> Optional[DenseArray]:
r"""Obtain the Gibbs samples for factorized quadratic coefficient `V`. Returns `None` if the model is not fit yet.
Returns:
Samples for lienar coefficients.
The first dimension is for the sample index, the second for the feature index, and the third for the factorized dimension.
"""
if self.predictor_ is None:
return None
return np.asfarray([fm.V for fm in self.predictor_.samples])

def _predict_core(
self,
X: Optional[ArrayLike],
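
A short sketch of how the new Gibbs accessors might be consumed downstream, assuming a hypothetical already-fitted `MyFMGibbsRegressor` named `fm`:

```python
import numpy as np

# `fm` is assumed to be an already-fitted MyFMGibbsRegressor / MyFMGibbsClassifier.
w_samples = fm.w_samples          # expected shape: (n_kept_samples, n_features)
assert w_samples is not None      # the accessor returns None before fit()

# Posterior mean and a simple 95% credible interval for each linear coefficient.
w_mean = w_samples.mean(axis=0)
w_lo, w_hi = np.percentile(w_samples, [2.5, 97.5], axis=0)

# Pairwise interaction weight implied by the factorized term, averaged over samples.
V_samples = fm.V_samples          # expected shape: (n_kept_samples, n_features, rank)
interaction_01 = float(np.mean([V[0].dot(V[1]) for V in V_samples]))
print(w_mean, w_lo, w_hi, interaction_01)
```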
src/myfm/variational.py: 94 additions & 1 deletion
@@ -1,4 +1,4 @@
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional, Tuple, TypeVar

import numpy as np
import scipy.sparse as sps
@@ -17,11 +17,26 @@
REAL,
ArrayLike,
ClassifierMixin,
DenseArray,
MyFMBase,
RegressorMixin,
check_data_consistency,
)

ArrayOrDenseArray = TypeVar("ArrayOrDenseArray", DenseArray, float)


def runtime_error_to_optional(
fm: "MyFMVariationalBase",
retrieve_method: Callable[[VariationalFM], ArrayOrDenseArray],
) -> Optional[ArrayOrDenseArray]:
# `_fetch_predictor` raises a RuntimeError while the model is not fit yet;
# map that case to `None` so the weight accessors can be called at any time.
try:
predictor = fm._fetch_predictor()
except RuntimeError:
return None
weights = predictor.weights()
return retrieve_method(weights)


class MyFMVariationalBase(
MyFMBase[
@@ -31,6 +46,84 @@ class MyFMVariationalBase(
VariationalLearningHistory,
]
):
@property
def w0_mean(self) -> Optional[float]:
"""Mean of variational posterior distribution of global bias `w0`.
Returns:
Mean of variational posterior distribution of global bias `w0`.
"""

def _retrieve(fm: VariationalFM) -> float:
return fm.w0

return runtime_error_to_optional(self, _retrieve)

@property
def w0_var(self) -> Optional[float]:
"""Variance of variational posterior distribution of global bias `w0`.
Returns:
Variance of variational posterior distribution of global bias `w0`.
"""

def _retrieve(fm: VariationalFM) -> float:
return fm.w0_var

return runtime_error_to_optional(self, _retrieve)

@property
def w_mean(self) -> Optional[DenseArray]:
"""Mean of variational posterior distribution of linear coefficnent `w`.
Returns:
Mean of variational posterior distribution of linear coefficnent `w.
"""

def _retrieve(fm: VariationalFM) -> DenseArray:
return fm.w

return runtime_error_to_optional(self, _retrieve)

@property
def w_var(self) -> Optional[DenseArray]:
"""Variance of variational posterior distribution of linear coefficnent `w`.
Returns:
Variance of variational posterior distribution of linear coefficnent `w.
"""

def _retrieve(fm: VariationalFM) -> DenseArray:
return fm.w_var

return runtime_error_to_optional(self, _retrieve)

@property
def V_mean(self) -> Optional[DenseArray]:
"""Mean of variational posterior distribution of factorized quadratic coefficnent `V`.
Returns:
Mean of variational posterior distribution of factorized quadratic coefficient `w.
"""

def _retrieve(fm: VariationalFM) -> DenseArray:
return fm.V

return runtime_error_to_optional(self, _retrieve)

@property
def V_var(self) -> Optional[DenseArray]:
"""Variance of variational posterior distribution of factorized quadratic coefficnent `V`.
Returns:
Variance of variational posterior distribution of factorized quadratic coefficient `w.
"""

def _retrieve(fm: VariationalFM) -> DenseArray:
return fm.V_var

return runtime_error_to_optional(self, _retrieve)

@classmethod
def _train_core(
cls,
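
And a corresponding sketch for the variational accessors, again assuming a hypothetical already-fitted model named `vfm`; because `runtime_error_to_optional` swallows the not-yet-fit error, each property is either a value or `None`:

```python
import numpy as np

# `vfm` is assumed to be an already-fitted VariationalFMRegressor / VariationalFMClassifier.
if vfm.w_mean is None or vfm.w_var is None:
    raise RuntimeError("call fit() before reading the variational weights")

# Posterior mean +/- one standard deviation for each linear coefficient.
w_sd = np.sqrt(vfm.w_var)
for j, (m, s) in enumerate(zip(vfm.w_mean, w_sd)):
    print(f"w[{j}] = {m:.3f} +/- {s:.3f}")

# Mean interaction weight between features 0 and 1 from the factorized term.
print(float(vfm.V_mean[0].dot(vfm.V_mean[1])))
```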
tests/classification/test_classification.py: 19 additions & 4 deletions
@@ -32,22 +32,37 @@ def test_middle_clf(
if use_libfm_callback:
np.testing.assert_allclose(fm.predict_proba(X), callback.predictions / 200)

vfm = VariationalFMClassifier(3).fit(
vfm_before_fit = VariationalFMClassifier(3)
assert vfm_before_fit.w0_mean is None
assert vfm_before_fit.w0_var is None
assert vfm_before_fit.w_mean is None
assert vfm_before_fit.w_var is None
assert vfm_before_fit.V_mean is None
assert vfm_before_fit.V_var is None

vfm = vfm_before_fit.fit(
X, y, X_test=X, y_test=y, n_iter=200 # , n_kept_samples=50
)

assert vfm.w0_mean is not None
assert vfm.w0_var is not None
assert vfm.w_mean is not None
assert vfm.w_var is not None
assert vfm.V_mean is not None
assert vfm.V_var is not None

assert fm.predictor_ is not None

last_samples = fm.predictor_.samples[-20:]

for i in range(3):
for j in range(i + 1, 3):
cross_term = stub_weight.factors[:, i].dot(stub_weight.factors[:, j])
m = vfm.predictor_.weights()
if abs(cross_term) < 0.5:
continue
sign = cross_term / abs(cross_term)
assert m.V[i].dot(m.V[j]) > sign * cross_term * 0.8
assert m.V[i].dot(m.V[j]) < sign * cross_term * 1.2
assert vfm.V_mean[i].dot(vfm.V_mean[j]) > sign * cross_term * 0.8
assert vfm.V_mean[i].dot(vfm.V_mean[j]) < sign * cross_term * 1.2

for s in last_samples:
sample_cross_term = s.V[i].dot(s.V[j])
tests/regression/test_fit.py: 26 additions & 12 deletions
@@ -1,15 +1,21 @@
from typing import Tuple
from typing import Optional, Tuple

import numpy as np
import pytest
from scipy import sparse as sps

from myfm import MyFMGibbsRegressor, VariationalFMRegressor
from myfm.base import DenseArray
from myfm.utils.callbacks import RegressionCallback

from ..test_utils import FMWeights


def assert_unwrap(x: Optional[DenseArray]) -> DenseArray:
assert x is not None
return x


@pytest.mark.parametrize("alpha_inv", [0.3, 1.0, 3])
def test_middle_reg(
alpha_inv: float,
@@ -22,25 +28,33 @@ def test_middle_reg(

callback = RegressionCallback(100, X_test=X, y_test=y)

fm = MyFMGibbsRegressor(3).fit(
fm_init = MyFMGibbsRegressor(3)
assert fm_init.w0_samples is None
assert fm_init.w_samples is None
assert fm_init.V_samples is None
fm = fm_init.fit(
X, y, X_test=X, y_test=y, n_iter=100, n_kept_samples=100, callback=callback
)

np.testing.assert_allclose(fm.predict(X), callback.predictions / 100)
vfm = VariationalFMRegressor(3).fit(X, y, X_test=X, y_test=y, n_iter=50)
vfm_weights = vfm.predictor_.weights()
hp_trance = fm.get_hyper_trace()
last_alphs = hp_trance["alpha"].iloc[-20:].values
hp_trace = fm.get_hyper_trace()
last_alphs = hp_trace["alpha"].iloc[-20:].values
assert np.all(last_alphs > ((1 / alpha_inv**2) / 2))
assert np.all(last_alphs < ((1 / alpha_inv**2) * 2))

last_samples = fm.predictor_.samples[-20:]
assert np.all([s.w0 < stub_weight.global_bias + 0.5 for s in last_samples])
assert np.all([s.w0 > stub_weight.global_bias - 0.5 for s in last_samples])
last_w0_samples = assert_unwrap(fm.w0_samples)[-20:]
assert np.all(last_w0_samples < (stub_weight.global_bias + 0.5))
assert np.all(last_w0_samples > (stub_weight.global_bias - 0.5))

last_w_samples = assert_unwrap(fm.w_samples)[-20:]

for w_ in last_w_samples:
assert np.all(w_ < (stub_weight.weight + 1.0))
assert np.all(w_ > (stub_weight.weight - 1.0))

for s in last_samples:
assert np.all(s.w < (stub_weight.weight + 1.0))
assert np.all(s.w > (stub_weight.weight - 1.0))
last_V_samples = assert_unwrap(fm.V_samples)[-20:]

for i in range(3):
for j in range(i + 1, 3):
@@ -52,7 +66,7 @@
assert vfm_cross_term > sign * cross_term * 0.8
assert vfm_cross_term < sign * cross_term * 1.25

for s in last_samples:
sample_cross_term = s.V[i].dot(s.V[j])
for V_ in last_V_samples:
sample_cross_term = V_[i].dot(V_[j])
assert sample_cross_term > sign * cross_term * 0.5
assert sample_cross_term < sign * cross_term * 2
