-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_toy_trajectory.py
113 lines (91 loc) · 2.88 KB
/
plot_toy_trajectory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/env python
# coding: utf-8
import matplotlib as mpl
# Baseline font size applied as soon as matplotlib is imported; the __main__
# block further down overrides this dict with the full font configuration.
pgf_with_rc_fonts = {}
pgf_with_rc_fonts["font.size"] = 10
mpl.rcParams.update(pgf_with_rc_fonts)
import argparse
import json
import os
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import comb
def nceloss_min(k, c=100, b=1):
    """Evaluate the binomial-expectation NCE loss term for each K in ``k``.

    For every K this returns

        sum_{m=0}^{K} Binom(m; K, 1/c) * log(1 + m + (K - m) * exp(-2 * b**2))

    i.e. the expectation of ``log(1 + m + (K - m) * exp(-2 b^2))`` with
    ``m ~ Binomial(K, 1/c)``. ``plot_trajectory`` uses it as the x-position
    ``x0`` of its vertical marker.

    Args:
        k: An integer or an iterable of integers (the K values).
        c: Number of classes; the binomial success probability is ``1 / c``.
        b: Scale entering the ``exp(-2 * b**2)`` term.

    Returns:
        1-D float numpy array with one entry per K (length 1 for scalar ``k``).
    """
    # Local import keeps this block self-contained. binom.pmf is computed in
    # log-space internally, unlike comb(K, m) * p**m * (1 - p)**(K - m). That
    # product overflows comb() to inf for K above ~1030, multiplies it by an
    # underflowed 0, and yields NaN.
    from scipy.stats import binom

    p_same = 1.0 / c
    neg_weight = np.exp(-2 * b ** 2)

    values = []
    for kk in np.atleast_1d(k):
        kk = int(kk)
        m = np.arange(kk + 1)
        weights = binom.pmf(m, kk, p_same)
        values.append(np.sum(weights * np.log(1 + m + (kk - m) * neg_weight)))
    return np.array(values, dtype=float)
def _select_key(mapping, key):
    """Return ``key`` (or its string form) for ``mapping``; first key if None."""
    if key is None:
        # Same behaviour as the original: take the first entry
        # (IndexError on an empty mapping).
        return list(mapping)[0]
    # JSON object keys are strings, so allow callers to pass ints.
    return key if key in mapping else str(key)


def plot_trajectory(result, output_filename="toy_trajectory.pdf", c=None, k=None):
    """Plot one validation trajectory against the theoretical loss bounds.

    The figure shows:
    * the lower bound (Thm 4) and upper bound (Thm 3) lines;
    * a dotted vertical marker at ``x0 = nceloss_min([K], C)``;
    * a dotted horizontal line at ``log(1 + (C - 1) e^-2)``;
    * the recorded (contrastive loss, supervised loss) trajectory.

    The shaded band spans ``max(lb, min_sup_loss)`` to ``min(ub, max_sup_loss)``.
    It is drawn only where ``x >= x0`` and ``lb <= max_sup_loss``.

    Args:
        result: Nested mapping ``result[c][k]`` whose value holds the sequences
            ``"val_cont_loss"`` and ``"val_sup_loss"`` (the main block loads it
            from JSON, so keys are strings there).
        output_filename: Path the figure is saved to.
        c: Key of ``result`` to plot; defaults to the first key.
        k: Key of ``result[c]`` to plot; defaults to the first key.
    """
    c_key = _select_key(result, c)
    k_key = _select_key(result[c_key], k)
    run = result[c_key][k_key]
    trj_x = run["val_cont_loss"]
    trj_y = run["val_sup_loss"]
    c = int(c_key)
    k = int(k_key)

    x = np.linspace(0, 12, 1000)
    # nceloss_min returns one value per K; unwrap so x0 is a plain scalar for
    # the marker coordinates and the fill_between mask.
    x0 = float(nceloss_min([k], c)[0])
    min_trj_y = np.min(trj_y)

    # Upper bound (slope 1).
    au = 1
    bu = 2 * np.log(np.cosh(1)) - np.log(k / c)
    ub = au * x + bu
    # Lower bound (slope 1).
    al = 1
    bl = np.log(c) + np.log(k / (k + 1) ** 2) - 2 * np.log(np.cosh(1))
    lb = al * x + bl
    # Range of the supervised loss used to clip the shaded band.
    min_sup_loss = np.log(1 + (c - 1) * np.exp(-2))
    max_sup_loss = np.log(1 + (c - 1) * np.exp(+2))

    # Draw on a fresh figure so repeated calls do not stack onto each other.
    fig = plt.figure()
    plt.plot(x, lb, "b-", lw=1.5, label=r"Lower bound (Thm 4)")
    plt.plot(x, ub, "b-.", lw=1.5, label=r"Upper bound (Thm 3)")
    plt.plot([x0 - 1e-3, x0 + 1e-3], [0, 10], "b:", lw=1.5)
    plt.plot(
        x, min_sup_loss * np.ones_like(x), "b:", lw=1.5
    )  # NOTE: lower bound of sup loss
    plt.plot(trj_x, trj_y, "k-", lw=3, label=r"trajectory")
    plt.fill_between(
        x,
        np.maximum(lb, min_sup_loss),
        np.minimum(ub, max_sup_loss),
        where=(x >= x0) & (lb <= max_sup_loss),
        facecolor="#aaaaff",
        alpha=0.5,
    )
    plt.title(f"C={c}, K={k} (min sup loss = {min_trj_y:.3f})")
    plt.xlabel(r"contrastive loss")
    plt.ylabel(r"mean supervised loss")
    plt.xlim(0, 8)
    plt.ylim(0, 5)
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.savefig(output_filename, bbox_inches="tight", pad_inches=0.1)
    # Release the figure; otherwise every call leaves an open figure behind.
    plt.close(fig)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "filename",
        # nargs="?" makes the positional optional. Without it argparse always
        # demands the argument, and the default below can never take effect.
        nargs="?",
        help="location of execution log file name",
        # Resolve against the script's absolute directory. If __file__ is a
        # bare filename, dirname() returns "" and string concatenation would
        # produce the root-relative path "/../code/results.json".
        default=os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "..", "code", "results.json"
        ),
    )
    parser.add_argument(
        "-o",
        "--output",
        default="toy_trajectory.pdf",
        help="path of the figure to write (default: %(default)s)",
    )
    args = parser.parse_args()

    with open(args.filename, "r", encoding="utf-8") as f:
        result = json.load(f)

    # Full font configuration for the saved figure. It replaces the baseline
    # dict set at import time. text.usetex renders text through LaTeX.
    pgf_with_rc_fonts = {
        "font.serif": [],  # use latex default serif font
        "font.sans-serif": ["DejaVu Sans"],  # use a specific sans-serif font
        "font.size": 10,
        "ps.useafm": True,
        "pdf.use14corefonts": True,
        "text.usetex": True,
    }
    mpl.rcParams.update(pgf_with_rc_fonts)
    plot_trajectory(result, args.output)