Skip to content

Commit

Permalink
Feature: Added logic behind class weight calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
Flippchen committed Feb 11, 2024
1 parent 87ea96e commit 6878af6
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions utilities/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import logging
import platform
from sklearn.utils.class_weight import compute_class_weight


def load_dataset(path: str, batch_size: int, img_height: int, img_width: int, seed: int) -> tuple[tf.data.Dataset, tf.data.Dataset, list]:
Expand Down Expand Up @@ -184,6 +185,47 @@ def show_augmented_batch(train_ds, data_augmentation) -> None:
plt.show()


def _count_samples_per_class(class_names: list, dataset: tf.data.Dataset) -> dict:
    """
    Counts how many samples of each class a dataset contains
    :param class_names: List of class names; the label of a sample is used as index into this list
    :param dataset: Batched dataset yielding (images, labels) pairs
    :return: Dictionary mapping each class name to its number of samples
    """
    counts = {class_name: 0 for class_name in class_names}

    # Unbatch so every element carries exactly one label.
    # NOTE(review): assumes each label is a scalar integer index - confirm against load_dataset
    for _, label in dataset.unbatch():
        counts[class_names[label.numpy()]] += 1

    return counts


def compute_class_weights(class_names: list, dataset_train: tf.data.Dataset, dataset_val: tf.data.Dataset) -> dict:
    """
    Computes the balanced class weights for the training dataset

    The weights are derived from the training dataset only. The validation dataset is
    counted as well, but its class distribution is merely printed for inspection.
    :param class_names: List of class names
    :param dataset_train: Train-Dataset to compute the class weights for
    :param dataset_val: Validation-Dataset whose class distribution is printed
    :return: Dictionary mapping the numerical class index to its class weight
    :raises ValueError: If class_names is empty or a class has no training samples
    """
    class_counts = _count_samples_per_class(class_names, dataset_train)
    class_count_validation = _count_samples_per_class(class_names, dataset_val)

    # These are raw sample counts per class, not weights
    print("Validation class counts:", class_count_validation)
    print("Train class counts:", class_counts)

    if not class_names:
        raise ValueError("Cannot compute class weights: class_names is empty")

    # 'balanced' weights are inversely proportional to the class frequency, so a class
    # without any training sample has no defined weight. Fail with a clear message
    # instead of the generic error raised inside compute_class_weight.
    missing_classes = [class_name for class_name, count in class_counts.items() if count == 0]
    if missing_classes:
        raise ValueError(f"Cannot compute class weights: no training samples for classes {missing_classes}")

    # Class counts in the order of class names, so position i belongs to class index i
    class_samples = np.array([class_counts[class_name] for class_name in class_names], dtype=int)
    class_indices = np.arange(len(class_names))

    # compute_class_weight expects one label per sample, so expand the counts back
    # into a label vector (class index i repeated class_samples[i] times)
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=class_indices,
        y=np.repeat(class_indices, class_samples),
    )

    # Keys are numerical class indices; plain floats keep the result free of numpy scalars
    return {index: float(weight) for index, weight in enumerate(class_weights)}


def plot_model_score(history, name: str, model_type: str) -> None:
"""
Plots the accuracy and loss of the model
Expand Down

0 comments on commit 6878af6

Please sign in to comment.