From 19712a543ab04e2d88c66354c8bb9f1b6c691cda Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 8 Mar 2024 10:51:46 +0100 Subject: [PATCH 1/2] Reformat and sort imports --- huggingface_sb3/__init__.py | 2 +- huggingface_sb3/push_to_hub.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/huggingface_sb3/__init__.py b/huggingface_sb3/__init__.py index f2e01ef..af23752 100644 --- a/huggingface_sb3/__init__.py +++ b/huggingface_sb3/__init__.py @@ -1,3 +1,3 @@ from .load_from_hub import load_from_hub +from .naming_schemes import EnvironmentName, ModelName, ModelRepoId from .push_to_hub import package_to_hub, push_to_hub -from .naming_schemes import ModelName, EnvironmentName, ModelRepoId diff --git a/huggingface_sb3/push_to_hub.py b/huggingface_sb3/push_to_hub.py index 6b90df9..d63758e 100644 --- a/huggingface_sb3/push_to_hub.py +++ b/huggingface_sb3/push_to_hub.py @@ -94,13 +94,16 @@ def _evaluate_agent( return mean_reward, std_reward + def entry_point(env_id: str) -> str: try: return str(gym.envs.registry[env_id].entry_point) except KeyError: import gym as gym26 + return str(gym26.envs.registry[env_id].entry_point) + def is_atari(env_id: str) -> bool: """ Check if the environment is an Atari one From 69bee7731c898811d2fb75522839df8f9718b675 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 8 Mar 2024 11:13:47 +0100 Subject: [PATCH 2/2] Do not allow to download untrusted model by default --- README.md | 8 ++++++++ huggingface_sb3/load_from_hub.py | 30 ++++++++++++++++++++++++++++++ tests/test_load_from_hub.py | 5 ++++- 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index cf71150..16ceb07 100644 --- a/README.md +++ b/README.md @@ -16,13 +16,21 @@ We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 [here](https:// If you use **Colab or a Virtual/Screenless Machine**, you can check Case 3 and Case 4. ### Case 1: I want to download a model from the Hub + +You will need to set the `TRUST_REMOTE_CODE` environment variable to `True` to allow the use of `pickle.load()`: + ```python +import os import gymnasium as gym from huggingface_sb3 import load_from_hub from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy +# Allow the use of `pickle.load()` when downloading model from the hub +# Please make sure that the organization from which you download can be trusted +os.environ["TRUST_REMOTE_CODE"] = "True" + # Retrieve the model from the hub ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name}) ## filename = name of the model zip file from the repository diff --git a/huggingface_sb3/load_from_hub.py b/huggingface_sb3/load_from_hub.py index cc426b1..483aa37 100644 --- a/huggingface_sb3/load_from_hub.py +++ b/huggingface_sb3/load_from_hub.py @@ -1,3 +1,22 @@ +import os + + +# Vendored from distutils.util +def strtobool(val: str) -> bool: + """Convert a string representation of truth to true (1) or false (0). + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; + False values are 'n', 'no', 'f', 'false', 'off', and '0'. + Raises ValueError if 'val' is anything else. + """ + val = val.lower() + if val in {"y", "yes", "t", "true", "on", "1"}: + return 1 + if val in {"n", "no", "f", "false", "off", "0"}: + return 0 + raise ValueError(f"Invalid truth value {val!r}") + + def load_from_hub(repo_id: str, filename: str) -> str: """ Download a model from Hugging Face Hub. @@ -12,6 +31,17 @@ def load_from_hub(repo_id: str, filename: str) -> str: "See https://pypi.org/project/huggingface-hub/ for installation." ) + # Copied from https://github.com/huggingface/transformers/pull/27776 + if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")): + raise ValueError( + "You are about to download a model from the HF hub that will be loaded using `pickle.load`. " + "`pickle.load` is insecure and will execute arbitrary code that is " + "potentially malicious. It's recommended to never unpickle data that could have come from an " + "untrusted source, or that could have been tampered with. If you trust the pickle " + "data and decided to use it, you can set the environment variable " + "`TRUST_REMOTE_CODE` to `True` to allow it." + ) + # Get the model from the Hub, download and cache the model on your local disk downloaded_model_file = hf_hub_download( repo_id=repo_id, diff --git a/tests/test_load_from_hub.py b/tests/test_load_from_hub.py index f97e08d..dbc47bd 100644 --- a/tests/test_load_from_hub.py +++ b/tests/test_load_from_hub.py @@ -1,9 +1,12 @@ -import gym +import os +import gymnasium as gym from huggingface_sb3 import load_from_hub, ModelRepoId, ModelName, EnvironmentName from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy +# Test models from sb3 organization can be trusted +os.environ["TRUST_REMOTE_CODE"] = "True" def test_load_from_hub_with_naming_scheme_utils(): # Retrieve the model from the hub