Skip to content

Commit

Permalink
Merge pull request #93 from Flippchen/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
Flippchen authored Feb 8, 2024
2 parents 297e5f7 + ca20618 commit 60f7f61
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 31 deletions.
6 changes: 4 additions & 2 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_load_dataset(tmp_path):
p2 = d2 / f"img{i + 1}.jpg"
p2.write_text("fake image data")

train_ds, val_ds, class_names = load_dataset(str(d), 2, 32, 32)
train_ds, val_ds, class_names = load_dataset(str(d), 2, 32, 32, 123)

assert len(train_ds) == 4
assert len(val_ds) == 1
Expand All @@ -30,10 +30,12 @@ def test_load_dataset(tmp_path):

def test_create_augmentation_layer():
data_augmentation = create_augmentation_layer(32, 32)
assert len(data_augmentation.layers) == 3
assert len(data_augmentation.layers) == 5
assert isinstance(data_augmentation.layers[0], tf.keras.layers.RandomFlip)
assert isinstance(data_augmentation.layers[1], tf.keras.layers.RandomRotation)
assert isinstance(data_augmentation.layers[2], tf.keras.layers.RandomZoom)
assert isinstance(data_augmentation.layers[3], tf.keras.layers.RandomContrast)
assert isinstance(data_augmentation.layers[4], tf.keras.layers.GaussianNoise)


def test_get_data_path_addon():
Expand Down
19 changes: 9 additions & 10 deletions training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from keras.applications import EfficientNetV2B1
from utilities.tools import get_data_path_addon, get_base_path, suppress_tf_warnings, load_dataset, show_augmented_batch, create_augmentation_layer, plot_model_score, show_sample_batch, show_batch_shape
from utilities.discord_callback import DiscordCallback
from keras.optimizers import Adam
from keras.optimizers import AdamW
from keras.regularizers import l1_l2
from keras.callbacks import EarlyStopping, ModelCheckpoint
import os
import random

import tensorflow as tf
# Ignore warnings
import warnings
Expand All @@ -30,6 +32,8 @@
# Set to True to load trained model
load_model = False
load_path = "../models/all_model_variants/efficientnet-old-head-model-variants.h5"
# Set seed for reproducibility
random_seed = True
# Config
base_path = get_base_path()
path_addon = get_data_path_addon(model_type)
Expand All @@ -38,6 +42,7 @@
"batch_size": 32,
"img_height": img_height,
"img_width": img_width,
"seed": random.randint(0, 1000) if random_seed else 123
}

# Load dataset and classes
Expand All @@ -52,11 +57,6 @@
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# Normalize the data
normalization_layer = layers.Rescaling(1. / 255)
normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))

# Create data augmentation layer and show augmented batch
data_augmentation = create_augmentation_layer(img_height, img_width)
show_augmented_batch(train_ds, data_augmentation)
Expand Down Expand Up @@ -87,11 +87,11 @@
]) if not load_model else keras.models.load_model(load_path)

# Define optimizer
optimizer = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
optimizer = AdamW(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, use_ema=True)

# Define learning rate scheduler
initial_learning_rate = 0.001
lr_decay_steps = 1000
lr_decay_steps = 10
lr_decay_rate = 0.96
lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate,
Expand All @@ -101,7 +101,7 @@

# Compile model
model.compile(optimizer=optimizer,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['accuracy'])
model.summary()

Expand Down Expand Up @@ -134,4 +134,3 @@
# Save model
model.save(f"{save_path}{name}.h5")

# TODO: Different data augmentation (vertical, ..), Augmentation before training
26 changes: 7 additions & 19 deletions utilities/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import platform


def load_dataset(path: str, batch_size: int, img_height: int, img_width: int) -> tuple[tf.data.Dataset, tf.data.Dataset, list]:
def load_dataset(path: str, batch_size: int, img_height: int, img_width: int, seed: int) -> tuple[tf.data.Dataset, tf.data.Dataset, list]:
"""
:param path: Path to the Dataset folder
:param batch_size: Integer which defines how many Images are in one Batch
Expand All @@ -20,27 +20,19 @@ def load_dataset(path: str, batch_size: int, img_height: int, img_width: int) ->
:return: Tuple of train, val Dataset and Class names
"""
data_dir = pathlib.Path(path)
# if "more_classes" in path:
# image_count = len(list(data_dir.glob('*/*/*.jpg')))
# else:
# image_count = len(list(data_dir.glob('*/*/*/*.jpg')))

# print("Image count:", image_count)
# cars = list(data_dir.glob('*/*/*/*.jpg'))
# PIL.Image.open(str(cars[0]))
train_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=123,
seed=seed,
image_size=(img_height, img_width),
batch_size=batch_size)

val_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="validation",
seed=123,
seed=seed,
image_size=(img_height, img_width),
batch_size=batch_size)

Expand Down Expand Up @@ -99,12 +91,6 @@ def load_image_subset(path: str, batch_size: int, img_height: int, img_width: in
:return: Subset of Dataset
"""
data_dir = pathlib.Path(path)
# if "more_classes" in path:
# image_count = len(list(data_dir.glob('*/*/*.jpg')))
# else:
# image_count = len(list(data_dir.glob('*/*/*/*.jpg')))

# print("Image count:", image_count)

data = tf.keras.utils.image_dataset_from_directory(
data_dir,
Expand Down Expand Up @@ -168,12 +154,14 @@ def create_augmentation_layer(img_height: int, img_width: int) -> keras.Sequenti
"""
return keras.Sequential(
[
layers.RandomFlip("horizontal",
layers.RandomFlip("vertical",
input_shape=(img_height,
img_width,
3)),
layers.RandomRotation(0.1),
layers.RandomRotation(0.2),
layers.RandomZoom(0.1),
layers.RandomContrast(0.1),
layers.GaussianNoise(0.1)
]
)

Expand Down

0 comments on commit 60f7f61

Please sign in to comment.