Skip to content

Commit

Permalink
xSIM++ release
Browse files Browse the repository at this point in the history
Co-authored-by: Mingda Chen <[email protected]>
Co-authored-by: Gustavo Gianotti <[email protected]>
  • Loading branch information
3 people committed Jun 26, 2023
1 parent 1be8b55 commit 5b9820b
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 37 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
LASER is a library to calculate and use multilingual sentence embeddings.

**NEWS**
* 2023/06/26 [**xSIM++**](https://arxiv.org/abs/2306.12907) evaluation pipeline and data [**released**](tasks/xsimplusplus/README.md)
* 2022/07/06 Updated LASER models with support for over 200 languages are [**now available**](nllb/README.md)
* 2022/07/06 Multilingual similarity search (**xsim**) evaluation pipeline [**released**](tasks/xsim/README.md)
* 2022/05/03 [**Librivox S2S is available**](tasks/librivox-s2s): Speech-to-Speech translations automatically mined in Librivox [9]
Expand Down
114 changes: 84 additions & 30 deletions source/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,23 @@
import pandas
import tempfile
import numpy as np
from pathlib import Path
import itertools
import logging
import sys
from typing import List, Tuple
from typing import List, Tuple, Dict
from tabulate import tabulate
from pathlib import Path
from collections import defaultdict
from xsim import xSIM
from embed import embed_sentences, SentenceEncoder, HuggingFaceEncoder, load_model
from embed import embed_sentences, load_model

logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger('eval')
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("eval")


class Eval:
def __init__(self, args):
Expand All @@ -47,14 +50,13 @@ def __init__(self, args):
self.encoder_args = {
k: v
for k, v in args._get_kwargs()
if k
in ["max_sentences", "max_tokens", "cpu", "fp16", "sort_kind", "verbose"]
if k in ["max_sentences", "max_tokens", "cpu", "sort_kind", "verbose"]
}
self.src_bpe_codes = args.src_bpe_codes
self.tgt_bpe_codes = args.tgt_bpe_codes
self.src_spm_model = args.src_spm_model
self.tgt_spm_model = args.tgt_spm_model
logger.info('loading src encoder')
logger.info("loading src encoder")
self.src_encoder = load_model(
args.src_encoder,
self.src_spm_model,
Expand All @@ -63,7 +65,7 @@ def __init__(self, args):
**self.encoder_args,
)
if args.tgt_encoder:
logger.info('loading tgt encoder')
logger.info("loading tgt encoder")
self.tgt_encoder = load_model(
args.tgt_encoder,
self.tgt_spm_model,
Expand All @@ -72,7 +74,7 @@ def __init__(self, args):
**self.encoder_args,
)
else:
logger.info('encoding tgt using src encoder')
logger.info("encoding tgt using src encoder")
self.tgt_encoder = self.src_encoder
self.tgt_bpe_codes = self.src_bpe_codes
self.tgt_spm_model = self.src_spm_model
Expand All @@ -81,40 +83,63 @@ def __init__(self, args):
self.fp16 = args.fp16
self.margin = args.margin

def _embed(self, tmpdir, langs, encoder, spm_model, bpe_codes) -> List[List[str]]:
def _embed(
self, tmpdir, langs, encoder, spm_model, bpe_codes, tgt_aug_langs=[]
) -> List[List[str]]:
emb_data = []
for lang in langs:
augjson = None
fname = f"{lang}.{self.split}"
infile = os.path.join(self.base_dir, self.corpus, self.split, fname)
outfile = os.path.join(tmpdir, fname)
infile = self.base_dir / self.corpus / self.split / fname
assert infile.exists(), f"{infile} does not exist"
outfile = tmpdir / fname
if lang in tgt_aug_langs:
fname = f"{lang}_augmented.{self.split}"
fjname = f"{lang}_errtype.{self.split}.json"
augment_dir = self.base_dir / self.corpus / (self.split + "_augmented")
augjson = augment_dir / fjname
auginfile = augment_dir / fname
assert augjson.exists(), f"{augjson} does not exist"
assert auginfile.exists(), f"{auginfile} does not exist"
combined_infile = tmpdir / f"combined_{lang}"
with open(combined_infile, "w") as newfile:
for f in [infile, auginfile]:
with open(f) as fin:
newfile.write(fin.read())
infile = combined_infile
embed_sentences(
infile,
outfile,
str(infile),
str(outfile),
encoder=encoder,
spm_model=spm_model,
bpe_codes=bpe_codes,
token_lang=lang if bpe_codes else "--",
buffer_size=self.buffer_size,
fp16=self.fp16,
**self.encoder_args,
)
assert (
os.path.isfile(outfile) and os.path.getsize(outfile) > 0
), f"Error encoding {infile}"
emb_data.append([lang, infile, outfile])
emb_data.append([lang, infile, outfile, augjson])
return emb_data

def _xsim(self, src_emb, src_lang, tgt_emb, tgt_lang, tgt_txt) -> Tuple[int, int]:
err, nbex = xSIM(
def _xsim(
self, src_emb, src_lang, tgt_emb, tgt_lang, tgt_txt, augjson=None
) -> Tuple[int, int, Dict[str, int]]:
return xSIM(
src_emb,
tgt_emb,
margin=self.margin,
dim=self.emb_dimension,
fp16=self.fp16,
eval_text=tgt_txt if not self.index_comparison else None,
augmented_json=augjson,
)
return err, nbex

def calc_xsim(self, embdir, src_langs, tgt_langs, err_sum=0, totl_nbex=0) -> None:
def calc_xsim(
self, embdir, src_langs, tgt_langs, tgt_aug_langs, err_sum=0, totl_nbex=0
) -> None:
outputs = []
src_emb_data = self._embed(
embdir,
Expand All @@ -129,13 +154,19 @@ def calc_xsim(self, embdir, src_langs, tgt_langs, err_sum=0, totl_nbex=0) -> Non
self.tgt_encoder,
self.tgt_spm_model,
self.tgt_bpe_codes,
tgt_aug_langs,
)
aug_df = defaultdict(lambda: defaultdict())
combs = list(itertools.product(src_emb_data, tgt_emb_data))
for (src_lang, src_txt, src_emb), (tgt_lang, tgt_txt, tgt_emb) in combs:
for (src_lang, _, src_emb, _), (tgt_lang, tgt_txt, tgt_emb, augjson) in combs:
if src_lang == tgt_lang:
continue
err, nbex = self._xsim(src_emb, src_lang, tgt_emb, tgt_lang, tgt_txt)
err, nbex, aug_report = self._xsim(
src_emb, src_lang, tgt_emb, tgt_lang, tgt_txt, augjson
)
result = round(100 * err / nbex, 2)
if tgt_lang in tgt_aug_langs:
aug_df[tgt_lang][src_lang] = aug_report
if nbex < self.min_sents:
result = "skipped"
else:
Expand All @@ -154,9 +185,22 @@ def calc_xsim(self, embdir, src_langs, tgt_langs, err_sum=0, totl_nbex=0) -> Non
)
print(
tabulate(
outputs, tablefmt="psql", headers=["dataset", "src-tgt", "xsim", "nbex"]
outputs,
tablefmt="psql",
headers=[
"dataset",
"src-tgt",
"xsim" + ("(++)" if tgt_aug_langs else ""),
"nbex",
],
)
)
for tgt_aug_lang in tgt_aug_langs:
df = pandas.DataFrame.from_dict(aug_df[tgt_aug_lang]).fillna(0).T
print(
f"\nAbsolute error under augmented transformations for: {tgt_aug_lang}"
)
print(f"{tabulate(df, df.columns, floatfmt='.2f', tablefmt='grid')}")

def calc_xsim_nway(self, embdir, langs) -> None:
err_matrix = np.zeros((len(langs), len(langs)))
Expand All @@ -167,12 +211,14 @@ def calc_xsim_nway(self, embdir, langs) -> None:
self.src_spm_model,
self.src_bpe_codes,
)
for i1, (src_lang, src_txt, src_emb) in enumerate(emb_data):
for i2, (tgt_lang, tgt_txt, tgt_emb) in enumerate(emb_data):
for i1, (src_lang, _, src_emb, _) in enumerate(emb_data):
for i2, (tgt_lang, tgt_txt, tgt_emb, _) in enumerate(emb_data):
if src_lang == tgt_lang:
err_matrix[i1, i2] = 0
else:
err, nbex = self._xsim(src_emb, src_lang, tgt_emb, tgt_lang, tgt_txt)
err, nbex, _ = self._xsim(
src_emb, src_lang, tgt_emb, tgt_lang, tgt_txt
)
err_matrix[i1, i2] = 100 * err / nbex
df = pandas.DataFrame(err_matrix, columns=langs, index=langs)
df.loc["avg"] = df.sum() / float(df.shape[0] - 1) # exclude diagonal in average
Expand All @@ -188,16 +234,17 @@ def run_eval(args) -> None:
embed_dir = args.embed_dir
else:
tmp_dir = tempfile.TemporaryDirectory()
embed_dir = tmp_dir.name
embed_dir = Path(tmp_dir.name)
src_langs = sorted(args.src_langs.split(","))
tgt_aug_langs = sorted(args.tgt_aug_langs.split(",")) if args.tgt_aug_langs else []
if evaluation.nway:
evaluation.calc_xsim_nway(embed_dir, src_langs)
else:
assert (
args.tgt_langs
), "Please provide tgt langs when not performing n-way comparison"
tgt_langs = sorted(args.tgt_langs.split(","))
evaluation.calc_xsim(embed_dir, src_langs, tgt_langs)
evaluation.calc_xsim(embed_dir, src_langs, tgt_langs, tgt_aug_langs)
if tmp_dir:
tmp_dir.cleanup() # remove temporary directory

