fix: average epoch loss for mini-batch
Az-r-ow committed May 18, 2024
1 parent 32f3a62 commit 43d70dd
Showing 4 changed files with 24 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -12,6 +12,7 @@
 venv
 __pycache__
 dataset/
+checkpoints/*.bin

 ### binaries ###
 *.bin
Empty file.
28 changes: 16 additions & 12 deletions examples/train-predict-MNIST/main.py
@@ -15,15 +15,15 @@
 NUM_PREDICTIONS = 1000
 MNIST_DATASET_FILE = "./dataset/mnist.npz"

-# Adding the module path to the sys path
+# Adding the module path to the sys path
 so_dir = add_module_path_to_sys_path(__file__)

 import NeuralNetPy as NNP

 # If file doesn't exists create and download the data
 if not file_exists(MNIST_DATASET_FILE):
-  print("Mnist dataset not found")
-  get_MNIST_dataset(MNIST_DATASET_FILE)
+    print("Mnist dataset not found")
+    get_MNIST_dataset(MNIST_DATASET_FILE)

 # Otherwise load data from file
 (x_train, y_train), (x_test, y_test) = load_data(MNIST_DATASET_FILE)
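Note: for context, a plausible sketch of the two helpers assumed above; the implementations and the download URL are assumptions for illustration, not the utilities actually shipped with this example:

import os
import urllib.request

# Hypothetical stand-ins for this example's helpers.
def file_exists(path: str) -> bool:
    return os.path.isfile(path)

def get_MNIST_dataset(path: str) -> None:
    # Fetch the .npz archive from the mirror Keras uses (assumed URL).
    os.makedirs(os.path.dirname(path), exist_ok=True)
    url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
    urllib.request.urlretrieve(url, path)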
@@ -38,13 +38,13 @@
 # Setting up the networks parameters
 network.setup(optimizer=NNP.optimizers.Adam(0.01), loss=NNP.LOSS.MCE)

-# combining the data with the labels for later shuffling
+# combining the data with the labels for later shuffling
 combined = list(zip(x_train, y_train))

-# shuffling the combined list
+# shuffling the combined list
 random.shuffle(combined)

-# separating them
+# separating them
 x_train, y_train = zip(*combined)

 # preparing the training data
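Note: the touched comments above document a standard shuffle-in-unison pattern: zip the inputs with their labels, shuffle the pairs, then unzip. A self-contained sketch of the same idea:

import random

x = [[0.0], [0.1], [0.2], [0.3]]
y = [0, 1, 2, 3]

pairs = list(zip(x, y))
random.shuffle(pairs)          # one shuffle keeps inputs and labels aligned
x, y = map(list, zip(*pairs))  # unzip back into two lists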
@@ -54,7 +54,11 @@

 trainingData.batch(128)

-callbacks = [NNP.callbacks.EarlyStopping("LOSS", 0.01), NNP.callbacks.CSVLogger("training.csv"), NNP.callbacks.ModelCheckpoint("checkpoints", False, 2,verbose=True)]
+callbacks = [
+    NNP.callbacks.EarlyStopping("LOSS", 0.01),
+    NNP.callbacks.CSVLogger("training.csv"),
+    NNP.callbacks.ModelCheckpoint("checkpoints", True, verbose=True),
+]

 network.train(trainingData, 10, callbacks)

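Note: this callback list is also why the epoch-loss fix in Network.cpp below matters: EarlyStopping("LOSS", 0.01) presumably reads the loss reported at epoch end, which before this commit was just the final batch's loss. A minimal sketch of that kind of stopping rule (a hypothetical helper, not the NeuralNetPy API):

def should_stop(prev_epoch_loss: float, epoch_loss: float, min_delta: float = 0.01) -> bool:
    # Stop once the epoch-average loss improves by less than min_delta;
    # feeding this a single noisy batch loss makes the decision unreliable.
    return prev_epoch_loss - epoch_loss < min_delta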
@@ -67,15 +71,15 @@

 (accuracy, n, correct) = get_accuracy(predicted_numbers, y_test)

-# Getting the prediction's accuracy
+# Getting the prediction's accuracy
 print(f"Num correct predictions : {correct}/{n} - accuracy {accuracy}")

 # Saving the trained model in a bin file
-NNP.models.Model.save_to_file('./model.bin', network)
+NNP.models.Model.save_to_file("./model.bin", network)

 saved_model = NNP.models.Network()

-NNP.models.Model.load_from_file('./model.bin', saved_model)
+NNP.models.Model.load_from_file("./model.bin", saved_model)

 # preparing the testing data
 predictions = saved_model.predict(f_x_test[:NUM_PREDICTIONS])
@@ -84,8 +88,8 @@

 (accuracy, n, correct) = get_accuracy(predicted_numbers, y_test)

-# Getting the prediction's accuracy
+# Getting the prediction's accuracy
 print(f"Num correct predictions : {correct}/{n} - accuracy {accuracy}")

 # Remove sys.path modification
-sys.path.remove(so_dir)
+sys.path.remove(so_dir)
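Note: for reference, a plausible shape for the get_accuracy helper called twice above, matching its (accuracy, n, correct) return tuple; an assumption for illustration, not the repository's actual implementation:

def get_accuracy(predicted, expected):
    # Count exact label matches and report (accuracy, total, correct).
    pairs = list(zip(predicted, expected))
    correct = sum(int(p == e) for p, e in pairs)
    n = len(pairs)
    return (correct / n, n, correct)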
9 changes: 7 additions & 2 deletions src/NeuralNet/Network.cpp
@@ -131,9 +131,11 @@ double Network::miniBatchTraining(
   trainingCheckpoint("onTrainBegin", callbacks);

   for (cEpoch = 0; cEpoch < epochs; cEpoch++) {
+    double sumBatchLoss = 0;
+    int numBatches = trainingData.inputs.size();
     trainingCheckpoint("onEpochBegin", callbacks);
-    TrainingGauge g(trainingData.inputs.size(), 0, epochs, (cEpoch + 1));
-    for (int b = 0; b < trainingData.inputs.size(); b++) {
+    TrainingGauge g(numBatches, 0, epochs, (cEpoch + 1));
+    for (int b = 0; b < numBatches; b++) {
       trainingCheckpoint("onBatchBegin", callbacks);
       const int numOutputs = this->getOutputLayer()->getNumNeurons();
       const int inputsSize = trainingData.inputs.batches[b].size();
@@ -144,13 +146,16 @@
       Eigen::MatrixXd o =
          this->forwardProp(trainingData.inputs.batches[b], true);
       loss = this->cmpLoss(o, y) / inputsSize;
+      sumBatchLoss += loss;
       accuracy = computeAccuracy(o, y);
       sumLoss += loss;
       this->backProp(o, y);
       trainingCheckpoint("onBatchEnd", callbacks);
       if (!this->progBar) continue;  // Skip when disabled
       g.printWithLAndA(loss, accuracy);
     }
+    // calculating current epoch avg loss
+    loss = sumBatchLoss / static_cast<double>(numBatches);
     trainingCheckpoint("onEpochEnd", callbacks);
   }

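Note: the hunk above is the commit's core fix. Inside the batch loop, loss is overwritten on every iteration, so before this change the value reaching onEpochEnd (and thus epoch-end callbacks) was merely the last batch's loss; now it is the mean over all batches. A tiny worked example of the difference:

# Per-batch losses over one epoch (each already normalized by batch size):
batch_losses = [0.92, 0.61, 0.47, 0.40]

last_batch_loss = batch_losses[-1]                  # old epoch "loss": 0.40
epoch_loss = sum(batch_losses) / len(batch_losses)  # new epoch loss:  0.60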
