-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
1 addition
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"C0ffGllJpULR"},"outputs":[],"source":["import keras\n","import matplotlib.pyplot as plt\n","import numpy as np\n","import os\n","import tensorflow as tf\n","import tensorflow_datasets as tfds\n","import pickle\n","from collections import namedtuple\n","from ctypes import ArgumentError\n","from dataclasses import dataclass\n","from keras.applications.vgg16 import VGG16\n","from keras.applications.vgg16 import preprocess_input\n","from keras.layers import RandomZoom, RandomFlip, RandomRotation, Input, Dense, Dropout, GlobalAveragePooling2D, Flatten\n","from keras.models import Model\n","from keras.optimizers import RMSprop\n","from textwrap import dedent\n","from tqdm import tqdm\n","from keras.callbacks import (\n"," ModelCheckpoint,\n"," EarlyStopping,\n"," Callback,\n",")\n","\n","keras.mixed_precision.set_global_policy(\"mixed_float16\")\n","\n","SEED = 42\n","\n","def set_global_determinism():\n"," os.environ['PYTHONHASHSEED'] = str(SEED)\n"," tf.random.set_seed(SEED)\n"," np.random.seed(SEED)\n","\n"," os.environ['TF_DETERMINISTIC_OPS'] = '1'\n"," os.environ['TF_CUDNN_DETERMINISTIC'] = '1'\n","\n"," tf.config.threading.set_inter_op_parallelism_threads(1)\n"," tf.config.threading.set_intra_op_parallelism_threads(1)\n","\n","\n","set_global_determinism()"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"7KvkQ_JBQ79F"},"outputs":[],"source":["@tf.function\n","def compress_image_with_energy(image, energy_factor=0.9):\n"," # Returns a compressed image based on a desired energy factor\n"," image_rescaled = tf.convert_to_tensor(image)\n"," image_batched = tf.transpose(image_rescaled, [2, 0, 1])\n"," s, U, V = tf.linalg.svd(image_batched, compute_uv=True, full_matrices=False)\n","\n"," # Extracting singular values\n"," props_rgb = tf.map_fn(lambda x: tf.cumsum(x) / tf.reduce_sum(x), s)\n"," props_rgb_mean = tf.reduce_mean(props_rgb, axis=0)\n","\n"," # Find closest k that corresponds to the energy factor\n"," k = tf.argmin(tf.abs(props_rgb_mean - energy_factor)) + 1\n","\n"," # Compute the low-rank approximation\n"," s_k, U_k, V_k = s[..., :k], U[..., :, :k], V[..., :, :k]\n"," A_k = tf.einsum(\"...s,...us,...vs->...uv\", s_k, U_k, V_k)\n","\n"," return tf.transpose(A_k, [1, 2, 0])"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"t4Y7ypGP3Qxk"},"outputs":[],"source":["@dataclass\n","class TrialDetail:\n"," dataset_name: str\n"," augmentation_method: str\n"," energy_factor: float\n"," percentage_data: float = 1\n","\n"," def __getitem__(self, idx):\n"," props = (\n"," self.dataset_name,\n"," self.augmentation_method,\n"," self.percentage_data,\n"," self.energy_factor,\n"," )\n"," return props[idx]\n","\n"," def __repr__(self):\n"," energy_factor_repr = (\n"," f\"{self.energy_factor:.2%}\"\n"," if self.augmentation_method in {\"all\", \"svd\"}\n"," else \"N/A\"\n"," )\n"," return dedent(\n"," f\"\"\"\n"," Model & Dataset Name: {self.dataset_name.upper()}\n"," Percentage of Data Used: {self.percentage_data:.2%}\n"," Augmentation Method: {self.augmentation_method.upper()}\n"," Energy Factor: {energy_factor_repr}\n"," \"\"\"\n"," )\n","\n"," def file_str(self):\n"," return f\"{self.dataset_name}_{self.augmentation_method}_{self.percentage_data}_{self.energy_factor}\""]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dWLYaTy72Prr"},"outputs":[],"source":["def get_kfold_callbacks(trial_detail: TrialDetail, fold_num):\n"," checkpoint_cb = ModelCheckpoint(\n"," f\"{trial_detail.file_str()}_fine_tuning\",\n"," save_best_only=True,\n"," monitor=\"val_accuracy\",\n"," )\n"," kfold_train_details_logger_cb = _KFoldTrainingDetailsLogger(trial_detail, fold_num)\n","\n"," return [\n"," checkpoint_cb,\n"," kfold_train_details_logger_cb,\n"," ]\n","\n","\n","def get_retrain_callbacks(trial_detail: TrialDetail, purpose=\"retrain\"):\n"," checkpoint_cb = ModelCheckpoint(\n"," f\"{trial_detail.file_str()}_{purpose}\",\n"," save_best_only=True,\n"," monitor=\"accuracy\",\n"," )\n"," retrain_train_details_logger_cb = _RetrainDetailsLogger(trial_detail)\n","\n"," return [\n"," checkpoint_cb,\n"," retrain_train_details_logger_cb,\n"," ]\n","\n","\n","class _KFoldTrainingDetailsLogger(Callback):\n"," def __init__(self, trial_detail, fold_num):\n"," self.trial_detail = trial_detail\n"," self.fold_num = fold_num + 1\n","\n"," def on_train_begin(self, logs=None):\n"," print(\n"," dedent(\n"," f\"\"\"\n"," \\n\\n{'*' * 80}\\n\\nSTART OF TRAINING - FOLD #{self.fold_num}:\\n{self.trial_detail!r}\\n\\n{'*' * 80}\\n\\n\"\"\"\n"," )\n"," )\n","\n"," def on_train_end(self, logs=None):\n"," print(\n"," dedent(\n"," f\"\"\"\n"," \\n\\n{'*' * 80}\\n\\nEND OF TRAINING - FOLD #{self.fold_num}:\\n{self.trial_detail!r}\\n\\n{'*' * 80}\\n\\n\"\"\"\n"," )\n"," )\n","\n","\n","class _RetrainDetailsLogger(Callback):\n"," def __init__(self, trial_detail):\n"," self.trial_detail = trial_detail\n","\n"," def on_train_begin(self, logs=None):\n"," print(\n"," dedent(\n"," f\"\"\"\n"," \\n\\n{'*' * 80}\\n{'*' * 80}\\n{'*' * 80}\\n\\nRETRAINING ON ENTIRE DATASET:\\n{self.trial_detail!r}\\n\\n{'*' * 80}\\n{'*' * 80}\\n{'*' * 80}\\n\\n\"\"\"\n"," )\n"," )\n","\n"," def on_train_end(self, logs=None):\n"," print(\n"," dedent(\n"," f\"\"\"\n"," \\n\\n{'*' * 80}\\n{'*' * 80}\\n{'*' * 80}\\n\\nEND OF RETRAINING:\\n{self.trial_detail!r}\\n\\n{'*' * 80}\\n{'*' * 80}\\n{'*' * 80}\\n\\n\"\"\"\n"," )\n"," )"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"JVDj8PJyN9hD"},"outputs":[],"source":["class RandomCompression():\n"," \"\"\"Utilizes low rank approximation of images using SVD with randomized energy factors\"\"\"\n"," def __init__(\n"," self,\n"," max_energy_factor=0.975,\n"," min_energy_factor=0.95,\n"," distribution=\"normal\",\n"," skip_threshold=0.80,\n"," **kwargs\n"," ):\n"," super(RandomCompression, self).__init__(**kwargs)\n"," self.min_energy_factor = min_energy_factor\n"," self.max_energy_factor = max_energy_factor\n"," self.distribution = distribution\n"," self.skip_threshold = skip_threshold\n","\n"," def __call__(self, input, training=True):\n"," if training:\n"," return self.compressed_input(input)\n"," else:\n"," return input\n","\n"," def compressed_input(self, input):\n"," def compress_with_distribution(x):\n"," if self._will_skip():\n"," return None\n"," else:\n"," energy_factor = self._sample_from_distribution()\n"," return compress_image_with_energy(x, energy_factor)\n","\n","\n"," compressed_image = compress_with_distribution(input)\n","\n"," return compressed_image\n","\n"," def _will_skip(self):\n"," uncompressed_threshold = tf.random.uniform(shape=(), minval=0.0, maxval=1.0)\n"," return uncompressed_threshold < self.skip_threshold\n","\n"," def _sample_from_distribution(self):\n"," if self.distribution in [\"gaussian\", \"normal\"]:\n"," return tf.random.normal(shape=(), mean=self.max_energy_factor, stddev=0.025)\n"," elif self.distribution == \"uniform\":\n"," return tf.random.uniform(\n"," shape=(), minval=self.min_energy_factor, maxval=self.max_energy_factor\n"," )\n"," else:\n"," raise ArgumentError(\n"," \"Random Compression layer only supports uniform and gaussian distributions\"\n"," )\n","\n","\n","\n","def preprocess(image, label):\n"," TARGET_SIZE = (180, 180)\n"," image = tf.image.resize(image, TARGET_SIZE)\n"," image = tf.cast(image, tf.float32)\n"," return image, label\n","\n","\n","def extend_ds(ds):\n"," print(f\"\\n{'*'*50}Compressing Images{'*'*50}\")\n","\n"," ds = ds.unbatch()\n"," aug_ds = []\n","\n"," total = len(list(ds))\n"," compression = RandomCompression(max_energy_factor=0.975)\n"," total_compressed_images = 0\n","\n"," for idx, (image, label) in tqdm(enumerate(ds.map(preprocess)), total=total, ncols=110):\n"," compressed_image_0 = compression(image)\n"," if compressed_image_0 is not None:\n"," aug_ds.append((compressed_image_0, label))\n"," total_compressed_images += 1\n","\n","\n"," images = [item[0] for item in aug_ds]\n"," labels = [item[1] for item in aug_ds]\n","\n"," # Create a TensorFlow dataset from the lists\n"," aug_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(BATCH_SZ)\n","\n"," print(\"Total compressed images:\", total_compressed_images) # Print total compressed images\n","\n"," print(f\"{'*'*46}Finished Compressing Images{'*'*46}\\n\")\n","\n"," return aug_ds"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"TQ42qDcetClG"},"outputs":[],"source":["BUFFER_SZ = 1024\n","BATCH_SZ = 32\n","\n","\n","def load_dataset():\n"," ds = tfds.load(\n"," \"cats_vs_dogs\", split=\"train\", as_supervised=True, shuffle_files=True,\n"," )\n","\n"," num_train_samples = int(len(ds) * 0.6) # 60% Train\n"," num_val_samples = int(len(ds) * 0.2) # 20% Validation\n","\n"," ds = ds.shuffle(BUFFER_SZ).prefetch(tf.data.AUTOTUNE)\n","\n"," ds_train = ds.take(num_train_samples).cache()\n"," ds_val = ds.skip(num_train_samples)\n"," ds_test = ds.skip(num_train_samples + num_val_samples)\n","\n","\n"," ds_train = ds_train.map(preprocess).batch(BATCH_SZ)\n"," ds_test = ds_test.map(preprocess).batch(BATCH_SZ)\n"," ds_val = ds_val.map(preprocess).batch(BATCH_SZ)\n","\n"," return ds_train, ds_test, ds_val\n","\n","\n","def split_dataset_kfold(ds_train, k):\n"," num_train_samples = len(list(ds_train))\n"," fold_size = num_train_samples // k\n","\n"," fold_datasets = []\n"," for fold in range(k):\n"," start_index = fold * fold_size\n"," end_index = (fold + 1) * fold_size\n","\n"," ds_val_fold = (\n"," ds_train.skip(start_index).take(fold_size)\n"," )\n","\n"," ds_train_fold_1 = ds_train.take(start_index)\n"," ds_train_fold_2 = ds_train.skip(end_index)\n"," ds_train_fold = ds_train_fold_1.concatenate(ds_train_fold_2)\n","\n"," fold_datasets.append((ds_train_fold, ds_val_fold))\n","\n","\n"," return fold_datasets"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"YZhvy8tc7WAb"},"outputs":[],"source":["TrialResult = namedtuple(\"TrialResult\", \"histories generalization_performance\")\n","\n","def _save_trial_result(histories, trial_detail):\n"," dataset_name, augmentation_method, percentage_data, energy_factor = trial_detail\n","\n"," if augmentation_method not in {\"all\", \"svd\"}:\n"," energy_factor = 0\n","\n"," output_path = os.path.join(dataset_name)\n"," file_path = f\"{augmentation_method}_{percentage_data}_{energy_factor}.pkl\"\n"," os.makedirs(output_path, exist_ok=True)\n","\n"," with open(os.path.join(output_path, file_path), \"wb\") as f:\n"," pickle.dump(trial_result, f)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"o982eoLhtNqe"},"outputs":[],"source":["def default_augmentations_layer():\n"," data_augmentation = keras.Sequential(\n"," [\n"," RandomFlip(\"horizontal\"),\n"," RandomRotation(0.1),\n"," RandomZoom(0.2),\n"," ],\n"," name = 'hflip_rot_zoom'\n"," )\n"," return data_augmentation"]},{"cell_type":"code","source":["def build_feat_ext_model():\n"," feature_ext_model = keras.models.load_model(f\"{trial_detail_1.file_str()}_feature_extraction\")\n"," conv_base = feature_ext_model.get_layer(\"vgg16\")\n","\n"," conv_base.trainable = True\n"," for layer in conv_base.layers[:-4]:\n"," layer.trainable = False\n","\n"," feature_ext_model.compile(\n"," optimizer=RMSprop(learning_rate=1e-5),\n"," loss=\"binary_crossentropy\",\n"," metrics=\"accuracy\",\n"," )\n","\n"," return feature_ext_model"],"metadata":{"id":"5u2gnzhGAt-z"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"zZOx5Ba4Bpdz","outputId":"1c616d51-737b-44b2-9812-46308905a320","executionInfo":{"status":"ok","timestamp":1715214584968,"user_tz":240,"elapsed":5852,"user":{"displayName":"Kennette Basco","userId":"00881963222092464928"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["146/146 [==============================] - 5s 9ms/step - loss: 0.2285 - accuracy: 0.9854\n"]}],"source":["conv_base = keras.applications.vgg16.VGG16(\n"," weights=\"imagenet\",\n"," include_top=False)\n","conv_base.trainable = False\n","\n","data_augmentation = default_augmentations_layer()\n","\n","inputs = keras.Input(shape=(180, 180, 3))\n","x = data_augmentation(inputs)\n","x = keras.applications.vgg16.preprocess_input(x)\n","x = conv_base(x)\n","x = Flatten()(x)\n","x = Dense(256)(x)\n","x = Dropout(0.5)(x)\n","outputs = Dense(1, activation=\"sigmoid\")(x)\n","model = keras.Model(inputs, outputs)\n","model.compile(loss=\"binary_crossentropy\",\n"," optimizer=\"rmsprop\",\n"," metrics=[\"accuracy\"])\n","\n","\n","trial_detail_1 = TrialDetail(dataset_name=\"cats_vs_dogs\",\n"," percentage_data=1.0,\n"," augmentation_method=\"svd\",\n"," energy_factor=0.975\n"," )\n","print(trial_detail_1)\n","\n","ds_train, ds_test, ds_val = load_dataset()\n","ds_train_augmented = extend_ds(ds_train)\n","ds_train_ext = ds_train.concatenate(ds_train_augmented).shuffle(BUFFER_SZ).cache().prefetch(tf.data.AUTOTUNE)\n","\n","\"\"\"\n","*****************************Feature Extraction*****************************\n","\"\"\"\n","\n","callbacks = [EarlyStopping(patience=10,\n"," monitor=\"val_accuracy\",\n"," restore_best_weights=True),\n"," ModelCheckpoint(filepath=f\"{trial_detail_1.file_str()}_feature_extraction\",\n"," monitor=\"val_accuracy\",\n"," save_best_only=True),\n"," ]\n","\n","print(f\"\\n\\n{'*'*50}\\nFeature Extraction\\n{'*'*50}\\n\\n\")\n","history_1 = model.fit(\n"," ds_train_ext,\n"," epochs=50,\n"," validation_data=ds_val,\n"," callbacks=callbacks,\n",")\n","\n","\"\"\"\n","*****************************Fine Tuning*****************************\n","\"\"\"\n","\n","\n","all_histories = []\n","best_epochs_loss = []\n","best_epochs_acc = []\n","\n","fold_datasets = split_dataset_kfold(ds_train_ext, 5)\n","\n","print(f\"\\n\\n{'*'*50}\\nFine Tuning\\n{'*'*50}\\n\\n\")\n","for fold, (ds_train_fold, ds_val_fold) in enumerate(fold_datasets):\n"," callbacks = get_kfold_callbacks(trial_detail_1, fold)\n"," feature_ext_model = build_feat_ext_model()\n","\n"," history = feature_ext_model.fit(\n"," ds_train_fold,\n"," epochs=100,\n"," validation_data=ds_val_fold,\n"," callbacks=callbacks,\n"," )\n","\n"," all_histories.append(history.history)\n"," best_epochs_loss.append(np.argmin(history.history['val_loss']))\n"," best_epochs_acc.append(np.argmax(history.history['val_accuracy']))\n","\n","\n","\"\"\"\n","*****************************Retraining on Entire Dataset*****************************\n","\"\"\"\n","\n","ds_train_ext = ds_train_ext.concatenate(ds_val).shuffle(BUFFER_SZ).cache().prefetch(tf.data.AUTOTUNE)\n","\n","print(f\"\\n\\n{'*'*50}\\nRetraining\\n{'*'*50}\\n\\n\")\n","feature_ext_model = build_feat_ext_model()\n","callbacks = get_retrain_callbacks(trial_detail_1, \"naive\")\n","best_epoch = int(50 * 1.2)\n","history_2 = feature_ext_model.fit(\n"," ds_train_ext,\n"," epochs=best_epoch,\n"," callbacks=callbacks,\n"," verbose=1,\n",")\n","\n","feature_ext_model = build_feat_ext_model()\n","callbacks = get_retrain_callbacks(trial_detail_1, \"argmin_loss\")\n","best_epoch = int(np.mean(best_epochs_loss) * 1.2)\n","history_3 = feature_ext_model.fit(\n"," ds_train_ext,\n"," epochs=best_epoch,\n"," callbacks=callbacks,\n"," verbose=1,\n",")\n","\n","feature_ext_model = build_feat_ext_model()\n","callbacks = get_retrain_callbacks(trial_detail_1, \"argmax_acc\")\n","best_epoch = int(np.mean(best_epochs_acc) * 1.2)\n","history_4 = feature_ext_model.fit(\n"," ds_train_ext,\n"," epochs=best_epoch,\n"," callbacks=callbacks,\n"," verbose=1,\n",")\n","\n","\n","\n","test_model = keras.models.load_model(f\"{trial_detail_1.file_str()}_feature_extraction\")\n","_, test_acc = test_model.evaluate(ds_test)\n","print(f\"Feature Extraction Test accuracy: {test_acc:.3f}\")\n","\n","test_model = keras.models.load_model(f\"{trial_detail_1.file_str()}_fine_tuning\")\n","_, test_acc = test_model.evaluate(ds_test)\n","print(f\"Fine Tuning Test accuracy: {test_acc:.3f}\")\n","\n","test_model = keras.models.load_model(f\"{trial_detail_1.file_str()}_naive\")\n","_, test_acc = test_model.evaluate(ds_test)\n","print(f\"Naive Retraining Test accuracy: {test_acc:.3f}\")\n","\n","test_model = keras.models.load_model(f\"{trial_detail_1.file_str()}_argmin_loss\")\n","_, test_acc = test_model.evaluate(ds_test)\n","print(f\"Argmin Retraining accuracy: {test_acc:.3f}\")\n","\n","test_model = keras.models.load_model(f\"{trial_detail_1.file_str()}_argmax_acc\")\n","_, test_acc = test_model.evaluate(ds_test)\n","print(f\"Argmax Retraining accuracy: {test_acc:.3f}\")\n","\n","generalization_performance = test_model.evaluate(ds_test)\n","trial_result = TrialResult(all_histories, generalization_performance)\n","_save_trial_result(trial_result, trial_detail_1)\n","\n","plot_history(history_1)\n","plot_history(history_2)\n","plot_history(history_3)\n","plot_history(history_4)"]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"A100","machine_shape":"hm","provenance":[{"file_id":"1fm5oDBimTjUo4VGCWx3LPXPwWS76YrxB","timestamp":1714596391673}]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0} |