Expand All @@ -208,7 +255,7 @@ def run_eval(args) -> None:
)
parser.add_argument(
"--base-dir",
type=str,
type=Path,
default=None,
help="Base directory for evaluation files",
required=True,
Expand Down Expand Up @@ -244,7 +291,7 @@ def run_eval(args) -> None:
)
parser.add_argument(
"--embed-dir",
type=str,
type=Path,
default=None,
help="Store/load embeddings from specified directory (default temporary)",
)
Expand Down Expand Up @@ -299,6 +346,13 @@ def run_eval(args) -> None:
default=None,
help="Target-side languages for evaluation",
)
parser.add_argument(
"--tgt-aug-langs",
type=str,
default=None,
help="languages with augmented data",
required=False,
)
parser.add_argument(
"--fp16",
action="store_true",
Expand Down
41 changes: 34 additions & 7 deletions source/xsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import typing as tp
import os
import json
from enum import Enum


Expand All @@ -38,14 +39,15 @@ def xSIM(
dim: int = 1024,
fp16: bool = False,
eval_text: str = None,
) -> tp.Tuple[int, int]:
augmented_json: str = None,
) -> tp.Tuple[int, int, tp.Dict[str, int]]:
assert Margin.has_value(margin), f"Margin type: {margin}, is not supported."
if not isinstance(x, np.ndarray):
x = _load_embeddings(x, dim, fp16)
if not isinstance(y, np.ndarray):
y = _load_embeddings(y, dim, fp16)
# calculate xSIM error
return calculate_error(x, y, margin, k, eval_text)
return calculate_error(x, y, margin, k, eval_text, augmented_json)


