
Continue training of a BC policy with RL fails #862

Open
timosturm opened this issue Dec 14, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@timosturm

Bug description

My goal is to pre-train a policy with BC and then fine-tune it with RL, e.g., PPO. The problem is that I cannot find an example of this, and the approaches I have tried do not work.

Steps to reproduce

Below is a minimal example based on the quickstart.

"""This is a simple example demonstrating how to clone the behavior of an expert.

Refer to the jupyter notebooks for more detailed examples of how to use the algorithms.
"""
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms import bc
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env

rng = np.random.default_rng(0)
env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=rng,
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # for computing rollouts
)


def train_expert():
    # note: use `download_expert` instead to download a pretrained, competent expert
    print("Training a expert.")
    expert = PPO(
        policy=MlpPolicy,
        env=env,
        seed=0,
        batch_size=64,
        ent_coef=0.0,
        learning_rate=0.0003,
        n_epochs=10,
        n_steps=64,
    )
    expert.learn(1_000)  # Note: change this to 100_000 to train a decent expert.
    return expert


def sample_expert_transitions():
    expert = train_expert()  # or load a pretrained expert instead

    print("Sampling expert transitions.")
    rollouts = rollout.rollout(
        expert,
        env,
        rollout.make_sample_until(min_timesteps=None, min_episodes=50),
        rng=rng,
    )
    
    return rollout.flatten_trajectories(rollouts)

transitions = sample_expert_transitions()
bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
)

evaluation_env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=rng,
    # env_make_kwargs={"render_mode": "human"},  # for rendering
)

bc_trainer.train(n_epochs=1)

ppo = PPO(
    policy=bc_trainer.policy,
    env=env,
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
    n_steps=64,
)

Running the code like this gives the following error:

.../site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

TypeError: forward() got an unexpected keyword argument 'use_sde'
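For reference, the workaround I am currently considering (an untested sketch, and it assumes the BC policy and PPO's MlpPolicy use the same network architecture) is to construct PPO with its usual policy class and copy the BC-trained weights over, instead of passing the policy instance to the PPO constructor:

# Untested sketch: build PPO normally, then copy the BC-trained weights into it.
# This only works if bc_trainer.policy and PPO's MlpPolicy share the same architecture.
ppo = PPO(
    policy=MlpPolicy,
    env=env,
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
    n_steps=64,
)
ppo.policy.load_state_dict(bc_trainer.policy.state_dict())
ppo.learn(10_000)

Is something like this the intended way, or is there a supported API for it?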

Passing PPO's default policy class (ActorCriticPolicy) to BC like this

from stable_baselines3.common.policies import ActorCriticPolicy

transitions = sample_expert_transitions()
bc_trainer = bc.BC(
    policy=ActorCriticPolicy,
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
)

leads to another error:

.../site-packages/torch/nn/modules/module.py:1340, in Module.to(self, *args, **kwargs)
   1337         else:
   1338             raise
-> 1340 return self._apply(convert)

AttributeError: 'torch.device' object has no attribute '_apply'
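My guess (untested) is that BC expects a policy instance rather than a class, so instantiating ActorCriticPolicy before passing it might avoid the .to(device) error; a sketch of what I mean, with a hypothetical constant learning-rate schedule:

from stable_baselines3.common.policies import ActorCriticPolicy

# Untested sketch: pass an *instance* of the policy to BC, not the class.
policy = ActorCriticPolicy(
    observation_space=env.observation_space,
    action_space=env.action_space,
    lr_schedule=lambda _: 0.0003,  # constant schedule, chosen arbitrarily
)
bc_trainer = bc.BC(
    policy=policy,
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
)

But even if that works, it is unclear to me how the resulting policy should then be handed to PPO.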

Environment

  • Python version: 3.9.18
  • Output of pip freeze --all:
absl-py==2.1.0
aiohappyeyeballs==2.4.0
aiohttp==3.10.5
aiohttp-cors==0.7.0
aiosignal==1.3.1
ale-py==0.10.1
alembic==1.14.0
annotated-types==0.7.0
anyio==4.4.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
astunparse==1.6.3
async-lru==2.0.4
async-timeout==4.0.3
attrs==24.2.0
babel==2.16.0
beautifulsoup4==4.12.3
bleach==6.1.0
bokeh==3.4.3
cachetools==5.5.0
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
colorful==0.5.6
colorlog==6.9.0
comm==0.2.2
contourpy==1.3.0
cycler==0.12.1
datasets==3.2.0
debugpy==1.8.5
decorator==4.4.2
defusedxml==0.7.1
Deprecated==1.2.14
dill==0.3.8
distlib==0.3.8
dm-tree==0.1.8
docker-pycreds==0.4.0
docopt-ng==0.9.0
docstring_parser==0.16
etils==1.5.2
eval_type_backport==0.2.0
exceptiongroup==1.2.2
executing==2.1.0
Farama-Notifications==0.0.4
fastapi==0.114.2
fastjsonschema==2.20.0
filelock==3.16.0
flatbuffers==24.3.25
fonttools==4.53.1
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.9.0
gast==0.4.0
gitdb==4.0.11
GitPython==3.1.43
glfw==2.8.0
google-api-core==2.19.2
google-auth==2.34.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
googleapis-common-protos==1.65.0
GPUtil==1.4.0
greenlet==3.1.1
grpcio==1.62.0
gymnasium==0.29.1
h11==0.14.0
h5py==3.11.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.2
huggingface-hub==0.26.5
huggingface-sb3==3.0
icecream==2.1.3
idna==3.10
imageio==2.35.1
imageio-ffmpeg==0.5.1
imitation==1.0.0
importlib_metadata==8.4.0
importlib_resources==6.4.5
ipykernel==6.29.5
ipython==8.18.1
ipywidgets==8.1.5
isoduration==20.11.0
jax==0.4.30
jax-jumpy==1.0.0
jaxlib==0.4.30
jedi==0.19.1
Jinja2==3.1.4
joblib==1.4.2
json5==0.9.25
jsonpickle==4.0.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter==1.1.1
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.2
jupyter_core==5.7.2
jupyter_server==2.14.2
jupyter_server_terminals==0.5.3
jupyterlab==4.2.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
keras==3.7.0
kiwisolver==1.4.7
lazy_loader==0.4
libclang==18.1.1
lightning-utilities==0.11.8
lz4==4.3.3
Mako==1.3.8
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.2
matplotlib-inline==0.1.7
mdurl==0.1.2
mistune==3.0.2
ml-dtypes==0.4.1
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.1.0
mujoco==3.2.6
multidict==6.1.0
multiprocess==0.70.16
munch==4.0.0
namex==0.0.8
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.2.1
notebook==7.2.2
notebook_shim==0.2.4
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
oauthlib==3.2.2
opencensus==0.11.4
opencensus-context==0.1.3
opencv-python==4.10.0.84
opentelemetry-api==1.27.0
opentelemetry-exporter-otlp==1.27.0
opentelemetry-exporter-otlp-proto-common==1.27.0
opentelemetry-exporter-otlp-proto-grpc==1.27.0
opentelemetry-exporter-otlp-proto-http==1.27.0
opentelemetry-proto==1.27.0
opentelemetry-sdk==1.27.0
opentelemetry-semantic-conventions==0.48b0
opt-einsum==3.3.0
optree==0.12.1
optuna==4.1.0
overrides==7.7.0
packaging==24.1
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==10.4.0
pip @ file:///croot/pip_1723484598856/work
platformdirs==4.3.3
plotly==5.24.1
proglog==0.1.10
prometheus_client==0.20.0
prompt_toolkit==3.0.47
proto-plus==1.24.0
protobuf==4.25.4
psutil==6.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
py-cpuinfo==9.0.0
py-spy==0.3.14
pyarrow==17.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycparser==2.22
pydantic==2.9.1
pydantic_core==2.23.3
pygame==2.6.1
Pygments==2.18.0
PyOpenGL==3.1.7
pyparsing==3.1.4
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
pytorch-lightning==2.4.0
pytz==2024.2
PyYAML==6.0.2
pyzmq==26.2.0
ray==2.10.0
ray-cpp==2.10.0
rdkit==2024.3.6
referencing==0.35.1
requests==2.32.3
requests-oauthlib==2.0.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.8.1
rpds-py==0.20.0
rsa==4.9
sacred==0.8.7
scikit-image==0.24.0
scikit-learn==1.5.0
scipy==1.13.1
seals==0.2.1
Send2Trash==1.8.3
sentry-sdk==2.17.0
setproctitle==1.3.3
setuptools==75.1.0
shellingham==1.5.4
shtab==1.7.1
six==1.16.0
smart-open==7.0.4
smmap==5.0.1
sniffio==1.3.1
soundfile==0.12.1
soupsieve==2.6
SQLAlchemy==2.0.36
stable_baselines3==2.4.0
stack-data==0.6.3
starlette==0.38.5
sympy==1.13.1
tenacity==9.0.0
tensorboard==2.18.0
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
tensorflow==2.18.0
tensorflow-estimator==2.12.0
tensorflow-io-gcs-filesystem==0.37.1
termcolor==2.4.0
terminado==0.18.1
threadpoolctl==3.5.0
tifffile==2024.8.30
tinycss2==1.3.0
tomli==2.0.1
torch==2.5.1+cpu
torchaudio==2.5.1+cpu
torchmetrics==1.5.1
torchvision==0.20.1+cpu
tornado==6.4.1
tqdm==4.66.5
traitlets==5.14.3
triton==3.1.0
typer==0.12.5
types-python-dateutil==2.9.0.20240906
typing_extensions==4.12.2
tyro==0.9.1
tzdata==2024.1
uri-template==1.3.0
urllib3==2.2.3
uvicorn==0.30.6
uvloop==0.20.0
virtualenv==20.26.4
wandb==0.18.5
wasabi==1.1.3
watchfiles==0.24.0
wcwidth==0.2.13
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
websockets==13.0.1
Werkzeug==3.0.4
wheel==0.44.0
widgetsnbextension==4.0.13
wrapt==1.14.1
xxhash==3.5.0
xyzservices==2024.9.0
yarl==1.11.1
zipp==3.20.2
@timosturm added the bug ("Something isn't working") label on Dec 14, 2024
@timosturm changed the title from "How to continue training of a BC policy with RL" to "Continue training of a BC policy with RL fails" on Dec 14, 2024
@ChenJiangxi

I'm having the same problem.
