Skip to content

Commit

Permalink
Neater use off nn.Sequential in controlnet (#7754)
Browse files Browse the repository at this point in the history
Part of #7227  .

### Description
Tidies up some of controlnet

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Mark Graham <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
marksgraham and KumoLiu authored May 22, 2024
1 parent a052c44 commit a423bcd
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 13 deletions.
48 changes: 47 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import os
import subprocess
import sys
import importlib
import inspect

sys.path.insert(0, os.path.abspath(".."))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
Expand Down Expand Up @@ -137,7 +139,7 @@ def generate_apidocs(*args):
"github_user": "Project-MONAI",
"github_repo": "MONAI",
"github_version": "dev",
"doc_path": "docs/",
"doc_path": "docs/source",
"conf_py_path": "/docs/",
"VERSION": version,
}
Expand All @@ -162,3 +164,47 @@ def setup(app):
# Hook to allow for automatic generation of API docs
# before doc deployment begins.
app.connect("builder-inited", generate_apidocs)


# -- Linkcode configuration --------------------------------------------------
DEFAULT_REPOSITORY = "Project-MONAI/MONAI"
repository = os.environ.get("GITHUB_REPOSITORY", DEFAULT_REPOSITORY)

base_code_url = f"https://github.com/{repository}/blob/{version}"
MODULE_ROOT_FOLDER = "monai"


# Adjusted from https://github.com/python-websockets/websockets/blob/main/docs/conf.py
def linkcode_resolve(domain, info):
if domain != "py":
raise ValueError(
f"expected domain to be 'py', got {domain}."
"Please adjust linkcode_resolve to either handle this domain or ignore it."
)

mod = importlib.import_module(info["module"])
if "." in info["fullname"]:
objname, attrname = info["fullname"].split(".")
obj = getattr(mod, objname)
try:
# object is a method of a class
obj = getattr(obj, attrname)
except AttributeError:
# object is an attribute of a class
return None
else:
obj = getattr(mod, info["fullname"])

try:
file = inspect.getsourcefile(obj)
source, lineno = inspect.getsourcelines(obj)
except TypeError:
# e.g. object is a typing.Union
return None
file = os.path.relpath(file, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
if not file.startswith(MODULE_ROOT_FOLDER):
# e.g. object is a typing.NewType
return None
start, end = lineno, lineno + len(source) - 1
url = f"{base_code_url}/{file}#L{start}-L{end}"
return url
65 changes: 59 additions & 6 deletions monai/networks/nets/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from collections.abc import Sequence

import torch
import torch.nn.functional as F
from torch import nn

from monai.networks.blocks import Convolution
Expand All @@ -57,7 +56,8 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
adn_ordering="A",
act="SWISH",
)

self.blocks = nn.ModuleList([])
Expand All @@ -73,7 +73,8 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
adn_ordering="A",
act="SWISH",
)
)

Expand All @@ -85,7 +86,8 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann
strides=2,
kernel_size=3,
padding=1,
conv_only=True,
adn_ordering="A",
act="SWISH",
)
)

Expand All @@ -103,11 +105,9 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann

def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)

for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)

embedding = self.conv_out(embedding)

Expand Down Expand Up @@ -410,3 +410,56 @@ def forward(
mid_block_res_sample *= conditioning_scale

return down_block_res_samples, mid_block_res_sample

def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
"""
Load a state dict from a ControlNet trained with
[MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).
Args:
old_state_dict: state dict from the old ControlNet model.
"""

new_state_dict = self.state_dict()
# if all keys match, just load the state dict
if all(k in new_state_dict for k in old_state_dict):
print("All keys match, loading state dict.")
self.load_state_dict(old_state_dict)
return

if verbose:
# print all new_state_dict keys that are not in old_state_dict
for k in new_state_dict:
if k not in old_state_dict:
print(f"key {k} not found in old state dict")
# and vice versa
print("----------------------------------------------")
for k in old_state_dict:
if k not in new_state_dict:
print(f"key {k} not found in new state dict")

# copy over all matching keys
for k in new_state_dict:
if k in old_state_dict:
new_state_dict[k] = old_state_dict[k]

# fix the attention blocks
attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k]
for block in attention_blocks:
new_state_dict[f"{block}.attn1.qkv.weight"] = torch.concat(
[
old_state_dict[f"{block}.attn1.to_q.weight"],
old_state_dict[f"{block}.attn1.to_k.weight"],
old_state_dict[f"{block}.attn1.to_v.weight"],
],
dim=0,
)

