fit_function.cpp
#include "Halide.h"

using namespace Halide;

int main(int argc, char **argv) {
    // Fit an odd polynomial to sin from 0 to pi/2 using Halide's derivative support
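    // coeffs holds the polynomial coefficients being fit, learning_rate is the
    // gradient-descent step size, order is the number of (odd) terms, and
    // samples is the number of points at which we sample [0, pi/2].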
    ImageParam coeffs(Float(64), 1);
    Param<double> learning_rate;
    Param<int> order, samples;

    Func approx_sin;
    Var x, y;
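    // Map the sample index x in [0, samples) onto an angle in [0, pi/2).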
    Expr fx = (x / cast<double>(samples)) * Expr(M_PI / 2);

    // We'll evaluate the polynomial using a slightly modified Horner's
    // method. We need to save the intermediate results for the
    // backwards pass to use. We'll leave the ultimate result at index
    // 0.
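    // Each step computes p = (p * fx + coeffs(i)) * fx, so the value left at
    // index 0 is coeffs(0)*fx + coeffs(1)*fx^3 + ... -- an odd polynomial in fx.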
    RDom r(0, order);
    Expr r_flipped = order - 1 - r;
    approx_sin(x, y) = cast<double>(0);
    approx_sin(x, r_flipped) = (approx_sin(x, r_flipped + 1) * fx + coeffs(r_flipped)) * fx;
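    // The reference values we are fitting against, at the same sample points.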
    Func exact_sin;
    exact_sin(x) = sin(fx);

    // Minimize squared relative error. We'll be careful not to
    // evaluate it at zero. We're correct there by construction
    // anyway, because our polynomial is odd.
    Func err;
    err(x) = pow((approx_sin(x, 0) - exact_sin(x)) / exact_sin(x), 2);

    RDom d(1, samples - 1);
    Func average_err;
    average_err() = sum(err(d)) / samples;
    // Take the derivative of the output w.r.t. the coefficients. The
    // returned object acts like a map from Funcs to the derivative of
    // the err w.r.t. those Funcs.
    auto d_err_d = propagate_adjoints(average_err);

    // Compute the new coefficients in terms of the old.
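    // This is a single explicit gradient-descent step: c <- c - learning_rate * dE/dc.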
    Func new_coeffs;
    new_coeffs(x) = coeffs(x) - learning_rate * d_err_d(coeffs)(x);

    // Schedule
    err.compute_root().vectorize(x, 4);
    new_coeffs.compute_root().vectorize(x, 4);
    approx_sin.compute_root().vectorize(x, 4).update().vectorize(x, 4);
    exact_sin.compute_root().vectorize(x, 4);
    average_err.compute_root();
    // d_err_d(coeffs) is just a Func, and you can schedule
    // it. However, each Func in the pipeline actually creates a
    // sequence of synthesized Funcs to compute its derivative, and
    // you may want to schedule all of them (or just use the
    // autoscheduler). Here we will write a quick-and-dirty
    // autoscheduler for this pipeline to illustrate how you can
    // access the new synthesized derivative Funcs.
    Var v;
    Func fs[] = {coeffs, approx_sin, err};
    for (Func f : fs) {
        // Iterate over the derivative Funcs for this Func. We get
        // them in order from output to input. The first Func in the
        // vector is the one returned by operator(), and is the
        // fully-computed derivative. It is always a zero boundary
        // condition which defines the region in which the derivative
        // is non-zero, and we always want to inline that into the
        // consumer, so we'll skip the first one.
        bool first = true;
        for (Func df : d_err_d.funcs(f)) {
            if (first) {
                first = false;
                continue;
            }
            df.compute_root().vectorize(df.args()[0], 4);
            for (int i = 0; i < df.num_update_definitions(); i++) {
                // Find a pure var to vectorize over
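                // (vectorizing an update over one of its RVars isn't generally
                // safe, so we stick to the pure dimensions here)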
                for (auto d : df.update(i).get_schedule().dims()) {
                    if (d.is_pure()) {
                        df.update(i).vectorize(Var(d.var), 4);
                        break;
                    }
                }
            }
        }
    }
    // Gradient descent loop
    // Let's use eight terms and a thousand samples
    const int terms = 8;
    Buffer<double> c(terms);
    order.set(terms);
    samples.set(1000);
    auto e = Buffer<double>::make_scalar();
    coeffs.set(c);
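    // The pipeline produces two outputs per step: the scalar average error
    // (into e) and the updated coefficients (into c). Since coeffs is backed
    // by c, each realization updates the coefficients in place.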
    Pipeline p({average_err, new_coeffs});

    c.fill(0);
    // Initialize to the Taylor series for sin about zero
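    // sin(x) = x - x^3/3! + x^5/5! - ..., so c(i) is (-1)^i / (2i+1)!.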
    c(0) = 1;
    for (int i = 1; i < terms; i++) {
        c(i) = -c(i - 1) / (i * 2 * (i * 2 + 1));
    }

    // This gradient descent is not particularly well-conditioned,
    // because the standard polynomial basis is nowhere near
    // orthogonal over [0, pi/2]. This should probably use a Chebyshev
    // basis instead. We'll use a very slow learning rate and lots of
    // steps.
    learning_rate.set(0.00001);
    const int steps = 10000;

    double initial_error;
    for (int i = 0; i <= steps; i++) {
        bool should_print = (i == 0 || i == steps / 2 || i == steps);
        if (should_print) {
            printf("Iteration %d\n"
                   "Coefficients: ", i);
            for (int j = 0; j < terms; j++) {
                printf("%g ", c(j));
            }
            printf("\n");
        }
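        // One gradient-descent step: computes the current error into e and
        // the updated coefficients into c.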
        p.realize({e, c});
        if (should_print) {
            printf("Error: %g\n", e());
        }
        if (i == 0) {
            initial_error = e();
        }
    }

    double final_error = e();
    if (final_error <= 1e-10 && final_error < initial_error) {
        printf("Success!\n");
        return 0;
    } else {
        printf("Did not converge\n");
        return -1;
    }
}