
Commit

mamba hacks, please forgive me
jaidhyani committed Mar 8, 2024
1 parent 0524030 commit 29e986e
Showing 4 changed files with 39 additions and 27 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -15,4 +15,5 @@ chardet==5.2.0
 sentencepiece==0.1.99
 protobuf==4.25.2
 plotly==5.18.0
-spacy-transformers==1.3.4
+spacy-transformers==1.3.4
+mamba_ssm==1.2.0.post1; sys_platform != 'darwin'
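
The new requirement uses a PEP 508 environment marker so that mamba_ssm is skipped on macOS, presumably because its CUDA-only kernels cannot be installed there. A minimal sketch of how that marker evaluates, assuming the third-party packaging library is available (it is not part of this commit):

    # Evaluate the same environment marker pip consults when installing requirements.txt.
    from packaging.markers import Marker

    marker = Marker("sys_platform != 'darwin'")
    # True on Linux/Windows (mamba_ssm is installed), False on macOS (it is skipped).
    print(marker.evaluate())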
2 changes: 1 addition & 1 deletion scripts/sample_mamba.json
@@ -4,7 +4,7 @@
   "log_interval": 1,
   "eval_iters": 10,
   "eval_only": false,
-  "architecture": "ModelTypes.MAMBA",
+  "architecture": "mamba",
   "always_save_checkpoint": false,
   "init_from": "scratch",
   "wandb_log": true,
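
The sample config now selects the architecture with the plain lowercase string "mamba" instead of the stringified enum-style "ModelTypes.MAMBA". A small sanity-check sketch, assuming only the file path and key shown in the diff above:

    # Load the sample config and confirm the new architecture value.
    import json

    with open("scripts/sample_mamba.json") as f:
        config = json.load(f)

    assert config["architecture"] == "mamba"  # previously "ModelTypes.MAMBA"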
5 changes: 4 additions & 1 deletion src/delphi/train/architectures.py
@@ -5,7 +5,10 @@
 from llama2c.model import ModelArgs as Llama2ModelArgs
 from llama2c.model import Transformer as Llama2Model

-from delphi.train.mamba import Mamba, MambaArgs
+try:
+    from delphi.train.mamba import Mamba, MambaArgs
+except Exception as e:
+    print("no mamba for you")


 class ModelTypes:
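
The unconditional import becomes a guarded one, so architectures.py still loads on machines without mamba_ssm (e.g. macOS) and only fails when a Mamba model is actually requested. A hedged sketch of the same pattern in isolation; the HAS_MAMBA flag and build_mamba helper are illustrative names, not part of this repo:

    # Optional-import guard: tolerate a missing mamba_ssm at import time,
    # fail loudly only when a Mamba model is actually constructed.
    try:
        from delphi.train.mamba import Mamba, MambaArgs
        HAS_MAMBA = True
    except Exception:
        Mamba = MambaArgs = None  # placeholders so later references don't NameError
        HAS_MAMBA = False

    def build_mamba(params):
        if not HAS_MAMBA:
            raise RuntimeError("mamba_ssm is not installed on this platform")
        return Mamba(params)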
56 changes: 32 additions & 24 deletions src/delphi/train/mamba.py
@@ -3,34 +3,42 @@

 import torch
 import torch.nn.functional as F
-from mamba_ssm.models.config_mamba import MambaConfig
-from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

+_hasmamba = False
+try:
+    from mamba_ssm.models.config_mamba import MambaConfig
+    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

-@dataclass
-class MambaArgs(MambaConfig):
-    pass
+    _hasmamba = True
+except Exception as e:
+    print("mamba_ssm not installed")


-class Mamba(MambaLMHeadModel):
-    def __init__(self, params: MambaArgs) -> None:
-        super().__init__(params)
+if _hasmamba:

-    def forward(
-        self, input_ids: torch.Tensor, target_ids: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
-        """
-        "position_ids" is just to be compatible with Transformer generation. We don't use it.
-        num_last_tokens: if > 0, only return the logits for the last n tokens
-        """
-        hidden_states = self.backbone(input_ids)
-        logits = self.lm_head(hidden_states)
-        self.last_loss = F.cross_entropy(
-            logits.view(-1, logits.size(-1)), target_ids.view(-1), ignore_index=-1
-        )
+    @dataclass
+    class MambaArgs(MambaConfig):
+        pass

-        return logits
+    class Mamba(MambaLMHeadModel):
+        def __init__(self, params: MambaArgs) -> None:
+            super().__init__(params)

-    def estimate_mfu(self, fwdbwd_per_iter, dt):
-        """I don't want to implement this"""
-        return 0
+        def forward(
+            self, input_ids: torch.Tensor, target_ids: Optional[torch.Tensor] = None
+        ) -> torch.Tensor:
+            """
+            "position_ids" is just to be compatible with Transformer generation. We don't use it.
+            num_last_tokens: if > 0, only return the logits for the last n tokens
+            """
+            hidden_states = self.backbone(input_ids)
+            logits = self.lm_head(hidden_states)
+            self.last_loss = F.cross_entropy(
+                logits.view(-1, logits.size(-1)), target_ids.view(-1), ignore_index=-1
+            )
+
+            return logits
+
+        def estimate_mfu(self, fwdbwd_per_iter, dt):
+            """I don't want to implement this"""
+            return 0
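
The whole module body is now wrapped in an if _hasmamba: block, so importing mamba.py is safe when mamba_ssm is absent; MambaArgs and Mamba simply are not defined in that case. A usage sketch under stated assumptions: mamba_ssm 1.2.x is installed, a CUDA device is available (its fused kernels are CUDA-only), and MambaConfig accepts d_model, n_layer, and vocab_size. Note that this forward always computes a loss, so target_ids must be passed despite its Optional annotation:

    import torch

    from delphi.train.mamba import Mamba, MambaArgs

    # Tiny model purely for illustration; real configs live in scripts/*.json.
    args = MambaArgs(d_model=64, n_layer=2, vocab_size=512)
    model = Mamba(args).to("cuda")

    input_ids = torch.randint(0, 512, (1, 16), device="cuda")
    target_ids = input_ids.clone()  # toy targets, just to exercise the loss path

    logits = model(input_ids, target_ids)  # shape (batch, seq_len, vocab_size)
    print(logits.shape, model.last_loss.item())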
