Skip to content

Commit

Permalink
better representation of Predictor
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Aug 30, 2023
1 parent 23de93a commit 1874c15
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
3 changes: 2 additions & 1 deletion mellon/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __repr__(self):
L = object_str(self.L, ["cells", "ranks"])
nn_distances = object_str(self.nn_distances, ["cells"])
initial_value = object_str(self.initial_value, ["ranks"])
d = object_str(self.d, ["cells"])
string = (
f"{name}("
f"\n cov_func_curry={self.cov_func_curry},"
Expand All @@ -137,7 +138,7 @@ def __repr__(self):
f"\n optimizer={self.optimizer},"
f"\n landmarks={landmarks},"
f"\n nn_distances={nn_distances},"
f"\n d={self.d},"
f"\n d={d},"
f"\n mu={self.mu},"
f"\n ls={self.ls},"
f"\n ls_factor={self.ls_factor},"
Expand Down
5 changes: 3 additions & 2 deletions mellon/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
deserialize,
ensure_2d,
make_multi_time_argument,
object_str,
)
from .derivatives import (
gradient,
Expand Down Expand Up @@ -121,8 +122,8 @@ def __repr__(self):
+ "and data:\n"
+ "\n".join(
[
str(key) + ": " + repr(getattr(self, key))
for key in self._data_dict().keys()
str(key) + ": " + object_str(v)
for key, v in self._data_dict().items()
]
)
)
Expand Down
3 changes: 2 additions & 1 deletion mellon/time_sensitive_density_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def __repr__(self):
normalize_per_time_point = object_str(
self.normalize_per_time_point, ["time points"]
)
d = object_str(self.d, ["cells"])
string = (
f"{name}("
f"\n cov_func_curry={self.cov_func_curry},"
Expand All @@ -323,7 +324,7 @@ def __repr__(self):
f"\n landmarks={landmarks},"
f"\n nn_distances={nn_distances},"
f"\n normalize_per_time_point={normalize_per_time_point},"
f"\n d={self.d},"
f"\n d={d},"
f"\n mu={self.mu},"
f"\n ls={self.ls},"
f"\n ls_time={self.ls_time},"
Expand Down
2 changes: 1 addition & 1 deletion tests/time_sensitive_density_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def test_density_estimator_errors(common_setup_time_sensitive):
est.fit_predict()
est.predict.n_obs = None
with pytest.raises(ValueError):
est.predict(X, normalize=True)
est.predict(X, time=times, normalize=True)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 1874c15

Please sign in to comment.