diff --git a/README.md b/README.md
index eda0dbe..ec8a297 100644
--- a/README.md
+++ b/README.md
@@ -123,7 +123,7 @@ fm = test_myfm(df_train, df_test, rank=8, grouping=True)
## Examples for Relational Data format
-Below is a toy movielens-like example which utilizes relational data format proposed in [3].
+Below is a toy movielens-like example that utilizes relational data format proposed in [3].
This example, however, is too simplistic to exhibit the computational advantage of this data format. For an example with drastically reduced computational complexity, see `examples/ml-100k-extended.ipynb`;
diff --git a/doc/source/index.rst b/doc/source/index.rst
index 2c2f367..3551f2a 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -10,10 +10,10 @@ myFM - Bayesian Factorization Machines in Python/C++
**myFM** is an unofficial implementation of Bayesian Factorization Machines in Python/C++.
Notable features include:
-* Implementation most functionalities of `libFM `_ MCMC engine (including grouping & relation block)
-* A simpler and faster implementation with `Pybind11 `_ and `Eigen `_
+* Implementation of all corresponding functionalities in `libFM `_ MCMC engine (including grouping & relation block)
+* A simpler and faster implementation using `Pybind11 `_ and `Eigen `_
* Gibbs sampling for **ordinal regression** with probit link function. See :ref:`the tutorial ` for its usage.
-* Variational inference which converges faster and requires lower memory (but usually less accurate than the Gibbs sampling).
+* Support for variational inference, which converges faster and requires less memory (but is usually less accurate than Gibbs sampling).
In most cases, you can install the library from PyPI: ::
diff --git a/doc/source/ordinal-regression.rst b/doc/source/ordinal-regression.rst
index c924523..2d949a5 100644
--- a/doc/source/ordinal-regression.rst
+++ b/doc/source/ordinal-regression.rst
@@ -116,10 +116,10 @@ you can train our ordered probit regressor by
y_train = df_train.rating.values
y_test = df_test.rating.values
- fm_grouped_ordered = myfm.MyFMOrderedProbit(
+ fm = myfm.MyFMOrderedProbit(
rank=FM_RANK, random_seed=42,
)
- fm_grouped_ordered.fit(
+ fm.fit(
X_train, y_train - 1, n_iter=300, n_kept_samples=300,
group_shapes=[len(group) for group in ohe.categories_]
)
@@ -132,7 +132,7 @@ We can predict the class probability given ``X_test`` as
.. testcode ::
- p_ordinal = fm_grouped_ordered.predict_proba(X_test)
+ p_ordinal = fm.predict_proba(X_test)
and the expected rating as
@@ -154,11 +154,10 @@ which gives us RMSE=0.8906 and MAE=0.6985, a slight improvement over the regress
To see why it had an advantage over regression, let us check
the posterior samples for the cutpoint parameters.
-You can access them via ``fm_grouped_ordered.predictor_.samples``: ::
- cutpoints = np.vstack(
- [ fm.cutpoints[0] - fm.w0 for fm in fm_grouped_ordered.predictor_.samples]
- )
+.. testcode ::
+
+ cutpoints = fm.cutpoint_samples - fm.w0_samples[:, None]
You can see how rating boundaries vs cutpoints looks like. ::
diff --git a/doc/source/relation-blocks.rst b/doc/source/relation-blocks.rst
index f0b747f..2737f52 100644
--- a/doc/source/relation-blocks.rst
+++ b/doc/source/relation-blocks.rst
@@ -205,8 +205,8 @@ but the result should be the same up to floating point artifacts:
.. testcode ::
for i in range(3):
- sample_naive = fm_naive.predictor_.samples[i].w
- sample_rb = fm_rb.predictor_.samples[i].w
+ sample_naive = fm_naive.w_samples[i]
+ sample_rb = fm_rb.w_samples[i]
assert(np.max(np.abs(sample_naive - sample_rb)) < 1e-5)
# should print tiny numbers
diff --git a/examples/oprobit_example.py b/examples/oprobit_example.py
index 2a2727a..abab773 100644
--- a/examples/oprobit_example.py
+++ b/examples/oprobit_example.py
@@ -28,5 +28,4 @@
n_kept_samples=10000,
)
-c0 = np.asfarray([s.cutpoints[0] for s in fm.predictor_.samples])
-print(c0.mean(axis=0))
+print(fm.cutpoint_samples.mean(axis=0))
diff --git a/src/myfm/gibbs.py b/src/myfm/gibbs.py
index d738c5c..1dee991 100644
--- a/src/myfm/gibbs.py
+++ b/src/myfm/gibbs.py
@@ -530,3 +530,14 @@ def predict(
result: ClassIndexArray = self.predict_proba(X, X_rel=X_rel).argmax(axis=1)
return result
+
+    @property
+    def cutpoint_samples(self) -> Optional[DenseArray]:
+        r"""Obtain the posterior samples of the cutpoints.
+
+        Returns:
+            An array stacking ``cutpoints[0]`` of every sample kept in
+            ``self.predictor_.samples`` (one row per sample), or `None`
+            if the model is not fit yet (``self.predictor_`` is `None`).
+        """
+        if self.predictor_ is None:
+            return None
+        # ``np.asfarray`` was removed in NumPy 2.0; ``np.asarray`` with an
+        # explicit float64 dtype is its documented replacement.
+        return np.asarray(
+            [fm.cutpoints[0] for fm in self.predictor_.samples],
+            dtype=np.float64,
+        )
diff --git a/src/myfm/variational.py b/src/myfm/variational.py
index 47179ab..47320df 100644
--- a/src/myfm/variational.py
+++ b/src/myfm/variational.py
@@ -17,13 +17,12 @@
REAL,
ArrayLike,
ClassifierMixin,
- DenseArray,
MyFMBase,
RegressorMixin,
check_data_consistency,
)
-ArrayOrDenseArray = TypeVar("ArrayOrDenseArray", DenseArray, float)
+ArrayOrDenseArray = TypeVar("ArrayOrDenseArray", np.ndarray, float)
def runtime_error_to_optional(
@@ -48,7 +47,8 @@ class MyFMVariationalBase(
):
@property
def w0_mean(self) -> Optional[float]:
- """Mean of variational posterior distribution of global bias `w0`.
+ r"""Mean of variational posterior distribution of global bias `w0`.
+ If the model is not fit yet, returns `None`.
Returns:
Mean of variational posterior distribution of global bias `w0`.
@@ -61,7 +61,8 @@ def _retrieve(fm: VariationalFM) -> float:
@property
def w0_var(self) -> Optional[float]:
- """Variance of variational posterior distribution of global bias `w0`.
+ r"""Variance of variational posterior distribution of global bias `w0`.
+ If the model is not fit yet, returns `None`.
Returns:
Variance of variational posterior distribution of global bias `w0`.
@@ -73,53 +74,57 @@ def _retrieve(fm: VariationalFM) -> float:
return runtime_error_to_optional(self, _retrieve)
@property
- def w_mean(self) -> Optional[DenseArray]:
- """Mean of variational posterior distribution of linear coefficnent `w`.
+ def w_mean(self) -> Optional[np.ndarray]:
+ r"""Mean of variational posterior distribution of linear coefficient `w`.
+ If the model is not fit yet, returns `None`.
Returns:
- Mean of variational posterior distribution of linear coefficnent `w.
+ Mean of variational posterior distribution of linear coefficient `w`.
"""
- def _retrieve(fm: VariationalFM) -> DenseArray:
+ def _retrieve(fm: VariationalFM) -> np.ndarray:
return fm.w
return runtime_error_to_optional(self, _retrieve)
@property
- def w_var(self) -> Optional[DenseArray]:
- """Variance of variational posterior distribution of linear coefficnent `w`.
+ def w_var(self) -> Optional[np.ndarray]:
+ r"""Variance of variational posterior distribution of linear coefficient `w`.
+ If the model is not fit yet, returns `None`.
Returns:
- Variance of variational posterior distribution of linear coefficnent `w.
+ Variance of variational posterior distribution of linear coefficient `w`.
"""
- def _retrieve(fm: VariationalFM) -> DenseArray:
+ def _retrieve(fm: VariationalFM) -> np.ndarray:
return fm.w_var
return runtime_error_to_optional(self, _retrieve)
@property
- def V_mean(self) -> Optional[DenseArray]:
- """Mean of variational posterior distribution of factorized quadratic coefficnent `V`.
+ def V_mean(self) -> Optional[np.ndarray]:
+ r"""Mean of variational posterior distribution of factorized quadratic coefficient `V`.
+ If the model is not fit yet, returns `None`.
Returns:
- Mean of variational posterior distribution of factorized quadratic coefficient `w.
+ Mean of variational posterior distribution of factorized quadratic coefficient `V`.
"""
- def _retrieve(fm: VariationalFM) -> DenseArray:
+ def _retrieve(fm: VariationalFM) -> np.ndarray:
return fm.V
return runtime_error_to_optional(self, _retrieve)
@property
- def V_var(self) -> Optional[DenseArray]:
- """Variance of variational posterior distribution of factorized quadratic coefficnent `V`.
+ def V_var(self) -> Optional[np.ndarray]:
+ r"""Variance of variational posterior distribution of factorized quadratic coefficient `V`.
+ If the model is not fit yet, returns `None`.
Returns:
- Variance of variational posterior distribution of factorized quadratic coefficient `w.
+ Variance of variational posterior distribution of factorized quadratic coefficient `V`.
"""
- def _retrieve(fm: VariationalFM) -> DenseArray:
+ def _retrieve(fm: VariationalFM) -> np.ndarray:
return fm.V_var
return runtime_error_to_optional(self, _retrieve)
diff --git a/tests/oprobit/test_oprobit_1dim.py b/tests/oprobit/test_oprobit_1dim.py
index b5d1b51..6d655f5 100644
--- a/tests/oprobit/test_oprobit_1dim.py
+++ b/tests/oprobit/test_oprobit_1dim.py
@@ -31,8 +31,8 @@ def test_oprobit(use_libfm_callback: bool) -> None:
)
assert fm.predictor_ is not None
- for sample in fm.predictor_.samples[-10:]:
- cp_1, cp_2, cp_3 = sample.cutpoints[0]
+ for cutpoint_sample in fm.cutpoint_samples[-10:]:
+ cp_1, cp_2, cp_3 = cutpoint_sample
assert abs(cp_1) < 0.25
assert abs(cp_2 - cp_1 - 0.5) < 0.25
assert abs(cp_3 - cp_1 - 1.5) < 0.25
diff --git a/tests/regression/test_block.py b/tests/regression/test_block.py
index e1bf8a0..708b137 100644
--- a/tests/regression/test_block.py
+++ b/tests/regression/test_block.py
@@ -57,27 +57,21 @@ def test_block_vfm() -> None:
y,
n_iter=100,
)
- fm_blocked = VariationalFMRegressor(3).fit(
+ fm_blocked_serialized = VariationalFMRegressor(3).fit(
tm_column,
y,
blocks,
n_iter=100,
)
- assert fm_flatten.predictor_ is not None
- assert fm_blocked.predictor_ is not None
with tempfile.TemporaryFile() as temp_fs:
- pickle.dump(fm_blocked, temp_fs)
- del fm_blocked
+ pickle.dump(fm_blocked_serialized, temp_fs)
+ del fm_blocked_serialized
temp_fs.seek(0)
- fm_blocked = pickle.load(temp_fs)
+ fm_blocked: VariationalFMRegressor = pickle.load(temp_fs)
- np.testing.assert_allclose(
- fm_flatten.predictor_.weights().w, fm_blocked.predictor_.weights().w
- )
- np.testing.assert_allclose(
- fm_flatten.predictor_.weights().V, fm_blocked.predictor_.weights().V
- )
+ np.testing.assert_allclose(fm_flatten.w_mean, fm_blocked.w_mean)
+ np.testing.assert_allclose(fm_flatten.V_mean, fm_blocked.V_mean)
predicton_flatten = fm_flatten.predict(tm_column, blocks)
predicton_blocked = fm_blocked.predict(X_flatten)
np.testing.assert_allclose(predicton_flatten, predicton_blocked)