[kunlunxin] update mixtral patch for py39 upgrade (#296)
Co-authored-by: zhangsanfeng2022 <[email protected]>
brianlcy123 and zhangsanfeng2022 authored Dec 24, 2024
1 parent 1011644 commit f2bc020
Showing 1 changed file with 8 additions and 100 deletions.
hardware/kunlunxin_R300p/237377e9/237377e9.patch (108 changes: 8 additions & 100 deletions)
@@ -1,117 +1,25 @@
-From 115b26cc46200236cccfe072cf0049b39853b168 Mon Sep 17 00:00:00 2001
+From cd33c8caedf1a87a250ab5d03554397f8fc05293 Mon Sep 17 00:00:00 2001
 From: brianlcy123 <[email protected]>
-Date: Sun, 24 Nov 2024 19:12:03 +0800
-Subject: [PATCH] [kunlunxin] add patch for mixtral
+Date: Tue, 24 Dec 2024 10:19:56 +0800
+Subject: [PATCH] [kunlunxin] update mixtral patch for py39
 
 ---
- .../megatron/core/dist_checkpointing/strategies/base.py | 4 ++--
- .../megatron/core/distributed/param_and_grad_buffer.py | 7 ++++++-
- megatron/megatron/core/transformer/moe/moe_utils.py | 6 +++---
- megatron/megatron/core/transformer/moe/token_dispatcher.py | 4 ++--
- megatron/megatron/training/checkpointing.py | 3 ++-
- 5 files changed, 15 insertions(+), 9 deletions(-)
+ megatron/megatron/training/checkpointing.py | 3 ++-
+ 1 file changed, 2 insertions(+), 1 deletion(-)
 
-diff --git a/megatron/megatron/core/dist_checkpointing/strategies/base.py b/megatron/megatron/core/dist_checkpointing/strategies/base.py
-index cc1c83b9..125779a0 100644
---- a/megatron/megatron/core/dist_checkpointing/strategies/base.py
-+++ b/megatron/megatron/core/dist_checkpointing/strategies/base.py
-@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
- from collections import defaultdict
- from enum import Enum
- from pathlib import Path
--from typing import Any, DefaultDict
-+from typing import Any, DefaultDict, Dict, Tuple
- 
- from ..mapping import CheckpointingException, ShardedStateDict, StateDict
- from .async_utils import AsyncCallsQueue, AsyncRequest
-@@ -20,7 +20,7 @@ class StrategyAction(Enum):
- 
- 
- _import_trigger = None
--default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict)
-+default_strategies: DefaultDict[str, Dict[Tuple, Any]] = defaultdict(dict)
- 
- async_calls = AsyncCallsQueue()
- 
-diff --git a/megatron/megatron/core/distributed/param_and_grad_buffer.py b/megatron/megatron/core/distributed/param_and_grad_buffer.py
-index 77ecd7be..c2761c6e 100644
---- a/megatron/megatron/core/distributed/param_and_grad_buffer.py
-+++ b/megatron/megatron/core/distributed/param_and_grad_buffer.py
-@@ -248,6 +248,11 @@ class ParamAndGradBuffer:
-         def _pad(number_to_be_padded: int, divisor: int) -> int:
-             return int(math.ceil(number_to_be_padded / divisor) * divisor)
- 
-+        import math
-+
-+        def _lcm(a, b):
-+            return abs(a * b) // math.gcd(a, b)
-+
-         def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int:
-             """
-             Pads end index of bucket if using distributed optimizer (to ensure uniform sharding).
-@@ -257,7 +262,7 @@ class ParamAndGradBuffer:
-                 # This also helps cuBLAS pick more efficient algorithms for GEMMs.
-                 # We now ensure that all buckets start at a memory address that is 256-byte
-                 # aligned (128 values since params and grads use >= 16-bit precision).
--                return _pad(bucket_end_index, math.lcm(self.data_parallel_world_size, 128))
-+                return _pad(bucket_end_index, _lcm(self.data_parallel_world_size, 128))
-             return bucket_end_index
- 
-         def _pad_start_of_param_if_needed(param_start_index: int) -> int:
-diff --git a/megatron/megatron/core/transformer/moe/moe_utils.py b/megatron/megatron/core/transformer/moe/moe_utils.py
-index ee4bb690..a3c1fd69 100644
---- a/megatron/megatron/core/transformer/moe/moe_utils.py
-+++ b/megatron/megatron/core/transformer/moe/moe_utils.py
-@@ -366,8 +366,8 @@ def topk_softmax_with_capacity(
- 
-     if capacity_factor is None:
-         # TopK without capacity
--        tokens_per_expert = torch.bincount(top_indices.view(-1), minlength=num_experts)
--        return probs, top_indices, tokens_per_expert
-+        tokens_per_expert = torch.bincount(top_indices.cpu().view(-1), minlength=num_experts)
-+        return probs, top_indices, tokens_per_expert.cuda()
-     else:
-         # TopK with capacity
-         expert_capacity = get_capacity(
-@@ -380,7 +380,7 @@ def topk_softmax_with_capacity(
-         # Maskout exceeded tokens
-         if drop_policy == "probs":
-             capacity_probs, capacity_indices = torch.topk(
--                topk_masked_gates, k=expert_capacity, dim=0, sorted=False
-+                topk_masked_gates, k=expert_capacity, dim=0, sorted=True #mod by zh
-             )
-             capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
-         elif drop_policy == "position":
-diff --git a/megatron/megatron/core/transformer/moe/token_dispatcher.py b/megatron/megatron/core/transformer/moe/token_dispatcher.py
-index 84f3d450..6a0b4a28 100644
---- a/megatron/megatron/core/transformer/moe/token_dispatcher.py
-+++ b/megatron/megatron/core/transformer/moe/token_dispatcher.py
-@@ -179,10 +179,10 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
- 
-         with torch.no_grad():
-             tokens_per_expert = torch.bincount(
--                local_indices.view(-1), minlength=self.config.num_moe_experts
-+                local_indices.cpu().view(-1), minlength=self.config.num_moe_experts
-             )
-             if self.num_local_experts < self.config.num_moe_experts:
--                tokens_per_expert = tokens_per_expert[
-+                tokens_per_expert = tokens_per_expert.cuda()[
-                     self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
-                 ]
-             tokens_per_expert = tokens_per_expert.cpu().to(torch.long)
 diff --git a/megatron/megatron/training/checkpointing.py b/megatron/megatron/training/checkpointing.py
-index 6e58b317..6c650c4e 100644
+index 6e58b317..7906ea88 100644
 --- a/megatron/megatron/training/checkpointing.py
 +++ b/megatron/megatron/training/checkpointing.py
 @@ -1057,7 +1057,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
      restore_modelopt_state(model, state_dict)
 
      # Model.
 -    strict = False if args.retro_add_retriever else strict
-+    # strict = False if args.retro_add_retriever else strict
++    #strict = False if args.retro_add_retriever else strict
 +    strict = False
      if len(model) == 1:
          model[0].load_state_dict(state_dict['model'], strict=strict)
      else:
 --
-2.25.1
+2.34.1

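Note: among the hunks this commit deletes from the patch, two were Python 3.8 compatibility shims that the py39 upgrade makes unnecessary: math.lcm only exists from Python 3.9 (the old patch substituted a gcd-based _lcm), and PEP 585 builtin generics such as dict[tuple, Any] only became subscriptable in 3.9 (the old patch rewrote them with typing.Dict/typing.Tuple). A minimal sketch of the 3.9 behavior the dropped hunks worked around (standalone demo code, not part of the repository):

    import math
    import sys

    assert sys.version_info >= (3, 9)

    # Pre-3.9 fallback carried by the old patch: lcm(a, b) = |a*b| // gcd(a, b).
    def _lcm(a: int, b: int) -> int:
        return abs(a * b) // math.gcd(a, b)

    # On 3.9+ the stdlib function and the fallback agree.
    assert math.lcm(48, 128) == _lcm(48, 128) == 384

    # PEP 585 (3.9): builtin containers are subscriptable, so rewriting
    # dict[tuple, ...] with typing.Dict/typing.Tuple is no longer needed.
    default_strategies: dict[tuple, object] = {}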
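The one change the updated patch keeps forces non-strict checkpoint loading. In PyTorch, strict=False makes Module.load_state_dict report missing and unexpected keys instead of raising; a small illustration (hypothetical toy module, not repository code):

    import torch

    net = torch.nn.Linear(4, 4)
    ckpt = {"weight": torch.zeros(4, 4)}  # deliberately omits "bias"

    # strict=False tolerates the mismatch and returns it for inspection
    # instead of raising a RuntimeError.
    result = net.load_state_dict(ckpt, strict=False)
    print(result.missing_keys)     # ['bias']
    print(result.unexpected_keys)  # []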