-
Notifications
You must be signed in to change notification settings - Fork 7
/
compress.py
executable file
·107 lines (90 loc) · 3.39 KB
/
compress.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python3
# Compress an embedding matrix to 8-bit quantized form, keeping only a
# filtered/deduplicated subset of its words.
import numpy as np
import argparse
import tqdm
import re

parser = argparse.ArgumentParser()
parser.add_argument("vectors", help="Vectors to compress")
# Typo fixed: "coresponding" -> "corresponding".
parser.add_argument("words", help="The corresponding words, because we do filtering")
parser.add_argument("-n", default=50000, type=int, help="Number of words to include")
parser.add_argument("-f", type=str, nargs='*', help="Take only words from this file")
def quantize_8bit(data, alpha):
    """Clip *data* to its alpha-scaled extremes and quantize to uint8.

    Returns ``(quantized, min_val, max_val)`` where the bounds are the
    clip limits required to invert the mapping in ``dequantize_8bit``.
    """
    lo = np.min(data) * alpha
    hi = np.max(data) * alpha
    clipped = data.clip(lo, hi)
    # Linearly map [lo, hi] onto [0, 1], then onto byte codes 0..255
    # (astype truncates toward zero, matching the original behavior).
    fraction = (clipped - lo) / (hi - lo)
    codes = (fraction * 255).astype(np.uint8)
    return codes, lo, hi
def dequantize_8bit(quantized, min_val, max_val):
    """Invert ``quantize_8bit``: map byte codes back onto [min_val, max_val]."""
    # Byte code 0 -> min_val, 255 -> max_val; intermediate codes interpolate.
    fraction = quantized.astype(np.float32) / 255
    return fraction * (max_val - min_val) + min_val
def main(args):
    """Load vectors, filter and deduplicate their words, grid-search the
    best clip factor, then write quantized bytes and the kept words to
    ``<vectors>.out`` / ``<words>.out``.
    """
    vecs = np.load(args.vectors)
    n, dim = vecs.shape
    if dim != 300:
        # Reduce to the expected 300 dimensions; imported lazily since
        # sklearn is only needed on this path.
        from sklearn.decomposition import PCA
        vecs = PCA(n_components=300).fit_transform(vecs)
    # vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)
    with open(args.words) as file:
        words = file.readlines()
    if not args.f:
        # BUGFIX: the original used set(words), whose entries keep their
        # trailing newlines and original case, while each candidate word
        # below is lowercased and regex-stripped — so the membership test
        # rejected essentially everything. An empty set disables the
        # `if good_words and ...` filter, which is the intended "keep all".
        good_words = set()
    else:
        good_words = set()
        print(args.f)
        for path in args.f:
            with open(path) as file:
                good_words |= {line.lower().strip() for line in file}
        print(f'{len(good_words)=}')
        print(list(good_words)[:3])
    print(len(vecs), len(words))
    assert len(vecs) == len(words)
    included_vectors = []
    included_words = []
    seen = set()
    # Walk words in order, keeping the first args.n that survive filtering.
    # (tqdm total fixed: the loop consumes `words`, not args.n items.)
    for vec, word in zip(vecs, tqdm.tqdm(words)):
        # Normalize: lowercase, keep only ascii letters/digits.
        word = re.sub('[^a-z0-9]', '', word.lower())
        if not word or word.isdigit():
            continue
        if word in seen:
            continue
        if good_words and word not in good_words:
            continue
        seen.add(word)
        included_vectors.append(vec)
        included_words.append(word)
        if len(included_words) == args.n:
            break
    x = np.stack(included_vectors)
    # Grid-search alpha (clip factor) minimizing mean relative
    # reconstruction error of the quantize/dequantize round trip.
    best_alpha, best_err = 0, 1000
    for alpha in tqdm.tqdm(np.linspace(x.std()/np.abs(x).max(), 1)):
        compressed, min_val, max_val = quantize_8bit(x, alpha)
        restored = dequantize_8bit(compressed, min_val, max_val)
        err = np.linalg.norm(x - restored, axis=1) / np.linalg.norm(x, axis=1)
        merr = err.mean()
        if merr < best_err:
            best_err = merr
            best_alpha = alpha
            print(f"{alpha}, Mean error: {merr}")
    print(f"{best_alpha=}")
    compressed, min_val, max_val = quantize_8bit(x, best_alpha)
    # min/max must be saved by the user — they are required to dequantize.
    print("IMPORTANT:")
    print(f"min={min_val}, max={max_val}")
    restored = dequantize_8bit(compressed, min_val, max_val)
    err = np.linalg.norm(x - restored, axis=1) / np.linalg.norm(x, axis=1)
    print(f"Mean error: {err.mean()}")
    data = compressed.tobytes()
    print(f"Size: {len(data)/10**6}MB")
    with open(f'{args.vectors}.out', 'wb') as file:
        file.write(data)
    with open(f'{args.words}.out', 'w') as file:
        file.write("\n".join(included_words))
# Entry point: parse CLI arguments and run the compression pipeline.
if __name__ == '__main__':
    main(parser.parse_args())