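"""main.py: entry point for training next-token language models.

Loads one of the supported datasets (wikitext-*, rocstories, alicewonderland),
builds dataloaders, trains an LSTM, xLSTM, or Transformer model, optionally
logs metrics to Weights & Biases, saves the best and last-epoch checkpoints
under trained_models/, and finally evaluates the best model on the test split.
"""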
import os
from pathlib import Path
import configargparse
from datasets import load_dataset
import torchtext
import torchmetrics
torchtext.disable_torchtext_deprecation_warning()
from dataset.dataset import LoaderConstructor
from dataset.dataset import create_rocstories_dataset, create_alicewonderland_dataset
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from models.lstm import LSTM
from models.xlstm import xLSTM
from models.transformer import Transformer
from trainer import Trainer
from scheduler import ChainedScheduler


def get_args():
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        # Reject anything that is not a recognized boolean string; argparse
        # reports a ValueError from a type callable as an invalid argument.
        raise ValueError(f"Cannot interpret {v!r} as a boolean")

    config_path = Path(__file__).parent / "config.yaml"
    parser = configargparse.ArgumentParser(
        default_config_files=[config_path],
        config_file_parser_class=configargparse.YAMLConfigFileParser,
    )
    parser.add_argument("--model", type=str, default="xlstm")
    parser.add_argument("--dataset", type=str, default="wikitext-2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--max_length", type=int, default=20)
    parser.add_argument("--embed-dim", type=int, default=512)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--min-text-length", type=int, default=200)
    parser.add_argument("--wandb", type=str2bool, default=False)
    parser.add_argument("--chained-scheduler", type=str2bool, default=False)
    return parser.parse_args()
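
# Typical invocation (defaults come from config.yaml next to this file); the
# command below is only an illustration built from the arguments defined above:
#   python main.py --model xlstm --dataset wikitext-2 --epochs 50 --lr 5e-5 --wandb true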

def get_model(model, vocab_size, embed_dim, seq_len, output_dim, device):
    if model == "lstm":
        return LSTM(
            vocab_size=vocab_size,
            embedding_dim=embed_dim,
            hidden_dim=512,
            output_dim=output_dim,
            num_layers=2,
        )
    elif model == "transformer":
        return Transformer(
            vocab_size=vocab_size,
            seq_len=seq_len,
            embed_dim=embed_dim,
            output_dim=output_dim,
            num_layers=2,
            num_heads=8,
            dropout=0.1,
        )
    elif model == "xlstm":
        return xLSTM(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            seq_len=seq_len,
            out_features=output_dim,
            device=device,
        )
    else:
        raise ValueError(f"Model {model} not found")


if __name__ == "__main__":
    cfg = get_args()

    # Load the dataset
    if "wikitext" in cfg.dataset:
        dataset = load_dataset("wikitext", f"{cfg.dataset}-raw-v1")
        for split in dataset.keys():
            # Filter out short texts because they are very noisy
            dataset[split] = dataset[split].filter(
                lambda x: len(x["text"]) > cfg.min_text_length
            )
    elif cfg.dataset == "rocstories":
        dataset = create_rocstories_dataset(os.getcwd())
    elif cfg.dataset == "alicewonderland":
        dataset = create_alicewonderland_dataset(os.getcwd())
    else:
        raise ValueError(f"Dataset {cfg.dataset} not found")

    # Construct the dataloaders
    lc = LoaderConstructor(
        dataset=dataset,
        batch_size=cfg.batch_size,
        max_length=cfg.max_length,
        labels_sequence=False,
        min_freq=1 if cfg.dataset == "alicewonderland" else 3,
    )
    loaders = {}
    for loader in ["train", "validation", "test"]:
        loaders[loader] = lc.construct_loader(split=loader)
    input_size = loaders["train"].dataset.input_size
    vocab_size = lc.vocab_size
    output_size = lc.output_size

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize the model
    model = get_model(
        model=cfg.model,
        vocab_size=vocab_size,
        embed_dim=cfg.embed_dim,
        seq_len=input_size,
        output_dim=output_size,
        device=device,
    )

    # Init wandb logger
    if cfg.wandb:
        wandb.init(
            project="text-generation",
            config=cfg,
            name=f"{cfg.model}_{cfg.dataset}_lr={cfg.lr}",
        )

    # Initialize the optimizer, loss function, and accuracy metric
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr)
    criterion = nn.CrossEntropyLoss()
    accuracy = torchmetrics.Accuracy(
        task="multiclass", num_classes=output_size, top_k=5
    ).to(device)

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        accuracy=accuracy,
        batch_size=cfg.batch_size,
        output_dim=output_size,
        wandb=cfg.wandb,
        device=device,
    )

    scheduler = None
    if cfg.chained_scheduler:
        warmup_steps = 4
        scheduler = ChainedScheduler(
            trainer.optimizer,
            T_0=(cfg.epochs - warmup_steps),
            T_mul=1,
            eta_min=cfg.lr / 10,
            gamma=0.5,
            max_lr=cfg.lr,
            warmup_steps=warmup_steps,
        )

    # Train the model
    model_filename = f"trained_models/{cfg.model}_{cfg.dataset}_lr={str(cfg.lr).replace('.', '_')}_best.pt"
    # Create the directory if it doesn't exist
    Path("trained_models").mkdir(parents=True, exist_ok=True)
    best_valid_loss = float("inf")
    for epoch in range(cfg.epochs):
        if cfg.wandb:
            lr = trainer.optimizer.param_groups[0]["lr"]
            wandb.log({"Learning Rate": lr}, step=epoch)

        # Training
        model.train()
        trainer.train_validate_epoch(loaders["train"], epoch, "train")

        # Validation
        model.eval()
        with torch.no_grad():
            val_loss = trainer.train_validate_epoch(
                loaders["validation"], epoch, "validation"
            )

        # Save the best model
        if val_loss < best_valid_loss:
            best_valid_loss = val_loss
            torch.save(model.state_dict(), model_filename)
            print("Model improved, saving model")

        if scheduler:
            scheduler.step()

    torch.save(model.state_dict(), model_filename.replace("best", "lastepoch"))

    # Load the best model
    model.load_state_dict(torch.load(model_filename))

    # Test the model
    model.eval()
    with torch.no_grad():
        trainer.train_validate_epoch(loaders["test"], None, "test")