allow fp16 embeddings with mining
heffernankevin committed Jun 16, 2023
1 parent faf08e8 commit 67565cb
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions source/embed.py
@@ -447,8 +447,8 @@ def EncodeFile(
 
 
 # Load existing embeddings
-def EmbedLoad(fname, dim=1024, verbose=False):
-    x = np.fromfile(fname, dtype=np.float32, count=-1)
+def EmbedLoad(fname, dim=1024, verbose=False, fp16=False):
+    x = np.fromfile(fname, dtype=(np.float16 if fp16 else np.float32), count=-1)
     x.resize(x.shape[0] // dim, dim)
     if verbose:
         print(" - Embeddings: {:s}, {:d}x{:d}".format(fname, x.shape[0], dim))
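The change to EmbedLoad keeps the on-disk layout untouched (a flat binary array of sentence vectors) and only switches the dtype used to interpret the file. Below is a minimal sketch of that behaviour, assuming an embedding file written with NumPy's tofile; the file name and toy data are illustrative, not from the repository.

import numpy as np

# Write a toy fp16 embedding file: 3 sentence vectors of dimension 1024,
# stored as a flat float16 buffer (the layout LASER embeddings use on disk).
dim = 1024
toy = np.random.rand(3, dim).astype(np.float16)
toy.tofile("toy_embeddings.fp16.bin")

# Read it back the way EmbedLoad now does when fp16=True.
x = np.fromfile("toy_embeddings.fp16.bin", dtype=np.float16, count=-1)
x.resize(x.shape[0] // dim, dim)   # reshape the flat buffer to (n_sentences, dim)
print(x.shape, x.dtype)            # (3, 1024) float16

Reading a float16 file with fp16=False (or the reverse) would misinterpret the bytes and yield the wrong number of vectors, so the flag has to match how the embeddings were written.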
8 changes: 5 additions & 3 deletions source/mine_bitexts.py
@@ -193,6 +193,8 @@ def score_candidates(x, y, candidate_inds, fwd_mean, bwd_mean, margin, verbose=F
                         help='Precomputed target sentence embeddings')
     parser.add_argument('--dim', type=int, default=1024,
                         help='Embedding dimensionality')
+    parser.add_argument('--fp16', action='store_true',
+                        help='Load precomputed embeddings in float16 format')
     args = parser.parse_args()
 
     print('LASER: tool to search, score or mine bitexts')
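The new argument uses argparse's store_true action, so mining behaves exactly as before unless --fp16 is passed. A small standalone sketch of that behaviour, including only the two arguments visible in this hunk:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dim', type=int, default=1024,
                    help='Embedding dimensionality')
parser.add_argument('--fp16', action='store_true',
                    help='Load precomputed embeddings in float16 format')

print(parser.parse_args([]).fp16)          # False: default keeps float32 loading
print(parser.parse_args(['--fp16']).fp16)  # True: embeddings are read as float16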
@@ -211,12 +213,12 @@ def unique_embeddings(emb, ind, verbose=False):
             print(' - unify embeddings: {:d} -> {:d}'.format(len(emb), len(aux)))
         return emb[[aux[i] for i in range(len(aux))]]
 
-    # load the embeddings
-    x = EmbedLoad(args.src_embeddings, args.dim, verbose=args.verbose)
+    # load the embeddings and store as np.float32 (required for FAISS)
+    x = EmbedLoad(args.src_embeddings, args.dim, verbose=args.verbose, fp16=args.fp16).astype(np.float32)
     if args.unify:
         x = unique_embeddings(x, src_inds, args.verbose)
     faiss.normalize_L2(x)
-    y = EmbedLoad(args.trg_embeddings, args.dim, verbose=args.verbose)
+    y = EmbedLoad(args.trg_embeddings, args.dim, verbose=args.verbose, fp16=args.fp16).astype(np.float32)
     if args.unify:
         y = unique_embeddings(y, trg_inds, args.verbose)
     faiss.normalize_L2(y)
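On the mining side the flag only changes how the files are read: the vectors are upcast to float32 right after loading because FAISS indexing and faiss.normalize_L2 operate on float32 arrays, so storing embeddings in fp16 halves the disk footprint while the search itself still runs in float32. A short sketch of that load-then-upcast path, assuming EmbedLoad is importable from source/embed.py and using placeholder file names:

import numpy as np
import faiss

from embed import EmbedLoad   # source/embed.py

dim = 1024
# Read fp16 embeddings from disk, then upcast: FAISS expects float32 input.
x = EmbedLoad("src.fp16.bin", dim, fp16=True).astype(np.float32)
y = EmbedLoad("trg.fp16.bin", dim, fp16=True).astype(np.float32)

# L2-normalise in place so inner-product search corresponds to cosine
# similarity, as mine_bitexts.py does before running the kNN search.
faiss.normalize_L2(x)
faiss.normalize_L2(y)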
