diff --git a/docs/source/conf.py b/docs/source/conf.py index fdb10fbe03..827626d12e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,6 +13,8 @@ import os import subprocess import sys +import importlib +import inspect sys.path.insert(0, os.path.abspath("..")) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) @@ -137,7 +139,7 @@ def generate_apidocs(*args): "github_user": "Project-MONAI", "github_repo": "MONAI", "github_version": "dev", - "doc_path": "docs/", + "doc_path": "docs/source", "conf_py_path": "/docs/", "VERSION": version, } @@ -162,3 +164,47 @@ def setup(app): # Hook to allow for automatic generation of API docs # before doc deployment begins. app.connect("builder-inited", generate_apidocs) + + +# -- Linkcode configuration -------------------------------------------------- +DEFAULT_REPOSITORY = "Project-MONAI/MONAI" +repository = os.environ.get("GITHUB_REPOSITORY", DEFAULT_REPOSITORY) + +base_code_url = f"https://github.com/{repository}/blob/{version}" +MODULE_ROOT_FOLDER = "monai" + + +# Adjusted from https://github.com/python-websockets/websockets/blob/main/docs/conf.py +def linkcode_resolve(domain, info): + if domain != "py": + raise ValueError( + f"expected domain to be 'py', got {domain}." + "Please adjust linkcode_resolve to either handle this domain or ignore it." + ) + + mod = importlib.import_module(info["module"]) + if "." in info["fullname"]: + objname, attrname = info["fullname"].split(".") + obj = getattr(mod, objname) + try: + # object is a method of a class + obj = getattr(obj, attrname) + except AttributeError: + # object is an attribute of a class + return None + else: + obj = getattr(mod, info["fullname"]) + + try: + file = inspect.getsourcefile(obj) + source, lineno = inspect.getsourcelines(obj) + except TypeError: + # e.g. object is a typing.Union + return None + file = os.path.relpath(file, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) + if not file.startswith(MODULE_ROOT_FOLDER): + # e.g. object is a typing.NewType + return None + start, end = lineno, lineno + len(source) - 1 + url = f"{base_code_url}/{file}#L{start}-L{end}" + return url diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index 7450c87314..fe6746e017 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -34,7 +34,6 @@ from collections.abc import Sequence import torch -import torch.nn.functional as F from torch import nn from monai.networks.blocks import Convolution @@ -57,7 +56,8 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann strides=1, kernel_size=3, padding=1, - conv_only=True, + adn_ordering="A", + act="SWISH", ) self.blocks = nn.ModuleList([]) @@ -73,7 +73,8 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann strides=1, kernel_size=3, padding=1, - conv_only=True, + adn_ordering="A", + act="SWISH", ) ) @@ -85,7 +86,8 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann strides=2, kernel_size=3, padding=1, - conv_only=True, + adn_ordering="A", + act="SWISH", ) ) @@ -103,11 +105,9 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann def forward(self, conditioning): embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) for block in self.blocks: embedding = block(embedding) - embedding = F.silu(embedding) embedding = self.conv_out(embedding) @@ -410,3 +410,56 @@ def forward( mid_block_res_sample *= conditioning_scale return down_block_res_samples, mid_block_res_sample + + def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: + """ + Load a state dict from a ControlNet trained with + [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels). + + Args: + old_state_dict: state dict from the old ControlNet model. + """ + + new_state_dict = self.state_dict() + # if all keys match, just load the state dict + if all(k in new_state_dict for k in old_state_dict): + print("All keys match, loading state dict.") + self.load_state_dict(old_state_dict) + return + + if verbose: + # print all new_state_dict keys that are not in old_state_dict + for k in new_state_dict: + if k not in old_state_dict: + print(f"key {k} not found in old state dict") + # and vice versa + print("----------------------------------------------") + for k in old_state_dict: + if k not in new_state_dict: + print(f"key {k} not found in new state dict") + + # copy over all matching keys + for k in new_state_dict: + if k in old_state_dict: + new_state_dict[k] = old_state_dict[k] + + # fix the attention blocks + attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + for block in attention_blocks: + new_state_dict[f"{block}.attn1.qkv.weight"] = torch.concat( + [ + old_state_dict[f"{block}.attn1.to_q.weight"], + old_state_dict[f"{block}.attn1.to_k.weight"], + old_state_dict[f"{block}.attn1.to_v.weight"], + ], + dim=0, + ) + + # projection + new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] + new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] + + new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] + new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] + + self.load_state_dict(new_state_dict) diff --git a/monai/visualize/utils.py b/monai/visualize/utils.py index f6718fe7a5..88c9a0d66a 100644 --- a/monai/visualize/utils.py +++ b/monai/visualize/utils.py @@ -24,11 +24,9 @@ from monai.utils.type_conversion import convert_data_type, convert_to_dst_type if TYPE_CHECKING: - from matplotlib import cm from matplotlib import pyplot as plt else: plt, _ = optional_import("matplotlib", name="pyplot") - cm, _ = optional_import("matplotlib", name="cm") __all__ = ["matshow3d", "blend_images"] @@ -210,7 +208,7 @@ def blend_images( image = repeat(image, 3, axis=0) def get_label_rgb(cmap: str, label: NdarrayOrTensor) -> NdarrayOrTensor: - _cmap = cm.get_cmap(cmap) + _cmap = plt.colormaps.get_cmap(cmap) label_np, *_ = convert_data_type(label, np.ndarray) label_rgb_np = _cmap(label_np[0]) label_rgb_np = np.moveaxis(label_rgb_np, -1, 0)[:3] diff --git a/requirements-dev.txt b/requirements-dev.txt index ce28d3ebe2..35ff3382be 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -36,7 +36,7 @@ einops transformers>=4.36.0 mlflow>=1.28.0, <=2.11.3 clearml>=1.10.0rc0 -matplotlib!=3.5.0 +matplotlib>=3.6.3 tensorboardX types-PyYAML pyyaml diff --git a/setup.cfg b/setup.cfg index c8ae1630f7..c90b043c1c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -68,7 +68,7 @@ all = transformers<4.22; python_version <= '3.10' mlflow>=1.28.0, <=2.11.3 clearml>=1.10.0rc0 - matplotlib + matplotlib>=3.6.3 tensorboardX pyyaml fire @@ -127,7 +127,7 @@ transformers = mlflow = mlflow>=1.28.0, <=2.11.3 matplotlib = - matplotlib + matplotlib>=3.6.3 clearml = clearml tensorboardX = diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py index 05ceb69fa3..4746c7ce22 100644 --- a/tests/test_controlnet.py +++ b/tests/test_controlnet.py @@ -11,15 +11,19 @@ from __future__ import annotations +import os +import tempfile import unittest from unittest import skipUnless import torch from parameterized import parameterized +from monai.apps import download_url from monai.networks import eval_mode from monai.networks.nets.controlnet import ControlNet from monai.utils import optional_import +from tests.utils import skip_if_downloading_fails, testing_data_config _, has_einops = optional_import("einops") UNCOND_CASES_2D = [ @@ -177,6 +181,35 @@ def test_shape_conditioned_models(self, input_param, expected_output_shape): self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) self.assertEqual(result[1].shape, expected_output_shape) + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = ControlNet( + spatial_dims=2, + in_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + resblock_updown=True, + ) + + tmpdir = tempfile.mkdtemp() + key = "controlnet_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "controlnet_monai_generative_weights.pt" + + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_matshow3d.py b/tests/test_matshow3d.py index e54bb523e4..2eba310f4e 100644 --- a/tests/test_matshow3d.py +++ b/tests/test_matshow3d.py @@ -114,6 +114,7 @@ def test_3d_rgb(self): every_n=2, frame_dim=-1, channel_dim=0, + fill_value=0, show=False, ) diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index 318331e5f7..8b1d2868b7 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -153,6 +153,11 @@ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth", "hash_type": "sha256", "hash_val": "6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184" + }, + "controlnet_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/controlnet.pth", + "hash_type": "sha256", + "hash_val": "cd100d0c69f47569ae5b4b7df653a1cb19f5e02eff1630db3210e2646fb1ab2e" } }, "configs": {