forked from pytorch/torchtitan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
528 lines (444 loc) · 19.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import gc
import os
import time
from dataclasses import dataclass, field
from datetime import timedelta
from io import BytesIO
from timeit import default_timer as timer
from typing import Any, Dict, List
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributed import destroy_process_group
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor.parallel import loss_parallel
from torchtitan.checkpoint import CheckpointManager
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_hf_data_loader, create_tokenizer
from torchtitan.float8_linear import (
maybe_build_fp8_linear,
maybe_precompute_fp8_dynamic_scale_for_fsdp,
)
from torchtitan.logging_utils import init_logger, logger
from torchtitan.lr_scheduling import get_lr_schedulers
from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.parallelisms import (
models_parallelize_fns,
models_pipelining_fns,
ParallelDims,
)
from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
from torchtitan.utils import (
Color,
dist_max,
dist_mean,
get_metrics_rank,
get_num_flop_per_token,
get_num_params,
get_peak_flops,
init_distributed,
NoColor,
set_pg_timeouts,
)
@dataclass
class TrainState(Stateful):
step: int = 0
global_avg_losses: List[float] = field(default_factory=list)
global_max_losses: List[float] = field(default_factory=list)
log_steps: List[int] = field(default_factory=list)
def state_dict(self) -> Dict[str, Any]:
# Only checkpoint global_avg_losses and global_max_losses per log frequency
# to avoid sync overhead in every iteration.
global_avg_losses_bytes = BytesIO()
torch.save(self.global_avg_losses, global_avg_losses_bytes)
global_max_losses_bytes = BytesIO()
torch.save(self.global_max_losses, global_max_losses_bytes)
log_steps_bytes = BytesIO()
torch.save(self.log_steps, log_steps_bytes)
return {
"step": torch.tensor(self.step, dtype=torch.int32),
"global_avg_losses": global_avg_losses_bytes,
"global_max_losses": global_max_losses_bytes,
"log_steps": log_steps_bytes,
}
def load_state_dict(self, state_dict) -> None:
self.step = state_dict["step"].item()
state_dict["global_avg_losses"].seek(0)
self.global_avg_losses = torch.load(
state_dict["global_avg_losses"], weights_only=False
)
state_dict["global_max_losses"].seek(0)
self.global_max_losses = torch.load(
state_dict["global_max_losses"], weights_only=False
)
state_dict["log_steps"].seek(0)
self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
def build_optimizers(model_parts, job_config: JobConfig):
    """Build one optimizer per model part and bundle them in a container.

    The returned ``OptimizersContainer`` exposes ``step()`` and ``zero_grad()``,
    which fan out to every child optimizer. The raw list is available as
    ``.optimizers``.

    Args:
        model_parts: iterable of modules; one optimizer is created for each.
        job_config: job configuration; reads ``optimizer.name``,
            ``optimizer.lr`` and ``optimizer.fused``.

    Raises:
        NotImplementedError: if ``optimizer.name`` is not "Adam" or "AdamW".
    """

    def _build_optimizer(model):
        opt_cfg = job_config.optimizer
        optimizer_name = opt_cfg.name
        use_fused = opt_cfg.fused
        # TODO: make the optimizer options configurable by toml/cmd args
        shared_kwargs = dict(
            lr=opt_cfg.lr,
            betas=(0.9, 0.95),
            weight_decay=0.1,
            fused=use_fused,
            # exactly one of the fused / foreach implementations is requested
            foreach=not use_fused,
        )
        optimizer_cls = {
            "Adam": torch.optim.Adam,
            "AdamW": torch.optim.AdamW,
        }.get(optimizer_name)
        if optimizer_cls is None:
            raise NotImplementedError(f"Optimizer {optimizer_name} not added.")
        return optimizer_cls(model.parameters(), **shared_kwargs)

    class OptimizersContainer:
        """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages"""

        def __init__(self, optimizers):
            self.optimizers = optimizers

        def step(self):
            for opt in self.optimizers:
                opt.step()

        def zero_grad(self):
            for opt in self.optimizers:
                opt.zero_grad()

    return OptimizersContainer(list(map(_build_optimizer, model_parts)))
