def _count_labels(class_names: list, dataset: "tf.data.Dataset") -> list:
    """
    Counts how many instances of each class a dataset contains

    :param class_names: List of class names; a label is used as an index into this list
    :param dataset: Batched dataset yielding (images, labels) pairs
    :return: List of instance counts in the order of class_names
    """
    counts = [0] * len(class_names)
    for _, label in dataset.unbatch():  # Iterate over each instance
        # Index a plain list so that a non-scalar label (e.g. one-hot encoded)
        # fails loudly instead of being silently miscounted
        counts[label.numpy()] += 1
    return counts


def compute_class_weights(class_names: list, dataset_train: "tf.data.Dataset", dataset_val: "tf.data.Dataset") -> dict:
    """
    Computes balanced class weights for the dataset

    The weight of class i is n_samples / (n_classes * count_i), where the counts
    are taken from the train dataset only. This is the same formula scikit-learn
    uses for class_weight='balanced'. The validation dataset is only counted so
    that its class distribution can be printed next to the train distribution.

    :param class_names: List of class names
    :param dataset_train: Train-Dataset to compute the class weights for
    :param dataset_val: Validation-Dataset whose class distribution is printed
    :return: Dictionary mapping the numerical class index to its weight
    :raises ValueError: If class_names is empty or a class has no training samples
    """
    if not class_names:
        raise ValueError("class_names must not be empty")

    train_counts = _count_labels(class_names, dataset_train)
    validation_counts = _count_labels(class_names, dataset_val)

    # NOTE: despite the labels, these two lines show instance counts per class, not weights
    print("Validation Weights:", dict(zip(class_names, validation_counts)))
    print("Train Weights:", dict(zip(class_names, train_counts)))

    # A class without training samples would need an infinite weight, so reject it
    # with a message that names the offending classes
    missing = [name for name, count in zip(class_names, train_counts) if count == 0]
    if missing:
        raise ValueError(f"Cannot compute class weights: no training samples for {missing}")

    total_samples = sum(train_counts)
    n_classes = len(class_names)

    # Keys are numerical class indices, matching the position in class_names
    return {index: total_samples / (n_classes * count) for index, count in enumerate(train_counts)}