
Metrics

Simon Flügel edited this page Jun 12, 2024 · 2 revisions

Metrics implemented in ChEB-AI

  1. F1 score:

    • Micro F1 score: Uses the standard MultilabelF1Score implementation from the torchmetrics library.
    • Macro F1 score: A custom implementation that extends torchmetrics.Metric, the abstract base class for all torchmetrics metrics.
  2. Balanced Accuracy:

    • A custom implementation that computes Balanced Accuracy as the mean of the true positive rate (TPR) and the true negative rate (TNR): $$\text{Balanced Accuracy} = \frac{\text{TPR} + \text{TNR}}{2} = \frac{\frac{\text{TP}}{\text{TP} + \text{FN}} + \frac{\text{TN}}{\text{TN} + \text{FP}}}{2}$$
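To make the difference between the averaging schemes concrete, here is a small pure-Python sketch that computes micro F1, macro F1, and balanced accuracy from 0/1 label matrices. It only illustrates the formulas and is not ChEB-AI's code. The helper names are hypothetical, scores with a zero denominator are set to 0.0 by convention, and balanced accuracy is computed from counts pooled over all labels, which is an assumption about how labels are aggregated.

```python
def per_label_counts(y_true, y_pred):
    """Return (TP, FP, FN, TN) for each label column of two 0/1 matrices."""
    counts = []
    for j in range(len(y_true[0])):
        tp = fp = fn = tn = 0
        for true_row, pred_row in zip(y_true, y_pred):
            t, p = true_row[j], pred_row[j]
            if t and p:
                tp += 1
            elif p:
                fp += 1
            elif t:
                fn += 1
            else:
                tn += 1
        counts.append((tp, fp, fn, tn))
    return counts


def f1(tp, fp, fn):
    # F1 = 2TP / (2TP + FP + FN); 0.0 if the label is never present and never predicted
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def micro_f1(counts):
    # Micro: sum the counts over all labels first, then compute a single F1
    tp, fp, fn, _ = map(sum, zip(*counts))
    return f1(tp, fp, fn)


def macro_f1(counts):
    # Macro: compute F1 per label, then take the unweighted mean
    return sum(f1(tp, fp, fn) for tp, fp, fn, _ in counts) / len(counts)


def balanced_accuracy(counts):
    # Assumption: TP/FP/FN/TN are pooled over all labels before applying the formula
    tp, fp, fn, tn = map(sum, zip(*counts))
    tpr = tp / (tp + fn) if tp + fn else 0.0
    tnr = tn / (tn + fp) if tn + fp else 0.0
    return (tpr + tnr) / 2


y_true = [[1, 0, 1], [0, 1, 1], [1, 1, 0], [0, 0, 1]]
y_pred = [[1, 0, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1]]
counts = per_label_counts(y_true, y_pred)
print(micro_f1(counts))           # 10/13, about 0.769
print(macro_f1(counts))           # (4/5 + 2/3 + 4/5) / 3 = 34/45, about 0.756
print(balanced_accuracy(counts))  # (5/7 + 4/5) / 2 = 53/70, about 0.757
```

Micro F1 weights every individual label decision equally, so frequent labels dominate the score; macro F1 weights every label equally, so rare labels count as much as common ones.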

Changing Tracked Metrics

To change which metrics are tracked during training, validation, or testing, follow these steps:

  1. Define the metrics to track in a YAML configuration file.

  2. Pass the path of that file to the ChEB-AI CLI using the following parameters:

    • --model.train_metrics=[config_file_path]
    • --model.val_metrics=[config_file_path]
    • --model.test_metrics=[config_file_path]

Customizing Checkpoint Names with Metrics

To change which metrics appear in the names of checkpoints:

  1. Create a configuration file containing the class path of the custom model checkpoint class used by ChEB-AI:

    • Custom model checkpoint class: chebai.callbacks.model_checkpoint.CustomModelCheckpoint
  2. In the same configuration file, pass the name of the metric to the monitor parameter, along with any other relevant parameters.

By default, checkpoint names include the validation loss, micro F1, and macro F1 (see configs/training/default_callbacks.yml).

Adding New Metrics

The easiest way to add a new metric is to use an existing torchmetrics implementation, but custom implementations can be integrated as well:

  1. To add a new metric from the torchmetrics library:

    • Create a YAML config file in configs/metrics, modeled on the existing ones.
    • Use torchmetrics.MetricCollection to combine the metrics, referencing each metric by its class path.
  2. To add a custom implementation of a metric:

    • Create a new class that extends the torchmetrics.Metric abstract base class.
    • In the constructor, register state variables (with add_state) that accumulate statistics of the data, such as true positives (TPs) and false negatives (FNs).
    • Implement the update method, which updates the state variables after each batch.
    • Implement the compute method, which computes the metric score from the accumulated state after each epoch.
    • Add a new config file as described above, and reference this custom implementation by its class path.
    • For a more detailed explanation, see the torchmetrics tutorial.