diff --git a/TransformerLens/transformer_lens/components/abstract_attention.py b/TransformerLens/transformer_lens/components/abstract_attention.py
index 5518b15..757d76e 100644
--- a/TransformerLens/transformer_lens/components/abstract_attention.py
+++ b/TransformerLens/transformer_lens/components/abstract_attention.py
@@ -209,7 +209,7 @@ def forward(
                     self.apply_rotary(k, 0, attention_mask)
                 )  # keys are cached so no offset

-        if self.cfg.dtype not in [torch.float32, torch.float64]:
+        if self.cfg.dtype not in [torch.float32, torch.float64, torch.bfloat16]:
             # If using 16 bits, increase the precision to avoid numerical instabilities
             q = q.to(torch.float32)
             k = k.to(torch.float32)
diff --git a/TransformerLens/transformer_lens/components/rms_norm.py b/TransformerLens/transformer_lens/components/rms_norm.py
index 26d5c7c..9d041a5 100644
--- a/TransformerLens/transformer_lens/components/rms_norm.py
+++ b/TransformerLens/transformer_lens/components/rms_norm.py
@@ -36,7 +36,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[i
     def forward(
         self, x: Float[torch.Tensor, "batch pos length"]
     ) -> Float[torch.Tensor, "batch pos length"]:
-        if self.cfg.dtype not in [torch.float32, torch.float64]:
+        if self.cfg.dtype not in [torch.float32, torch.float64, torch.bfloat16]:
             x = x.to(torch.float32)
         scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
             (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
diff --git a/TransformerLens/transformer_lens/components/rms_norm_pre.py b/TransformerLens/transformer_lens/components/rms_norm_pre.py
index 2d2ff57..f08569d 100644
--- a/TransformerLens/transformer_lens/components/rms_norm_pre.py
+++ b/TransformerLens/transformer_lens/components/rms_norm_pre.py
@@ -26,7 +26,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
     def forward(
         self, x: Float[torch.Tensor, "batch pos length"]
     ) -> Float[torch.Tensor, "batch pos length"]:
-        if self.cfg.dtype not in [torch.float32, torch.float64]:
+        if self.cfg.dtype not in [torch.float32, torch.float64, torch.bfloat16]:
             x = x.to(torch.float32)
         scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(