
Commit

Changed naming for task types
lukasgarbas committed Oct 26, 2024
1 parent ea65cec commit 86f7ec5
Showing 4 changed files with 27 additions and 29 deletions.
20 changes: 10 additions & 10 deletions transformer_ranker/datacleaner.py
@@ -1,12 +1,12 @@
-import torch
+import logging
+from typing import Dict, List, Optional, Set, Tuple, Type, Union
+
 import datasets
-from datasets.dataset_dict import DatasetDict, Dataset
+import torch
+from datasets.dataset_dict import Dataset, DatasetDict
 from tokenizers.pre_tokenizers import Whitespace
-from .utils import configure_logger
-
-import logging
-from typing import List, Dict, Optional, Set, Union, Tuple, Type
 
+from .utils import configure_logger
 
 logger = configure_logger('transformer_ranker', logging.INFO)

@@ -189,10 +189,10 @@ def merge_texts(example: Dict[str, str]) -> Dict[str, str]:
     def _find_task_type(label_column: str, label_type: Union[Type[int], Type[str], Type[list], Type[float]]) -> str:
         """Determine task type based on the label column's data type."""
         label_type_to_task_type = {
-            int: "sentence classification",  # labels can be integers
-            str: "sentence classification",  # or strings e.g. "positive"
-            list: "word classification",
-            float: "sentence regression",
+            int: "text classification",  # labels can be integers
+            str: "text classification",  # or strings e.g. "positive"
+            list: "token classification",
+            float: "text regression",
         }
 
         task_type = label_type_to_task_type.get(label_type, None)
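For reference, the renamed mapping can be exercised on its own. The sketch below mirrors the dictionary above with a small illustrative wrapper; the wrapper name and the example calls are not part of the commit:

from typing import Type, Union

# Mirrors the renamed mapping from datacleaner.py
label_type_to_task_type = {
    int: "text classification",    # integer class labels
    str: "text classification",    # string labels, e.g. "positive"
    list: "token classification",  # one label per word/token
    float: "text regression",      # continuous targets
}


def find_task_type(label_type: Union[Type[int], Type[str], Type[list], Type[float]]) -> str:
    """Return the task-type name for a given label type, as the data cleaner does."""
    task_type = label_type_to_task_type.get(label_type, None)
    if task_type is None:
        raise ValueError(f"Unsupported label type: {label_type}")
    return task_type


print(find_task_type(str))    # text classification
print(find_task_type(list))   # token classification
print(find_task_type(float))  # text regression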
10 changes: 4 additions & 6 deletions transformer_ranker/embedder.py
@@ -1,11 +1,9 @@
-from transformers import AutoModel, AutoTokenizer
-from tokenizers.pre_tokenizers import Whitespace
-import torch
+from typing import List, Optional, Union
 
+import torch
+from tokenizers.pre_tokenizers import Whitespace
 from tqdm import tqdm
-from typing import Optional, List, Union
-
-from transformers import PreTrainedTokenizerFast
+from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast
 
 
 class Embedder:
20 changes: 10 additions & 10 deletions transformer_ranker/ranker.py
@@ -1,15 +1,15 @@
+import logging
+from typing import List, Optional, Union
+
 import torch
 from datasets.dataset_dict import Dataset, DatasetDict
 from tqdm import tqdm
 
 from .datacleaner import DatasetCleaner
 from .embedder import Embedder
-from .estimators import HScore, LogME, KNN
+from .estimators import KNN, HScore, LogME
 from .utils import Result, configure_logger
 
-import logging
-from typing import List, Optional, Union
-
 logger = configure_logger('transformer_ranker', logging.INFO)


@@ -91,7 +91,7 @@ def run(
         layer_pooling = "mean" if "mean" in layer_aggregator else None
 
         # Sentence pooling is only applied for text classification tasks
-        effective_sentence_pooling = None if self.task_type == "word classification" else sentence_pooling
+        effective_sentence_pooling = None if self.task_type == "token classification" else sentence_pooling
 
         embedder = Embedder(
             model=model,
@@ -110,7 +110,7 @@
         )
 
         # Single list of embeddings for sequence tagging tasks
-        if self.task_type == "word classification":
+        if self.task_type == "token classification":
             embeddings = [word_embedding for sentence_embedding in embeddings
                           for word_embedding in sentence_embedding]
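The comprehension above flattens per-sentence word embeddings into a single list so that token-level labels line up one-to-one with embeddings. A minimal standalone sketch with dummy tensors (shapes and variable names are illustrative only):

import torch

# Two "sentences" with 3 and 2 word embeddings of dimension 8 (dummy data)
embeddings = [torch.randn(3, 8), torch.randn(2, 8)]

# Same flattening as in run(): one entry per word across the whole dataset
embeddings = [word_embedding for sentence_embedding in embeddings
              for word_embedding in sentence_embedding]

print(len(embeddings))       # 5 word embeddings in total
print(embeddings[0].shape)   # torch.Size([8])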

@@ -194,15 +194,15 @@ def _confirm_ranker_setup(self, estimator, layer_aggregator) -> None:
             raise ValueError(f"Unsupported layer pooling: {layer_aggregator}. "
                              f"Use one of the following {valid_layer_aggregators}")
 
-        valid_task_types = ["sentence classification", "word classification", "sentence regression"]
+        valid_task_types = ["text classification", "token classification", "text regression"]
         if self.task_type not in valid_task_types:
             raise ValueError("Unable to determine task type of the dataset. Please specify it as a parameter: "
-                             "task_type= \"sentence classification\", \"sentence regression\", or "
-                             "\"word classification\"")
+                             "task_type= \"text classification\", \"token classification\", or "
+                             "\"text regression\"")
 
     def _estimate_score(self, estimator, embeddings: torch.Tensor, labels: torch.Tensor) -> float:
         """Use an estimator to score a transformer"""
-        regression = self.task_type == "sentence regression"
+        regression = self.task_type == "text regression"
         if estimator in ['hscore'] and regression:
             logger.warning(f'Specified estimator="{estimator}" does not support regression tasks.')

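The renamed strings are what the setup check and the scoring step now expect. Below is a standalone sketch of the same validation and warning logic outside the ranker class; it is illustrative only, not the library API:

import logging

logger = logging.getLogger("transformer_ranker")

VALID_TASK_TYPES = ["text classification", "token classification", "text regression"]


def check_setup(task_type: str, estimator: str) -> None:
    """Validate the task type and warn when hscore is paired with a regression task."""
    if task_type not in VALID_TASK_TYPES:
        raise ValueError(
            "Unable to determine task type of the dataset. Please specify it as a parameter: "
            'task_type="text classification", "token classification", or "text regression"'
        )
    if estimator in ["hscore"] and task_type == "text regression":
        logger.warning(f'Specified estimator="{estimator}" does not support regression tasks.')


check_setup("token classification", "logme")   # passes silently
check_setup("text regression", "hscore")       # logs a warning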
6 changes: 3 additions & 3 deletions transformer_ranker/utils.py
@@ -1,9 +1,9 @@
 import logging
-import warnings
-from transformers import logging as transformers_logging
 import operator
+import warnings
+from typing import Dict, List
 
-from typing import List, Dict
+from transformers import logging as transformers_logging
 
 
 def prepare_popular_models(model_size='base') -> List[str]:
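prepare_popular_models, visible at the edge of this hunk, is untouched apart from the import reshuffle. A brief usage sketch; the exact model handles in the returned list depend on the library version:

from transformer_ranker.utils import prepare_popular_models

# 'base' is the default model size; the function returns a list of model name handles
base_models = prepare_popular_models(model_size='base')
print(len(base_models), base_models[:3])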
