Skip to content

Commit

Permalink
default fast model loading 🔥 (huggingface#1115)
Browse files Browse the repository at this point in the history
* make accelerate hard dep

* default fast init

* move params to cpu when device map is None

* handle device_map=None

* handle torch < 1.9

* remove device_map="auto"

* style

* add accelerate in torch extra

* remove accelerate from extras["test"]

* raise an error if torch is available but not accelerate

* update installation docs

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* improve defautl loading speed even further, allow disabling fats loading

* address review comments

* adapt the tests

* fix test_stable_diffusion_fast_load

* fix test_read_init

* temp fix for dummy checks

* Trigger Build

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <[email protected]>

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Anton Lozhkov <[email protected]>
  • Loading branch information
3 people authored Nov 3, 2022
1 parent ef2ea33 commit 7482178
Show file tree
Hide file tree
Showing 23 changed files with 564 additions and 109 deletions.
16 changes: 13 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ More precisely, 🤗 Diffusers offers:

## Installation

### For PyTorch

**With `pip`**

```bash
pip install --upgrade diffusers
pip install --upgrade diffusers[torch]
```

**With `conda`**
Expand All @@ -39,6 +41,14 @@ pip install --upgrade diffusers
conda install -c conda-forge diffusers
```

### For Flax

**With `pip`**

```bash
pip install --upgrade diffusers[flax]
```

**Apple Silicon (M1/M2) support**

Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
Expand Down Expand Up @@ -354,7 +364,7 @@ There are many ways to try running Diffusers! Here we outline code-focused tools
If you want to run the code yourself 💻, you can try out:
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
```python
# !pip install diffusers transformers
# !pip install diffusers["torch"] transformers
from diffusers import DiffusionPipeline

device = "cuda"
Expand All @@ -373,7 +383,7 @@ image.save("squirrel.png")
```
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
```python
# !pip install diffusers
# !pip install diffusers["torch"]
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline

model_id = "google/ddpm-celebahq-256"
Expand Down
40 changes: 36 additions & 4 deletions docs/source/installation.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ specific language governing permissions and limitations under the License.

# Installation

Install Diffusers for with PyTorch. Support for other libraries will come in the future
Install 🤗 Diffusers for whichever deep learning library you’re working with.

🤗 Diffusers is tested on Python 3.7+, and PyTorch 1.7.0+.
🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and flax. Follow the installation instructions below for the deep learning library you are using:

- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.

## Install with pip

Expand All @@ -36,12 +39,30 @@ source .env/bin/activate

Now you're ready to install 🤗 Diffusers with the following command:

**For PyTorch**

```bash
pip install diffusers["torch"]
```

**For Flax**

```bash
pip install diffusers
pip install diffusers["flax"]
```

## Install from source

Before intsalling `diffusers` from source, make sure you have `torch` and `accelerate` installed.

For `torch` installation refer to the `torch` [docs](https://pytorch.org/get-started/locally/#start-locally).

To install `accelerate`

```bash
pip install accelerate
```

Install 🤗 Diffusers from source with the following command:

```bash
Expand All @@ -67,7 +88,18 @@ Clone the repository and install 🤗 Diffusers with the following commands:
```bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install -e .
```

**For PyTorch**

```
pip install -e ".[torch]"
```

**For Flax**

```
pip install -e ".[flax]"
```

These commands will link the folder you cloned the repository to and your Python library paths.
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def run(self):
extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list(
"accelerate",
"datasets",
"parameterized",
"pytest",
Expand All @@ -188,7 +187,7 @@ def run(self):
"torchvision",
"transformers"
)
extras["torch"] = deps_list("torch")
extras["torch"] = deps_list("torch", "accelerate")

if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
Expand Down
8 changes: 8 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .utils import (
is_accelerate_available,
is_flax_available,
is_inflect_available,
is_onnx_available,
Expand All @@ -16,6 +17,13 @@
from .utils import logging


# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py"
# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available
if is_torch_available() and not is_accelerate_available():
error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501
raise ImportError(error_msg)


if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
Expand Down
47 changes: 39 additions & 8 deletions src/diffusers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import torch
from torch import Tensor, device

from diffusers.utils import is_accelerate_available
import accelerate
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
Expand Down Expand Up @@ -268,6 +270,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
fast_load (`bool`, *optional*, defaults to `True`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
this argument will be ignored and the model will be loaded normally.
<Tip>
Expand Down Expand Up @@ -296,6 +311,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
fast_load = kwargs.pop("fast_load", True)

# Check if we can handle device_map and dispatching the weights
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0")

# Fast init is only possible if torch version is >= 1.9.0
_INIT_EMPTY_WEIGHTS = fast_load or device_map is not None
if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"):
logger.warn("Loading with `fast_load` requires torch >= 1.9.0. Falling back to normal loading.")

user_agent = {
"diffusers": __version__,
Expand Down Expand Up @@ -378,12 +403,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

# restore default dtype

if device_map == "auto":
if is_accelerate_available():
import accelerate
else:
raise ImportError("Please install accelerate via `pip install accelerate`")

if _INIT_EMPTY_WEIGHTS:
# Instantiate model with empty weights
with accelerate.init_empty_weights():
model, unused_kwargs = cls.from_config(
config_path,
Expand All @@ -400,7 +421,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
**kwargs,
)

accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
# if device_map is Non,e load the state dict on move the params from meta device to the cpu
if device_map is None:
param_device = "cpu"
state_dict = load_state_dict(model_file)
# move the parms from meta device to cpu
for param_name, param in state_dict.items():
set_module_tensor_to_device(model, param_name, param_device, value=param)
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by deafult the device_map is None and the weights are loaded on the CPU
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)

loading_info = {
"missing_keys": [],
Expand Down
10 changes: 10 additions & 0 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
device_map = kwargs.pop("device_map", None)
fast_load = kwargs.pop("fast_load", True)

# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
Expand Down Expand Up @@ -572,6 +573,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
)

if is_diffusers_model:
loading_kwargs["fast_load"] = fast_load

# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
# To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag which is `True` by default.
# This makes sure that the weights won't be initialized which significantly speeds up loading.
if is_transformers_model and device_map is None:
loading_kwargs["low_cpu_mem_usage"] = fast_load

if is_diffusers_model or is_transformers_model:
loading_kwargs["device_map"] = device_map

Expand Down
Loading

0 comments on commit 7482178

Please sign in to comment.