diff --git a/mellon/base_model.py b/mellon/base_model.py index 5dcd292..b022b2f 100644 --- a/mellon/base_model.py +++ b/mellon/base_model.py @@ -127,6 +127,7 @@ def __repr__(self): L = object_str(self.L, ["cells", "ranks"]) nn_distances = object_str(self.nn_distances, ["cells"]) initial_value = object_str(self.initial_value, ["ranks"]) + d = object_str(self.d, ["cells"]) string = ( f"{name}(" f"\n cov_func_curry={self.cov_func_curry}," @@ -137,7 +138,7 @@ def __repr__(self): f"\n optimizer={self.optimizer}," f"\n landmarks={landmarks}," f"\n nn_distances={nn_distances}," - f"\n d={self.d}," + f"\n d={d}," f"\n mu={self.mu}," f"\n ls={self.ls}," f"\n ls_factor={self.ls_factor}," diff --git a/mellon/base_predictor.py b/mellon/base_predictor.py index 6c366f3..7aeeaca 100644 --- a/mellon/base_predictor.py +++ b/mellon/base_predictor.py @@ -19,6 +19,7 @@ deserialize, ensure_2d, make_multi_time_argument, + object_str, ) from .derivatives import ( gradient, @@ -121,8 +122,8 @@ def __repr__(self): + "and data:\n" + "\n".join( [ - str(key) + ": " + repr(getattr(self, key)) - for key in self._data_dict().keys() + str(key) + ": " + object_str(v) + for key, v in self._data_dict().items() ] ) ) diff --git a/mellon/time_sensitive_density_estimator.py b/mellon/time_sensitive_density_estimator.py index 25253ce..02e34c1 100644 --- a/mellon/time_sensitive_density_estimator.py +++ b/mellon/time_sensitive_density_estimator.py @@ -312,6 +312,7 @@ def __repr__(self): normalize_per_time_point = object_str( self.normalize_per_time_point, ["time points"] ) + d = object_str(self.d, ["cells"]) string = ( f"{name}(" f"\n cov_func_curry={self.cov_func_curry}," @@ -323,7 +324,7 @@ def __repr__(self): f"\n landmarks={landmarks}," f"\n nn_distances={nn_distances}," f"\n normalize_per_time_point={normalize_per_time_point}," - f"\n d={self.d}," + f"\n d={d}," f"\n mu={self.mu}," f"\n ls={self.ls}," f"\n ls_time={self.ls_time}," diff --git a/tests/time_sensitive_density_estimator.py b/tests/time_sensitive_density_estimator.py index 14a82d6..1b78471 100644 --- a/tests/time_sensitive_density_estimator.py +++ b/tests/time_sensitive_density_estimator.py @@ -241,7 +241,7 @@ def test_density_estimator_errors(common_setup_time_sensitive): est.fit_predict() est.predict.n_obs = None with pytest.raises(ValueError): - est.predict(X, normalize=True) + est.predict(X, time=times, normalize=True) @pytest.mark.parametrize(