Skip to content

Commit

Permalink
#625 add test for pebble agent trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Michelfeit committed Dec 1, 2022
1 parent f3decf1 commit 473c7b2
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion tests/algorithms/test_preference_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

import imitation.testing.reward_nets as testing_reward_nets
from imitation.algorithms import preference_comparisons
from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.data import types
from imitation.data.types import TrajectoryWithRew
from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.regularization import regularizers, updaters
from imitation.rewards import reward_nets
from imitation.util import networks, util
Expand Down Expand Up @@ -71,6 +73,23 @@ def agent_trainer(agent, reward_net, venv, rng):
return preference_comparisons.AgentTrainer(agent, reward_net, venv, rng)


@pytest.fixture
def replay_buffer(rng):
    """Return a ReplayBufferView over a small random observation batch.

    The view exposes a (10, 8, 4) array of random floats and a selector
    that yields every stored sample (``slice(None)``).
    """
    observations = rng.random((10, 8, 4))
    select_all = lambda: slice(None)
    return ReplayBufferView(observations, select_all)


@pytest.fixture
def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer):
    """Build a PEBBLE agent trainer with an entropy-based reward function.

    The reward wraps ``reward_net.predict_processed`` and is connected to
    the test replay buffer (observation shape ``(4,)``) before being handed
    to the trainer.
    """
    entropy_reward = PebbleStateEntropyReward(reward_net.predict_processed)
    entropy_reward.set_replay_buffer(replay_buffer, (4,))
    trainer = preference_comparisons.PebbleAgentTrainer(
        algorithm=agent,
        reward_fn=entropy_reward,
        venv=venv,
        rng=rng,
    )
    return trainer


def _check_trajs_equal(
trajs1: Sequence[types.TrajectoryWithRew],
trajs2: Sequence[types.TrajectoryWithRew],
Expand Down Expand Up @@ -277,14 +296,17 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng):
"schedule",
["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)],
)
@pytest.mark.parametrize("agent_fixture", ["agent_trainer", "pebble_agent_trainer"])
def test_trainer_no_crash(
agent_trainer,
request,
agent_fixture,
reward_net,
random_fragmenter,
custom_logger,
schedule,
rng,
):
agent_trainer = request.getfixturevalue(agent_fixture)
main_trainer = preference_comparisons.PreferenceComparisons(
agent_trainer,
reward_net,
Expand Down

0 comments on commit 473c7b2

Please sign in to comment.