Implement gemma2 training #127

Open · wants to merge 7 commits into base: main
4 changes: 4 additions & 0 deletions experiments/run_mntp.py
@@ -56,6 +56,7 @@
MistralBiForMNTP,
LlamaBiForMNTP,
GemmaBiForMNTP,
Gemma2BiForMNTP,
Qwen2BiForMNTP,
)

@@ -80,6 +81,8 @@ def get_model_class(config):
return LlamaBiForMNTP
elif config_class_name == "GemmaConfig":
return GemmaBiForMNTP
elif config_class_name == "Gemma2Config":
return Gemma2BiForMNTP
elif config_class_name == "Qwen2Config":
return Qwen2BiForMNTP
else:
@@ -97,6 +100,7 @@ def initialize_peft(
"LlamaConfig",
"MistralConfig",
"GemmaConfig",
"Gemma2Config",
"Qwen2Config",
]:
lora_modules = [
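For reference, a minimal sketch of how the new branch in `get_model_class` is exercised. It assumes `config_class_name` is derived from `type(config).__name__`, as the existing branches imply, and uses a default `Gemma2Config` purely as a stand-in for a real checkpoint config:

```python
# Hypothetical smoke check for the new dispatch branch (not part of this PR).
from transformers import Gemma2Config, Gemma2ForCausalLM
from llm2vec.models import Gemma2BiForMNTP

config = Gemma2Config()  # stand-in for AutoConfig.from_pretrained(<gemma-2 checkpoint>)
assert type(config).__name__ == "Gemma2Config"          # the key get_model_class dispatches on
assert issubclass(Gemma2BiForMNTP, Gemma2ForCausalLM)    # the MNTP head reuses the causal-LM head
```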
1 change: 1 addition & 0 deletions experiments/run_simcse.py
@@ -52,6 +52,7 @@ def initialize_peft(
"LlamaConfig",
"MistralConfig",
"GemmaConfig",
"Gemma2Config",
"Qwen2Config",
]:
lora_modules = [
2 changes: 2 additions & 0 deletions experiments/run_supervised.py
@@ -59,6 +59,7 @@ def prepare_for_tokenization(model, text, pooling_mode="mean"):
]:
text = "[INST] " + text.strip() + " [/INST]"
if model.config._name_or_path in [
"google/gemma-7b-it",
"google/gemma-2-9b-it",
]:
text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
@@ -92,6 +93,7 @@ def initialize_peft(
"LlamaConfig",
"MistralConfig",
"GemmaConfig",
"Gemma2Config",
"Qwen2Config",
]:
lora_modules = [
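A small illustration of the formatting `prepare_for_tokenization` now applies for `google/gemma-2-9b-it`; the template string mirrors the diff, the input sentence is made up:

```python
# Toy example of the Gemma 2 chat wrapping added above (input text is made up).
text = "Given a query, retrieve relevant passages. Query: what is mean pooling?"
wrapped = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
print(wrapped)
# <bos><start_of_turn>user
# Given a query, retrieve relevant passages. Query: what is mean pooling?<end_of_turn>
```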
3 changes: 3 additions & 0 deletions llm2vec/llm2vec.py
@@ -24,6 +24,7 @@
MistralBiModel,
LlamaBiModel,
GemmaBiModel,
Gemma2BiModel,
Qwen2BiModel,
)

@@ -69,6 +70,8 @@ def _get_model_class(cls, config_class_name, enable_bidirectional):
return LlamaBiModel
elif config_class_name == "GemmaConfig":
return GemmaBiModel
elif config_class_name == "Gemma2Config":
return Gemma2BiModel
elif config_class_name == "Qwen2Config":
return Qwen2BiModel
else:
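With this branch, `_get_model_class` routes `Gemma2Config` to `Gemma2BiModel` whenever bidirectional attention is enabled. A hedged end-to-end sketch; the `from_pretrained`/`encode` arguments follow the existing README pattern and are assumptions here, not something this diff shows:

```python
import torch
from llm2vec import LLM2Vec

# Hypothetical usage: a Gemma 2 checkpoint should now be wrapped with the
# bidirectional Gemma2BiModel instead of falling through to the unsupported-model error.
l2v = LLM2Vec.from_pretrained(
    "google/gemma-2-9b-it",
    enable_bidirectional=True,
    torch_dtype=torch.bfloat16,
)
embeddings = l2v.encode(["what is mean pooling?"])
print(embeddings.shape)  # (1, hidden_size) sentence embeddings
```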
1 change: 1 addition & 0 deletions llm2vec/models/__init__.py
@@ -1,4 +1,5 @@
from .bidirectional_mistral import MistralBiModel, MistralBiForMNTP
from .bidirectional_llama import LlamaBiModel, LlamaBiForMNTP
from .bidirectional_gemma import GemmaBiModel, GemmaBiForMNTP
from .bidirectional_gemma2 import Gemma2BiModel, Gemma2BiForMNTP
from .bidirectional_qwen2 import Qwen2BiModel, Qwen2BiForMNTP
56 changes: 27 additions & 29 deletions llm2vec/models/bidirectional_gemma.py
@@ -62,6 +62,7 @@ def __init__(self, *args, **kwargs):
class ModifiedGemmaDecoderLayer(GemmaDecoderLayer):
def __init__(self, config: GemmaConfig, layer_idx: int):
nn.Module.__init__(self)
self.config = config
self.hidden_size = config.hidden_size

self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](
@@ -107,19 +108,26 @@ def _update_causal_mask(
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache = None,
output_attentions: bool = False,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
# if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
# if AttentionMaskConverter._ignore_causal_mask_sdpa(
# attention_mask,
@@ -140,39 +148,31 @@ def _update_causal_mask(
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError(
"Custom 4D attention mask should be passed in inverted form with max==0`"
)
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.zeros(
(sequence_length, target_length), dtype=dtype, device=device
) # in original implementation - torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
# Commenting out next 2 lines to disable causal masking
)
# causal_mask = torch.full(
# (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
# )
# if sequence_length != 1:
# causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(
input_tensor.shape[0], 1, -1, -1
)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = (
causal_mask[:, :, :, :mask_length]
+ attention_mask[:, None, None, :]
)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
@@ -182,9 +182,7 @@
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(
causal_mask, min_dtype
)
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask

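The core change in `_update_causal_mask` (here and in the new Gemma 2 variant below) is starting from `torch.zeros` instead of the usual `torch.full(..., min_dtype)` plus `torch.triu`, so no position is masked by causality and only padded keys are blocked. A standalone sketch of the resulting mask, using toy shapes and a made-up padding pattern:

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
batch_size, seq_len = 1, 4
attention_mask = torch.tensor([[0, 1, 1, 1]])   # position 0 is (left) padding
cache_position = torch.arange(seq_len)

# All-zeros base mask: every real token may attend to every other token.
causal_mask = torch.zeros((seq_len, seq_len), dtype=dtype)
causal_mask *= torch.arange(seq_len) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

# Block attention to padded key positions only.
padding_mask = (causal_mask[:, :, :, :seq_len] + attention_mask[:, None, None, :]) == 0
causal_mask[:, :, :, :seq_len] = causal_mask[:, :, :, :seq_len].masked_fill(
    padding_mask, min_dtype
)

print(causal_mask[0, 0])
# Column 0 (the pad token) is min_dtype for every query row; all other entries stay 0,
# i.e. fully bidirectional attention over the non-padded tokens.
```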
174 changes: 174 additions & 0 deletions llm2vec/models/bidirectional_gemma2.py
@@ -0,0 +1,174 @@
import torch

from packaging import version
import importlib.metadata

from transformers import Gemma2Model, Gemma2ForCausalLM, Gemma2PreTrainedModel, Gemma2Config
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2DecoderLayer,
Gemma2Attention,
Gemma2FlashAttention2,
Gemma2SdpaAttention,
Gemma2MLP,
Gemma2RMSNorm,
)

from torch import nn
from transformers.utils import logging

from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.utils.import_utils import _is_package_available
from transformers.cache_utils import Cache, StaticCache

from peft import PeftModel

logger = logging.get_logger(__name__)


def is_transformers_attn_greater_or_equal_4_41():
if not _is_package_available("transformers"):
return False

return version.parse(importlib.metadata.version("transformers")) >= version.parse(
"4.41.0"
)


class ModifiedGemma2Attention(Gemma2Attention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False


class ModifiedGemma2FlashAttention2(Gemma2FlashAttention2):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False


class ModifiedGemma2SdpaAttention(Gemma2SdpaAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False


GEMMA2_ATTENTION_CLASSES = {
"eager": ModifiedGemma2Attention,
"flash_attention_2": ModifiedGemma2FlashAttention2,
"sdpa": ModifiedGemma2SdpaAttention,
}


class ModifiedGemma2DecoderLayer(Gemma2DecoderLayer):
def __init__(self, config: Gemma2Config, layer_idx: int):
nn.Module.__init__(self)
self.config = config
self.hidden_size = config.hidden_size

self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](
config=config, layer_idx=layer_idx
)

self.mlp = Gemma2MLP(config)
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.is_sliding = not bool(layer_idx % 2)
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window



class Gemma2BiModel(Gemma2Model):
_no_split_modules = ["ModifiedGemma2DecoderLayer"]

def __init__(self, config: Gemma2Config):
Gemma2PreTrainedModel.__init__(self, config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
ModifiedGemma2DecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False

# Initialize weights and apply final processing
self.post_init()

def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if past_key_values is not None:
target_length = past_key_values.get_max_length()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.zeros(
(sequence_length, target_length), dtype=dtype, device=device
)
# causal_mask = torch.full(
# (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
# )
# if sequence_length != 1:
# causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask


class Gemma2BiForMNTP(Gemma2ForCausalLM):
def __init__(self, config):
Gemma2PreTrainedModel.__init__(self, config)
self.model = Gemma2BiModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

# Initialize weights and apply final processing
self.post_init()

# getter for PEFT model
def get_model_for_peft(self):
return self.model

# setter for PEFT model
def set_model_for_peft(self, model: PeftModel):
self.model = model

# save the PEFT model
def save_peft_model(self, path):
self.model.save_pretrained(path)
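A quick instantiation sketch for the new classes. The tiny config below is made up so the check runs on random weights without downloading a checkpoint; it only verifies what the file above sets up (causal masking disabled in every attention module, and the PEFT accessor returning the backbone):

```python
from transformers import Gemma2Config
from llm2vec.models import Gemma2BiModel, Gemma2BiForMNTP

# Toy config (made-up sizes) so this runs quickly with random weights.
config = Gemma2Config(
    vocab_size=256,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
)

model = Gemma2BiModel(config)
# ModifiedGemma2*Attention sets is_causal=False in every decoder layer.
print(all(not layer.self_attn.is_causal for layer in model.layers))  # True

mntp = Gemma2BiForMNTP(config)
backbone = mntp.get_model_for_peft()    # the Gemma2BiModel backbone handed to PEFT
print(type(backbone).__name__)          # Gemma2BiModel
```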
3 changes: 3 additions & 0 deletions llm2vec/models/bidirectional_llama.py
@@ -11,6 +11,7 @@
LlamaSdpaAttention,
LlamaMLP,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)

from torch import nn
@@ -70,6 +71,7 @@ def __init__(self, *args, **kwargs):
class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LlamaConfig, layer_idx: int):
nn.Module.__init__(self)
self.config = config
self.hidden_size = config.hidden_size

self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
@@ -105,6 +107,7 @@ def __init__(self, config: LlamaConfig):
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False

# Initialize weights and apply final processing
1 change: 1 addition & 0 deletions llm2vec/models/bidirectional_mistral.py
@@ -57,6 +57,7 @@ def __init__(self, *args, **kwargs):
class ModifiedMistralDecoderLayer(MistralDecoderLayer):
def __init__(self, config: MistralConfig, layer_idx: int):
nn.Module.__init__(self)
self.config = config
self.hidden_size = config.hidden_size

self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](
1 change: 1 addition & 0 deletions llm2vec/models/bidirectional_qwen2.py
@@ -52,6 +52,7 @@ def __init__(self, *args, **kwargs):
class ModifiedQwen2DecoderLayer(Qwen2DecoderLayer):
def __init__(self, config: Qwen2Config, layer_idx: int):
nn.Module.__init__(self)
self.config = config
self.hidden_size = config.hidden_size

self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](