Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
serialize: doctest doesn't like save_to_disk(PosixPath)
The error message: Warning, treated as error: ********************************************************************** File "algorithms/dagger.rst", line 45, in default Failed example: import tempfile import numpy as np import gymnasium as gym from stable_baselines3.common.evaluation import evaluate_policy from imitation.algorithms import bc from imitation.algorithms.dagger import SimpleDAggerTrainer from imitation.policies.serialize import load_policy from imitation.util.util import make_vec_env rng = np.random.default_rng(0) env = make_vec_env( "seals:seals/CartPole-v0", rng=rng, ) expert = load_policy( "ppo-huggingface", organization="HumanCompatibleAI", env_name="seals-CartPole-v0", venv=env, ) bc_trainer = bc.BC( observation_space=env.observation_space, action_space=env.action_space, rng=rng, ) with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir: print(tmpdir) dagger_trainer = SimpleDAggerTrainer( venv=env, scratch_dir=tmpdir, expert_policy=expert, bc_trainer=bc_trainer, rng=rng, ) dagger_trainer.train(8_000) reward, _ = evaluate_policy(dagger_trainer.policy, env, 10) print("Reward:", reward) Exception raised: Traceback (most recent call last): File "/usr/lib/python3.8/doctest.py", line 1336, in __run exec(compile(example.source, filename, "single", File "<doctest default[0]>", line 38, in <module> dagger_trainer.train(8_000) File "/venv/lib/python3.8/site-packages/imitation/algorithms/dagger.py", line 669, in train trajectories = rollout.generate_trajectories( File "/venv/lib/python3.8/site-packages/imitation/data/rollout.py", line 447, in generate_trajectories obs, rews, dones, infos = venv.step(acts) File "/venv/lib/python3.8/site-packages/stable_baselines3/common/vec_env/base_vec_env.py", line 206, in step return self.step_wait() File "/venv/lib/python3.8/site-packages/imitation/algorithms/dagger.py", line 285, in step_wait _save_dagger_demo(traj, traj_index, self.save_dir, self.rng) File "/venv/lib/python3.8/site-packages/imitation/algorithms/dagger.py", line 147, in _save_dagger_demo serialize.save(npz_path, [trajectory]) File "/venv/lib/python3.8/site-packages/imitation/data/serialize.py", line 23, in save huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p) File "/venv/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 1470, in save_to_disk fs, _ = url_to_fs(dataset_path, **(storage_options or {})) File "/venv/lib/python3.8/site-packages/fsspec/core.py", line 383, in url_to_fs chain = _un_chain(url, kwargs) File "/venv/lib/python3.8/site-packages/fsspec/core.py", line 323, in _un_chain if "::" in path TypeError: argument of type 'PosixPath' is not iterable
- Loading branch information