PyTorch implementation of the L2L execution algorithm from the paper "Training Large Neural Networks with Constant Memory using a New Execution Algorithm" (arXiv:2002.05645).

You need to define a torch model where all layers are specified in a `ModuleList`. See the `examples` folder for details.
```python
from typing import Optional

import torch
from torch import nn, optim


class M(nn.Module):
    def __init__(self, depth: int, dim: int, hidden_dim: Optional[int] = None):
        super().__init__()
        hidden_dim = hidden_dim or dim
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.LeakyReLU(),
                )
            ]
            + [
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.LeakyReLU(),
                )
                for _ in range(depth)
            ]
            + [nn.Linear(hidden_dim, dim), nn.Sigmoid()]
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        x = batch
        for layer in self.layers:
            x = layer(x)
        return x


model = M(depth=5, dim=40).train()  # on CPU
```
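As a quick sanity check (an illustrative snippet, not part of the library), you can run a plain forward pass before wrapping the model:

```python
# Illustrative sanity check: the model maps (batch, dim) -> (batch, dim).
with torch.no_grad():
    out = model(torch.rand(8, 40))
print(out.shape)  # torch.Size([8, 40])
```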
Then, you can wrap this model with the `Layer2Layer` wrapper:
```python
from layer_to_layer_pytorch.l2l import Layer2Layer

l2l_model = Layer2Layer(
    model,
    layers_attr="layers",   # name of the attribute holding the ModuleList
    microbatch_size=100,    # size of a micro-batch within a mini-batch (see the original paper)
    verbose=False,          # set to True to show tqdm progress bars
)
```
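Conceptually, the wrapper follows the layer-to-layer idea from the paper: only one layer at a time is resident on the GPU, and each mini-batch is split into micro-batches that are streamed through that layer while activations stay on the CPU. A minimal, hypothetical sketch of the forward pass under those assumptions (this is not the library's actual implementation, and it ignores the backward pass and gradient accumulation):

```python
import torch


# Simplified sketch of the L2L forward idea (not the library's actual code).
def l2l_forward_sketch(layers, batch, microbatch_size, device="cuda"):
    activations = batch
    for layer in layers:
        layer.to(device)                # only this layer lives on the GPU
        outputs = []
        for micro in activations.split(microbatch_size):
            outputs.append(layer(micro.to(device)).cpu())
        activations = torch.cat(outputs)
        layer.to("cpu")                 # free GPU memory before the next layer
    return activations
```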
And train it (almost) like a regular torch model:
```python
from tqdm.auto import tqdm, trange

x = torch.rand(1_000, 40)  # on CPU
y = torch.rand(1_000, 40)  # on CPU

losses = []
criterion = nn.MSELoss()
optimizer = optim.AdamW(l2l_model.main_params)  # the optimizer works with the main model on CPU

for i in trange(2000):
    l2l_model.zero_grad()
    _ = l2l_model.forward(x)

    loss_value: float = l2l_model.compute_loss(y, criterion)
    if i % 50 == 0:
        tqdm.write(f"[{i}] loss = {loss_value}")
        losses.append(loss_value)

    l2l_model.backward()
    optimizer.step()
    l2l_model.update_main_model_params()  # sync params back to the main model on CPU
```
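Assuming `update_main_model_params()` syncs the trained weights back into the wrapped `model` (as the comment above suggests), you can then run inference with plain PyTorch on CPU. A small illustrative snippet, not from the repo:

```python
# Illustrative inference on the wrapped main model (plain PyTorch, on CPU),
# assuming its parameters were synced by update_main_model_params().
model.eval()
with torch.no_grad():
    x_val = torch.rand(100, 40)
    preds = model(x_val)
print(preds.shape)  # torch.Size([100, 40])
```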
Cross-mixed-precision (fp16) training is available via the init params:
```python
from layer_to_layer_pytorch.l2l import Layer2Layer

l2l_model = Layer2Layer(
    model,
    layers_attr="layers",
    microbatch_size=100,

    # fp16
    mixed_precision=True,
    loss_scale=128.0,
)
```
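The `loss_scale` value follows the common static loss-scaling recipe for fp16 training: the loss is multiplied by the scale before the backward pass so that small gradients do not underflow in half precision, and gradients are divided by the same factor before the optimizer step. A hand-rolled sketch of that idea in plain PyTorch (hypothetical, not the wrapper's internals):

```python
# Illustrative static loss scaling in plain PyTorch (not the wrapper's internals).
import torch
from torch import nn, optim

net = nn.Linear(40, 40)
opt = optim.SGD(net.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
loss_scale = 128.0

x, y = torch.rand(16, 40), torch.rand(16, 40)

opt.zero_grad()
loss = loss_fn(net(x), y)
(loss * loss_scale).backward()    # scale the loss up so fp16 gradients do not underflow
for p in net.parameters():
    p.grad.div_(loss_scale)       # scale gradients back down before the optimizer step
opt.step()
```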
And then train it the same way.
Install the package with pip:

```bash
pip install layer-to-layer-pytorch
```

or install with Poetry:

```bash
poetry add layer-to-layer-pytorch
```
You can see the list of available releases on the GitHub Releases page.
We follow the Semantic Versioning specification.
This project is licensed under the terms of the MIT license. See LICENSE for more details.
To cite this repository:

```bibtex
@misc{layer-to-layer-pytorch,
  author = {Roman Tezikov},
  title = {PyTorch implementation of L2L execution algorithm},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/TezRomacH/layer-to-layer-pytorch}}
}
```
And the original paper:

```bibtex
@article{Pudipeddi2020TrainingLN,
  title={Training Large Neural Networks with Constant Memory using a New Execution Algorithm},
  author={Bharadwaj Pudipeddi and Maral Mesmakhosroshahi and J. Xi and S. Bharadwaj},
  journal={ArXiv},
  year={2020},
  volume={abs/2002.05645}
}
```
This project was generated with python-package-template.