megatron-v11.patch
diff --git a/megatron/initialize.py b/megatron/initialize.py
index 9adae00..94e764c 100644
--- a/megatron/initialize.py
+++ b/megatron/initialize.py
@@ -103,6 +103,7 @@ def _initialize_distributed():
     if args.rank == 0:
         print('> initializing torch distributed ...', flush=True)
     # Manually set the device ids.
+    print("world:", args.world_size, " rank:", args.rank, " device count: ", device_count, " local_rank:", args.local_rank);
     if device_count > 0:
         device = args.rank % device_count
         if args.local_rank is not None:
diff --git a/megatron/training.py b/megatron/training.py
index 065d8fa..23374e8 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -237,7 +237,11 @@ def backward_step(optimizer, model, loss):
     # Backward pass.
     timers('backward-backward').start()
-    optimizer.zero_grad(set_grads_to_None=True)
+#   optimizer.zero_grad(set_grads_to_None=True)
+    if args.fp16:
+        optimizer.zero_grad(set_grads_to_None=True)
+    else:
+        optimizer.zero_grad()
     if args.fp16:
         optimizer.backward(loss, update_master_grads=False)
     else:
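
Note (not part of the patch): the second hunk guards zero_grad() because the set_grads_to_None keyword belongs to Megatron's fp16 optimizer wrapper; a plain torch.optim optimizer does not accept it, so an unconditional call fails when --fp16 is off. A minimal Python sketch of the same dispatch, assuming a helper name of our own choosing:

    import torch

    def zero_grad_compat(optimizer, fp16_enabled):
        """Clear gradients the way the patched backward_step does."""
        if fp16_enabled:
            # Assumed to be the fp16 optimizer wrapper, which takes this keyword.
            optimizer.zero_grad(set_grads_to_None=True)
        else:
            # Plain torch optimizer: call zero_grad() with no arguments.
            optimizer.zero_grad()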