Skip to content

Commit

Permalink
feat: convolutional vae for multivariate time series (#237)
Browse files Browse the repository at this point in the history
- Introduce variational autoencoder for multivariate time series data
- Customized trainer for vae
- Causal convolutional modules
- Improved training console logging

Exploration in numaproj/numalogic-benchmarks#3

---------

Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 authored Aug 8, 2023
1 parent ecc1192 commit 2dfd84c
Show file tree
Hide file tree
Showing 12 changed files with 633 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ clean:
@find . -type f -name "*.py[co]" -exec rm -rf {} +

format: clean
poetry run black numalogic/ examples/ tests/ benchmarks/
poetry run black numalogic/ examples/ tests/

lint: format
poetry run ruff check --fix .
Expand Down
3 changes: 3 additions & 0 deletions numalogic/models/vae/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from numalogic.models.vae.trainer import VAETrainer

__all__ = ["VAETrainer"]
63 changes: 63 additions & 0 deletions numalogic/models/vae/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Callable

import torch.nn.functional as F
from torch import Tensor, optim

from numalogic.base import TorchModel


def _init_criterion(loss_fn: str) -> Callable:
if loss_fn == "huber":
return F.huber_loss
if loss_fn == "l1":
return F.l1_loss
if loss_fn == "mse":
return F.mse_loss
raise ValueError(f"Unsupported loss function provided: {loss_fn}")


class BaseVAE(TorchModel):
"""
Abstract Base class for all Pytorch based variational autoencoder models.
Args:
----
lr: learning rate (default: 3e-4)
weight_decay: weight decay factor weight for regularization (default: 0.0)
loss_fn: loss function used to train the model
supported values include: {mse, l1, huber}
"""

def __init__(
self,
lr: float = 3e-4,
weight_decay: float = 0.0,
loss_fn: str = "mse",
):
super().__init__()
self._lr = lr
self.weight_decay = weight_decay
self.criterion = _init_criterion(loss_fn)

def configure_shape(self, x: Tensor) -> Tensor:
"""Method to configure the batch shape for each type of model architecture."""
return x

def configure_optimizers(self) -> dict:
optimizer = optim.Adam(self.parameters(), lr=self._lr, weight_decay=self.weight_decay)
return {"optimizer": optimizer}

def recon_loss(self, batch: Tensor, recon: Tensor, reduction: str = "sum"):
return self.criterion(batch, recon, reduction=reduction)

def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor:
"""Validation step for the model."""
p, recon = self.forward(batch)
loss = self.recon_loss(batch, recon)
self.log("val_loss", loss)
return loss

def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
"""Prediction step for the model."""
p, recon = self.forward(batch)
return self.recon_loss(batch, recon, reduction="none")
62 changes: 62 additions & 0 deletions numalogic/models/vae/layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from torch import nn, Tensor
import torch.nn.functional as F


class CausalConv1d(nn.Conv1d):
"""Temporal convolutional layer with causal padding."""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
):
super().__init__(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
)

self.__padding = (kernel_size - 1) * dilation

def forward(self, x: Tensor) -> Tensor:
return super().forward(F.pad(x, (self.__padding, 0)))


class CausalConvBlock(nn.Module):
"""Basic convolutional block consisting of:
- causal 1D convolutional layer
- batch norm
- relu activation.
"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
):
super().__init__()
self.conv = CausalConv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
)
self.bnorm = nn.BatchNorm1d(out_channels)
self.relu = nn.ReLU(inplace=True)

def forward(self, input_: Tensor) -> Tensor:
return self.relu(self.bnorm(self.conv(input_)))
69 changes: 69 additions & 0 deletions numalogic/models/vae/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import sys
import warnings
from typing import Optional

import torch
from torch import Tensor
from pytorch_lightning import Trainer, LightningModule

from numalogic.tools.callbacks import ConsoleLogger
from numalogic.tools.data import inverse_window


class VAETrainer(Trainer):
"""A PyTorch Lightning Trainer for VAE models.
Args:
----
max_epochs: The maximum number of epochs to train for. (default: 100)
logger: Whether to use a console logger to log metrics. (default: True)
log_freq: The number of epochs between logging. (default: 5)
check_val_every_n_epoch: The number of epochs between validation checks. (default: 5)
enable_checkpointing: Whether to enable checkpointing. (default: False)
enable_progress_bar: Whether to enable the progress bar. (default: False)
enable_model_summary: Whether to enable the model summary. (default: False)
**trainer_kw: Additional keyword arguments to pass to the Lightning Trainer.
"""

def __init__(
self,
max_epochs: int = 100,
logger: bool = True,
log_freq: int = 5,
check_val_every_n_epoch: int = 5,
enable_checkpointing: bool = False,
enable_progress_bar: bool = False,
enable_model_summary: bool = False,
**trainer_kw
):
if not sys.warnoptions:
warnings.simplefilter("ignore", category=UserWarning)

if logger:
logger = ConsoleLogger(log_freq=log_freq)

super().__init__(
logger=logger,
max_epochs=max_epochs,
check_val_every_n_epoch=check_val_every_n_epoch,
enable_checkpointing=enable_checkpointing,
enable_progress_bar=enable_progress_bar,
enable_model_summary=enable_model_summary,
**trainer_kw
)

def predict(self, model: Optional[LightningModule] = None, unbatch=True, **kwargs) -> Tensor:
r"""Predicts the output of the model.
Args:
----
model: The model to predict with. (default: None)
unbatch: Whether to inverse window the output. (default: True)
**kwargs: Additional keyword arguments to pass to the Lightning
trainers predict method.
"""
recon_err = super().predict(model, **kwargs)
recon_err = torch.vstack(recon_err)
if unbatch:
return inverse_window(recon_err, method="keep_last")
return recon_err
3 changes: 3 additions & 0 deletions numalogic/models/vae/variants/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from numalogic.models.vae.variants.conv import Conv1dVAE

__all__ = ["Conv1dVAE"]
Loading

0 comments on commit 2dfd84c

Please sign in to comment.