[kunlunxin] update mixtral patch for py39 upgrade (#296)
Co-authored-by: zhangsanfeng2022 <[email protected]>
1 parent 1011644 · commit f2bc020
Showing 1 changed file with 8 additions and 100 deletions.
@@ -1,117 +1,25 @@
From 115b26cc46200236cccfe072cf0049b39853b168 Mon Sep 17 00:00:00 2001
From cd33c8caedf1a87a250ab5d03554397f8fc05293 Mon Sep 17 00:00:00 2001
From: brianlcy123 <[email protected]>
Date: Sun, 24 Nov 2024 19:12:03 +0800
Subject: [PATCH] [kunlunxin] add patch for mixtral
Date: Tue, 24 Dec 2024 10:19:56 +0800
Subject: [PATCH] [kunlunxin] update mixtral patch for py39

---
.../megatron/core/dist_checkpointing/strategies/base.py | 4 ++--
.../megatron/core/distributed/param_and_grad_buffer.py | 7 ++++++-
megatron/megatron/core/transformer/moe/moe_utils.py | 6 +++---
megatron/megatron/core/transformer/moe/token_dispatcher.py | 4 ++--
megatron/megatron/training/checkpointing.py | 3 ++-
5 files changed, 15 insertions(+), 9 deletions(-)
megatron/megatron/training/checkpointing.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/megatron/megatron/core/dist_checkpointing/strategies/base.py b/megatron/megatron/core/dist_checkpointing/strategies/base.py
index cc1c83b9..125779a0 100644
--- a/megatron/megatron/core/dist_checkpointing/strategies/base.py
+++ b/megatron/megatron/core/dist_checkpointing/strategies/base.py
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from pathlib import Path
-from typing import Any, DefaultDict
+from typing import Any, DefaultDict, Dict, Tuple

from ..mapping import CheckpointingException, ShardedStateDict, StateDict
from .async_utils import AsyncCallsQueue, AsyncRequest
@@ -20,7 +20,7 @@ class StrategyAction(Enum):


_import_trigger = None
-default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict)
+default_strategies: DefaultDict[str, Dict[Tuple, Any]] = defaultdict(dict)

async_calls = AsyncCallsQueue()

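Note: the base.py hunk above is one of the four dropped from the patch by this commit. Subscripting the built-in dict/tuple types (PEP 585) requires Python 3.9+, so the typing aliases were presumably only needed while the stack still ran Python 3.8. A minimal sketch of the difference, separate from the patch itself:

import sys
from collections import defaultdict
from typing import Any, DefaultDict, Dict, Tuple

# Works on Python 3.8 and later: the generics come from the typing module,
# which is what the (now removed) patch hunk switched to.
strategies_38: DefaultDict[str, Dict[Tuple, Any]] = defaultdict(dict)

if sys.version_info >= (3, 9):
    # PEP 585: built-in types are subscriptable, so the upstream annotation
    # dict[tuple, Any] evaluates without error on 3.9+ (it raises TypeError on 3.8).
    strategies_39: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict)
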
diff --git a/megatron/megatron/core/distributed/param_and_grad_buffer.py b/megatron/megatron/core/distributed/param_and_grad_buffer.py
index 77ecd7be..c2761c6e 100644
--- a/megatron/megatron/core/distributed/param_and_grad_buffer.py
+++ b/megatron/megatron/core/distributed/param_and_grad_buffer.py
@@ -248,6 +248,11 @@ class ParamAndGradBuffer:
def _pad(number_to_be_padded: int, divisor: int) -> int:
return int(math.ceil(number_to_be_padded / divisor) * divisor)

+ import math
+
+ def _lcm(a, b):
+ return abs(a * b) // math.gcd(a, b)
+
def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int:
"""
Pads end index of bucket if using distributed optimizer (to ensure uniform sharding).
@@ -257,7 +262,7 @@ class ParamAndGradBuffer:
# This also helps cuBLAS pick more efficient algorithms for GEMMs.
# We now ensure that all buckets start at a memory address that is 256-byte
# aligned (128 values since params and grads use >= 16-bit precision).
- return _pad(bucket_end_index, math.lcm(self.data_parallel_world_size, 128))
+ return _pad(bucket_end_index, _lcm(self.data_parallel_world_size, 128))
return bucket_end_index

def _pad_start_of_param_if_needed(param_start_index: int) -> int:
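Note: the param_and_grad_buffer.py hunk above is likewise dropped. math.lcm only exists on Python 3.9+, and the patch's gcd-based helper is the usual two-argument equivalent, used here to round the bucket end up to a common multiple of the data-parallel world size and 128. A small standalone check, not part of the patch (the world size of 6 is just an example value):

import math

def _lcm(a: int, b: int) -> int:
    # Same identity the old patch relied on: lcm(a, b) = |a * b| / gcd(a, b).
    return abs(a * b) // math.gcd(a, b)

def _pad(n: int, divisor: int) -> int:
    # Round n up to the next multiple of divisor, as in ParamAndGradBuffer._pad.
    return int(math.ceil(n / divisor) * divisor)

assert _lcm(6, 128) == 384 == math.lcm(6, 128)   # math.lcm itself requires Python 3.9+
assert _pad(1000, _lcm(6, 128)) == 1152          # 3 * 384
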
diff --git a/megatron/megatron/core/transformer/moe/moe_utils.py b/megatron/megatron/core/transformer/moe/moe_utils.py
index ee4bb690..a3c1fd69 100644
--- a/megatron/megatron/core/transformer/moe/moe_utils.py
+++ b/megatron/megatron/core/transformer/moe/moe_utils.py
@@ -366,8 +366,8 @@ def topk_softmax_with_capacity(

if capacity_factor is None:
# TopK without capacity
- tokens_per_expert = torch.bincount(top_indices.view(-1), minlength=num_experts)
- return probs, top_indices, tokens_per_expert
+ tokens_per_expert = torch.bincount(top_indices.cpu().view(-1), minlength=num_experts)
+ return probs, top_indices, tokens_per_expert.cuda()
else:
# TopK with capacity
expert_capacity = get_capacity(
@@ -380,7 +380,7 @@ def topk_softmax_with_capacity(
# Maskout exceeded tokens
if drop_policy == "probs":
capacity_probs, capacity_indices = torch.topk(
- topk_masked_gates, k=expert_capacity, dim=0, sorted=False
+ topk_masked_gates, k=expert_capacity, dim=0, sorted=True #mod by zh
)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
elif drop_policy == "position":
diff --git a/megatron/megatron/core/transformer/moe/token_dispatcher.py b/megatron/megatron/core/transformer/moe/token_dispatcher.py
index 84f3d450..6a0b4a28 100644
--- a/megatron/megatron/core/transformer/moe/token_dispatcher.py
+++ b/megatron/megatron/core/transformer/moe/token_dispatcher.py
@@ -179,10 +179,10 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):

with torch.no_grad():
tokens_per_expert = torch.bincount(
- local_indices.view(-1), minlength=self.config.num_moe_experts
+ local_indices.cpu().view(-1), minlength=self.config.num_moe_experts
)
if self.num_local_experts < self.config.num_moe_experts:
- tokens_per_expert = tokens_per_expert[
+ tokens_per_expert = tokens_per_expert.cuda()[
self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
]
tokens_per_expert = tokens_per_expert.cpu().to(torch.long)
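Note: both MoE hunks above route torch.bincount through the CPU and copy the counts back to the device afterwards (the moe_utils.py hunk also switches torch.topk to sorted=True). These hunks are also removed from the patch; the CPU round-trip reads like a workaround for bincount on the kunlunxin backend, though the commit itself does not say so. An illustrative sketch of the pattern, not taken from the patch:

import torch

def tokens_per_expert_via_cpu(top_indices: torch.Tensor, num_experts: int) -> torch.Tensor:
    # Count how many routed tokens landed on each expert, running bincount on
    # the CPU and then moving the result back to the indices' device.
    counts = torch.bincount(top_indices.cpu().view(-1), minlength=num_experts)
    return counts.to(top_indices.device)   # the original patch used .cuda() here

indices = torch.tensor([0, 2, 2, 1, 3])
print(tokens_per_expert_via_cpu(indices, num_experts=4))   # tensor([1, 1, 2, 1])
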
diff --git a/megatron/megatron/training/checkpointing.py b/megatron/megatron/training/checkpointing.py
index 6e58b317..6c650c4e 100644
index 6e58b317..7906ea88 100644
--- a/megatron/megatron/training/checkpointing.py
+++ b/megatron/megatron/training/checkpointing.py
@@ -1057,7 +1057,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
restore_modelopt_state(model, state_dict)

# Model.
- strict = False if args.retro_add_retriever else strict
+ # strict = False if args.retro_add_retriever else strict
+ #strict = False if args.retro_add_retriever else strict
+ strict = False
if len(model) == 1:
model[0].load_state_dict(state_dict['model'], strict=strict)
else:
--
2.25.1
2.34.1
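Note: the only hunk this commit keeps is the checkpointing.py change, which comments out the retro-specific condition and hard-codes strict=False, so checkpoints whose keys do not exactly match the model load with the mismatches reported rather than raised. A minimal sketch of that load_state_dict behaviour on a stand-in module (not the Megatron model classes):

import torch

net = torch.nn.Linear(4, 4)
state = net.state_dict()
state.pop("bias")                        # simulate a checkpoint missing a key

# net.load_state_dict(state)             # default strict=True: raises RuntimeError
result = net.load_state_dict(state, strict=False)
print(result.missing_keys)               # ['bias'] -- reported instead of raised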