Remove workarounds to have causal_mask in uint8 for GPT2, GPT-J and CodeGen (huggingface#592)
regisss authored Dec 12, 2023
1 parent d8ec6d6 commit 54971f0
Showing 4 changed files with 11 additions and 241 deletions.
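The pattern this commit removes is sketched below: a hypothetical, minimal reconstruction (not the exact Optimum Habana code) of the uint8 causal-mask buffer the Gaudi attention classes used to register, plus the explicit cast back to bool before use, compared with slicing a bool buffer directly as the upstream Transformers classes provide.

import torch

max_positions = 8

# Before this commit: the Gaudi classes re-registered the causal mask as
# uint8 (the removed comments cite HCCL's handling of bool tensors) and
# cast it back to bool right before applying it.
causal_mask_uint8 = torch.tril(
    torch.ones((max_positions, max_positions), dtype=torch.uint8)
).view(1, 1, max_positions, max_positions)
mask_before = causal_mask_uint8[:, :, :4, :4].bool()

# After this commit: the bool buffer inherited from the upstream
# Transformers classes is sliced and used as-is, with no cast.
causal_mask_bool = torch.tril(
    torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions)
mask_after = causal_mask_bool[:, :, :4, :4]

assert torch.equal(mask_before, mask_after)  # same mask either way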
8 changes: 2 additions & 6 deletions optimum/habana/transformers/modeling_utils.py
@@ -179,9 +179,8 @@ def adapt_transformers_to_gaudi():
transformers.models.bart.modeling_bart.BartForConditionalGeneration.prepare_inputs_for_generation = (
gaudi_BartForConditionalGeneration_prepare_inputs_for_generation
)

# Optimization for codegen generation on Gaudi
# The bias in the CodeGenAttention layer is a Boolean
# Since HCCL cannot handle this dtype, we revert it back to uint8
transformers.models.codegen.modeling_codegen.CodeGenAttention = GaudiCodeGenAttention
transformers.models.codegen.modeling_codegen.CodeGenForCausalLM = GaudiCodeGenForCausalLM
transformers.models.codegen.modeling_codegen.CodeGenModel.forward = gaudi_codegen_model_forward
@@ -194,8 +193,7 @@ def adapt_transformers_to_gaudi():
# AlbertModel.forward does not rely on get_extended_attention_mask so it also needs to be replaced
transformers.models.albert.modeling_albert.AlbertModel.forward = gaudi_albert_forward

# From Transformers 4.27, the bias in the GPT2Attention layer is a Boolean
# Since HCCL cannot handle this dtype, we revert it back to uint8 (same behaviour as Transformers <= 4.26)
# Optimization for GPT2 on Gaudi
transformers.models.gpt2.modeling_gpt2.GPT2Attention = GaudiGPT2Attention
transformers.models.gpt2.modeling_gpt2.GPT2Model.forward = gaudi_gpt2_forward
transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel = GaudiGPT2LMHeadModel
@@ -216,8 +214,6 @@ def adapt_transformers_to_gaudi():
transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding = GaudiOPTLearnedPositionalEmbedding

# Optimization for GPTJ on Gaudi
# The bias in the GPTJAttention layer is a Boolean
# Since HCCL cannot handle this dtype, we revert it back to uint8 (same behaviour as Transformers <= 4.26)
transformers.models.gptj.modeling_gptj.GPTJAttention = GaudiGPTJAttention
transformers.models.gptj.modeling_gptj.GPTJForCausalLM = GaudiGPTJForCausalLM
transformers.models.gptj.modeling_gptj.GPTJBlock.forward = gaudi_gptj_block_forward
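The file above wires the Gaudi classes into Transformers by monkey-patching the library at runtime. A minimal sketch of that pattern, with illustrative names rather than the actual Optimum Habana implementation:

import transformers
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention


# Illustrative subclass: inheriting keeps __init__, weights and buffers
# (including the bool causal-mask "bias" buffer) from upstream, so only
# the methods needing Gaudi-specific behavior would be overridden.
class MyGaudiGPT2Attention(GPT2Attention):
    pass


def my_adapt_transformers_to_gaudi():
    # Replace the upstream class so that models built afterwards
    # instantiate the Gaudi variant instead.
    transformers.models.gpt2.modeling_gpt2.GPT2Attention = MyGaudiGPT2Attention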
95 changes: 2 additions & 93 deletions optimum/habana/transformers/models/codegen/modeling_codegen.py
@@ -2,108 +2,17 @@

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
CodeGenForCausalLM,
apply_rotary_pos_emb,
create_sinusoidal_positions,
logger,
)


class GaudiCodeGenAttention(nn.Module):
def __init__(self, config):
super().__init__()

max_positions = config.max_position_embeddings
self.register_buffer(
"causal_mask",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)

self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)

self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)

self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.rotary_dim = config.rotary_dim
pos_embd_dim = self.rotary_dim or self.embed_dim
self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

def _split_heads(self, x, n_head, dim_head, mp_num):
reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
return reshaped

def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into n_ctx
"""
if len(tensor.shape) == 5:
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)

def _attn(
self,
query,
key,
value,
attention_mask=None,
head_mask=None,
):
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].bool()

# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)

attn_weights = torch.matmul(query, key.transpose(-1, -2))

attn_weights = attn_weights / self.scale_attn
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)

if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask

attn_weights = nn.Softmax(dim=-1)(attn_weights)
attn_weights = attn_weights.to(value.dtype)
attn_weights = self.attn_dropout(attn_weights)

# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask

attn_output = torch.matmul(attn_weights, value)

return attn_output, attn_weights

class GaudiCodeGenAttention(CodeGenAttention):
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
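With this change, GaudiCodeGenAttention only overrides forward and inherits everything else, including the causal_mask buffer, from CodeGenAttention. A rough sketch of that subclassing pattern, assuming a transformers version whose CodeGenAttention registers causal_mask as a bool buffer (illustrative names, not the real override):

from transformers import CodeGenConfig
from transformers.models.codegen.modeling_codegen import CodeGenAttention


class SketchGaudiCodeGenAttention(CodeGenAttention):
    # The real GaudiCodeGenAttention replaces forward with a Gaudi-specific
    # version; this placeholder simply defers to the parent implementation.
    def forward(self, hidden_states, *args, **kwargs):
        return super().forward(hidden_states, *args, **kwargs)


config = CodeGenConfig(n_embd=64, n_head=4, n_positions=128, rotary_dim=16)
attn = SketchGaudiCodeGenAttention(config)
print(attn.causal_mask.dtype)  # expected to be torch.bool here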
85 changes: 4 additions & 81 deletions optimum/habana/transformers/models/gpt2/modeling_gpt2.py
@@ -3,77 +3,16 @@
import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, logger
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2LMHeadModel, logger


class GaudiGPT2Attention(torch.nn.Module):
class GaudiGPT2Attention(GPT2Attention):
"""
Copied from GPT2Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
The only differences are:
- `self.bias` is a torch.uint8 and not a torch.bool
- it is casted to bool before being used in torch.where
- optimize KV cache
"""

def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__()

max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)

self.scale_attn_weights = config.scale_attn_weights
self.is_cross_attention = is_cross_attention

# Layer-wise attention scaling, reordering, and upcasting
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
self.layer_idx = layer_idx
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

if self.is_cross_attention:
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
else:
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

self.attn_dropout = torch.nn.Dropout(config.attn_pdrop)
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)

self.pruned_heads = set()

def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

# Prune conv1d layers
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

# Update hyper params
self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
self.num_heads = self.num_heads - len(heads)
self.pruned_heads = self.pruned_heads.union(heads)

def _attn(self, query, key, value, attention_mask=None, head_mask=None):
key = key.contiguous()
value = value.contiguous()
@@ -91,7 +30,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
@@ -141,7 +80,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
@@ -168,22 +107,6 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea

return attn_output, attn_weights

def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)

def _merge_heads(self, tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)

def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
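The _attn diff above drops the .bool() cast because self.bias is already a boolean buffer; the masking logic itself is unchanged. Roughly, with illustrative shapes, it works as follows:

import torch

# Illustrative single step of causal masking as done in _attn-style code:
# slice the precomputed lower-triangular mask to the current query/key
# lengths, then push masked positions to the dtype minimum before softmax.
batch, heads, query_length, key_length, head_dim = 1, 2, 3, 5, 4
bias = torch.tril(torch.ones((8, 8), dtype=torch.bool)).view(1, 1, 8, 8)

query = torch.randn(batch, heads, query_length, head_dim)
key = torch.randn(batch, heads, key_length, head_dim)
attn_weights = torch.matmul(query, key.transpose(-1, -2))

causal_mask = bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.tensor(torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)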
64 changes: 3 additions & 61 deletions optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -5,73 +5,15 @@
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.gptj.modeling_gptj import (
GPTJAttention,
GPTJForCausalLM,
apply_rotary_pos_emb,
create_sinusoidal_positions,
logger,
)


class GaudiGPTJAttention(nn.Module):
def __init__(self, config):
super().__init__()

max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e9))

self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)

self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.rotary_dim = config.rotary_dim

def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
tensor = tensor.view(new_shape)
if rotary:
return tensor
if len(tensor.shape) == 5:
return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features)
elif len(tensor.shape) == 4:
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")

def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
if len(tensor.shape) == 5:
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)

class GaudiGPTJAttention(GPTJAttention):
def _attn(
self,
query,
@@ -82,7 +24,7 @@ def _attn(
):
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

query = query.contiguous()
key = key.contiguous()
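The GPT-J change mirrors the GPT2 and CodeGen ones: the .bool() cast disappears because the inherited bias buffer is already boolean. A small sanity check of why the cast is no longer needed (a standalone sketch, not Optimum Habana code):

import torch

# torch.where takes a boolean condition, so a bool causal-mask buffer can
# be used directly; the old uint8 buffer had to be cast first, which is
# the `.bool()` call this commit removes.
bias_bool = torch.tril(torch.ones((4, 4), dtype=torch.bool)).view(1, 1, 4, 4)
bias_uint8 = bias_bool.to(torch.uint8)

scores = torch.randn(1, 1, 4, 4)
min_value = torch.tensor(torch.finfo(scores.dtype).min)

masked_new = torch.where(bias_bool, scores, min_value)          # new code path
masked_old = torch.where(bias_uint8.bool(), scores, min_value)  # old code path
assert torch.equal(masked_new, masked_old)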
