Varuna Profiler Issue #8

Open
powern90 opened this issue May 25, 2022 · 1 comment
powern90 commented May 25, 2022

Hello, I was trying to profile Megatron-LM (the one in the examples) and I got an error.

My environment is:
Nvidia NGC Container pytorch 22.02-py3
pytorch 1.11.0
CUDA 11.6.0
Ubuntu 20.04

The error I got is:

Traceback (most recent call last):
  File "pretrain_gpt2.py", line 170, in 
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step, 
  File "/workspace/Megatron-LM/megatron/training.py", line 108, in pretrain
    profile = model.profile_all(list(range(1,25)))
  File "/opt/conda/lib/python3.8/site-packages/varuna-0.0.1-py3.8.egg/varuna/profiler.py", line 476, in profile_all
    self.profile(microbatch_sizes, optimizer)
  File "/opt/conda/lib/python3.8/site-packages/varuna-0.0.1-py3.8.egg/varuna/profiler.py", line 755, in profile
    self.profile_mbs(batch_size, optimizer)
  File "/opt/conda/lib/python3.8/site-packages/varuna-0.0.1-py3.8.egg/varuna/profiler.py", line 841, in profile_mbs
    optimizer.step()
  File "/opt/conda/lib/python3.8/site-packages/apex/amp/_process_optimizer.py", line 357, in new_step
    retval = old_step(global_grad_norm=global_grad_norm)
  File "/opt/conda/lib/python3.8/site-packages/torch/optim/optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
TypeError: step() got an unexpected keyword argument 'global_grad_norm'

My run script is:

#! /bin/bash

DATA_PATH=/workspace/data/gpt_text_document
GPUS_PER_SERVER=4


python -m varuna.run_varuna --nstages 4 --chunk_size 4 --batch_size 16 \
        --gpus_per_node $GPUS_PER_SERVER --no_morphing --machine_list /workspace/IPs pretrain_gpt2.py \
        --num-layers 16 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
        --seq-length 1024 \
        --max-position-embeddings 1024 \
        --train-iters 100 \
        --lr-decay-iters 100 \
        --data-path $DATA_PATH \
        --distributed-backend gloo \
        --vocab-file /workspace/data/gpt2-vocab.json \
        --merge-file /workspace/data/gpt2-merges.txt \
        --save /workspace/profile \
        --save-interval 1000 \
        --data-impl mmap \
        --split 1000,0,0 \
        --lr 0.00001 \
        --min-lr 1e-5 \
        --lr-decay-style cosine \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
        --use-cpu-initialization \
        --warmup .05 \
        --fp16 \
        --varuna \
        --profiling

I followed all the instructions in the README, such as checking out the exact commits of apex and Megatron-LM and applying the patches.

Please let me know what is wrong.
Thank you.

@insujang

It seems the Varuna apex patch only modifies the FusedLAMB optimizer, which is not used by Megatron-LM.
Try manually modifying the FusedAdam optimizer in apex as follows and run again.

--- a/apex/optimizers/fused_adam.py
+++ b/apex/optimizers/fused_adam.py
@@ -87,7 +87,7 @@ class FusedAdam(torch.optim.Optimizer):
         else:
             super(FusedAdam, self).zero_grad()

-    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
+    def step(self, global_grad_norm=-1, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
         """Performs a single optimization step.

         Arguments:
@@ -102,72 +102,72 @@ class FusedAdam(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()

-        for group in self.param_groups:
-            bias_correction = 1 if group['bias_correction'] else 0
-            beta1, beta2 = group['betas']
-
-            # assume same step across group now to simplify things
-            # per parameter step can be easily support by making it tensor, or pass list into kernel
-            if 'step' in group:
-                group['step'] += 1
-            else:
-                group['step'] = 1
-
-            # create lists for multi-tensor apply
-            g_16, p_16, m_16, v_16 = [], [], [], []
-            g_32, p_32, m_32, v_32 = [], [], [], []
-
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                if p.grad.data.is_sparse:
-                    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')
-
-                state = self.state[p]
-                # State initialization
-                if len(state) == 0:
-                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
-
-                if p.dtype == torch.float16:
-                    g_16.append(p.grad.data)
-                    p_16.append(p.data)
-                    m_16.append(state['exp_avg'])
-                    v_16.append(state['exp_avg_sq'])
-                elif p.dtype == torch.float32:
-                    g_32.append(p.grad.data)
-                    p_32.append(p.data)
-                    m_32.append(state['exp_avg'])
-                    v_32.append(state['exp_avg_sq'])
+        if global_grad_norm == -1:
+            for group in self.param_groups:
+                bias_correction = 1 if group['bias_correction'] else 0
+                beta1, beta2 = group['betas']
+
+                # assume same step across group now to simplify things
+                # per parameter step can be easily support by making it tensor, or pass list into kernel
+                if 'step' in group:
+                    group['step'] += 1
                 else:
-                    raise RuntimeError('FusedAdam only support fp16 and fp32.')
-
-            if(len(g_16) > 0):
-                multi_tensor_applier(self.multi_tensor_adam,
-                                     self._dummy_overflow_buf,
-                                     [g_16, p_16, m_16, v_16],
-                                     group['lr'],
-                                     beta1,
-                                     beta2,
-                                     group['eps'],
-                                     group['step'],
-                                     self.adam_w_mode,
-                                     bias_correction,
-                                     group['weight_decay'])
-            if(len(g_32) > 0):
-                multi_tensor_applier(self.multi_tensor_adam,
-                                     self._dummy_overflow_buf,
-                                     [g_32, p_32, m_32, v_32],
-                                     group['lr'],
-                                     beta1,
-                                     beta2,
-                                     group['eps'],
-                                     group['step'],
-                                     self.adam_w_mode,
-                                     bias_correction,
-                                     group['weight_decay'])
+                    group['step'] = 1

+                # create lists for multi-tensor apply
+                g_16, p_16, m_16, v_16 = [], [], [], []
+                g_32, p_32, m_32, v_32 = [], [], [], []
+
+                for p in group['params']:
+                    if p.grad is None:
+                        continue
+                    if p.grad.data.is_sparse:
+                        raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')
+
+                    state = self.state[p]
+                    # State initialization
+                    if len(state) == 0:
+                        # Exponential moving average of gradient values
+                        state['exp_avg'] = torch.zeros_like(p.data)
+                        # Exponential moving average of squared gradient values
+                        state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+                    if p.dtype == torch.float16:
+                        g_16.append(p.grad.data)
+                        p_16.append(p.data)
+                        m_16.append(state['exp_avg'])
+                        v_16.append(state['exp_avg_sq'])
+                    elif p.dtype == torch.float32:
+                        g_32.append(p.grad.data)
+                        p_32.append(p.data)
+                        m_32.append(state['exp_avg'])
+                        v_32.append(state['exp_avg_sq'])
+                    else:
+                        raise RuntimeError('FusedAdam only support fp16 and fp32.')
+
+                if(len(g_16) > 0):
+                    multi_tensor_applier(self.multi_tensor_adam,
+                                        self._dummy_overflow_buf,
+                                        [g_16, p_16, m_16, v_16],
+                                        group['lr'],
+                                        beta1,
+                                        beta2,
+                                        group['eps'],
+                                        group['step'],
+                                        self.adam_w_mode,
+                                        bias_correction,
+                                        group['weight_decay'])
+                if(len(g_32) > 0):
+                    multi_tensor_applier(self.multi_tensor_adam,
+                                        self._dummy_overflow_buf,
+                                        [g_32, p_32, m_32, v_32],
+                                        group['lr'],
+                                        beta1,
+                                        beta2,
+                                        group['eps'],
+                                        group['step'],
+                                        self.adam_w_mode,
+                                        bias_correction,
+                                        group['weight_decay'])

I cannot guarantee that it will actually train models as originally intended, though.
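
For reference, the TypeError in the traceback comes from apex amp's optimizer wrapper (_process_optimizer.py) forwarding global_grad_norm into the underlying step(), which the stock FusedAdam.step() signature does not accept. Below is a minimal sketch of that mismatch using hypothetical stand-in classes, not the real apex code, just to illustrate why adding the keyword argument makes the error go away:

# Minimal sketch (hypothetical stand-in classes, not the real apex code) of the
# signature mismatch behind the TypeError in the traceback above.

class StockFusedAdam:
    # step() signature as in upstream apex before the patch above
    def step(self, closure=None, grads=None, output_params=None,
             scale=None, grad_norms=None):
        return "stepped"

class PatchedFusedAdam:
    # step() signature after the patch above accepts global_grad_norm
    def step(self, global_grad_norm=-1, closure=None, grads=None,
             output_params=None, scale=None, grad_norms=None):
        return "stepped (global_grad_norm=%s)" % global_grad_norm

def amp_new_step(optimizer, global_grad_norm):
    # mimics apex amp's wrapped new_step(), which forwards the norm to step()
    return optimizer.step(global_grad_norm=global_grad_norm)

print(amp_new_step(PatchedFusedAdam(), 0.5))  # works
print(amp_new_step(StockFusedAdam(), 0.5))    # raises the TypeError shown in the traceback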
