diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py
index b6039402c..28da4c21f 100644
--- a/tests/algorithms/test_preference_comparisons.py
+++ b/tests/algorithms/test_preference_comparisons.py
@@ -16,8 +16,10 @@
 
 import imitation.testing.reward_nets as testing_reward_nets
 from imitation.algorithms import preference_comparisons
+from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.data import types
 from imitation.data.types import TrajectoryWithRew
+from imitation.policies.replay_buffer_wrapper import ReplayBufferView
 from imitation.regularization import regularizers, updaters
 from imitation.rewards import reward_nets
 from imitation.util import networks, util
@@ -71,6 +73,23 @@ def agent_trainer(agent, reward_net, venv, rng):
     return preference_comparisons.AgentTrainer(agent, reward_net, venv, rng)
 
 
+@pytest.fixture
+def replay_buffer(rng):
+    return ReplayBufferView(rng.random((10, 8, 4)), lambda: slice(None))
+
+
+@pytest.fixture
+def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer):
+    reward_fn = PebbleStateEntropyReward(reward_net.predict_processed)
+    reward_fn.set_replay_buffer(replay_buffer, (4,))
+    return preference_comparisons.PebbleAgentTrainer(
+        algorithm=agent,
+        reward_fn=reward_fn,
+        venv=venv,
+        rng=rng,
+    )
+
+
 def _check_trajs_equal(
     trajs1: Sequence[types.TrajectoryWithRew],
     trajs2: Sequence[types.TrajectoryWithRew],
@@ -277,14 +296,17 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng):
     "schedule",
     ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)],
 )
+@pytest.mark.parametrize("agent_fixture", ["agent_trainer", "pebble_agent_trainer"])
 def test_trainer_no_crash(
-    agent_trainer,
+    request,
+    agent_fixture,
     reward_net,
     random_fragmenter,
     custom_logger,
     schedule,
     rng,
 ):
+    agent_trainer = request.getfixturevalue(agent_fixture)
     main_trainer = preference_comparisons.PreferenceComparisons(
         agent_trainer,
         reward_net,