Skip to content

Commit

Permalink
serialize: doctest doesn't like save_to_disk(PosixPath)
Browse files Browse the repository at this point in the history
The error message:
Warning, treated as error:
**********************************************************************
File "algorithms/dagger.rst", line 45, in default
Failed example:
    import tempfile

    import numpy as np
    import gymnasium as gym
    from stable_baselines3.common.evaluation import evaluate_policy

    from imitation.algorithms import bc
    from imitation.algorithms.dagger import SimpleDAggerTrainer
    from imitation.policies.serialize import load_policy
    from imitation.util.util import make_vec_env

    rng = np.random.default_rng(0)
    env = make_vec_env(
        "seals:seals/CartPole-v0",
        rng=rng,
    )
    expert = load_policy(
        "ppo-huggingface",
        organization="HumanCompatibleAI",
        env_name="seals-CartPole-v0",
        venv=env,
    )

    bc_trainer = bc.BC(
        observation_space=env.observation_space,
        action_space=env.action_space,
        rng=rng,
    )
    with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
        print(tmpdir)
        dagger_trainer = SimpleDAggerTrainer(
            venv=env,
            scratch_dir=tmpdir,
            expert_policy=expert,
            bc_trainer=bc_trainer,
            rng=rng,
        )
        dagger_trainer.train(8_000)

    reward, _ = evaluate_policy(dagger_trainer.policy, env, 10)
    print("Reward:", reward)
Exception raised:
    Traceback (most recent call last):
      File "/usr/lib/python3.8/doctest.py", line 1336, in __run
        exec(compile(example.source, filename, "single",
      File "<doctest default[0]>", line 38, in <module>
        dagger_trainer.train(8_000)
      File "/venv/lib/python3.8/site-packages/imitation/algorithms/dagger.py", line 669, in train
        trajectories = rollout.generate_trajectories(
      File "/venv/lib/python3.8/site-packages/imitation/data/rollout.py", line 447, in generate_trajectories
        obs, rews, dones, infos = venv.step(acts)
      File "/venv/lib/python3.8/site-packages/stable_baselines3/common/vec_env/base_vec_env.py", line 206, in step
        return self.step_wait()
      File "/venv/lib/python3.8/site-packages/imitation/algorithms/dagger.py", line 285, in step_wait
        _save_dagger_demo(traj, traj_index, self.save_dir, self.rng)
      File "/venv/lib/python3.8/site-packages/imitation/algorithms/dagger.py", line 147, in _save_dagger_demo
        serialize.save(npz_path, [trajectory])
      File "/venv/lib/python3.8/site-packages/imitation/data/serialize.py", line 23, in save
        huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)
      File "/venv/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 1470, in save_to_disk
        fs, _ = url_to_fs(dataset_path, **(storage_options or {}))
      File "/venv/lib/python3.8/site-packages/fsspec/core.py", line 383, in url_to_fs
        chain = _un_chain(url, kwargs)
      File "/venv/lib/python3.8/site-packages/fsspec/core.py", line 323, in _un_chain
        if "::" in path
    TypeError: argument of type 'PosixPath' is not iterable
  • Loading branch information
tomtseng committed Jan 6, 2025
1 parent b765aff commit 30f9048
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/imitation/data/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]) -> None:
trajectories: The trajectories to save.
"""
p = util.parse_path(path)
huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)
huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(str(p))
logging.info(f"Dumped demonstrations to {p}.")


Expand Down

0 comments on commit 30f9048

Please sign in to comment.