diff --git a/README.md b/README.md index eda0dbe..ec8a297 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ fm = test_myfm(df_train, df_test, rank=8, grouping=True) ## Examples for Relational Data format -Below is a toy movielens-like example which utilizes relational data format proposed in [3]. +Below is a toy movielens-like example that utilizes relational data format proposed in [3]. This example, however, is too simplistic to exhibit the computational advantage of this data format. For an example with drastically reduced computational complexity, see `examples/ml-100k-extended.ipynb`; diff --git a/doc/source/index.rst b/doc/source/index.rst index 2c2f367..3551f2a 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -10,10 +10,10 @@ myFM - Bayesian Factorization Machines in Python/C++ **myFM** is an unofficial implementation of Bayesian Factorization Machines in Python/C++. Notable features include: -* Implementation most functionalities of `libFM `_ MCMC engine (including grouping & relation block) -* A simpler and faster implementation with `Pybind11 `_ and `Eigen `_ +* Implementation of all corresponding functionalities in `libFM `_ MCMC engine (including grouping & relation block) +* A simpler and faster implementation using `Pybind11 `_ and `Eigen `_ * Gibbs sampling for **ordinal regression** with probit link function. See :ref:`the tutorial ` for its usage. -* Variational inference which converges faster and requires lower memory (but usually less accurate than the Gibbs sampling). +* Support for variational inference, which converges faster and requires lower memory (but usually less accurate than the Gibbs sampling). 
In most cases, you can install the library from PyPI: :: diff --git a/doc/source/ordinal-regression.rst b/doc/source/ordinal-regression.rst index c924523..2d949a5 100644 --- a/doc/source/ordinal-regression.rst +++ b/doc/source/ordinal-regression.rst @@ -116,10 +116,10 @@ you can train our ordered probit regressor by y_train = df_train.rating.values y_test = df_test.rating.values - fm_grouped_ordered = myfm.MyFMOrderedProbit( + fm = myfm.MyFMOrderedProbit( rank=FM_RANK, random_seed=42, ) - fm_grouped_ordered.fit( + fm.fit( X_train, y_train - 1, n_iter=300, n_kept_samples=300, group_shapes=[len(group) for group in ohe.categories_] ) @@ -132,7 +132,7 @@ We can predict the class probability given ``X_test`` as .. testcode :: - p_ordinal = fm_grouped_ordered.predict_proba(X_test) + p_ordinal = fm.predict_proba(X_test) and the expected rating as @@ -154,11 +154,10 @@ which gives us RMSE=0.8906 and MAE=0.6985, a slight improvement over the regress To see why it had an advantage over regression, let us check the posterior samples for the cutpoint parameters. -You can access them via ``fm_grouped_ordered.predictor_.samples``: :: - cutpoints = np.vstack( - [ fm.cutpoints[0] - fm.w0 for fm in fm_grouped_ordered.predictor_.samples] - ) +.. testcode :: + + cutpoints = fm.cutpoint_samples - fm.w0_samples[:, None] You can see how rating boundaries vs cutpoints looks like. :: diff --git a/doc/source/relation-blocks.rst b/doc/source/relation-blocks.rst index f0b747f..2737f52 100644 --- a/doc/source/relation-blocks.rst +++ b/doc/source/relation-blocks.rst @@ -205,8 +205,8 @@ but the result should be the same up to floating point artifacts: .. 
testcode :: for i in range(3): - sample_naive = fm_naive.predictor_.samples[i].w - sample_rb = fm_rb.predictor_.samples[i].w + sample_naive = fm_naive.w_samples[i] + sample_rb = fm_rb.w_samples[i] assert(np.max(np.abs(sample_naive - sample_rb)) < 1e-5) # should print tiny numbers diff --git a/examples/oprobit_example.py b/examples/oprobit_example.py index 2a2727a..abab773 100644 --- a/examples/oprobit_example.py +++ b/examples/oprobit_example.py @@ -28,5 +28,4 @@ n_kept_samples=10000, ) -c0 = np.asfarray([s.cutpoints[0] for s in fm.predictor_.samples]) -print(c0.mean(axis=0)) +print(fm.cutpoint_samples.mean(axis=0)) diff --git a/src/myfm/gibbs.py b/src/myfm/gibbs.py index d738c5c..1dee991 100644 --- a/src/myfm/gibbs.py +++ b/src/myfm/gibbs.py @@ -530,3 +530,14 @@ def predict( result: ClassIndexArray = self.predict_proba(X, X_rel=X_rel).argmax(axis=1) return result + + @property + def cutpoint_samples(self) -> Optional[DenseArray]: + r"""Obtain samples for the cutpoints. If the model is not fit yet, return `None`. + + Returns: + Samples for cutpoints. + """ + if self.predictor_ is None: + return None + return np.asfarray([fm.cutpoints[0] for fm in self.predictor_.samples]) diff --git a/src/myfm/variational.py b/src/myfm/variational.py index 47179ab..47320df 100644 --- a/src/myfm/variational.py +++ b/src/myfm/variational.py @@ -17,13 +17,12 @@ REAL, ArrayLike, ClassifierMixin, - DenseArray, MyFMBase, RegressorMixin, check_data_consistency, ) -ArrayOrDenseArray = TypeVar("ArrayOrDenseArray", DenseArray, float) +ArrayOrDenseArray = TypeVar("ArrayOrDenseArray", np.ndarray, float) def runtime_error_to_optional( @@ -48,7 +47,8 @@ class MyFMVariationalBase( ): @property def w0_mean(self) -> Optional[float]: - """Mean of variational posterior distribution of global bias `w0`. + r"""Mean of variational posterior distribution of global bias `w0`. + If the model is not fit yet, returns `None`. Returns: Mean of variational posterior distribution of global bias `w0`. 
@@ -61,7 +61,8 @@ def _retrieve(fm: VariationalFM) -> float: @property def w0_var(self) -> Optional[float]: - """Variance of variational posterior distribution of global bias `w0`. + r"""Variance of variational posterior distribution of global bias `w0`. + If the model is not fit yet, returns `None`. Returns: Variance of variational posterior distribution of global bias `w0`. @@ -73,53 +74,57 @@ def _retrieve(fm: VariationalFM) -> float: return runtime_error_to_optional(self, _retrieve) @property - def w_mean(self) -> Optional[DenseArray]: - """Mean of variational posterior distribution of linear coefficnent `w`. + def w_mean(self) -> Optional[np.ndarray]: + r"""Mean of variational posterior distribution of linear coefficient `w`. + If the model is not fit yet, returns `None`. Returns: - Mean of variational posterior distribution of linear coefficnent `w. + Mean of variational posterior distribution of linear coefficient `w`. """ - def _retrieve(fm: VariationalFM) -> DenseArray: + def _retrieve(fm: VariationalFM) -> np.ndarray: return fm.w return runtime_error_to_optional(self, _retrieve) @property - def w_var(self) -> Optional[DenseArray]: - """Variance of variational posterior distribution of linear coefficnent `w`. + def w_var(self) -> Optional[np.ndarray]: + r"""Variance of variational posterior distribution of linear coefficient `w`. + If the model is not fit yet, returns `None`. Returns: - Variance of variational posterior distribution of linear coefficnent `w. + Variance of variational posterior distribution of linear coefficient `w`. """ - def _retrieve(fm: VariationalFM) -> DenseArray: + def _retrieve(fm: VariationalFM) -> np.ndarray: return fm.w_var return runtime_error_to_optional(self, _retrieve) @property - def V_mean(self) -> Optional[DenseArray]: - """Mean of variational posterior distribution of factorized quadratic coefficnent `V`. 
+ def V_mean(self) -> Optional[np.ndarray]: + r"""Mean of variational posterior distribution of factorized quadratic coefficient `V`. + If the model is not fit yet, returns `None`. Returns: - Mean of variational posterior distribution of factorized quadratic coefficient `w. + Mean of variational posterior distribution of factorized quadratic coefficient `V`. """ - def _retrieve(fm: VariationalFM) -> DenseArray: + def _retrieve(fm: VariationalFM) -> np.ndarray: return fm.V return runtime_error_to_optional(self, _retrieve) @property - def V_var(self) -> Optional[DenseArray]: - """Variance of variational posterior distribution of factorized quadratic coefficnent `V`. + def V_var(self) -> Optional[np.ndarray]: + r"""Variance of variational posterior distribution of factorized quadratic coefficient `V`. + If the model is not fit yet, returns `None`. Returns: - Variance of variational posterior distribution of factorized quadratic coefficient `w. + Variance of variational posterior distribution of factorized quadratic coefficient `V`. 
""" - def _retrieve(fm: VariationalFM) -> DenseArray: + def _retrieve(fm: VariationalFM) -> np.ndarray: return fm.V_var return runtime_error_to_optional(self, _retrieve) diff --git a/tests/oprobit/test_oprobit_1dim.py b/tests/oprobit/test_oprobit_1dim.py index b5d1b51..6d655f5 100644 --- a/tests/oprobit/test_oprobit_1dim.py +++ b/tests/oprobit/test_oprobit_1dim.py @@ -31,8 +31,8 @@ def test_oprobit(use_libfm_callback: bool) -> None: ) assert fm.predictor_ is not None - for sample in fm.predictor_.samples[-10:]: - cp_1, cp_2, cp_3 = sample.cutpoints[0] + for cutpoint_sample in fm.cutpoint_samples[-10:]: + cp_1, cp_2, cp_3 = cutpoint_sample assert abs(cp_1) < 0.25 assert abs(cp_2 - cp_1 - 0.5) < 0.25 assert abs(cp_3 - cp_1 - 1.5) < 0.25 diff --git a/tests/regression/test_block.py b/tests/regression/test_block.py index e1bf8a0..708b137 100644 --- a/tests/regression/test_block.py +++ b/tests/regression/test_block.py @@ -57,27 +57,21 @@ def test_block_vfm() -> None: y, n_iter=100, ) - fm_blocked = VariationalFMRegressor(3).fit( + fm_blocked_serialized = VariationalFMRegressor(3).fit( tm_column, y, blocks, n_iter=100, ) - assert fm_flatten.predictor_ is not None - assert fm_blocked.predictor_ is not None with tempfile.TemporaryFile() as temp_fs: - pickle.dump(fm_blocked, temp_fs) - del fm_blocked + pickle.dump(fm_blocked_serialized, temp_fs) + del fm_blocked_serialized temp_fs.seek(0) - fm_blocked = pickle.load(temp_fs) + fm_blocked: VariationalFMRegressor = pickle.load(temp_fs) - np.testing.assert_allclose( - fm_flatten.predictor_.weights().w, fm_blocked.predictor_.weights().w - ) - np.testing.assert_allclose( - fm_flatten.predictor_.weights().V, fm_blocked.predictor_.weights().V - ) + np.testing.assert_allclose(fm_flatten.w_mean, fm_blocked.w_mean) + np.testing.assert_allclose(fm_flatten.V_mean, fm_blocked.V_mean) predicton_flatten = fm_flatten.predict(tm_column, blocks) predicton_blocked = fm_blocked.predict(X_flatten) np.testing.assert_allclose(predicton_flatten, 
predicton_blocked)