def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
@contextlib.contextmanager
def context():
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(loss_parallel())
if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)
yield
return context
# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@record
def main(job_config: JobConfig):
    """Run one training job described by ``job_config``.

    Reads ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment and binds this
    process to the matching CUDA device. It then builds the tokenizer,
    dataloader and model, and applies the configured DP/TP/PP parallelisms.
    Next it builds optimizers and LR schedulers, restores from a checkpoint
    when one is available, and runs the training loop with periodic metric
    logging and checkpoint saving.

    If ``job_config.checkpoint.create_seed_checkpoint`` is set, a step-0
    checkpoint is saved and the function returns without training.

    Raises:
        AssertionError: when creating a seed checkpoint with world_size != 1.
        RuntimeError: when pipeline parallelism is enabled but no checkpoint
            could be loaded.
    """
    init_logger()
    logger.info(f"Starting job: {job_config.job.description}")

    # used for colorful printing
    color = Color if job_config.metrics.enable_color_printing else NoColor

    # take control of garbage collection to avoid stragglers
    _gc_freq = job_config.training.gc_freq
    gc.disable()
    gc.collect(1)

    # init distributed
    world_size = int(os.environ["WORLD_SIZE"])
    parallel_dims = ParallelDims(
        dp=job_config.training.data_parallel_degree,
        tp=job_config.training.tensor_parallel_degree,
        pp=job_config.experimental.pipeline_parallel_degree,
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
        dp_type=job_config.training.data_parallel_type,
    )
    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    torch.cuda.set_device(device)
    init_distributed(job_config)

    # build meshes
    world_mesh = parallel_dims.build_mesh(device_type="cuda")
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"]
        dp_degree = dp_mesh.size()
        dp_rank = dp_mesh.get_local_rank()
    else:
        dp_degree, dp_rank = 1, 0

    if parallel_dims.pp_enabled:
        pp_mesh = world_mesh["pp"]

    model_name = job_config.model.name

    # build tokenizer
    tokenizer_type = model_name_to_tokenizer[model_name]
    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

    # build dataloader
    data_loader = build_hf_data_loader(
        job_config.training.dataset,
        job_config.training.dataset_path,
        tokenizer,
        job_config.training.batch_size,
        job_config.training.seq_len,
        dp_degree,
        dp_rank,
    )

    train_context = get_train_context(
        parallel_dims.loss_parallel_enabled,
        job_config.experimental.enable_compiled_autograd,
    )

    # loss fn can be shared by pipeline-parallel or non-pp execution
    def loss_fn(pred, labels):
        # merge the leading (batch, seq) dims so cross_entropy sees one row per token
        return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

    # build model (using meta init)
    model_cls = model_name_to_cls[model_name]
    model_config = models_config[model_name][job_config.model.flavor]
    # set the model configs from training inputs:
    # 1. norm type to decide which norm layer to use
    # 2. vocab size from tokenizer
    # 3. max_seq_len based on inputs
    model_config.norm_type = job_config.model.norm_type
    model_config.vocab_size = tokenizer.n_words
    model_config.max_seq_len = job_config.training.seq_len

    logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}")
    with torch.device("meta"):
        whole_model = model_cls.from_model_args(model_config)

    # swap to Float8Linear based on fp8 config
    maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

    # log model size
    model_param_count = get_num_params(whole_model)
    num_flop_per_token = get_num_flop_per_token(
        get_num_params(whole_model, exclude_embedding=True),
        model_config,
        job_config.training.seq_len,
    )
    logger.info(
        f"{color.blue}Model {model_name} {job_config.model.flavor} "
        f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
    )

    # initialize GPU memory monitor before applying parallelisms to the model
    gpu_memory_monitor = build_gpu_memory_monitor()
    # obtain the peak flops of bf16 type for MFU calculation
    gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name)

    if parallel_dims.pp_enabled:
        stages, model_parts = models_pipelining_fns[model_name](
            whole_model, world_mesh, parallel_dims, job_config, device, model_config
        )
    else:
        # In 1D/2D cases or PP with simple schedules, model_parts is just one item
        # for PP with looped schedules, each item is one stage-model-chunk
        # we iterate all model_parts for applying SPMD parallelism, compilation, optimizer, and checkpointing
        model_parts = [whole_model]

    # apply PT-D DP/TP parallelisms and activation checkpointing
    model_parts = [
        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
        for m in model_parts
    ]

    init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
    for model in model_parts:
        model.to_empty(device=init_device)

    if parallel_dims.pp_enabled:
        pp_schedule = build_pipeline_schedule(
            job_config, parallel_dims, stages, loss_fn
        )
    else:
        # If PP is enabled, we can't rely on init_weights, because some layers are missing.
        # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
        # allocate sharded model on GPU and initialize weights via DTensor
        whole_model.init_weights()

    gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
    logger.info(
        f"GPU memory usage for model: "
        f"{gpu_mem_stats.max_reserved_gib:.2f}GiB"
        f"({gpu_mem_stats.max_reserved_pct:.2f}%)"
    )

    # build optimizer after applying parallelisms to the model
    optimizers = build_optimizers(model_parts, job_config)
    lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)

    metric_logger = build_metric_logger(
        job_config, metrics_log_rank=get_metrics_rank(world_mesh, parallel_dims)
    )

    train_state = TrainState()

    # put every model part into training mode
    for model in model_parts:
        model.train()

    # load initial checkpoint
    checkpoint = CheckpointManager(
        model_parts=model_parts,
        optimizers=optimizers.optimizers,
        lr_schedulers=lr_schedulers.schedulers,
        dataloader=data_loader,
        states={"train_state": train_state},
        job_config=job_config,
    )

    if job_config.checkpoint.create_seed_checkpoint:
        assert (
            world_size == 1
        ), "Must create seed-checkpoint using one gpu, to disable sharding"
        checkpoint.save(curr_step=0, force=True)
        logger.info("Created seed checkpoint")
        # release the metric logger on this early-exit path too
        metric_logger.close()
        return

    checkpoint_loaded = checkpoint.load()

    if parallel_dims.pp_enabled and not checkpoint_loaded:
        raise RuntimeError(
            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
        )

    # plot losses loaded from checkpoint (if any) to TensorBoard
    # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
    # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
    if train_state.step > 0:
        for idx, step in enumerate(train_state.log_steps):
            metrics = {
                "loss_metrics/global_avg_loss": train_state.global_avg_losses[idx],
                "loss_metrics/global_max_loss": train_state.global_max_losses[idx],
            }
            metric_logger.log(metrics, step=step)

    data_iterator = iter(data_loader)

    checkpoint.reset()

    # variables used to keep info for metrics logging
    losses_since_last_log: List[torch.Tensor] = []
    ntokens_since_last_log = 0
    data_loading_times: List[float] = []
    time_last_log = timer()
    gpu_memory_monitor.reset_peak_stats()

    if parallel_dims.pp_enabled:
        # this rank's pipeline position never changes during the run; compute once
        pp_is_first_stage = pp_mesh.get_local_rank() == 0
        pp_is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

    # train loop
    logger.info(f"Training starts at step {train_state.step + 1}")
    with maybe_enable_profiling(
        job_config, global_step=train_state.step
    ) as torch_profiler, maybe_enable_memory_snapshot(
        job_config, global_step=train_state.step
    ) as memory_profiler:
        while train_state.step < job_config.training.steps:
            train_state.step += 1
            if train_state.step > 1 and train_state.step % _gc_freq == 0:
                gc.collect(1)

            # get batch
            data_load_start = timer()
            batch = next(data_iterator)
            input_ids, labels = batch
            ntokens_since_last_log += labels.numel()
            data_loading_times.append(timer() - data_load_start)

            input_ids = input_ids.cuda()
            labels = labels.cuda()
            optimizers.zero_grad()

            if parallel_dims.pp_enabled:
                # pipeline parallel forward / backward inside step() call
                with train_context():
                    if pp_is_first_stage:
                        pp_schedule.step(input_ids)
                    elif pp_is_last_stage:
                        losses = []
                        pp_schedule.step(target=labels, losses=losses)
                    else:
                        pp_schedule.step()

                # accumulate losses across pipeline microbatches; stages other
                # than the last have no loss and report a -1.0 placeholder
                loss = (
                    torch.mean(torch.stack(losses))
                    if pp_is_last_stage
                    else torch.Tensor([-1.0])
                )
            else:
                # Non-PP forward / backward; without PP there is exactly one
                # model part, so index it explicitly rather than relying on a
                # leftover loop variable.
                with train_context():
                    pred = model_parts[0](input_ids)
                    loss = loss_fn(pred, labels)
                    # pred.shape=(bs, seq_len, vocab_size)
                    # need to free it before bwd to avoid peaking memory
                    del pred
                    loss.backward()

            # clip gradients
            for model in model_parts:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), job_config.training.max_norm, foreach=True
                )

            # optimizer step
            checkpoint.wait_for_staging()
            optimizers.step()
            lr_schedulers.step()

            # when fp8 config is on,
            # calculate float8 dynamic amax/scale for all-parameter for FSDP2
            # it issues a single all-reduce for all parameters at once for better performance.
            # Apply to every model part (looped PP schedules hold several per rank).
            for model in model_parts:
                maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config)

            # detach so the stored loss does not keep this step's autograd graph alive
            losses_since_last_log.append(loss.detach())

            # log metrics
            if (
                train_state.step == 1
                or train_state.step % job_config.metrics.log_freq == 0
            ):
                loss_values = [step_loss.item() for step_loss in losses_since_last_log]
                avg_loss, max_loss = (
                    np.mean(loss_values),
                    np.max(loss_values),
                )
                if parallel_dims.dp_enabled:
                    global_avg_loss, global_max_loss = (
                        dist_mean(avg_loss, dp_mesh).item(),
                        dist_max(max_loss, dp_mesh).item(),
                    )
                else:
                    # plain Python floats so TrainState history holds no numpy scalars
                    global_avg_loss, global_max_loss = float(avg_loss), float(max_loss)

                train_state.log_steps.append(train_state.step)
                train_state.global_avg_losses.append(global_avg_loss)
                train_state.global_max_losses.append(global_max_loss)

                time_delta = timer() - time_last_log
                # Steps actually covered by this log window. This is not always
                # log_freq: the step-1 log covers a single step, and the window
                # right after it covers log_freq - 1 steps.
                steps_since_last_log = len(losses_since_last_log)

                # tokens per second, abbr. as wps by convention
                wps = ntokens_since_last_log / (
                    time_delta * parallel_dims.model_parallel_size
                )
                # model FLOPS utilization
                # For its definition and calculation, please refer to the PaLM paper:
                # https://arxiv.org/abs/2204.02311
                mfu = 100 * num_flop_per_token * wps / gpu_peak_flops

                time_end_to_end = time_delta / steps_since_last_log
                time_data_loading = np.mean(data_loading_times)
                time_data_loading_pct = 100 * np.sum(data_loading_times) / time_delta

                gpu_mem_stats = gpu_memory_monitor.get_peak_stats()

                metrics = {
                    "loss_metrics/global_avg_loss": global_avg_loss,
                    "loss_metrics/global_max_loss": global_max_loss,
                    "wps": wps,
                    "mfu(%)": mfu,
                    "time_metrics/end_to_end(s)": time_end_to_end,
                    "time_metrics/data_loading(s)": time_data_loading,
                    "time_metrics/data_loading(%)": time_data_loading_pct,
                    "memory/max_active(GiB)": gpu_mem_stats.max_active_gib,
                    "memory/max_active(%)": gpu_mem_stats.max_active_pct,
                    "memory/max_reserved(GiB)": gpu_mem_stats.max_reserved_gib,
                    "memory/max_reserved(%)": gpu_mem_stats.max_reserved_pct,
                    "memory/num_alloc_retries": gpu_mem_stats.num_alloc_retries,
                    "memory/num_ooms": gpu_mem_stats.num_ooms,
                }
                metric_logger.log(metrics, step=train_state.step)

                logger.info(
                    f"{color.cyan}step: {train_state.step:2} "
                    f"{color.green}loss: {global_avg_loss:7.4f} "
                    f"{color.yellow}memory: {gpu_mem_stats.max_reserved_gib:5.2f}GiB"
                    f"({gpu_mem_stats.max_reserved_pct:.2f}%) "
                    f"{color.blue}wps: {round(wps):,} "
                    f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
                )

                losses_since_last_log.clear()
                ntokens_since_last_log = 0
                data_loading_times.clear()
                time_last_log = timer()
                gpu_memory_monitor.reset_peak_stats()

            checkpoint.save(
                train_state.step, force=(train_state.step == job_config.training.steps)
            )

            # signals the profiler that the next profiling step has started
            if torch_profiler:
                torch_profiler.step()

            if memory_profiler:
                memory_profiler.step()

            # Reduce timeout after first train step for faster signal (assumes lazy init, compile are finished)
            if train_state.step == 1:
                set_pg_timeouts(
                    timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
                    world_mesh=world_mesh,
                )

    if torch.distributed.get_rank() == 0:
        logger.info("Sleeping 2 seconds for other ranks to complete")
        time.sleep(2)

    metric_logger.close()
    logger.info("Training completed")
if __name__ == "__main__":
    # Script entry point: parse CLI/TOML options into a JobConfig, run the
    # training job, then tear down the default process group (presumably the
    # one set up by init_distributed inside main).
    job_config = JobConfig()
    job_config.parse_args()
    main(job_config)
    destroy_process_group()