Hi,

I have encountered some issues and open questions about the reduction of metrics when using `train_loop` / `eval_loop` together with Flax.
To make it easier to reproduce/debug the behaviour, I have extended the `flax/02_mnist_train_loop.py` example to also include an evaluation loop at the end. Namely, I am evaluating the model with two different batch sizes (first with 32 and then with 16). The full code of the example is now:
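In essence, the added evaluation does the following (a sketch, not the verbatim example code; `eval_step`, `iter_batches`, `test_ds`, and the trained `state` are assumed names standing in for the pieces of the original example):

```python
import jax
import jax.numpy as jnp

@jax.jit
def eval_step(state, batch):
    # Per-batch mean cross-entropy, i.e. the "loss" value each eval step reports.
    logits = state.apply_fn({"params": state.params}, batch["image"])
    one_hot = jax.nn.one_hot(batch["label"], 10)
    return -jnp.mean(jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1))

for batch_size in (32, 16):
    batch_losses = [eval_step(state, batch)
                    for batch in iter_batches(test_ds, batch_size)]
    # Aggregation: an unweighted mean over the per-batch means.
    print(f"Final test loss for batch size {batch_size}:",
          jnp.mean(jnp.stack(batch_losses)))
```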
As we use `mean` for the reduction, we would expect the aggregated test loss to be the same irrespective of the evaluation batch size. However, this is not the case, as the code outputs the following:

```
Final test loss for batch size 32: 0.9193118214607239
Final test loss for batch size 16: 0.9186758995056152
```
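My best guess (not verified against the loop internals): the MNIST test split has 10,000 samples, so batch size 32 leaves a ragged final batch of 16 (10,000 = 312 · 32 + 16), while batch size 16 divides it evenly. If the epoch loss is an unweighted mean over per-batch means, the samples in the ragged batch are weighted twice as heavily, which would make the result batch-size dependent. A quick NumPy check of that effect:

```python
import numpy as np

# 10_000 synthetic per-sample losses standing in for the MNIST test set.
losses = np.random.default_rng(0).normal(loc=0.92, scale=0.3, size=10_000)

def mean_of_batch_means(x, batch_size):
    batches = [x[i:i + batch_size] for i in range(0, len(x), batch_size)]
    return np.mean([b.mean() for b in batches])

print(losses.mean())                    # true per-sample mean
print(mean_of_batch_means(losses, 32))  # ragged last batch of 16 -> deviates
print(mean_of_batch_means(losses, 16))  # divides evenly -> matches exactly
```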
Related to this, I have a few questions:
What is the recommended way to read out the validation/test loss averaged over an epoch? Is it the following?

```python
loss_test, accuracy_test = history.collect("loss", "accuracy")
# now we extract the mean loss
print(f"Mean epoch loss {loss_test[-1]}")
```
Do I have to write the `Metrics(Collection)` class myself (which I assume merges the batch metrics over an epoch), or is there anything prebuilt in the library I can use instead?
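For concreteness, this is the kind of class I mean — a minimal sketch along the lines of `clu.metrics` (I am assuming this is the intended pattern here; the usage below is illustrative):

```python
import flax
from clu import metrics

@flax.struct.dataclass
class Metrics(metrics.Collection):
    # Average keeps a running (total, count) pair. Merging per-batch metrics
    # gives a correct per-sample mean only if it is fed per-example losses
    # (count = batch size), not an already-averaged scalar (count = 1).
    loss: metrics.Average.from_output("loss")
    accuracy: metrics.Accuracy

# Illustrative usage over one evaluation epoch:
# ms = Metrics.empty()
# for batch in iter_batches(test_ds, batch_size):
#     logits, per_example_loss = eval_step(state, batch)  # loss per sample
#     ms = ms.merge(Metrics.single_from_model_output(
#         logits=logits, labels=batch["label"], loss=per_example_loss))
# ms.compute()  # -> {'loss': ..., 'accuracy': ...}
```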
In general, I think certain aspects of metrics/logging could be made more specific in the documentation. Most importantly, I am wondering at which point the metrics are actually reduced (per batch, per epoch, or only at collection time).