PyTorch implementation of "Lessons on Parameter Sharing across Layers in Transformers."
Clone this repository.
git clone https://github.com/jaketae/param-share-transformer.git
Navigate to the cloned directory. You can start using the model via
>>> from pshare_transformer import ParameterSharedTransformerEncoder
>>> model = ParameterSharedTransformerEncoder()
By default, the model comes with the following parameters:
ParameterSharedTransformerEncoder(
d_model=512,
nhead=16,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
num_unique_layers=3,
num_total_layers=6,
mode="cycle_rev",
norm=False,
)
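Any of these arguments can be overridden at construction time. Below is a minimal sketch of a lighter configuration, using only the keyword arguments listed above.
>>> from pshare_transformer import ParameterSharedTransformerEncoder
>>> # smaller model: 2 unique layers reused across 4 total layers, plain cycle sharing
>>> model = ParameterSharedTransformerEncoder(
...     d_model=256,
...     nhead=8,
...     num_unique_layers=2,
...     num_total_layers=4,
...     mode="cycle",
... )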
You can check which layer is being used in each forward pass by toggling the verbose argument. By default, verbose is set to False. Also note that layer indices are zero-indexed.
Below is a simple demonstration of the model's behavior when initialized in cycle reverse mode, which is the default configuration.
>>> import torch
>>> x = torch.randn(8, 100, 512) # (batch_size, seq_len, d_model)
>>> from pshare_transformer import ParameterSharedTransformerEncoder
>>> model = ParameterSharedTransformerEncoder()
>>> model(x, verbose=True).shape
layer 0
layer 1
layer 2
layer 2
layer 1
layer 0
torch.Size([8, 100, 512])
The layers are "sandwiched" in the sense that the first layer is called again as the final layer, the second layer as the second-to-last, and so on.
If the model is initialized in cycle mode, each layer is called again only after all preceding unique layers have been consumed.
>>> model = ParameterSharedTransformerEncoder(mode="cycle")
>>> model(x, verbose=True).shape
layer 0
layer 1
layer 2
layer 0
layer 1
layer 2
torch.Size([8, 100, 512])
In sequence mode, the model repeatedly calls each layer before moving on to the next, in sequential fashion.
>>> model = ParameterSharedTransformerEncoder(mode="sequence")
>>> model(x, verbose=True).shape
layer 0
layer 0
layer 1
layer 1
layer 2
layer 2
torch.Size([8, 100, 512])
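For intuition, the layer schedules printed above can be reproduced with a few lines of plain Python. The sketch below only illustrates the index mapping implied by the traces; it is not the repository's internal implementation.
>>> def layer_schedule(num_unique, num_total, mode):
...     if mode == "sequence":
...         # repeat each unique layer num_total // num_unique times
...         return [i // (num_total // num_unique) for i in range(num_total)]
...     if mode == "cycle":
...         # cycle through the unique layers in order
...         return [i % num_unique for i in range(num_total)]
...     if mode == "cycle_rev":
...         # cycle forward, then reverse the order for the last num_unique layers
...         head = [i % num_unique for i in range(num_total - num_unique)]
...         return head + list(reversed(range(num_unique)))
...     raise ValueError(f"unknown mode: {mode}")
...
>>> layer_schedule(3, 6, "cycle_rev")
[0, 1, 2, 2, 1, 0]
>>> layer_schedule(3, 6, "cycle")
[0, 1, 2, 0, 1, 2]
>>> layer_schedule(3, 6, "sequence")
[0, 0, 1, 1, 2, 2]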
The authors present three strategies for sharing weights across Transformer layers: sequence, cycle, and cycle (rev). These strategies are distinct from other parameter sharing schemes, which typically assign the same weights to every layer of the model. Parameter-shared transformers achieve state-of-the-art performance on the WMT 2014 dataset while significantly reducing computational cost.
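To get a rough sense of the savings, you can compare the shared encoder against a vanilla PyTorch encoder of the same depth. The sketch below assumes the shared encoder registers only its unique layers as submodules, so each shared layer is counted once by parameters().
>>> import torch.nn as nn
>>> from pshare_transformer import ParameterSharedTransformerEncoder
>>> def count_params(module):
...     return sum(p.numel() for p in module.parameters())
...
>>> shared = ParameterSharedTransformerEncoder()  # 3 unique layers, 6 total
>>> layer = nn.TransformerEncoderLayer(d_model=512, nhead=16, dim_feedforward=2048)
>>> vanilla = nn.TransformerEncoder(layer, num_layers=6)  # 6 independent layers
>>> count_params(shared), count_params(vanilla)  # shared count expected to be roughly half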