def _load_embeddings(infile: str, dim: int, fp16: bool = False) -> np.ndarray:
Expand Down Expand Up @@ -111,17 +113,35 @@ def _score_knn(x: np.ndarray, y: np.ndarray, k: int, margin: str) -> np.ndarray:
return indices


def get_transform(augmented_json, closest_neighbor, src):
if (
closest_neighbor in augmented_json
and augmented_json[closest_neighbor]["src"] == src
):
return augmented_json[closest_neighbor]["errtype"]
return "Misaligned"


def calculate_error(
x: np.ndarray,
y: np.ndarray,
margin: str = None,
k: int = 4,
eval_text: str = None,
) -> tp.Tuple[int, int]:
assert (
x.shape == y.shape
), f"number of source {x.shape} / target {y.shape} shapes mismatch"
augmented_json: str = None,
) -> tp.Tuple[int, int, tp.Dict[str, int]]:
if augmented_json:
with open(augmented_json) as f:
augmented_json = json.load(f)
assert (
x.shape[0] < y.shape[0]
), f"Shape mismatch: {x.shape[0]} >= target {y.shape[0]}"
else:
assert (
x.shape == y.shape
), f"number of source {x.shape} / target {y.shape} shapes mismatch, "
nbex = x.shape[0]
augmented_report = {}

# for each x calculate the highest scoring neighbor from y
closest_neighbor = _score_knn(x, y, k, margin)
Expand All @@ -132,7 +152,14 @@ def calculate_error(
for ex in range(nbex):
if lines[ex] != lines[closest_neighbor[ex, 0]]:
err += 1
if augmented_json:
transform = get_transform(
augmented_json,
lines[closest_neighbor[ex, 0]].strip(),
lines[ex].strip(),
)
augmented_report[transform] = augmented_report.get(transform, 0) + 1
else: # calc index error
ref = np.linspace(0, nbex - 1, nbex).astype(int) # [0, nbex)
err = nbex - np.equal(closest_neighbor.reshape(nbex), ref).astype(int).sum()
return err, nbex
return err, nbex, augmented_report
20 changes: 20 additions & 0 deletions tasks/xsimplusplus/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# LASER: xSIM++

This README shows how to calculate the xSIM++ error rate for a given language pair.

xSIM++ is an extension of [xSIM](https://github.com/facebookresearch/LASER/tree/main/tasks/xsim). In comparison to xSIM, this evaluates using target-side data with additional synthetic, hard-to-distinguish examples. You can find more details about it in the publication: [xSIM++: An Improved Proxy to Bitext Mining Performance for Low-Resource Languages](https://arxiv.org/abs/2306.12907).

## Example usage

Simply run the example script `bash ./eval.sh` to download a sample dataset (flores200), download synthetically augmented English evaluation data from Flores, a sample encoder (laser2), and calculate both the sentence embeddings and the xSIM++ error rate for a set of (comma separated) languages.

The evaluation command is similar to xSIM, however there is an additional option to provide the comma-separated list of augmented languages: `--tgt-aug-langs`. These refer
to languages in the chosen evaluation set which also have a separate augmented data file. In addition to the error rate, the script also provides a breakdown of the number of errors by type (e.g. incorrect entity/number etc.).

You can also calculate xsim++ for encoders hosted on [HuggingFace sentence-transformers](https://huggingface.co/sentence-transformers). For example, to use LaBSE you can modify/add the following arguments in the sample script:
```
--src-encoder LaBSE
--use-hugging-face
--embedding-dimension 768
```
Note: for HuggingFace encoders there is no need to specify `--src-spm-model`.
Loading

0 comments on commit 5b9820b

Please sign in to comment.