diff --git a/src/embedding_scvi/_components.py b/src/embedding_scvi/_components.py
index d4cc3b8..a40d26c 100644
--- a/src/embedding_scvi/_components.py
+++ b/src/embedding_scvi/_components.py
@@ -1,73 +1,142 @@
+from __future__ import annotations
+
+from typing import Literal
+
 import numpy as np
 import torch
+from scvi.utils._exceptions import InvalidParameterError
 from torch import nn
 
 
 class MLPBlock(nn.Module):
+    """Multi-layer perceptron block.
+
+    Parameters
+    ----------
+    n_in
+        Number of input features.
+    n_out
+        Number of output features.
+    bias
+        Whether to include a bias term in the linear layer.
+    norm
+        Type of normalization to use. One of the following:
+
+        * ``"batch"``: :class:`~torch.nn.BatchNorm1d`
+        * ``"layer"``: :class:`~torch.nn.LayerNorm`
+        * ``None``: No normalization
+    activation
+        Type of activation to use. One of the following:
+
+        * ``"relu"``: :class:`~torch.nn.ReLU`
+        * ``"leaky_relu"``: :class:`~torch.nn.LeakyReLU`
+        * ``"softmax"``: :class:`~torch.nn.Softmax`
+        * ``"softplus"``: :class:`~torch.nn.Softplus`
+    dropout_rate
+        Dropout rate. If ``None``, no dropout is used.
+    residual
+        Whether to use residual connections. If ``True`` and ``n_in != n_out``,
+        then a linear layer is used to project the input to the correct
+        dimensionality.
+    """
+
     def __init__(
         self,
         n_in: int,
         n_out: int,
         bias: bool = True,
-        norm: str | None = None,
-        activation: str | None = None,
+        norm: Literal["batch", "layer"] | None = None,
+        norm_kwargs: dict | None = None,
+        activation: Literal["relu", "leaky_relu", "softmax", "softplus"] | None = None,
+        activation_kwargs: dict | None = None,
         dropout_rate: float | None = None,
         residual: bool = False,
     ):
         super().__init__()
         self.linear = nn.Linear(n_in, n_out, bias=bias)
+        self.norm = nn.Identity()
+        self.norm_kwargs = norm_kwargs or {}
+        self.activation = nn.Identity()
+        self.activation_kwargs = activation_kwargs or {}
+        self.dropout = nn.Identity()
         self.residual = residual
 
         if norm == "batch":
-            self.norm = nn.BatchNorm1d(n_out)
+            self.norm = nn.BatchNorm1d(n_out, **self.norm_kwargs)
         elif norm == "layer":
-            self.norm = nn.LayerNorm(n_out)
+            self.norm = nn.LayerNorm(n_out, **self.norm_kwargs)
         elif norm is not None:
-            raise ValueError(f"Unrecognized norm: {norm}")
-        else:
-            self.norm = nn.Identity()
-
-        if activation is not None:
-            self.activation = getattr(nn.functional, activation)
-        else:
-            self.activation = nn.Identity()
+            raise InvalidParameterError(param="norm", value=norm, valid=["batch", "layer", None])
+
+        if activation == "relu":
+            self.activation = nn.ReLU(**self.activation_kwargs)
+        elif activation == "leaky_relu":
+            self.activation = nn.LeakyReLU(**self.activation_kwargs)
+        elif activation == "softmax":
+            self.activation = nn.Softmax(**self.activation_kwargs)
+        elif activation == "softplus":
+            self.activation = nn.Softplus(**self.activation_kwargs)
+        elif activation is not None:
+            raise InvalidParameterError(
+                param="activation", value=activation, valid=["relu", "leaky_relu", "softmax", "softplus", None]
+            )
 
         if dropout_rate is not None:
             self.dropout = nn.Dropout(dropout_rate)
-        else:
-            self.dropout = nn.Identity()
 
-        if self.residual and n_in != n_out:
-            raise ValueError("`n_in` must equal `n_out` if `residual` is `True`.")
+        if residual and n_in != n_out:
+            self.residual_transform = nn.Linear(n_in, n_out, bias=False)
+        elif residual and n_in == n_out:
+            self.residual_transform = nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.linear(x)
        h = self.norm(h)
        h = self.activation(h)
        h = self.dropout(h)
