Skip to content

Commit

Permalink
Merge pull request #17 from wangyuxinwhy/fix-infer-record-type
Browse files Browse the repository at this point in the history
🐞fix:infer record type bug
  • Loading branch information
wangyuxinwhy authored Jun 21, 2023
2 parents b3d0eae + 93b1ce3 commit 753112c
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "uniem"
version = "0.2.2"
version = "0.2.3"
description = "unified embedding model"
authors = ["wangyuxin <[email protected]>"]
license = "MIT"
Expand Down
31 changes: 31 additions & 0 deletions tests/test_data_structures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
from uniem.data_structures import RecordType, infer_record_type


@pytest.mark.parametrize(
'record, expected_record_type',
[
pytest.param(
{'text': 'I like apples', 'text_pos': 'I like oranges'},
RecordType.PAIR,
id='pair',
),
pytest.param(
{'text': 'I like apples', 'text_pos': 'I like oranges', 'source': 'wikipedia'},
RecordType.PAIR,
id='pair_with_extra_fields',
),
pytest.param(
{'text': 'I like apples', 'text_pos': 'I like oranges', 'text_neg': 'I want to eat apples'},
RecordType.TRIPLET,
id='triplet',
),
pytest.param(
{'sentence1': 'I like apples', 'sentence2': 'I like oranges', 'label': 1.0},
RecordType.SCORED_PAIR,
id='scored_pair',
),
],
)
def test_infer_record_type(record: dict, expected_record_type: RecordType):
assert infer_record_type(record) == expected_record_type
4 changes: 2 additions & 2 deletions uniem/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
RecordType,
ScoredPairRecord,
TripletRecord,
get_record_type,
infer_record_type,
record_type_cls_map,
)
from uniem.types import Tokenizer
Expand Down Expand Up @@ -140,7 +140,7 @@ def __init__(
if record_type:
self.record_type = RecordType(record_type)
else:
self.record_type = get_record_type(dataset[0])
self.record_type = infer_record_type(dataset[0])
self.record_cls = record_type_cls_map[self.record_type]

def __getitem__(self, index: int):
Expand Down
7 changes: 4 additions & 3 deletions uniem/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ class ScoredPairRecord:
label: float


# * Order matters
record_type_cls_map: dict[RecordType, Any] = {
RecordType.PAIR: PairRecord,
RecordType.TRIPLET: TripletRecord,
RecordType.SCORED_PAIR: ScoredPairRecord,
RecordType.TRIPLET: TripletRecord,
RecordType.PAIR: PairRecord,
}


def get_record_type(record: dict) -> RecordType:
def infer_record_type(record: dict) -> RecordType:
record_type_field_names_map = {
record_type: [field.name for field in fields(record_cls)] for record_type, record_cls in record_type_cls_map.items()
}
Expand Down
8 changes: 6 additions & 2 deletions uniem/finetuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ScoredPairCollator,
TripletCollator,
)
from uniem.data_structures import RecordType, get_record_type
from uniem.data_structures import RecordType, infer_record_type
from uniem.model import (
EmbedderForPairInBatchNegTrain,
EmbedderForScoredPairTrain,
Expand All @@ -41,6 +41,7 @@ def __init__(
self,
model_name_or_path: str,
dataset: RawDataset,
record_type: RecordType | str | None = None,
):
self.model_name_or_path = model_name_or_path
self.raw_dataset = dataset
Expand All @@ -54,7 +55,10 @@ def __init__(
self.raw_dataset,
None,
)
self.record_type = get_record_type(self.raw_train_dataset[0]) # type: ignore

record_type = RecordType(record_type) if isinstance(record_type, str) else record_type
self.record_type = record_type or infer_record_type(self.raw_train_dataset[0])

self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)

def create_finetune_datasets(
Expand Down
2 changes: 1 addition & 1 deletion uniem/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.1'
__version__ = '0.2.3'

0 comments on commit 753112c

Please sign in to comment.