Commit 477082b: Add files via upload
kankenny authored Jun 5, 2024 (1 parent: 56ece96)
Showing 1 changed file: SVDNet.ipynb (1 addition, 0 deletions)

The uploaded notebook trains a VGG16-based cats-vs-dogs classifier and evaluates SVD low-rank image compression as a data augmentation method, comparing feature extraction, k-fold fine-tuning, and several full-dataset retraining schedules. The cells are reproduced below.
{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"C0ffGllJpULR"},"outputs":[],"source":["import keras\n","import matplotlib.pyplot as plt\n","import numpy as np\n","import os\n","import tensorflow as tf\n","import tensorflow_datasets as tfds\n","import pickle\n","from collections import namedtuple\n","from ctypes import ArgumentError\n","from dataclasses import dataclass\n","from keras.applications.vgg16 import VGG16\n","from keras.applications.vgg16 import preprocess_input\n","from keras.layers import RandomZoom, RandomFlip, RandomRotation, Input, Dense, Dropout, GlobalAveragePooling2D, Flatten\n","from keras.models import Model\n","from keras.optimizers import RMSprop\n","from textwrap import dedent\n","from tqdm import tqdm\n","from keras.callbacks import (\n"," ModelCheckpoint,\n"," EarlyStopping,\n"," Callback,\n",")\n","\n","keras.mixed_precision.set_global_policy(\"mixed_float16\")\n","\n","SEED = 42\n","\n","def set_global_determinism():\n"," os.environ['PYTHONHASHSEED'] = str(SEED)\n"," tf.random.set_seed(SEED)\n"," np.random.seed(SEED)\n","\n"," os.environ['TF_DETERMINISTIC_OPS'] = '1'\n"," os.environ['TF_CUDNN_DETERMINISTIC'] = '1'\n","\n"," tf.config.threading.set_inter_op_parallelism_threads(1)\n"," tf.config.threading.set_intra_op_parallelism_threads(1)\n","\n","\n","set_global_determinism()"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"7KvkQ_JBQ79F"},"outputs":[],"source":["@tf.function\n","def compress_image_with_energy(image, energy_factor=0.9):\n"," # Returns a compressed image based on a desired energy factor\n"," image_rescaled = tf.convert_to_tensor(image)\n"," image_batched = tf.transpose(image_rescaled, [2, 0, 1])\n"," s, U, V = tf.linalg.svd(image_batched, compute_uv=True, full_matrices=False)\n","\n"," # Extracting singular values\n"," props_rgb = tf.map_fn(lambda x: tf.cumsum(x) / tf.reduce_sum(x), s)\n"," props_rgb_mean = tf.reduce_mean(props_rgb, axis=0)\n","\n"," # Find closest k that corresponds to the energy factor\n"," k = tf.argmin(tf.abs(props_rgb_mean - energy_factor)) + 1\n","\n"," # Compute the low-rank approximation\n"," s_k, U_k, V_k = s[..., :k], U[..., :, :k], V[..., :, :k]\n"," A_k = tf.einsum(\"...s,...us,...vs->...uv\", s_k, U_k, V_k)\n","\n"," return tf.transpose(A_k, [1, 2, 0])"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"t4Y7ypGP3Qxk"},"outputs":[],"source":["@dataclass\n","class TrialDetail:\n"," dataset_name: str\n"," augmentation_method: str\n"," energy_factor: float\n"," percentage_data: float = 1\n","\n"," def __getitem__(self, idx):\n"," props = (\n"," self.dataset_name,\n"," self.augmentation_method,\n"," self.percentage_data,\n"," self.energy_factor,\n"," )\n"," return props[idx]\n","\n"," def __repr__(self):\n"," energy_factor_repr = (\n"," f\"{self.energy_factor:.2%}\"\n"," if self.augmentation_method in {\"all\", \"svd\"}\n"," else \"N/A\"\n"," )\n"," return dedent(\n"," f\"\"\"\n"," Model & Dataset Name: {self.dataset_name.upper()}\n"," Percentage of Data Used: {self.percentage_data:.2%}\n"," Augmentation Method: {self.augmentation_method.upper()}\n"," Energy Factor: {energy_factor_repr}\n"," \"\"\"\n"," )\n","\n"," def file_str(self):\n"," return f\"{self.dataset_name}_{self.augmentation_method}_{self.percentage_data}_{self.energy_factor}\""]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dWLYaTy72Prr"},"outputs":[],"source":["def get_kfold_callbacks(trial_detail: TrialDetail, fold_num):\n"," checkpoint_cb = ModelCheckpoint(\n"," 
f\"{trial_detail.file_str()}_fine_tuning\",\n"," save_best_only=True,\n"," monitor=\"val_accuracy\",\n"," )\n"," kfold_train_details_logger_cb = _KFoldTrainingDetailsLogger(trial_detail, fold_num)\n","\n"," return [\n"," checkpoint_cb,\n"," kfold_train_details_logger_cb,\n"," ]\n","\n","\n","def get_retrain_callbacks(trial_detail: TrialDetail, purpose=\"retrain\"):\n"," checkpoint_cb = ModelCheckpoint(\n"," f\"{trial_detail.file_str()}_{purpose}\",\n"," save_best_only=True,\n"," monitor=\"accuracy\",\n"," )\n"," retrain_train_details_logger_cb = _RetrainDetailsLogger(trial_detail)\n","\n"," return [\n"," checkpoint_cb,\n"," retrain_train_details_logger_cb,\n"," ]\n","\n","\n","class _KFoldTrainingDetailsLogger(Callback):\n"," def __init__(self, trial_detail, fold_num):\n"," self.trial_detail = trial_detail\n"," self.fold_num = fold_num + 1\n","\n"," def on_train_begin(self, logs=None):\n"," print(\n"," dedent(\n"," f\"\"\"\n"," \\n\\n{'*' * 80}\\n\\nSTART OF TRAINING - FOLD #{self.fold_num}:\\n{self.trial_detail!r}\\n\\n{'*' * 80}\\n\\n\"\"\"\n"," )\n"," )\n","\n"," def on_train_end(self, logs=None):\n"," print(\n"," dedent(\n"," f\"\"\"\n"," \\n\\n{'*' * 80}\\n\\nEND OF TRAINING - FOLD #{self.fold_num}:\\n{self.trial_detail!r}\\n\\n{'*' * 80}\\n\\n\"\"\"\n"," )\n"," )\n","\n","\n","class _RetrainDetailsLogger(Callback):\n"," def __init__(self, trial_detail):\n"," self.trial_detail = trial_detail\n","\n"," def on_train_begin(self, logs=None):\n"," print(\n"," dedent(\n"," f\"\"\"\n"," \\n\\n{'*' * 80}\\n{'*' * 80}\\n{'*' * 80}\\n\\nRETRAINING ON ENTIRE DATASET:\\n{self.trial_detail!r}\\n\\n{'*' * 80}\\n{'*' * 80}\\n{'*' * 80}\\n\\n\"\"\"\n"," )\n"," )\n","\n"," def on_train_end(self, logs=None):\n"," print(\n"," dedent(\n"," f\"\"\"\n"," \\n\\n{'*' * 80}\\n{'*' * 80}\\n{'*' * 80}\\n\\nEND OF RETRAINING:\\n{self.trial_detail!r}\\n\\n{'*' * 80}\\n{'*' * 80}\\n{'*' * 80}\\n\\n\"\"\"\n"," )\n"," )"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"JVDj8PJyN9hD"},"outputs":[],"source":["class RandomCompression():\n"," \"\"\"Utilizes low rank approximation of images using SVD with randomized energy factors\"\"\"\n"," def __init__(\n"," self,\n"," max_energy_factor=0.975,\n"," min_energy_factor=0.95,\n"," distribution=\"normal\",\n"," skip_threshold=0.80,\n"," **kwargs\n"," ):\n"," super(RandomCompression, self).__init__(**kwargs)\n"," self.min_energy_factor = min_energy_factor\n"," self.max_energy_factor = max_energy_factor\n"," self.distribution = distribution\n"," self.skip_threshold = skip_threshold\n","\n"," def __call__(self, input, training=True):\n"," if training:\n"," return self.compressed_input(input)\n"," else:\n"," return input\n","\n"," def compressed_input(self, input):\n"," def compress_with_distribution(x):\n"," if self._will_skip():\n"," return None\n"," else:\n"," energy_factor = self._sample_from_distribution()\n"," return compress_image_with_energy(x, energy_factor)\n","\n","\n"," compressed_image = compress_with_distribution(input)\n","\n"," return compressed_image\n","\n"," def _will_skip(self):\n"," uncompressed_threshold = tf.random.uniform(shape=(), minval=0.0, maxval=1.0)\n"," return uncompressed_threshold < self.skip_threshold\n","\n"," def _sample_from_distribution(self):\n"," if self.distribution in [\"gaussian\", \"normal\"]:\n"," return tf.random.normal(shape=(), mean=self.max_energy_factor, stddev=0.025)\n"," elif self.distribution == \"uniform\":\n"," return tf.random.uniform(\n"," shape=(), minval=self.min_energy_factor, 
```python
class RandomCompression:
    """Low-rank (SVD) image augmentation with a randomized energy factor."""

    def __init__(
        self,
        max_energy_factor=0.975,
        min_energy_factor=0.95,
        distribution="normal",
        skip_threshold=0.80,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.min_energy_factor = min_energy_factor
        self.max_energy_factor = max_energy_factor
        self.distribution = distribution
        self.skip_threshold = skip_threshold

    def __call__(self, image, training=True):
        if training:
            return self.compressed_input(image)
        return image

    def compressed_input(self, image):
        # Most images are skipped; only ~(1 - skip_threshold) are compressed.
        if self._will_skip():
            return None
        energy_factor = self._sample_from_distribution()
        return compress_image_with_energy(image, energy_factor)

    def _will_skip(self):
        return tf.random.uniform(shape=(), minval=0.0, maxval=1.0) < self.skip_threshold

    def _sample_from_distribution(self):
        if self.distribution in ["gaussian", "normal"]:
            return tf.random.normal(shape=(), mean=self.max_energy_factor, stddev=0.025)
        elif self.distribution == "uniform":
            return tf.random.uniform(
                shape=(), minval=self.min_energy_factor, maxval=self.max_energy_factor
            )
        else:
            raise ValueError(
                "RandomCompression only supports uniform and gaussian distributions"
            )


def preprocess(image, label):
    TARGET_SIZE = (180, 180)
    image = tf.image.resize(image, TARGET_SIZE)
    image = tf.cast(image, tf.float32)
    return image, label


def extend_ds(ds):
    """Build an augmented companion dataset of SVD-compressed copies."""
    print(f"\n{'*' * 50}Compressing Images{'*' * 50}")

    ds = ds.unbatch()
    aug_ds = []

    total = len(list(ds))  # materializes the dataset once, only to size tqdm
    compression = RandomCompression(max_energy_factor=0.975)
    total_compressed_images = 0

    for image, label in tqdm(ds.map(preprocess), total=total, ncols=110):
        compressed_image = compression(image)
        if compressed_image is not None:
            aug_ds.append((compressed_image, label))
            total_compressed_images += 1

    images = [item[0] for item in aug_ds]
    labels = [item[1] for item in aug_ds]

    # Create a TensorFlow dataset from the lists
    # (BATCH_SZ is defined in the next cell, before extend_ds is called)
    aug_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(BATCH_SZ)

    print("Total compressed images:", total_compressed_images)
    print(f"{'*' * 46}Finished Compressing Images{'*' * 46}\n")

    return aug_ds
```
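A small smoke test, not in the original notebook, showing the layer's skip behavior: with `skip_threshold=0.80`, roughly one call in five returns a compressed image and the rest return `None`, which `extend_ds` filters out.

```python
# Illustrative smoke test for RandomCompression (not in the notebook).
dummy = tf.random.uniform(shape=(180, 180, 3))
compression = RandomCompression(skip_threshold=0.80)

compressed = sum(compression(dummy) is not None for _ in range(100))
print(f"{compressed}/100 images were compressed (expected ~20)")
```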
f)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"o982eoLhtNqe"},"outputs":[],"source":["def default_augmentations_layer():\n"," data_augmentation = keras.Sequential(\n"," [\n"," RandomFlip(\"horizontal\"),\n"," RandomRotation(0.1),\n"," RandomZoom(0.2),\n"," ],\n"," name = 'hflip_rot_zoom'\n"," )\n"," return data_augmentation"]},{"cell_type":"code","source":["def build_feat_ext_model():\n"," feature_ext_model = keras.models.load_model(f\"{trial_detail_1.file_str()}_feature_extraction\")\n"," conv_base = feature_ext_model.get_layer(\"vgg16\")\n","\n"," conv_base.trainable = True\n"," for layer in conv_base.layers[:-4]:\n"," layer.trainable = False\n","\n"," feature_ext_model.compile(\n"," optimizer=RMSprop(learning_rate=1e-5),\n"," loss=\"binary_crossentropy\",\n"," metrics=\"accuracy\",\n"," )\n","\n"," return feature_ext_model"],"metadata":{"id":"5u2gnzhGAt-z"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"zZOx5Ba4Bpdz","outputId":"1c616d51-737b-44b2-9812-46308905a320","executionInfo":{"status":"ok","timestamp":1715214584968,"user_tz":240,"elapsed":5852,"user":{"displayName":"Kennette Basco","userId":"00881963222092464928"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["146/146 [==============================] - 5s 9ms/step - loss: 0.2285 - accuracy: 0.9854\n"]}],"source":["conv_base = keras.applications.vgg16.VGG16(\n"," weights=\"imagenet\",\n"," include_top=False)\n","conv_base.trainable = False\n","\n","data_augmentation = default_augmentations_layer()\n","\n","inputs = keras.Input(shape=(180, 180, 3))\n","x = data_augmentation(inputs)\n","x = keras.applications.vgg16.preprocess_input(x)\n","x = conv_base(x)\n","x = Flatten()(x)\n","x = Dense(256)(x)\n","x = Dropout(0.5)(x)\n","outputs = Dense(1, activation=\"sigmoid\")(x)\n","model = keras.Model(inputs, outputs)\n","model.compile(loss=\"binary_crossentropy\",\n"," optimizer=\"rmsprop\",\n"," metrics=[\"accuracy\"])\n","\n","\n","trial_detail_1 = TrialDetail(dataset_name=\"cats_vs_dogs\",\n"," percentage_data=1.0,\n"," augmentation_method=\"svd\",\n"," energy_factor=0.975\n"," )\n","print(trial_detail_1)\n","\n","ds_train, ds_test, ds_val = load_dataset()\n","ds_train_augmented = extend_ds(ds_train)\n","ds_train_ext = ds_train.concatenate(ds_train_augmented).shuffle(BUFFER_SZ).cache().prefetch(tf.data.AUTOTUNE)\n","\n","\"\"\"\n","*****************************Feature Extraction*****************************\n","\"\"\"\n","\n","callbacks = [EarlyStopping(patience=10,\n"," monitor=\"val_accuracy\",\n"," restore_best_weights=True),\n"," ModelCheckpoint(filepath=f\"{trial_detail_1.file_str()}_feature_extraction\",\n"," monitor=\"val_accuracy\",\n"," save_best_only=True),\n"," ]\n","\n","print(f\"\\n\\n{'*'*50}\\nFeature Extraction\\n{'*'*50}\\n\\n\")\n","history_1 = model.fit(\n"," ds_train_ext,\n"," epochs=50,\n"," validation_data=ds_val,\n"," callbacks=callbacks,\n",")\n","\n","\"\"\"\n","*****************************Fine Tuning*****************************\n","\"\"\"\n","\n","\n","all_histories = []\n","best_epochs_loss = []\n","best_epochs_acc = []\n","\n","fold_datasets = split_dataset_kfold(ds_train_ext, 5)\n","\n","print(f\"\\n\\n{'*'*50}\\nFine Tuning\\n{'*'*50}\\n\\n\")\n","for fold, (ds_train_fold, ds_val_fold) in enumerate(fold_datasets):\n"," callbacks = get_kfold_callbacks(trial_detail_1, fold)\n"," feature_ext_model = build_feat_ext_model()\n","\n"," history = feature_ext_model.fit(\n"," 
```python
conv_base = keras.applications.vgg16.VGG16(
    weights="imagenet",
    include_top=False,
)
conv_base.trainable = False

data_augmentation = default_augmentations_layer()

inputs = keras.Input(shape=(180, 180, 3))
x = data_augmentation(inputs)
x = keras.applications.vgg16.preprocess_input(x)
x = conv_base(x)
x = Flatten()(x)
x = Dense(256)(x)
x = Dropout(0.5)(x)
outputs = Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)
model.compile(
    loss="binary_crossentropy",
    optimizer="rmsprop",
    metrics=["accuracy"],
)

trial_detail_1 = TrialDetail(
    dataset_name="cats_vs_dogs",
    percentage_data=1.0,
    augmentation_method="svd",
    energy_factor=0.975,
)
print(trial_detail_1)

ds_train, ds_test, ds_val = load_dataset()
ds_train_augmented = extend_ds(ds_train)
ds_train_ext = (
    ds_train.concatenate(ds_train_augmented)
    .shuffle(BUFFER_SZ)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

"""
*****************************Feature Extraction*****************************
"""

callbacks = [
    EarlyStopping(patience=10, monitor="val_accuracy", restore_best_weights=True),
    ModelCheckpoint(
        filepath=f"{trial_detail_1.file_str()}_feature_extraction",
        monitor="val_accuracy",
        save_best_only=True,
    ),
]

print(f"\n\n{'*' * 50}\nFeature Extraction\n{'*' * 50}\n\n")
history_1 = model.fit(
    ds_train_ext,
    epochs=50,
    validation_data=ds_val,
    callbacks=callbacks,
)

"""
*****************************Fine Tuning*****************************
"""

all_histories = []
best_epochs_loss = []
best_epochs_acc = []

fold_datasets = split_dataset_kfold(ds_train_ext, 5)

print(f"\n\n{'*' * 50}\nFine Tuning\n{'*' * 50}\n\n")
for fold, (ds_train_fold, ds_val_fold) in enumerate(fold_datasets):
    callbacks = get_kfold_callbacks(trial_detail_1, fold)
    feature_ext_model = build_feat_ext_model()

    history = feature_ext_model.fit(
        ds_train_fold,
        epochs=100,
        validation_data=ds_val_fold,
        callbacks=callbacks,
    )

    all_histories.append(history.history)
    best_epochs_loss.append(np.argmin(history.history["val_loss"]))
    best_epochs_acc.append(np.argmax(history.history["val_accuracy"]))

"""
*****************************Retraining on Entire Dataset*****************************
"""

ds_train_ext = (
    ds_train_ext.concatenate(ds_val)
    .shuffle(BUFFER_SZ)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

print(f"\n\n{'*' * 50}\nRetraining\n{'*' * 50}\n\n")

# Naive schedule: 20% longer than the 50-epoch feature-extraction budget
feature_ext_model = build_feat_ext_model()
callbacks = get_retrain_callbacks(trial_detail_1, "naive")
best_epoch = int(50 * 1.2)
history_2 = feature_ext_model.fit(
    ds_train_ext,
    epochs=best_epoch,
    callbacks=callbacks,
    verbose=1,
)

# Schedule from the mean best val-loss epoch across folds, padded by 20%
feature_ext_model = build_feat_ext_model()
callbacks = get_retrain_callbacks(trial_detail_1, "argmin_loss")
best_epoch = int(np.mean(best_epochs_loss) * 1.2)
history_3 = feature_ext_model.fit(
    ds_train_ext,
    epochs=best_epoch,
    callbacks=callbacks,
    verbose=1,
)

# Schedule from the mean best val-accuracy epoch across folds, padded by 20%
feature_ext_model = build_feat_ext_model()
callbacks = get_retrain_callbacks(trial_detail_1, "argmax_acc")
best_epoch = int(np.mean(best_epochs_acc) * 1.2)
history_4 = feature_ext_model.fit(
    ds_train_ext,
    epochs=best_epoch,
    callbacks=callbacks,
    verbose=1,
)

# Evaluate every checkpoint on the held-out test set
test_model = keras.models.load_model(f"{trial_detail_1.file_str()}_feature_extraction")
_, test_acc = test_model.evaluate(ds_test)
print(f"Feature Extraction Test accuracy: {test_acc:.3f}")

test_model = keras.models.load_model(f"{trial_detail_1.file_str()}_fine_tuning")
_, test_acc = test_model.evaluate(ds_test)
print(f"Fine Tuning Test accuracy: {test_acc:.3f}")

test_model = keras.models.load_model(f"{trial_detail_1.file_str()}_naive")
_, test_acc = test_model.evaluate(ds_test)
print(f"Naive Retraining Test accuracy: {test_acc:.3f}")

test_model = keras.models.load_model(f"{trial_detail_1.file_str()}_argmin_loss")
_, test_acc = test_model.evaluate(ds_test)
print(f"Argmin Retraining Test accuracy: {test_acc:.3f}")

test_model = keras.models.load_model(f"{trial_detail_1.file_str()}_argmax_acc")
_, test_acc = test_model.evaluate(ds_test)
print(f"Argmax Retraining Test accuracy: {test_acc:.3f}")

generalization_performance = test_model.evaluate(ds_test)
trial_result = TrialResult(all_histories, generalization_performance)
_save_trial_result(trial_result, trial_detail_1)

plot_history(history_1)
plot_history(history_2)
plot_history(history_3)
plot_history(history_4)
```

Recorded cell output:

```
146/146 [==============================] - 5s 9ms/step - loss: 0.2285 - accuracy: 0.9854
```

Notebook metadata: Google Colab, A100 GPU (high-memory machine), Python 3 kernel.
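For completeness, a small sketch, not in the original notebook, of reading a saved trial result back. The path layout mirrors `_save_trial_result` above; the loader name and signature are hypothetical, and unpickling requires `TrialResult` to be defined in the running session.

```python
# Hypothetical loader mirroring _save_trial_result's layout:
# <dataset_name>/<augmentation>_<pct>_<energy>.pkl
import os
import pickle


def load_trial_result(dataset_name, augmentation_method, percentage_data, energy_factor):
    file_path = f"{augmentation_method}_{percentage_data}_{energy_factor}.pkl"
    with open(os.path.join(dataset_name, file_path), "rb") as f:
        return pickle.load(f)


# e.g. the trial saved above:
# result = load_trial_result("cats_vs_dogs", "svd", 1.0, 0.975)
# print(result.generalization_performance)
```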
