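"""main.py: benchmark four gradient descent implementations on the same layout task.

Starting from a random configuration X, each implementation (native Python,
NumPy, JAX, and a C++ binding) runs gradient descent so that the pairwise
distances of X approach a target distance matrix D; the script times each
one and can plot/animate the descent. Run it as a script (e.g. python
main.py, presumably from the repository root so that the cpp and python
packages are importable).
"""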
from timeit import timeit

import numpy as np

from cpp.binding import gradient_descent_cpp
from python.gradient_descent import gradient_descent, gradient_descent_cache
from python.gradient_descent_native import gradient_descent_native, gradient_descent_native_cache, PyMatrix
from python.gradient_descent_JAX import gradient_descent_JAX, gradient_descent_cache_JAX
from python.visuals import plot_gradient_descent, plot_gradient_descent_2D, animate_gradient_descent


def generate_radial_points(N, dim):
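    """Return N points of radius 0.5: evenly spaced on a circle (dim == 2),
    or along a spiral over the sphere (dim == 3), where phi is chosen so the
    points are uniform in cos(phi) and theta winds around the z-axis.
    """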
    r = 0.5
    points = []
    if dim == 2:
        for i in range(N):
            angle = 2 * np.pi * i / N
            points.append([r * np.cos(angle), r * np.sin(angle)])
    elif dim == 3:
        for i in range(N):
            phi = np.arccos(1 - 2 * (i / N))
            theta = np.sqrt(N * np.pi) * phi
            x = r * np.sin(phi) * np.cos(theta)
            y = r * np.sin(phi) * np.sin(theta)
            z = r * np.cos(phi)
            points.append([x, y, z])
    else:
        raise ValueError("Only supports 2D and 3D")
    return points


def generate_distance_matrix(points):
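    """Return the symmetric n x n matrix of pairwise Euclidean distances."""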
    n = len(points)
    distance_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            distance = np.linalg.norm(np.array(points[i]) - np.array(points[j]))
            distance_matrix[i, j] = distance
            distance_matrix[j, i] = distance
    return distance_matrix


NUM_ITERS = 10  # timeit repetitions averaged per benchmark
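

# Each helper below times one implementation with timeit and prints the
# mean wall-clock time over NUM_ITERS runs.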
def benchmark_gradient_descent_native(X_native, D_native, lr, niter):
    secs = timeit(lambda: gradient_descent_native(X_native, D_native, learning_rate=lr, num_iterations=niter), number=NUM_ITERS) / NUM_ITERS
    print(f"Average time python native: {secs:.6f} s")


def benchmark_gradient_descent(X, D, lr, niter):
    secs = timeit(lambda: gradient_descent(X, D, learning_rate=lr, num_iterations=niter), number=NUM_ITERS) / NUM_ITERS
    print(f"Average time python numpy: {secs:.6f} s")


def benchmark_gradient_descent_JAX(X, D, lr, niter):
    secs = timeit(lambda: gradient_descent_JAX(X, D, learning_rate=lr, num_iterations=niter), number=NUM_ITERS) / NUM_ITERS
    print(f"Average time JAX: {secs:.6f} s")


def benchmark_gradient_descent_cpp(X, D, lr, niter):
    secs = timeit(lambda: gradient_descent_cpp(X, D, learning_rate=lr, num_iterations=niter), number=NUM_ITERS) / NUM_ITERS
    print(f"Average time C++ binding: {secs:.6f} s")


def benchmarks(D, dim, lr, niter, plots=True):
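    """Run every implementation once, print average timings, and optionally
    plot and animate the descent (D: target distance matrix, dim: embedding
    dimension, lr: learning rate, niter: gradient descent iterations).
    """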
    N = len(D)
    D = np.array(D, dtype=np.float64)
    D_native = PyMatrix(D.tolist(), N, N)

    # Initial starting point
    np.random.seed(42)
    X = np.random.rand(N, dim)
    X_native = PyMatrix(X.tolist(), N, dim)

    ### Without visuals
    p1 = gradient_descent_native(X_native.copy(), D_native, learning_rate=lr, num_iterations=niter)
    p2 = gradient_descent(X.copy(), D, learning_rate=lr, num_iterations=niter)
    p3 = gradient_descent_JAX(X.copy(), D, learning_rate=lr, num_iterations=niter)
    p_cpp = gradient_descent_cpp(X.copy(), D, learning_rate=lr, num_iterations=niter)

    ### Benchmarks
    benchmark_gradient_descent_native(X_native.copy(), D_native, lr=lr, niter=niter)
    benchmark_gradient_descent(X.copy(), D, lr=lr, niter=niter)
    benchmark_gradient_descent_JAX(X.copy(), D, lr=lr, niter=niter)
    benchmark_gradient_descent_cpp(X.copy(), D, lr=lr, niter=niter)

    ## Visualization
    if plots:
        P, L = gradient_descent_cache(X.copy(), D, learning_rate=lr, num_iterations=niter)
        plot_gradient_descent_2D(P, L, title="Gradient Descent in python numpy")
        plot_gradient_descent(P, L, title="Gradient Descent in python numpy")

        P_native, L_native = gradient_descent_native_cache(X_native.copy(), D_native, learning_rate=lr, num_iterations=niter)
        plot_gradient_descent(P_native, L_native, title="Gradient Descent in native python")

        P_JAX, L_JAX = gradient_descent_cache_JAX(X.copy(), D, learning_rate=lr, num_iterations=niter)
        plot_gradient_descent(P_JAX.tolist(), L_JAX.tolist(), title="Gradient Descent in JAX")

        # C++ cache function not implemented: can only plot the final value
        plot_gradient_descent(p_cpp, -1, title="Gradient Descent in C++")

        animate_gradient_descent(P, L, trace=False)


if __name__ == "__main__":
    # Create optimization target
    n_circle = 10
    dim_circle = 2
    points = generate_radial_points(n_circle, dim_circle)  # circle/sphere
    # points = np.loadtxt("./shapes/modular.csv", delimiter=",")  # modular (N = 1000)
    # points = np.loadtxt("./shapes/flame.csv", delimiter=",")  # flame (N = 307)

    # Optimization input
    dim = 2
    lr = 0.001
    niter = 1000
    plots = True

    benchmarks(
        D=generate_distance_matrix(points),
        dim=dim,
        lr=lr,
        niter=niter,
        plots=plots,
    )
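
# A possible sanity check (a sketch, not part of the original script): with
# the fixed seed above, the NumPy and C++ results could be compared inside
# benchmarks(), assuming both use the same float64 math, e.g.:
#   assert np.allclose(p2, p_cpp, atol=1e-3)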