-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_vanilla.c
144 lines (125 loc) · 4.8 KB
/
mnist_vanilla.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#define STB_IMAGE_IMPLEMENTATION
#include "../lib/clear_net.h"
#include "./external/stb_image.h"
#include <dirent.h>
#include <sys/stat.h>
// Shorthand for the library's data-handling function table, so calls read
// `data.allocVectors(...)` instead of `cn.data.allocVectors(...)`.
// NOTE(review): a macro named `data` rewrites every later `data` token in
// this file, so do not use `data` as a variable or field name below it.
#define data cn.data

// MNIST images are 28x28 single-channel. The dimensions are enum constants so
// that every initializer below is a constant expression: a `const` object
// (e.g. `img_height`) is NOT a constant expression in C, so initialising
// `num_pixels` from `img_height * img_width` at file scope is rejected by
// strict compilers (clang: "initializer element is not a compile-time
// constant").
enum { MNIST_IMG_HEIGHT = 28, MNIST_IMG_WIDTH = 28 };
const ulong img_height = MNIST_IMG_HEIGHT;
const ulong img_width = MNIST_IMG_WIDTH;
// Length of one flattened input vector (network input width).
const ulong num_pixels = MNIST_IMG_HEIGHT * MNIST_IMG_WIDTH;
// Number of image files expected in the train and test directories.
const ulong num_train_files = 60000;
const ulong num_test_files = 10000;
// One-hot output width: one entry per digit class.
const ulong dim_output = 10;
/*
 * Load every regular, non-dot file in `path` as a single-channel image and
 * fill one input vector and one one-hot target vector per file.
 *
 * File names must start with the digit label (the dataset preparation script
 * names them that way). Pixel bytes are scaled to [0, 1] by dividing by 255.
 *
 * train     - array of at least `num_files` vectors, each with
 *             img_width * img_height elements.
 * targets   - array of at least `num_files` vectors, each with at least
 *             dim_output elements. Only the label position is written, so
 *             they are assumed to be zero-initialised by the allocator
 *             (TODO confirm allocVectors zeroes memory).
 * path      - directory to scan (not modified).
 * num_files - exact number of images the directory must contain.
 *
 * Returns 0 on success, 1 on any error (a message is printed to stderr).
 * The directory handle and every decoded image are released on all paths.
 */