-        h = h + x if self.residual else h
+        h = h + self.residual_transform(x) if self.residual else h
         return h
 
 
-class MultiOutput(nn.Module):
+class MultiOutputLinear(nn.Module):
+    """Multi-output linear layer.
+
+    Parameters
+    ----------
+    n_in
+        Number of input features.
+    n_out
+        Number of output features.
+    n_out_params
+        Number of output parameters.
+    activations
+        List containing the type of activation to use for each output parameter.
+        One of the following:
+
+        * ``"relu"``: :class:`~torch.nn.ReLU`
+        * ``"leaky_relu"``: :class:`~torch.nn.LeakyReLU`
+        * ``"softmax"``: :class:`~torch.nn.Softmax`
+        * ``"softplus"``: :class:`~torch.nn.Softplus`
+        * ``None``: No activation
+    activation_kwargs
+        List containing the keyword arguments to pass to the activation function
+        for each output parameter.
+    """
+
     def __init__(
         self,
         n_in: int,
         n_out: int,
         n_out_params: int,
-        param_activations: list[int] | None,
+        activations: list[str | None] | None,
+        activation_kwargs: list[dict | None] | None = None,
     ):
         super().__init__()
+        self.n_in = n_in
+        self.n_out = n_out
         self.n_out_params = n_out_params
-        self.param_activations = param_activations
-
-        if self.param_activations is not None and len(param_activations) != n_out_params:
-            raise ValueError(
-                f"Length of `param_activations` {len(param_activations)}) must "
-                f"match `n_out_params`: {n_out_params}."
-            )
-        elif self.param_activations is None:
-            self.param_activations = [None for _ in range(n_out_params)]
+        self.activations = activations or [None] * n_out_params
+        self.activation_kwargs = activation_kwargs or [{}] * n_out_params
 
         blocks = []
         for i in range(self.n_out_params):
@@ -76,13 +145,14 @@ def __init__(
                     n_in=n_in,
                     n_out=n_out,
                     bias=False,
-                    activation=self.param_activations[i],
+                    activation=self.activations[i],
+                    activation_kwargs=self.activation_kwargs[i],
                 )
             )
-        self._blocks = nn.ModuleList(blocks)
+        self.blocks = nn.ModuleList(blocks)
 
     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
-        return tuple(block(x) for block in self._blocks)
+        return tuple(block(x) for block in self.blocks)
 
 
 class MLP(nn.Module):
@@ -94,7 +164,9 @@ def __init__(
         n_layers: int,
         bias: bool = True,
         norm: str | None = None,
+        norm_kwargs: dict | None = None,
         activation: str | None = None,
+        activation_kwargs: dict | None = None,
         dropout_rate: float | None = None,
         residual: bool = False,
     ):