# projection
new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"]
new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"]

new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"]
new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"]

self.load_state_dict(new_state_dict)
4 changes: 1 addition & 3 deletions monai/visualize/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

if TYPE_CHECKING:
from matplotlib import cm
from matplotlib import pyplot as plt
else:
plt, _ = optional_import("matplotlib", name="pyplot")
cm, _ = optional_import("matplotlib", name="cm")

__all__ = ["matshow3d", "blend_images"]

Expand Down Expand Up @@ -210,7 +208,7 @@ def blend_images(
image = repeat(image, 3, axis=0)

def get_label_rgb(cmap: str, label: NdarrayOrTensor) -> NdarrayOrTensor:
_cmap = cm.get_cmap(cmap)
_cmap = plt.colormaps.get_cmap(cmap)
label_np, *_ = convert_data_type(label, np.ndarray)
label_rgb_np = _cmap(label_np[0])
label_rgb_np = np.moveaxis(label_rgb_np, -1, 0)[:3]
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ einops
transformers>=4.36.0
mlflow>=1.28.0, <=2.11.3
clearml>=1.10.0rc0
matplotlib!=3.5.0
matplotlib>=3.6.3
tensorboardX
types-PyYAML
pyyaml
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ all =
transformers<4.22; python_version <= '3.10'
mlflow>=1.28.0, <=2.11.3
clearml>=1.10.0rc0
matplotlib
matplotlib>=3.6.3
tensorboardX
pyyaml
fire
Expand Down Expand Up @@ -127,7 +127,7 @@ transformers =
mlflow =
mlflow>=1.28.0, <=2.11.3
matplotlib =
matplotlib
matplotlib>=3.6.3
clearml =
clearml
tensorboardX =
Expand Down
33 changes: 33 additions & 0 deletions tests/test_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@

from __future__ import annotations

import os
import tempfile
import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.apps import download_url
from monai.networks import eval_mode
from monai.networks.nets.controlnet import ControlNet
from monai.utils import optional_import
from tests.utils import skip_if_downloading_fails, testing_data_config

_, has_einops = optional_import("einops")
UNCOND_CASES_2D = [
Expand Down Expand Up @@ -177,6 +181,35 @@ def test_shape_conditioned_models(self, input_param, expected_output_shape):
self.assertEqual(len(result[0]), 2 * len(input_param["channels"]))
self.assertEqual(result[1].shape, expected_output_shape)

@skipUnless(has_einops, "Requires einops")
def test_compatibility_with_monai_generative(self):
# test loading weights from a model saved in MONAI Generative, version 0.2.3
with skip_if_downloading_fails():
net = ControlNet(
spatial_dims=2,
in_channels=1,
num_res_blocks=1,
channels=(8, 8, 8),
attention_levels=(False, False, True),
norm_num_groups=8,
with_conditioning=True,
transformer_num_layers=1,
cross_attention_dim=3,
resblock_updown=True,
)

tmpdir = tempfile.mkdtemp()
key = "controlnet_monai_generative_weights"
url = testing_data_config("models", key, "url")
hash_type = testing_data_config("models", key, "hash_type")
hash_val = testing_data_config("models", key, "hash_val")
filename = "controlnet_monai_generative_weights.pt"

weight_path = os.path.join(tmpdir, filename)
download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)

net.load_old_state_dict(torch.load(weight_path), verbose=False)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions tests/test_matshow3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def test_3d_rgb(self):
every_n=2,
frame_dim=-1,
channel_dim=0,
fill_value=0,
show=False,
)

Expand Down
5 changes: 5 additions & 0 deletions tests/testing_data/data_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth",
"hash_type": "sha256",
"hash_val": "6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184"
},
"controlnet_monai_generative_weights": {
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/controlnet.pth",
"hash_type": "sha256",
"hash_val": "cd100d0c69f47569ae5b4b7df653a1cb19f5e02eff1630db3210e2646fb1ab2e"
}
},
"configs": {
Expand Down

0 comments on commit a423bcd

Please sign in to comment.