int get_data_from_dir(Vector *train, Vector *targets, const char *path,
                      ulong num_files) {
    DIR *directory = opendir(path);
    if (directory == NULL) {
        fprintf(stderr, "Error: Failed to open %s.\n", path);
        return 1;
    }

    int rc = 1;
    ulong count = 0;
    struct dirent *entry;
    while ((entry = readdir(directory)) != NULL) {
        // Skip dotfiles (this also skips "." and "..").
        if (entry->d_name[0] == '.') {
            continue;
        }

        char file_path[PATH_MAX];
        int written = snprintf(file_path, sizeof file_path, "%s/%s", path,
                               entry->d_name);
        if (written < 0 || (size_t)written >= sizeof file_path) {
            fprintf(stderr, "ERROR: path too long: %s/%s\n", path,
                    entry->d_name);
            goto out;
        }

        // Only regular files are treated as images; skip directories etc.
        struct stat file_stat;
        if (stat(file_path, &file_stat) != 0 || !S_ISREG(file_stat.st_mode)) {
            continue;
        }

        // Never write past the caller-supplied arrays.
        if (count >= num_files) {
            fprintf(stderr, "ERROR: %s contains more than %lu images\n", path,
                    (unsigned long)num_files);
            goto out;
        }

        // The python script set it up so the first character is the label.
        // Validate it before it is used as an index into targets[count].
        char label_ch = entry->d_name[0];
        if (label_ch < '0' || label_ch > '9' ||
            (ulong)(label_ch - '0') >= dim_output) {
            fprintf(stderr, "ERROR: %s has no valid digit label\n", file_path);
            goto out;
        }
        ulong label = (ulong)(label_ch - '0');

        // Local names deliberately differ from the img_width/img_height
        // globals so the size check below compares against them.
        int width, height, comp;
        unsigned char *pixels =
            stbi_load(file_path, &width, &height, &comp, 0);
        if (pixels == NULL) {
            fprintf(
                stderr,
                "ERROR: could not read %s\n Did you download the data? The "
                "binary begins its search at the directory you call it.\n",
                file_path);
            goto out;
        }
        // Reject anything that is not exactly one 28x28 grey channel:
        // a larger image would overflow the destination vector.
        if (comp != 1 || width != (int)img_width ||
            height != (int)img_height) {
            fprintf(stderr, "ERROR: %s improperly formatted\n", file_path);
            stbi_image_free(pixels);
            goto out;
        }

        for (int j = 0; j < width * height; ++j) {
            VEC_AT(train[count], j) = pixels[j] / 255.f;
        }
        stbi_image_free(pixels);

        VEC_AT(targets[count], label) = 1;
        count++;
    }

    if (count != num_files) {
        fprintf(stderr, "ERROR: expected %lu images in %s, found %lu\n",
                (unsigned long)num_files, path, (unsigned long)count);
        goto out;
    }
    rc = 0;

out:
    closedir(directory);
    return rc;
}
int main(void) {
srand(0);
char *train_path = "./datasets/mnist/train";
Vector *vinputs = data.allocVectors(num_train_files, num_pixels);
Vector *vtargets = data.allocVectors(num_train_files, dim_output);
int res = get_data_from_dir(vinputs, vtargets, train_path, num_train_files);
if (res) {
return 1;
}
CNData *inputs = data.allocDataFromVectors(vinputs, num_train_files);
CNData *targets = data.allocDataFromVectors(vtargets, num_train_files);
// randomize for stochastic gradient descent
data.shuffleDatas(inputs, targets);
char *test_path = "./datasets/mnist/test";
Vector *vtest_in = data.allocVectors(num_test_files, num_pixels);
Vector *vtest_targets = data.allocVectors(num_test_files, dim_output);
res = get_data_from_dir(vtest_in, vtest_targets, test_path, num_test_files);
if (res != 0) {
return 1;
}
CNData *test_ins = data.allocDataFromVectors(vtest_in, num_train_files);
CNData *test_tars =
data.allocDataFromVectors(vtest_targets, num_train_files);
HParams *hp = cn.allocDefaultHParams();
cn.setRate(hp, 0.005);
cn.withMomentum(hp, 0.9);
Net *net = cn.allocVanillaNet(hp, num_pixels);
cn.allocDenseLayer(net, SIGMOID, 16);
cn.allocDenseLayer(net, SIGMOID, 16);
cn.allocDenseLayer(net, SIGMOID, dim_output);
cn.randomizeNet(net, -1, 1);
ulong num_epochs = 1000;
scalar error;
scalar error_break = 0.25;
ulong batch_size = 100;
CLEAR_NET_ASSERT(num_train_files % batch_size == 0);
printf("Initial Cost: %f\n", cn.lossVanilla(net, inputs, targets));
printf("Beginning Training\n");
// for SGD
CNData *batch_ins = data.allocEmptyData();
CNData *batch_tars = data.allocEmptyData();
for (ulong i = 0; i < num_epochs; ++i) {
for (ulong batch_num = 0; batch_num < (num_train_files / batch_size);
++batch_num) {
data.setBatch(inputs, targets, batch_num, batch_size, batch_ins,
batch_tars);
cn.lossVanilla(net, batch_ins, batch_tars);
cn.backprop(net);
}
error = cn.lossVanilla(net, inputs, targets);
printf("Cost after epoch %zu: %f\n", i, error);
if (error < error_break) {
printf("Less than: %f error after epoch %zu\n", error_break, i);
break;
}
}
printf("Final Error on training set: %f\n",
cn.lossVanilla(net, inputs, targets));
char *file = "model";
cn.saveNet(net, file);
cn.deallocNet(net);
net = cn.allocNetFromFile(file);
printf("Testing Predictions\n");
cn.printVanillaPredictions(net, test_ins, test_tars);
cn.deallocNet(net);
}