@@ -108,70 +180,29 @@ def __init__(
         n_outs = [n_hidden for _ in range(n_layers - 1)] + [n_out]
         blocks = []
         for n_in, n_out in zip(n_ins, n_outs):
-            _residual = residual and n_in == n_out
             blocks.append(
                 MLPBlock(
                     n_in=n_in,
                     n_out=n_out,
                     bias=bias,
                     norm=norm,
+                    norm_kwargs=norm_kwargs,
                     activation=activation,
+                    activation_kwargs=activation_kwargs,
                     dropout_rate=dropout_rate,
-                    residual=_residual,
+                    residual=residual,
                 )
             )
-        self._blocks = nn.Sequential(*blocks)
+        self.blocks = nn.Sequential(*blocks)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self._blocks(x)
-
-
-class MLPMultiOutput(nn.Module):
-    def __init__(
-        self,
-        n_in: int,
-        n_out: int,
-        n_hidden: int,
-        n_layers: int,
-        n_out_params: int,
-        param_activations: list[int] | None,
-        bias: bool = True,
-        norm: str | None = None,
-        activation: str | None = None,
-        dropout_rate: float | None = None,
-        residual: bool = False,
-    ):
-        super().__init__()
-        self.n_in = n_in
-        self.n_out = n_out
-        self.n_hidden = n_hidden
-        self.n_layers = n_layers
-
-        self._mlp = MLP(
-            n_in=n_in,
-            n_out=n_hidden,
-            n_hidden=n_hidden,
-            n_layers=n_layers,
-            bias=bias,
-            norm=norm,
-            activation=activation,
-            dropout_rate=dropout_rate,
-            residual=residual,
-        )
-        self._multi_output = MultiOutput(
-            n_in=n_hidden,
-            n_out=n_out,
-            n_out_params=n_out_params,
-            param_activations=param_activations,
-        )
-
-    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
-        h = self._mlp(x)
-        return self._multi_output(h)
+        return self.blocks(x)
 
 
 class ExtendableEmbedding(nn.Embedding):
+    """Extendable embedding layer."""
+
     @classmethod
     def extend_embedding(
         cls,
@@ -179,21 +210,29 @@
         init: int | list[int],
         freeze_prev: bool = True,
     ):
+        # (num_embeddings, embedding_dim)
         old_weight = embedding.weight.clone()
-        if isinstance(init, int) and init > 0:
-            num_init = init
+
+        if isinstance(init, int):
+            if init <= 0:
+                raise ValueError(f"`init` must be > 0, got {init}")
+            n_init = init
+            # (n_init, embedding_dim)
             new_weight = torch.empty(
                 (init, old_weight.shape[1]),
                 device=old_weight.device,
             )
             nn.init.normal_(new_weight)
         elif isinstance(init, list):
-            num_init = len(init)
+            n_init = len(init)
+            # (n_init, embedding_dim)
             new_weight = old_weight[init]
+
+        # (num_embeddings + n_init, embedding_dim)
         weight = torch.cat([old_weight, new_weight], dim=0)
 
         new_embedding = cls(
-            num_embeddings=embedding.num_embeddings + num_init,
+            num_embeddings=embedding.num_embeddings + n_init,
             embedding_dim=embedding.embedding_dim,
             _weight=weight,
             padding_idx=embedding.padding_idx,
@@ -203,8 +242,10 @@
             sparse=embedding.sparse,
         )
 
+        # freeze previous embeddings
         def _partial_freeze_hook(grad: torch.Tensor) -> torch.Tensor:
-            grad[: old_weight.shape[0]] = 0
+            grad = grad.clone()
+            grad[: embedding.num_embeddings] = 0.0
             return grad
 
         if freeze_prev:
@@ -225,6 +266,8 @@ def _load_from_state_dict(self, state_dict, *args, **kwargs):
 
 
 class ExtendableEmbeddingList(nn.Module):
+    """List of extendable embedding layers."""
+
     def __init__(
         self,
         num_embeddings: list[int],
@@ -261,8 +304,18 @@ def get_embedding_layer(self, index: int) -> nn.Embedding:
     def set_embedding_layer(self, index: int, embedding: nn.Embedding):
         self._embeddings[index] = embedding
 
+    def extend_embedding_layer(self, index: int, init: int | list[int], freeze_prev: bool = True) -> None:
+        self.set_embedding_layer(
+            index,
+            ExtendableEmbedding.extend_embedding(
+                self.get_embedding_layer(index),
+                init=init,
+                freeze_prev=freeze_prev,
+            ),
+        )
+
     def get_embedding_weight(self, index: int, as_tensor: bool = False) -> np.ndarray | torch.Tensor:
-        weight = self._embeddings[index].weight.detach().cpu()
+        weight = self.get_embedding_layer(index).weight.detach().cpu()
         if as_tensor:
             return weight
         return weight.numpy()
diff --git a/src/embedding_scvi/_module.py b/src/embedding_scvi/_module.py
index a1c66a7..cc110e6 100644
--- a/src/embedding_scvi/_module.py
+++ b/src/embedding_scvi/_module.py
@@ -12,7 +12,7 @@
 from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
 from torch import nn
 
-from ._components import ExtendableEmbeddingList, MLPMultiOutput
+from ._components import ExtendableEmbeddingList
 from ._constants import TENSORS_KEYS
 
 
@@ -53,17 +53,14 @@ def __init__(
         self.decoder_kwargs = decoder_kwargs or {}
 
         if likelihood == "zinb":
-            decoder_n_out_params = 3  # scale, r, dropout
-            decoder_param_activations = ["softmax", None, None]
+            pass
         elif likelihood == "nb":
-            decoder_n_out_params = 3  # mu, theta, scale
-            decoder_param_activations = [None, None, "softplus"]
+            pass
         elif likelihood == "poisson":
-            decoder_n_out_params = 2  # mu, scale
-            decoder_param_activations = [None, "softplus"]
+            pass
         else:
             raise ValueError(f"Invalid likelihood {likelihood}")
 
@@ -77,13 +74,13 @@ def __init__(
             "residual": True,
         }
         _encoder_kwargs.update(self.encoder_kwargs)
-        self.encoder = MLPMultiOutput(
-            n_in=self.n_vars,
-            n_out=self.n_latent,
-            n_out_params=2,
-            param_activations=[None, "softplus"],
-            **_encoder_kwargs,
-        )
+        # self.encoder = MLPMultiOutput(
+        #     n_in=self.n_vars,
+        #     n_out=self.n_latent,
+        #     n_out_params=2,
+        #     param_activations=[None, "softplus"],
+        #     **_encoder_kwargs,
+        # )
 
         _decoder_kwargs = {
             "n_hidden": 256,
@@ -95,13 +92,13 @@ def __init__(
             "residual": True,
         }
         _decoder_kwargs.update(self.decoder_kwargs)
-        self.decoder = MLPMultiOutput(
-            n_in=self.n_latent,
-            n_out=self.n_vars,
-            n_out_params=decoder_n_out_params,
-            param_activations=decoder_param_activations,
-            **_decoder_kwargs,
-        )
+        # self.decoder = MLPMultiOutput(
+        #     n_in=self.n_latent,
+        #     n_out=self.n_vars,
+        #     n_out_params=decoder_n_out_params,
+        #     param_activations=decoder_param_activations,
+        #     **_decoder_kwargs,
+        # )
 
         self.covariates_encoder = nn.Identity()
         if self.categorical_covariates is not None:
diff --git a/tests/test_components.py b/tests/test_components.py
index 6c7264b..5e891df 100644
--- a/tests/test_components.py
+++ b/tests/test_components.py
@@ -1,23 +1,174 @@
+import pytest
 import torch
+from torch import nn
 
-from embedding_scvi._components import MLP
+from embedding_scvi._components import MLP, ExtendableEmbedding, MLPBlock, MultiOutputLinear
 
 
-def test_mlp():
-    x = torch.randn(100, 10)
+@pytest.mark.parametrize("n_obs", [10])
+@pytest.mark.parametrize("n_in", [10])
+@pytest.mark.parametrize("n_out", [10, 20])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("norm", ["batch", "layer", None])
+@pytest.mark.parametrize("activation", ["relu", "softmax", None])
+@pytest.mark.parametrize("dropout_rate", [0.1, None])
+@pytest.mark.parametrize("residual", [True, False])
+def test_mlp_block(
+    n_obs: int,
+    n_in: int,
+    n_out: int,
+    bias: bool,
+    norm: str | None,
+    activation: str | None,
+    dropout_rate: float | None,
+    residual: bool,
+):
+    if norm == "batch":
+        norm_kwargs = {"eps": 1e-3, "momentum": 0.1}
+    elif norm == "layer":
+        norm_kwargs = {"eps": 1e-3}
+    else:
+        norm_kwargs = None
+
+    if activation == "softmax":
+        activation_kwargs = {"dim": 1}
+    else:
+        activation_kwargs = None
+
+    mlp_block = MLPBlock(
+        n_in=n_in,
+        n_out=n_out,
+        bias=bias,
+        norm=norm,
+        norm_kwargs=norm_kwargs,
+        activation=activation,
+        activation_kwargs=activation_kwargs,
+        dropout_rate=dropout_rate,
+        residual=residual,
+    )
+
+    x = torch.randn(n_obs, n_in)
+    h = mlp_block(x)
+    assert h.shape == (n_obs, n_out)
+
+
+@pytest.mark.parametrize("n_obs", [10])
+@pytest.mark.parametrize("n_in", [10])
+@pytest.mark.parametrize("n_out", [20])
+@pytest.mark.parametrize("n_out_params", [1, 2])
+@pytest.mark.parametrize("activation", ["relu", "softmax", None])
+def test_multi_output_linear(
+    n_obs: int,
+    n_in: int,
+    n_out: int,
+    n_out_params: int,
+    activation: str | None,
+):
+    if activation == "softmax":
+        activation_kwargs = {"dim": 1}
+    else:
+        activation_kwargs = None
+
+    multi_output_linear = MultiOutputLinear(
+        n_in=n_in,
+        n_out=n_out,
+        n_out_params=n_out_params,
+        activations=[activation] * n_out_params,
+        activation_kwargs=[activation_kwargs] * n_out_params,
+    )
+
+    x = torch.randn(n_obs, n_in)
+    h = multi_output_linear(x)
+    assert len(h) == n_out_params
+    assert all(h_i.shape == (n_obs, n_out) for h_i in h)
+
+
+@pytest.mark.parametrize("n_obs", [10])
+@pytest.mark.parametrize("n_in", [10])
+@pytest.mark.parametrize("n_out", [20])
+@pytest.mark.parametrize("n_hidden", [64])
+@pytest.mark.parametrize("n_layers", [1, 2]) +@pytest.mark.parametrize("bias", [True]) +@pytest.mark.parametrize("norm", ["batch"]) +@pytest.mark.parametrize("activation", ["relu"]) +@pytest.mark.parametrize("dropout_rate", [0.1]) +@pytest.mark.parametrize("residual", [True]) +def test_mlp( + n_obs: int, + n_in: int, + n_out: int, + n_hidden: int, + n_layers: int, + bias: bool, + norm: str, + activation: str, + dropout_rate: float, + residual: bool, +): + if norm == "batch": + norm_kwargs = {"eps": 1e-3, "momentum": 0.1} + elif norm == "layer": + norm_kwargs = {"eps": 1e-3} + else: + norm_kwargs = None + + if activation == "softmax": + activation_kwargs = {"dim": 1} + else: + activation_kwargs = None + mlp = MLP( - n_in=10, - n_out=10, - n_out_params=10, - n_hidden=10, - n_layers=2, - bias=True, - norm=None, - norm_kwargs=None, - activation=None, - activation_kwargs=None, - dropout_rate=None, - residual=False, + n_in=n_in, + n_out=n_out, + n_hidden=n_hidden, + n_layers=n_layers, + bias=bias, + norm=norm, + norm_kwargs=norm_kwargs, + activation=activation, + activation_kwargs=activation_kwargs, + dropout_rate=dropout_rate, + residual=residual, ) - y = mlp(x) - print(len(y)) + + x = torch.randn(n_obs, n_in) + h = mlp(x) + assert h.shape == (n_obs, n_out) + + +@pytest.mark.parametrize("num_embeddings", [10]) +@pytest.mark.parametrize("embedding_dim", [5]) +@pytest.mark.parametrize("init", [2, [0, 1]]) +@pytest.mark.parametrize("freeze_prev", [True, False]) +def test_extendable_embedding( + num_embeddings: int, + embedding_dim: int, + init: int | list[int], + freeze_prev: bool, +): + embedding = nn.Embedding(num_embeddings, embedding_dim) + ext_embedding = ExtendableEmbedding.extend_embedding(embedding, init=init, freeze_prev=freeze_prev) + n_init = len(init) if isinstance(init, list) else init + + assert ext_embedding.num_embeddings == num_embeddings + n_init + assert ext_embedding.embedding_dim == embedding_dim + assert ext_embedding.weight.shape == (num_embeddings + n_init, embedding_dim) + assert torch.equal(ext_embedding.weight[:num_embeddings], embedding.weight) + + if isinstance(init, list): + assert torch.equal(ext_embedding.weight[num_embeddings:], embedding.weight[init]) + + dummy_indexes = torch.arange(num_embeddings + n_init, dtype=torch.long) + dummy_prediction = ext_embedding(dummy_indexes) + dummy_target = torch.randn_like(dummy_prediction) + dummy_loss = torch.nn.functional.mse_loss(dummy_prediction, dummy_target, reduce=True) + dummy_loss.backward() + grad = ext_embedding.weight.grad + + if freeze_prev: + prev_grad = grad[:num_embeddings] + new_grad = grad[num_embeddings:] + assert torch.equal(prev_grad, torch.zeros_like(prev_grad)) + assert not torch.equal(new_grad, torch.zeros_like(new_grad)) + else: + assert not torch.equal(grad, torch.zeros_like(grad))