Commit 1ba3e85: Fixed merge conflict

Pringled committed Oct 29, 2024
2 parents: 749bfbd + 3de37ed
Showing 30 changed files with 2,104 additions and 1,035 deletions.
12 changes: 9 additions & 3 deletions .github/workflows/ci.yaml
@@ -10,7 +10,14 @@ jobs:
    strategy:
      matrix:
        os: ["ubuntu-latest", "windows-latest"]
-       python-version: ["3.10"]
+       python-version: ["3.9", "3.10", "3.11", "3.12"]
+       exclude:
+         - os: windows-latest
+           python-version: "3.9"
+         - os: windows-latest
+           python-version: "3.11"
+         - os: windows-latest
+           python-version: "3.12"
      fail-fast: false

    steps:
@@ -42,8 +49,7 @@ jobs:
      # Install dependencies using uv pip
      - name: Install dependencies
-       run: make install
-       # run: uv pip install -e ".[pytest]"
+       run: make install-no-pre-commit

      # Run tests with coverage
      - name: Run tests under coverage
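The widened matrix is presumably what drives the typing changes later in this diff: `match`/`case` and PEP 604 `X | Y` unions require Python 3.10, and a module-level type alias such as `PCADimType` is evaluated at runtime, so `from __future__ import annotations` does not rescue it on 3.9. A minimal sketch of the 3.9-compatible spelling adopted below:

```python
from typing import Literal, Union

# Python >= 3.10 only (PEP 604); fails at runtime on 3.9 because a
# module-level type alias is an ordinary expression:
#   PCADimType = int | None | Literal["auto"]

# Python 3.9-compatible spelling, as used later in this diff:
PCADimType = Union[int, None, Literal["auto"]]
```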
1 change: 0 additions & 1 deletion .gitignore
@@ -168,7 +168,6 @@ models
checkpoints/*
features/*
model2vec_models
-results/*
counts/*
results_old/*
local/*
3 changes: 3 additions & 0 deletions Makefile
@@ -8,6 +8,9 @@ install:
	uv sync --all-extras
	uv run pre-commit install

+install-no-pre-commit:
+	uv pip install ".[dev,distill]"
+
install-base:
	uv sync --extra dev
297 changes: 180 additions & 117 deletions README.md

Large diffs are not rendered by default.

Binary file added assets/images/logo_v2.png
Binary file added assets/images/speed_vs_accuracy_v4.png
Binary file added assets/images/speed_vs_mteb_score.png
Binary file added assets/images/speed_vs_mteb_score_v2.png
4 changes: 2 additions & 2 deletions model2vec/__init__.py
@@ -1,4 +1,4 @@
-from model2vec.distill import distill
from model2vec.model import StaticModel
+from model2vec.version import __version__

-__all__ = ["distill", "StaticModel"]
+__all__ = ["StaticModel", "__version__"]
7 changes: 7 additions & 0 deletions model2vec/distill/__init__.py
@@ -1,3 +1,10 @@
+from model2vec.utils import get_package_extras, importable
+
+_REQUIRED_EXTRA = "distill"
+
+for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
+    importable(extra_dependency, _REQUIRED_EXTRA)
+
from model2vec.distill.distillation import distill, distill_from_model

__all__ = ["distill", "distill_from_model"]
53 changes: 0 additions & 53 deletions model2vec/distill/__main__.py

This file was deleted.

16 changes: 13 additions & 3 deletions model2vec/distill/distillation.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
import logging
-from typing import Literal
+from typing import Literal, Union

import numpy as np
from huggingface_hub import model_info
@@ -26,7 +28,7 @@
logger = logging.getLogger(__name__)


-PCADimType = int | None | Literal["auto"]
+PCADimType = Union[int, None, Literal["auto"]]


def distill_from_model(
@@ -53,7 +55,7 @@ def distill_from_model(
    :param device: The device to use.
    :param pca_dims: The number of components to use for PCA.
        If this is None, we don't apply PCA.
-       If this is 'auto', we don't reduce dimenionality, but still apply PCA.
+       If this is 'auto', we don't reduce dimensionality, but still apply PCA.
    :param apply_zipf: Whether to apply Zipf weighting to the embeddings.
    :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
    :raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
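A hypothetical call illustrating the documented parameters; the model name and exact signature are assumptions, not part of this diff:

```python
from transformers import AutoModel, AutoTokenizer
from model2vec.distill import distill_from_model

model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

m2v_model = distill_from_model(
    model=model,
    tokenizer=tokenizer,
    pca_dims=256,      # int reduces; None skips PCA; "auto" applies PCA without reducing
    apply_zipf=True,   # weight embeddings by token frequency rank
    use_subword=True,  # keep subword tokens in the vocabulary
    device="cpu",
)
```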
@@ -214,9 +216,17 @@ def _post_process_embeddings(embeddings: np.ndarray, pca_dims: PCADimType, apply
    elif pca_dims <= embeddings.shape[1]:
        logger.info(f"Applying PCA with n_components {pca_dims}")

+       orig_dims = embeddings.shape[1]
        p = PCA(n_components=pca_dims, whiten=False)
        embeddings = p.fit_transform(embeddings)

+       if embeddings.shape[1] < orig_dims:
+           explained_variance_ratio = np.sum(p.explained_variance_ratio_)
+           explained_variance = np.sum(p.explained_variance_)
+           logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
+           logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
+           logger.info(f"Explained variance: {explained_variance:.3f}.")
+
    if apply_zipf:
        logger.info("Applying Zipf weighting")
        embeddings *= np.log(1 + np.arange(embeddings.shape[0]))[:, None]
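For reference, a standalone sketch of the post-processing step shown above, under the same semantics ("auto" applies PCA without reducing dimensionality); the function name is illustrative:

```python
from typing import Literal, Union

import numpy as np
from sklearn.decomposition import PCA

PCADimType = Union[int, None, Literal["auto"]]


def post_process(embeddings: np.ndarray, pca_dims: PCADimType, apply_zipf: bool) -> np.ndarray:
    if pca_dims is not None:
        orig_dims = embeddings.shape[1]
        # "auto" keeps every component: PCA decorrelates without reducing.
        n_components = orig_dims if pca_dims == "auto" else pca_dims
        p = PCA(n_components=n_components, whiten=False)
        embeddings = p.fit_transform(embeddings)
        if embeddings.shape[1] < orig_dims:
            print(
                f"Reduced {orig_dims} -> {embeddings.shape[1]} dims, "
                f"explained variance ratio {np.sum(p.explained_variance_ratio_):.3f}"
            )
    if apply_zipf:
        # Rows are assumed to be sorted by token frequency rank, so the
        # log(1 + rank) factor down-weights the most frequent tokens.
        embeddings *= np.log(1 + np.arange(embeddings.shape[0]))[:, None]
    return embeddings
```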
47 changes: 33 additions & 14 deletions model2vec/distill/inference.py
@@ -1,7 +1,10 @@
# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import inspect
import logging
from pathlib import Path
-from typing import Protocol
+from typing import Protocol, Union

import numpy as np
import torch
@@ -12,7 +15,7 @@
logger = logging.getLogger(__name__)


-PathLike = str | Path
+PathLike = Union[Path, str]

_DEFAULT_BATCH_SIZE = 1024

@@ -112,35 +115,51 @@ def create_output_embeddings_from_model_name(
    :return: The tokens and output embeddings.
    """
    model = model.to(device)
-   ids = torch.arange(tokenizer.vocab_size)
+
+   # Quick check to see if the tokenizer is consistent.
+   vocab_length = len(tokenizer.get_vocab())
+   if vocab_length != tokenizer.vocab_size:
+       logger.warning(
+           f"Reported vocab size {tokenizer.vocab_size} is inconsistent with the vocab size {vocab_length}."
+       )
+
+   ids = torch.arange(vocab_length)

    # Work-around to get the eos and bos token ids without having to go into tokenizer internals.
    dummy_encoding = tokenizer.encode("A")
-   eos_token_id, bos_token_id = dummy_encoding[0], dummy_encoding[-1]
+   bos_token_id, eos_token_id = dummy_encoding[0], dummy_encoding[-1]

-   eos = torch.full([len(ids)], fill_value=eos_token_id)
    bos = torch.full([len(ids)], fill_value=bos_token_id)
+   eos = torch.full([len(ids)], fill_value=eos_token_id)

-   stacked = torch.stack([bos, ids, eos], dim=1)
+   # NOTE: reversing the bos and eos tokens works better on our benchmarks.
+   stacked = torch.stack([eos, ids, bos], dim=1)

    intermediate_weights: list[np.ndarray] = []
    for batch_idx in tqdm(range(0, len(stacked), _DEFAULT_BATCH_SIZE)):
        batch = stacked[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE].to(model.device)
        with torch.no_grad():
-           # NOTE: we create these masks because nomic embed requires them.
-           # Normally, we could set them to None
-           token_type_ids = torch.zeros_like(batch)
            attention_mask = torch.ones_like(batch)
-           encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(
-               input_ids=batch.to(device), attention_mask=attention_mask, token_type_ids=token_type_ids
-           )
-           out: torch.Tensor = encoded.last_hidden_state
+           # Prepare model inputs
+           model_inputs = {"input_ids": batch.to(device), "attention_mask": attention_mask}
+
+           # Add token_type_ids only if the model supports it
+           if "token_type_ids" in inspect.getfullargspec(model.forward).args:
+               model_inputs["token_type_ids"] = torch.zeros_like(batch)
+
+           # Perform the forward pass
+           encoded_output: BaseModelOutputWithPoolingAndCrossAttentions = model(**model_inputs)
+           out: torch.Tensor = encoded_output.last_hidden_state
            # NOTE: If the dtype is bfloat16, we convert to float32,
            # because numpy does not support bfloat16.
            # See here: https://github.com/numpy/numpy/issues/19808
            if out.dtype == torch.bfloat16:
                out = out.float()
-           intermediate_weights.append(out[:, 1].cpu().numpy())
+
+       # Add the output to the intermediate weights
+       intermediate_weights.append(out[:, 1].detach().cpu().numpy())
+
+   # Concatenate the intermediate weights
    out_weights = np.concatenate(intermediate_weights)

    return tokenizer.convert_ids_to_tokens(ids), out_weights
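To make the dummy-encoding work-around concrete, a hypothetical example (the tokenizer name and ids are illustrative, not from this diff):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dummy_encoding = tokenizer.encode("A")
# For a BERT-style tokenizer this is e.g. [101, 1037, 102]:
# [CLS] (bos) first, [SEP] (eos) last, the token for "a" in between.
bos_token_id, eos_token_id = dummy_encoding[0], dummy_encoding[-1]
```

The `inspect.getfullargspec(model.forward)` check serves the same genericity goal: models whose `forward` does not accept `token_type_ids` (unlike BERT) are simply not passed one.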
87 changes: 46 additions & 41 deletions model2vec/distill/tokenizer.py
@@ -2,6 +2,7 @@

import json
import logging
+from typing import Any

from tokenizers import Tokenizer

@@ -36,11 +37,11 @@ def remove_tokens(tokenizer: Tokenizer, tokens_to_remove: list[str]) -> Tokenize
        logger.info("No tokens to remove.")
        return Tokenizer.from_str(tokenizer.to_str())

-   tokenizer_data = json.loads(tokenizer.to_str())
+   tokenizer_data: dict[str, Any] = json.loads(tokenizer.to_str())

    # Find all added tokens
-   added_tokens = tokenizer_data["added_tokens"]
-   added_tokens_str = {token["content"] for token in added_tokens}
+   added_tokens: list[dict[str, Any]] = tokenizer_data.get("added_tokens", [])
+   added_tokens_str: set[str] = {token["content"] for token in added_tokens}

    # Remove all added tokens from the list of tokens to remove.
    # Things will go bad if we keep them.
@@ -49,34 +50,36 @@ def remove_tokens(tokenizer: Tokenizer, tokens_to_remove: list[str]) -> Tokenize
    # Load the vocabulary.
    model_type = tokenizer_data["model"]["type"]

-   match model_type:
-       case "WordPiece":
-           # Vocab is a dictionary.
-           vocab: dict[str, int] = tokenizer_data["model"]["vocab"]
-           n_tokens = len(vocab)
-
-           # Remove the tokens.
-           for token in tokens_to_remove:
-               if vocab.pop(token, None) is None:
-                   logger.warning(f"Token {token} was not in the vocabulary.")
-
-           n_removed = n_tokens - len(vocab)
-           logger.info(f"Removed {n_removed} tokens from the vocabulary.")
-
-           # Reindex the vocabulary so that it is contiguous.
-           reindexed = {token: idx for idx, (token, _) in enumerate(sorted(vocab.items(), key=lambda x: x[1]))}
-           tokenizer_data["model"]["vocab"] = reindexed
-       case "Unigram":
-           raise ValueError("Removing tokens from a unigram tokenizer is not supported.")
-       case "BPE":
-           raise ValueError("Removing tokens from a bpe tokenizer is not supported.")
-       case _:
-           raise ValueError(f"Unknown model type {model_type}")
+   if model_type == "WordPiece":
+       # Vocab is a dictionary.
+       vocab: dict[str, int] = tokenizer_data["model"]["vocab"]
+       n_tokens = len(vocab)
+
+       # Remove the tokens.
+       for token in tokens_to_remove:
+           if vocab.pop(token, None) is None:
+               logger.warning(f"Token {token} was not in the vocabulary.")
+
+       n_removed = n_tokens - len(vocab)
+       logger.info(f"Removed {n_removed} tokens from the vocabulary.")
+
+       # Reindex the vocabulary so that it is contiguous.
+       reindexed = {token: idx for idx, (token, _) in enumerate(sorted(vocab.items(), key=lambda x: x[1]))}
+       tokenizer_data["model"]["vocab"] = reindexed
+
+   elif model_type == "Unigram":
+       raise ValueError("Removing tokens from a unigram tokenizer is not supported.")
+
+   elif model_type == "BPE":
+       raise ValueError("Removing tokens from a BPE tokenizer is not supported.")
+
+   else:
+       raise ValueError(f"Unknown model type {model_type}")

    # Reindex the special tokens (i.e., CLS and SEP for BertTokenizers.)
    special_tokens_post_processor: dict[str, dict] = tokenizer_data["post_processor"]["special_tokens"]
    for token, token_data in special_tokens_post_processor.items():
        token_data["ids"] = [reindexed[token] for token in token_data["tokens"]]
+   added_tokens = tokenizer_data.get("added_tokens", [])
+   for token_data in added_tokens:
+       token_data["id"] = reindexed[token_data["content"]]

    # Reinitialize the tokenizer from the json.
    tokenizer = Tokenizer.from_str(json.dumps(tokenizer_data))
@@ -97,18 +100,20 @@ def add_tokens(tokenizer: Tokenizer, tokens_to_add: list[str]) -> Tokenizer:

    model = data["model"]["type"]

-   match model:
-       case "WordPiece":
-           wordpiece_vocab: dict[str, int] = data["model"]["vocab"]
-           for token in tokens_to_add:
-               if token not in wordpiece_vocab:
-                   wordpiece_vocab[token] = len(wordpiece_vocab)
-       case "Unigram":
-           raise ValueError("Adding tokens to a unigram tokenizer is not supported.")
-       case "BPE":
-           raise ValueError("Adding tokens to a bpe tokenizer is not supported.")
-       case _:
-           raise ValueError(f"Unknown model type {model}")
+   if model == "WordPiece":
+       wordpiece_vocab: dict[str, int] = data["model"]["vocab"]
+       for token in tokens_to_add:
+           if token not in wordpiece_vocab:
+               wordpiece_vocab[token] = len(wordpiece_vocab)
+
+   elif model == "Unigram":
+       raise ValueError("Adding tokens to a unigram tokenizer is not supported.")
+
+   elif model == "BPE":
+       raise ValueError("Adding tokens to a BPE tokenizer is not supported.")
+
+   else:
+       raise ValueError(f"Unknown model type {model}")

    tokenizer = Tokenizer.from_str(json.dumps(data))

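A tiny worked example of the reindexing step in `remove_tokens` above, with a hypothetical vocabulary:

```python
# After removals the ids are non-contiguous; sorting by old id and
# re-enumerating restores a dense 0..n-1 index while preserving order.
vocab = {"[CLS]": 0, "hello": 2, "world": 5}
reindexed = {token: idx for idx, (token, _) in enumerate(sorted(vocab.items(), key=lambda x: x[1]))}
assert reindexed == {"[CLS]": 0, "hello": 1, "world": 2}
```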
2 changes: 2 additions & 0 deletions model2vec/distill/utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from logging import getLogger

import torch
(Diffs for the remaining changed files are not shown.)