diff --git a/optillm/entropy_decoding.py b/optillm/entropy_decoding.py index 30146d7..3a768fc 100644 --- a/optillm/entropy_decoding.py +++ b/optillm/entropy_decoding.py @@ -26,17 +26,30 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tup varentropy = torch.sum(probs * (log_probs / LN_2 + entropy.unsqueeze(-1))**2, dim=axis) return entropy, varentropy -def calculate_attention_metrics(attention_scores: torch.Tensor) -> Dict[str, torch.Tensor]: - attention_probs = F.softmax(attention_scores, dim=-1) +def calculate_attention_metrics(attention_weights: torch.Tensor) -> Dict[str, torch.Tensor]: + attention_probs = attention_weights + + # Calculate entropy attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1) - attn_varentropy = torch.var(attn_entropy, dim=-1) - attn_varentropy = torch.where(torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy) + # Calculate variance of entropy with unbiased=False to avoid df issues + # Also add a check for singleton dimensions + if attn_entropy.size(-1) > 1: + attn_varentropy = torch.var(attn_entropy, dim=-1, unbiased=False) + else: + attn_varentropy = torch.zeros_like(attn_entropy) + + attn_varentropy = torch.where(torch.isnan(attn_varentropy), + torch.zeros_like(attn_varentropy), + attn_varentropy) + + # Rest remains the same mean_attention = torch.mean(attention_probs, dim=1) agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2)) - - interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)) - + + attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0)) + interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3)) + return { "attn_entropy": torch.mean(attn_entropy), "attn_varentropy": torch.mean(attn_varentropy), diff --git a/optillm/litellm_wrapper.py b/optillm/litellm_wrapper.py index 9d19b09..5ac6232 100644 --- a/optillm/litellm_wrapper.py +++ b/optillm/litellm_wrapper.py @@ -24,7 +24,10 @@ class Chat: class Completions: @staticmethod def create(model: str, messages: List[Dict[str, str]], **kwargs): - response = completion(model=model, messages=messages, **kwargs, safety_settings=SAFETY_SETTINGS) + if model.startswith("gemini"): + response = completion(model=model, messages=messages, **kwargs, safety_settings=SAFETY_SETTINGS) + else: + response = completion(model=model, messages=messages, **kwargs) # Convert LiteLLM response to match OpenAI response structure return response diff --git a/setup.py b/setup.py index ae76402..7d91dfe 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="optillm", - version="0.0.6", + version="0.0.7", packages=find_packages(), py_modules=['optillm'], package_data={