Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 committed Jun 6, 2024
1 parent 92acc6d commit 5bb6331
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Type
from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type

import numpy as np
from llama_index.core.bridge.pydantic import Field, PrivateAttr
Expand All @@ -22,7 +22,7 @@ class HitRate(BaseRetrievalMetric):
metric_name (str): The name of the metric.
"""

metric_name: str = "hit_rate"
metric_name: ClassVar[str] = "hit_rate"
use_granular_hit_rate: bool = False

def compute(
Expand Down Expand Up @@ -81,7 +81,7 @@ class MRR(BaseRetrievalMetric):
metric_name (str): The name of the metric.
"""

metric_name: str = "mrr"
metric_name: ClassVar[str] = "mrr"
use_granular_mrr: bool = False

def compute(
Expand Down Expand Up @@ -143,8 +143,8 @@ def compute(
class CohereRerankRelevancyMetric(BaseRetrievalMetric):
"""Cohere rerank relevancy metric."""

metric_name: ClassVar[str] = "cohere_rerank_relevancy"
model: str = Field(description="Cohere model name.")
metric_name: str = "cohere_rerank_relevancy"

_client: Any = PrivateAttr()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, ClassVar, Dict, List, Optional

from llama_index.core.bridge.pydantic import BaseModel, Field

Expand Down Expand Up @@ -30,7 +30,7 @@ def __float__(self) -> float:
class BaseRetrievalMetric(BaseModel, ABC):
"""Base class for retrieval metrics."""

metric_name: str
metric_name: ClassVar[str]

@abstractmethod
def compute(
Expand Down
20 changes: 8 additions & 12 deletions llama-index-core/llama_index/core/indices/omni_modal/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,8 @@ def retrieve_multi_modal(
query_bundle: QueryBundle,
*,
query_type: KQ,
doc_types: Optional[
Collection[KD]
] = None, # Defaults to all document modalities
# Defaults to all document modalities
doc_types: Optional[Collection[KD]] = None,
) -> List[NodeWithScore]:
if doc_types is None:
doc_types = self._embed_model.document_modalities.keys()
Expand Down Expand Up @@ -479,9 +478,8 @@ async def aretrieve_multi_modal(
query_bundle: QueryBundle,
*,
query_type: KQ,
doc_types: Optional[
Collection[KD]
] = None, # Defaults to all document modalities
# Defaults to all document modalities
doc_types: Optional[Collection[KD]] = None,
) -> List[NodeWithScore]:
if doc_types is None:
doc_types = self._embed_model.document_modalities.keys()
Expand Down Expand Up @@ -580,9 +578,8 @@ def retrieve(
self,
str_or_query_bundle: QueryType,
*,
doc_types: Optional[
Collection[KD]
] = None, # Defaults to all document modalities
# Defaults to all document modalities
doc_types: Optional[Collection[KD]] = None,
) -> List[NodeWithScore]:
query_bundle = self._as_query_bundle(
str_or_query_bundle, query_type=Modalities.TEXT.key
Expand All @@ -598,9 +595,8 @@ async def aretrieve(
self,
str_or_query_bundle: QueryType,
*,
doc_types: Optional[
Collection[KD]
] = None, # Defaults to all document modalities
# Defaults to all document modalities
doc_types: Optional[Collection[KD]] = None,
) -> List[NodeWithScore]:
query_bundle = self._as_query_bundle(
str_or_query_bundle, query_type=Modalities.TEXT.key
Expand Down

0 comments on commit 5bb6331

Please sign in to comment.