diff --git a/transformer_ranker/datacleaner.py b/transformer_ranker/datacleaner.py index 388d7b0..7273c21 100644 --- a/transformer_ranker/datacleaner.py +++ b/transformer_ranker/datacleaner.py @@ -1,12 +1,12 @@ -import torch +import logging +from typing import Dict, List, Optional, Set, Tuple, Type, Union + import datasets -from datasets.dataset_dict import DatasetDict, Dataset +import torch +from datasets.dataset_dict import Dataset, DatasetDict from tokenizers.pre_tokenizers import Whitespace -from .utils import configure_logger - -import logging -from typing import List, Dict, Optional, Set, Union, Tuple, Type +from .utils import configure_logger logger = configure_logger('transformer_ranker', logging.INFO) @@ -189,10 +189,10 @@ def merge_texts(example: Dict[str, str]) -> Dict[str, str]: def _find_task_type(label_column: str, label_type: Union[Type[int], Type[str], Type[list], Type[float]]) -> str: """Determine task type based on the label column's data type.""" label_type_to_task_type = { - int: "sentence classification", # labels can be integers - str: "sentence classification", # or strings e.g. "positive" - list: "word classification", - float: "sentence regression", + int: "text classification", # labels can be integers + str: "text classification", # or strings e.g. "positive" + list: "token classification", + float: "text regression", } task_type = label_type_to_task_type.get(label_type, None) diff --git a/transformer_ranker/embedder.py b/transformer_ranker/embedder.py index 8205b04..81b7a3f 100644 --- a/transformer_ranker/embedder.py +++ b/transformer_ranker/embedder.py @@ -1,11 +1,9 @@ -from transformers import AutoModel, AutoTokenizer -from tokenizers.pre_tokenizers import Whitespace -import torch +from typing import List, Optional, Union +import torch +from tokenizers.pre_tokenizers import Whitespace from tqdm import tqdm -from typing import Optional, List, Union - -from transformers import PreTrainedTokenizerFast +from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast class Embedder: diff --git a/transformer_ranker/ranker.py b/transformer_ranker/ranker.py index a4278c9..26e4c8c 100644 --- a/transformer_ranker/ranker.py +++ b/transformer_ranker/ranker.py @@ -1,15 +1,15 @@ +import logging +from typing import List, Optional, Union + import torch from datasets.dataset_dict import Dataset, DatasetDict from tqdm import tqdm from .datacleaner import DatasetCleaner from .embedder import Embedder -from .estimators import HScore, LogME, KNN +from .estimators import KNN, HScore, LogME from .utils import Result, configure_logger -import logging -from typing import List, Optional, Union - logger = configure_logger('transformer_ranker', logging.INFO) @@ -91,7 +91,7 @@ def run( layer_pooling = "mean" if "mean" in layer_aggregator else None # Sentence pooling is only applied for text classification tasks - effective_sentence_pooling = None if self.task_type == "word classification" else sentence_pooling + effective_sentence_pooling = None if self.task_type == "token classification" else sentence_pooling embedder = Embedder( model=model, @@ -110,7 +110,7 @@ def run( ) # Single list of embeddings for sequence tagging tasks - if self.task_type == "word classification": + if self.task_type == "token classification": embeddings = [word_embedding for sentence_embedding in embeddings for word_embedding in sentence_embedding] @@ -194,15 +194,15 @@ def _confirm_ranker_setup(self, estimator, layer_aggregator) -> None: raise ValueError(f"Unsupported layer pooling: {layer_aggregator}. " f"Use one of the following {valid_layer_aggregators}") - valid_task_types = ["sentence classification", "word classification", "sentence regression"] + valid_task_types = ["text classification", "token classification", "text regression"] if self.task_type not in valid_task_types: raise ValueError("Unable to determine task type of the dataset. Please specify it as a parameter: " - "task_type= \"sentence classification\", \"sentence regression\", or " - "\"word classification\"") + "task_type= \"text classification\", \"token classification\", or " + "\"text regression\"") def _estimate_score(self, estimator, embeddings: torch.Tensor, labels: torch.Tensor) -> float: """Use an estimator to score a transformer""" - regression = self.task_type == "sentence regression" + regression = self.task_type == "text regression" if estimator in ['hscore'] and regression: logger.warning(f'Specified estimator="{estimator}" does not support regression tasks.') diff --git a/transformer_ranker/utils.py b/transformer_ranker/utils.py index 4c0936f..b708630 100644 --- a/transformer_ranker/utils.py +++ b/transformer_ranker/utils.py @@ -1,9 +1,9 @@ import logging -import warnings -from transformers import logging as transformers_logging import operator +import warnings +from typing import Dict, List -from typing import List, Dict +from transformers import logging as transformers_logging def prepare_popular_models(model_size='base') -> List[str]: