Metrics

F1 score:
- Micro F1 score: uses the standard `MultilabelF1Score` implementation from the `torchmetrics` library. Micro averaging pools true positives, false positives and false negatives over all labels before computing a single F1 score.
- Macro F1 score: uses a custom implementation that extends `torchmetrics.Metric`, the abstract base class for all metrics. Macro averaging computes an F1 score for each label and then takes their unweighted mean.
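
To illustrate why the two averages can diverge, here is a minimal pure-Python sketch that computes both from per-label counts. The helper names and the counts are hypothetical, and treating an undefined F1 score (zero denominator) as 0.0 is an assumed convention that may differ from ChEB-AI's implementation.

```python
def f1(tp: int, fp: int, fn: int) -> float:
    """F1 = 2TP / (2TP + FP + FN); returns 0.0 when the denominator is 0 (assumed convention)."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def micro_f1(counts: list[tuple[int, int, int]]) -> float:
    """Pool (TP, FP, FN) over all labels first, then compute a single F1 score."""
    tp = sum(c[0] for c in counts)
    fp = sum(c[1] for c in counts)
    fn = sum(c[2] for c in counts)
    return f1(tp, fp, fn)


def macro_f1(counts: list[tuple[int, int, int]]) -> float:
    """Compute F1 per label, then take the unweighted mean over labels."""
    return sum(f1(tp, fp, fn) for tp, fp, fn in counts) / len(counts)


# Hypothetical (TP, FP, FN) counts for one frequent label and one rare label.
counts = [(90, 10, 10), (1, 4, 5)]
print(round(micro_f1(counts), 3), round(macro_f1(counts), 3))  # 0.863 0.541
```

Micro F1 is dominated by the frequent label, whereas the poorly predicted rare label pulls macro F1 down sharply.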

Balanced Accuracy:
- Computed using a custom implementation, where balanced accuracy is the mean of the true positive rate (TPR, sensitivity) and the true negative rate (TNR, specificity):
  $$\text{Balanced Accuracy} = \frac{\text{TPR} + \text{TNR}}{2} = \frac{\frac{TP}{TP + FN} + \frac{TN}{TN + FP}}{2}$$
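
As a worked example with hypothetical counts TP = 40, FN = 10, TN = 900 and FP = 50, balanced accuracy comes out noticeably lower than plain accuracy, because the many true negatives no longer dominate the score:

$$\text{TPR} = \frac{40}{40 + 10} = 0.8, \qquad \text{TNR} = \frac{900}{900 + 50} \approx 0.9474$$

$$\text{Balanced Accuracy} \approx \frac{0.8 + 0.9474}{2} \approx 0.874, \qquad \text{Accuracy} = \frac{40 + 900}{40 + 10 + 900 + 50} = 0.94$$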

To change which metrics are tracked during training, validation and testing, follow these steps:
- Set up the metric configuration in a YAML file.
- Pass the file path to the ChEB-AI CLI command using the following parameters:
  - `--model.train_metrics=[config_file_path]`
  - `--model.val_metrics=[config_file_path]`
  - `--model.test_metrics=[config_file_path]`
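
For example, the same config file can be passed to all three stages. The sketch below only assembles and prints the command line; the entry point `python -m chebai fit` and the file name `configs/metrics/my-metrics.yml` are assumptions, and a real run also needs the trainer, model and data arguments, which are omitted here.

```python
import shlex

# Hypothetical path to the metrics config created in the first step.
METRICS_CONFIG = "configs/metrics/my-metrics.yml"


def metric_flags(config_path: str) -> list[str]:
    """Build the train/val/test metric flags, all pointing at the same config file."""
    return [f"--model.{stage}_metrics={config_path}" for stage in ("train", "val", "test")]


# Assumed entry point; trainer, model and data arguments are intentionally left out.
command = ["python", "-m", "chebai", "fit", *metric_flags(METRICS_CONFIG)]
print(shlex.join(command))
```

Using a different file per stage (for instance, a cheaper metric set during training) only requires passing different paths to the individual flags.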

To change which metrics appear in the names of checkpoints:
- Create a configuration file containing the class path of the custom model checkpoint class used by ChEB-AI:
  - Custom model checkpoint class: `chebai.callbacks.model_checkpoint.CustomModelCheckpoint`
- In the same configuration file, pass the name of the metric to the `monitor` parameter, along with any other relevant parameters.

By default, checkpoint names include the validation loss, micro-F1 and macro-F1 (see `configs/training/default_callbacks.yml`).

The easiest way to add a new metric is to reuse an existing `torchmetrics` implementation, but custom implementations can be integrated as well.

To add a new metric from the `torchmetrics` library:
- Create a YAML config file in `configs/metrics`, similar to the existing ones.
- Use `torchmetrics.MetricCollection` to group the metrics, referencing each metric by its class path.
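
Each `class_path` entry is the importable dotted path of a class, with its constructor arguments under `init_args`. The following is a simplified sketch of how such a nested spec can be turned into objects, assuming the LightningCLI-style `class_path`/`init_args` layout and that the metrics are listed under `init_args.metrics` (mirroring the `metrics` argument of `torchmetrics.MetricCollection`). It is not the loader ChEB-AI actually uses, which is presumably its CLI's config parser, and the helper names are hypothetical.

```python
import importlib
from typing import Any


def instantiate(spec: dict[str, Any]) -> Any:
    """Create an object from a {"class_path": ..., "init_args": {...}} mapping.

    Simplified illustration only: a real config loader also validates types,
    merges defaults and supports many more forms.
    """
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    kwargs = {key: _resolve(value) for key, value in spec.get("init_args", {}).items()}
    return cls(**kwargs)


def _resolve(value: Any) -> Any:
    """Recursively instantiate nested specs (e.g. each metric inside a collection)."""
    if isinstance(value, dict) and "class_path" in value:
        return instantiate(value)
    if isinstance(value, dict):
        return {key: _resolve(item) for key, item in value.items()}
    return value


# Stand-in spec built from stdlib classes so the sketch runs without torchmetrics.
# In a metrics config the outer class_path would be torchmetrics.MetricCollection
# and each entry under "metrics" would name a metric class instead of Fraction.
collection_spec = {
    "class_path": "types.SimpleNamespace",
    "init_args": {
        "metrics": {
            "half": {"class_path": "fractions.Fraction", "init_args": {"numerator": 1, "denominator": 2}},
            "quarter": {"class_path": "fractions.Fraction", "init_args": {"numerator": 1, "denominator": 4}},
        }
    },
}
collection = instantiate(collection_spec)
print(collection.metrics)  # {'half': Fraction(1, 2), 'quarter': Fraction(1, 4)}
```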

To add a custom implementation of a metric:
- Create a new class that extends the `torchmetrics.Metric` abstract base class.
- In the constructor, add state variables (registered with `self.add_state`) that capture statistics of the data, such as true positives (TPs), false negatives (FNs), etc.
- Implement the `update` method to update the state variables after each batch.
- Implement the `compute` method to compute the metric score after each epoch.
- Add a new config file as described above, and provide the class path of this custom implementation.
- For a more detailed explanation, see the torchmetrics tutorial.
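
The steps above can be sketched without any framework to show which part does what. The class below is a hypothetical, dependency-free stand-in that mirrors the constructor/`update`/`compute` split of `torchmetrics.Metric`; it pools counts over all labels and uses a 0.5 decision threshold, both assumptions, and it is not ChEB-AI's balanced accuracy implementation.

```python
from typing import Sequence


class BalancedAccuracySketch:
    """Framework-free sketch of the state/update/compute pattern of torchmetrics.Metric.

    In a real torchmetrics.Metric subclass, each counter would instead be registered in
    __init__ via self.add_state(name, default=torch.tensor(0), dist_reduce_fx="sum"),
    and reset() would be provided by the base class.
    """

    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold  # assumed decision threshold for predicted scores
        self.reset()

    def reset(self) -> None:
        # State variables: confusion-matrix counts pooled over all labels and batches.
        self.tp = self.fn = self.tn = self.fp = 0

    def update(self, preds: Sequence[Sequence[float]], target: Sequence[Sequence[int]]) -> None:
        # Called once per batch: binarise the scores and add this batch's counts to the state.
        for pred_row, target_row in zip(preds, target):
            for score, label in zip(pred_row, target_row):
                predicted = score >= self.threshold
                if label and predicted:
                    self.tp += 1
                elif label:
                    self.fn += 1
                elif predicted:
                    self.fp += 1
                else:
                    self.tn += 1

    def compute(self) -> float:
        # Called once per epoch: turn the accumulated counts into (TPR + TNR) / 2.
        tpr = self.tp / (self.tp + self.fn) if self.tp + self.fn else 0.0
        tnr = self.tn / (self.tn + self.fp) if self.tn + self.fp else 0.0
        return (tpr + tnr) / 2


metric = BalancedAccuracySketch()
metric.update([[0.9, 0.2], [0.4, 0.7]], [[1, 0], [1, 1]])  # batch 1
metric.update([[0.1, 0.3]], [[0, 0]])                      # batch 2
print(round(metric.compute(), 3))  # 0.833
```

Keeping only counts in the state, rather than storing every prediction, keeps memory use constant over an epoch; in torchmetrics, sum-reduced states like these can also be combined across devices in distributed training.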