Update the Gaudi trainer with transformers 4.45.2 (huggingface#1398)
yafshar authored and zzhang37 committed Dec 10, 2024
1 parent 14e473e commit 33a718f
Showing 5 changed files with 53 additions and 183 deletions.
4 changes: 2 additions & 2 deletions optimum/habana/transformers/modeling_utils.py
@@ -210,10 +210,10 @@
gaudi_gpt_neox_model_forward,
gaudi_invert_attention_mask,
gaudi_llama_rmsnorm_forward,
gaudi_MambaCache_update_conv_state,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_MambaMixer,
gaudi_MambaCache_update_conv_state,
gaudi_mistral_rmsnorm_forward,
gaudi_mixtral_block_dynamic_moe_forward,
gaudi_mixtral_block_sparse_moe_forward,
@@ -679,7 +679,7 @@ def adapt_transformers_to_gaudi():
transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaRMSNorm.forward = gaudi_llama_rmsnorm_forward
transformers.models.mamba.modeling_mamba.MambaMixer = gaudi_MambaMixer
transformers.cache_utils.MambaCache.update_conv_state = gaudi_MambaCache_update_conv_state

# Optimization for Whisper on Gaudi
transformers.models.whisper.modeling_whisper.WhisperSdpaAttention = GaudiWhisperSdpaAttention
transformers.models.whisper.modeling_whisper.WhisperDecoderLayer = GaudiWhisperDecoderLayer
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/__init__.py
@@ -158,10 +158,10 @@
from .llava import GaudiLlavaForConditionalGeneration
from .llava_next import GaudiLlavaNextForConditionalGeneration
from .mamba import (
gaudi_MambaCache_update_conv_state,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_MambaMixer,
gaudi_MambaCache_update_conv_state,
)
from .minicpm import MiniCPM3Config, MiniCPM3ForCausalLM
from .mistral import (
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/mamba/__init__.py
@@ -1,6 +1,6 @@
from .modeling_mamba import (
gaudi_MambaCache_update_conv_state,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_MambaMixer,
gaudi_MambaCache_update_conv_state,
)
206 changes: 38 additions & 168 deletions optimum/habana/transformers/models/mamba/modeling_mamba.py
@@ -1,31 +1,23 @@
from typing import Any, Dict, Optional, Union
import os
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from transformers.utils import (
ModelOutput,
logging,
)
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import MambaCache
from transformers.configuration_utils import PretrainedConfig
from transformers.models.mamba.configuration_mamba import MambaConfig
from transformers.activations import ACT2FN
from transformers.utils import (
ModelOutput,
logging,
)

from pathlib import Path
import os

my_dir = os.path.realpath(__file__)
my_len = my_dir.rfind("/")
base_dir = os.environ.get('HABANA_CUSTOM_OP_DIR', my_dir[:my_len])
base_dir = os.environ.get("HABANA_CUSTOM_OP_DIR", my_dir[:my_len])

custom_op_lib_path = str(
next(
Path(base_dir).glob("hpu_custom_pscan_all.cpython-*-x86_64-linux-gnu.so")
)
)
custom_op_lib_path = str(next(Path(base_dir).glob("hpu_custom_pscan_all.cpython-*-x86_64-linux-gnu.so")))
torch.ops.load_library(custom_op_lib_path)

logger = logging.get_logger(__name__)
@@ -35,18 +27,17 @@
use_pscan_kernel = False
if os.path.exists(custom_op_lib_path):
use_pscan_kernel = True

def Run_Mamba_Forward_Gaudi(in_state, in_x, in_dt, in_A, in_B, in_C, in_D, in_z):

in_state_h = in_state.unsqueeze(1).transpose(2,3)
in_x_h = in_x.transpose(1,2).unsqueeze(2)

def Run_Mamba_Forward_Gaudi(in_state, in_x, in_dt, in_A, in_B, in_C, in_D, in_z):
in_state_h = in_state.unsqueeze(1).transpose(2, 3)
in_x_h = in_x.transpose(1, 2).unsqueeze(2)
in_dt_h = in_dt.unsqueeze(2)
in_A_h = in_A.unsqueeze(0).unsqueeze(1).transpose(2,3)
in_A_h = in_A.unsqueeze(0).unsqueeze(1).transpose(2, 3)
in_B_h = in_B.unsqueeze(3)
in_C_h = in_C.unsqueeze(3)
in_D_h = in_D.unsqueeze(0).unsqueeze(1).unsqueeze(2)
in_z_h = in_z.transpose(1,2).unsqueeze(2)

in_z_h = in_z.transpose(1, 2).unsqueeze(2)

if in_state.dtype == torch.float:
state_out_h = torch.ops.custom_op.custom_pscan(in_state_h, in_x_h, in_dt_h, in_A_h, in_B_h)
@@ -57,10 +48,9 @@ def Run_Mamba_Forward_Gaudi(in_state, in_x, in_dt, in_A, in_B, in_C, in_D, in_z)
state_out_h = torch.ops.custom_op.custom_pscan_bf16(in_state_h, in_x_h, in_dt_h, in_A_h, in_B_h)
output_h = torch.ops.custom_op.custom_pscan_update_bf16(state_out_h, in_x_h, in_C_h, in_D_h, in_z_h)


output_hpu = output_h.squeeze(2).transpose(1,2)
state_hpu = state_out_h.transpose(2,3)
state_out = torch.select(state_hpu,1,output_hpu.shape[2]-1)
output_hpu = output_h.squeeze(2).transpose(1, 2)
state_hpu = state_out_h.transpose(2, 3)
state_out = torch.select(state_hpu, 1, output_hpu.shape[2] - 1)

return output_hpu, state_out

@@ -72,14 +62,15 @@ def gaudi_MambaCache_update_conv_state(
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

conv_state = conv_state.roll(shifts=-1, dims=-1)
#conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
# conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
for c, i in enumerate(cache_position):
conv_state[:, :, i] = new_conv_state[:,:,c].to(conv_state.device)
conv_state[:, :, i] = new_conv_state[:, :, c].to(conv_state.device)

self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]


def gaudi_MambaForCausalLM_update_model_kwargs_for_generation(
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1, **kwargs
) -> Dict[str, Any]:
@@ -162,6 +153,7 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation(
)
return model_inputs


class gaudi_MambaMixer(nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
@@ -213,125 +205,12 @@ def __init__(self, config: MambaConfig, layer_idx: int):
self.use_bias = config.use_bias

if not is_fast_path_available:
if self.use_mambapy:
if is_mambapy_available():
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
)
else:
raise ImportError(
"use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
)
else:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
)

def cuda_kernels_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)

if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
contextualized_states = mamba_inner_fn(
projected_states,
self.conv1d.weight,
self.conv1d.bias if self.use_conv_bias else None,
self.x_proj.weight,
self.dt_proj.weight,
self.out_proj.weight,
self.out_proj.bias.float() if self.use_bias else None,
-torch.exp(self.A_log.float()),
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
)

else:
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if cache_params is not None and cache_position[0] > 0:
hidden_states = causal_conv1d_update(
hidden_states.squeeze(-1),
cache_params.conv_states[self.layer_idx],
conv_weights,
self.conv1d.bias,
self.activation,
)
hidden_states = hidden_states.unsqueeze(-1)
else:
if cache_params is not None:
conv_states = nn.functional.pad(
hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
)
cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
hidden_states = causal_conv1d_fn(
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)

A = -torch.exp(self.A_log.float())
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
if cache_params is not None and cache_position[0] > 0:
scan_outputs = selective_state_update(
cache_params.ssm_states[self.layer_idx],
hidden_states[..., 0],
discrete_time_step[..., 0],
A,
B[:, 0],
C[:, 0],
self.D,
gate[..., 0],
time_proj_bias,
dt_softplus=True,
).unsqueeze(-1)
else:
scan_outputs, ssm_state = selective_scan_fn(
hidden_states,
discrete_time_step,
A,
B.transpose(1, 2),
C.transpose(1, 2),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
return_last_state=True,
)
if ssm_state is not None and cache_params is not None:
cache_params.update_ssm_state(self.layer_idx, ssm_state)

# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
return contextualized_states

# fmt: off
def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
"""
@@ -389,12 +268,12 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca
A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
if use_pscan_kernel:
scan_output, ssm_state = Run_Mamba_Forward_Gaudi(
ssm_state,
hidden_states,
ssm_state,
hidden_states,
discrete_time_step,
A,
B,
C,
A,
B,
C,
self.D,
gate
)
@@ -405,21 +284,14 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

# 3.c perform the recurrence y ← SSM(A, B, C)(x)
if self.use_mambapy and self.training and cache_params is None:
hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, seq_len, intermediate_size, ssm_state_size]

scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len]
scan_output = scan_output + hidden_states * self.D[None, :, None]
scan_output = scan_output * self.act(gate)
else:
scan_outputs = []
for i in range(seq_len):
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
scan_outputs.append(scan_output[:, :, 0])
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size]
scan_output = scan_output + (hidden_states * self.D[None, :, None])
scan_output = (scan_output * self.act(gate))
scan_outputs = []
for i in range(seq_len):
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
scan_outputs.append(scan_output[:, :, 0])
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size]
scan_output = scan_output + (hidden_states * self.D[None, :, None])
scan_output = (scan_output * self.act(gate))

if cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
@@ -436,6 +308,4 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
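For reference, here is a condensed sketch of the custom-op bootstrap performed by the rewritten modeling_mamba.py above: resolve the pscan shared library from HABANA_CUSTOM_OP_DIR (falling back to the module's own directory), load it with torch.ops.load_library, and enable the fused Gaudi pscan path only when the library is present. The names mirror the diff; the guard around a missing .so is a small addition for the sketch, not the exact module code.

# Illustrative sketch only -- mirrors the library-resolution logic in the diff above.
import os
from pathlib import Path

import torch

# HABANA_CUSTOM_OP_DIR overrides the default of "next to this module".
base_dir = os.environ.get("HABANA_CUSTOM_OP_DIR", os.path.dirname(os.path.realpath(__file__)))

# The custom pscan op ships as a CPython extension (.so); pick the first match if any.
candidates = sorted(Path(base_dir).glob("hpu_custom_pscan_all.cpython-*-x86_64-linux-gnu.so"))
use_pscan_kernel = bool(candidates)
if use_pscan_kernel:
    # Registers torch.ops.custom_op.custom_pscan / custom_pscan_update (and their bf16 variants).
    torch.ops.load_library(str(candidates[0]))
# When the library is absent, slow_forward falls back to the sequential scan.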
22 changes: 11 additions & 11 deletions tests/test_text_generation_example.py
@@ -218,36 +218,36 @@ def _test_text_generation(
command += ["--use_flash_attention"]

if "mamba" in model_name.lower():
from huggingface_hub import hf_hub_download
import subprocess

from huggingface_hub import hf_hub_download

cmd1 = subprocess.Popen(["pip", "list"], stdout=subprocess.PIPE)
cmd2 = subprocess.Popen(["grep", "habana-torch-plugin"], stdin=cmd1.stdout, stdout=subprocess.PIPE)
cmd1.stdout.close()
version_no, _ = cmd2.communicate()
version_no, _ = cmd2.communicate()

name_op = 'hpu_custom_pscan_all.cpython-310-x86_64-linux-gnu.so'
name_kernel = 'libcustom_tpc_perf_lib.so'
name_op = "hpu_custom_pscan_all.cpython-310-x86_64-linux-gnu.so"
name_kernel = "libcustom_tpc_perf_lib.so"
if "1.19.0" in version_no.decode():
name_op = 'hpu_custom_pscan_all.cpython-310-x86_64-linux-gnu_119.so'
name_kernel = 'libcustom_tpc_perf_lib_119.so'
name_op = "hpu_custom_pscan_all.cpython-310-x86_64-linux-gnu_119.so"
name_kernel = "libcustom_tpc_perf_lib_119.so"

file_op = hf_hub_download(repo_id="Habana/mamba", filename=name_op)
file_kernel = hf_hub_download(repo_id="Habana/mamba", filename=name_kernel)
file_kernel = hf_hub_download(repo_id="Habana/mamba", filename=name_kernel)

new_file_op = file_op
new_file_kernel = file_kernel

if "1.19.0" in version_no.decode():
new_file_op = file_op[:-7] + '.so'
new_file_kernel = file_kernel[:-7] + '.so'
new_file_op = file_op[:-7] + ".so"
new_file_kernel = file_kernel[:-7] + ".so"
os.rename(file_op, new_file_op)
os.rename(file_kernel, new_file_kernel)

env_variables["HABANA_CUSTOM_OP_DIR"] = os.path.dirname(new_file_op)
default_path = env_variables["GC_KERNEL_PATH"]
env_variables["GC_KERNEL_PATH"] = new_file_kernel + os.pathsep + default_path

env_variables["GC_KERNEL_PATH"] = new_file_kernel + os.pathsep + default_path

if (reuse_cache or torch_compile) and not parallel_strategy == "tp" and not is_starcoder_first_gen_model:
command += ["--reuse_cache"]
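As a usage note, the following is a minimal standalone sketch of the environment setup that the Mamba branch of this test performs, assuming the Habana/mamba Hub repository and artifact names shown in the diff (the habana-torch-plugin 1.19.0 rename handling is omitted for brevity):

# Illustrative sketch only -- not the exact test code.
import os

from huggingface_hub import hf_hub_download

# Artifact names taken from the diff; "_119" variants exist for habana-torch-plugin 1.19.0.
name_op = "hpu_custom_pscan_all.cpython-310-x86_64-linux-gnu.so"
name_kernel = "libcustom_tpc_perf_lib.so"

file_op = hf_hub_download(repo_id="Habana/mamba", filename=name_op)
file_kernel = hf_hub_download(repo_id="Habana/mamba", filename=name_kernel)

env_variables = dict(os.environ)
# modeling_mamba.py resolves the pscan .so through HABANA_CUSTOM_OP_DIR ...
env_variables["HABANA_CUSTOM_OP_DIR"] = os.path.dirname(file_op)
# ... and the TPC kernel library is prepended to GC_KERNEL_PATH, as the test does.
env_variables["GC_KERNEL_PATH"] = file_kernel + os.pathsep + env_variables.get("GC_KERNEL_PATH", "")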
