Unexpected behaviour for reducing metrics in flax/02_mnist_train_loop.py #13

mstoelzle opened this issue Sep 1, 2023 · 0 comments

Hi,

I have run into some issues / open questions about how metrics are reduced when using train_loop / test_loop together with Flax.
To make the behaviour easier to reproduce and debug, I have extended the flax/02_mnist_train_loop.py example to also include an evaluation loop at the end. Namely, I evaluate the model with two different batch sizes (first 32, then 16). The full code of the example is now:

from pathlib import Path
from time import time

import flax.linen as nn
import jax
jax.config.update("jax_platform_name", "cpu")  # run JAX on the CPU
import jax.numpy as jnp
import jax_metrics as jm
import matplotlib.pyplot as plt
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from clu.metrics import Accuracy, Average, Collection
from flax import struct
from flax.training import train_state

import ciclo

batch_size = 32

# load the MNIST dataset
ds_train: tf.data.Dataset = tfds.load("mnist", split="train", shuffle_files=True)
ds_train = ds_train.repeat().shuffle(1024).batch(batch_size).prefetch(1)
ds_test: tf.data.Dataset = tfds.load("mnist", split="test")
ds_test = ds_test.batch(32, drop_remainder=True).prefetch(1)


# Define model
class Linear(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x / 255.0
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=10)(x)
        return x


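# clu-based metrics Collection (not referenced below; the TrainState is created
# with a jax_metrics `jm.Metrics` instance instead)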
@struct.dataclass
class Metrics(Collection):
    loss: Average.from_output("loss")
    accuracy: Accuracy

    def update(self, **kwargs) -> "Metrics":
        updates = self.single_from_model_output(**kwargs)
        return self.merge(updates)


class TrainState(train_state.TrainState):
    metrics: jm.Metrics


@jax.jit
def train_step(state: TrainState, batch):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch["image"])
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch["label"]
        ).mean()
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = state.metrics.update(loss=loss, preds=logits, target=batch["label"])
    logs = ciclo.logs()
    logs.add_stateful_metrics(**metrics.compute())
    return logs, state.replace(metrics=metrics)


@jax.jit
def test_step(state: TrainState, batch):
    logits = state.apply_fn({"params": state.params}, batch["image"])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch["label"]
    ).mean()
    metrics = state.metrics.update(loss=loss, preds=logits, target=batch["label"])
    logs = ciclo.logs()
    logs.add_stateful_metrics(**metrics.compute())
    return logs, state.replace(metrics=metrics)


def reset_step(state: TrainState):
    return state.replace(metrics=state.metrics.reset())


# Initialize state
model = Linear()
variables = model.init(jax.random.PRNGKey(0), jnp.empty((1, 28, 28, 1)))
state = TrainState.create(
    apply_fn=model.apply,
    params=variables["params"],
    tx=optax.adamw(1e-3),
    metrics=jm.Metrics(
        {
            "accuracy": jm.metrics.Accuracy(),
            "loss": jm.metrics.Mean().from_argument("loss"),
        }
    ),
)

# training loop
total_samples = 32 * 100
total_steps = total_samples // batch_size
eval_steps = total_steps // 10
log_steps = total_steps // 50


state, history, _ = ciclo.train_loop(
    state,
    ds_train.as_numpy_iterator(),
    {
        ciclo.on_train_step: [train_step],
        ciclo.on_test_step: [test_step],
        ciclo.on_reset_step: [reset_step],
    },
    callbacks=[
        ciclo.checkpoint(
            f"logdir/{Path(__file__).stem}/{int(time())}",
            monitor="accuracy_test",
            mode="max",
        ),
        ciclo.keras_bar(total=total_steps),
    ],
    test_dataset=lambda: ds_test.as_numpy_iterator(),
    epoch_duration=eval_steps,
    stop=total_steps,
)

# %% Run evaluation

steps, loss, accuracy = history.collect("steps", "loss", "accuracy")
steps_test, loss_test, accuracy_test = history.collect(
    "steps", "loss_test", "accuracy_test"
)

test_batch_size = 32
ds_test: tf.data.Dataset = tfds.load("mnist", split="test")
ds_test = ds_test.batch(test_batch_size, drop_remainder=True).prefetch(1)
state, history, _ = ciclo.test_loop(
    state,
    ds_test.as_numpy_iterator(),
    {
        ciclo.on_test_step: [test_step],
    },
    callbacks=[
        ciclo.keras_bar(total=total_samples // test_batch_size),  # bar total = number of test steps
    ],
    stop=total_samples // test_batch_size,
)
loss_test, accuracy_test = history.collect(
    "loss", "accuracy"
)
print(f"Final test loss for batch size {test_batch_size}: {loss_test[-1]}")

test_batch_size = 16
ds_test: tf.data.Dataset = tfds.load("mnist", split="test")
ds_test = ds_test.batch(test_batch_size, drop_remainder=True).prefetch(1)
state, history, _ = ciclo.test_loop(
    state,
    ds_test.as_numpy_iterator(),
    {
        ciclo.on_test_step: [test_step],
    },
    callbacks=[
        ciclo.keras_bar(total=total_samples // test_batch_size),  # bar total = number of test steps
    ],
    stop=total_samples // test_batch_size,
)
loss_test, accuracy_test = history.collect(
    "loss", "accuracy"
)
print(f"Final test loss for batch size {test_batch_size}: {loss_test[-1]}")

As the metrics use a mean reduction, I would expect the aggregated test loss to be the same irrespective of the evaluation batch size (both runs cover the same 3200 test samples). However, this is not the case; the code outputs the following:

Final test loss for batch size 32: 0.9193118214607239
Final test loss for batch size 16: 0.9186758995056152
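To make this expectation concrete: with drop_remainder=True and stop=total_samples // test_batch_size, both evaluations cover the same 3200 test samples split into equal-sized batches, so the mean of the per-batch means should equal the overall per-sample mean for either batch size. Here is a minimal, standalone numpy sketch of that arithmetic (independent of ciclo / jax_metrics, with random numbers standing in for the actual losses):

import numpy as np

rng = np.random.default_rng(0)
per_sample_loss = rng.random(3200)  # stand-in for per-example test losses

def mean_of_batch_means(losses, batch_size):
    # equal-sized batches (as with drop_remainder=True), then average the
    # per-batch means -- this is what a Mean reduction should amount to
    return losses.reshape(-1, batch_size).mean(axis=1).mean()

print(mean_of_batch_means(per_sample_loss, 32))  # prints the same value ...
print(mean_of_batch_means(per_sample_loss, 16))  # ... as this line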

Related to this, I have a few questions:

  1. What is the recommended way to read out the validation / test loss averaged over an epoch? Is it the following (see also the sanity-check sketch after this list)?

     loss_test, accuracy_test = history.collect("loss", "accuracy")
     # the last entry should then hold the epoch-mean loss
     print(f"Mean epoch loss: {loss_test[-1]}")

  2. Do I have to write the Metrics(Collection) class myself (which, I assume, merges the batch metrics over an epoch), or is there something prebuilt in the library that I can use instead?
  3. In general, I think certain aspects of metrics / logging could be documented more explicitly. Most importantly, I am wondering at which point the metrics are reduced.
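For question 1, the kind of manual sanity check I have in mind is the sketch below: it computes a sample-weighted epoch-mean test loss directly from state and ds_test as defined above, outside of ciclo / jax_metrics, so it can be compared against the value collected from the history (manual_epoch_loss is just an illustrative name):

# sanity-check sketch (not part of ciclo): sample-weighted epoch-mean loss,
# computed directly from `state` and `ds_test` defined above
import numpy as np
import optax

loss_sums, sample_counts = [], []
for batch in ds_test.as_numpy_iterator():
    logits = state.apply_fn({"params": state.params}, batch["image"])
    per_sample_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch["label"]
    )
    loss_sums.append(float(per_sample_loss.sum()))
    sample_counts.append(per_sample_loss.shape[0])

manual_epoch_loss = np.sum(loss_sums) / np.sum(sample_counts)
print(f"Manually computed epoch loss: {manual_epoch_loss}")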