chore: adjust notebook for testing (#18)
* chore: adjust notebook for testing

* adjust

* make param

* fix poetry file

* adjust header
ghaiszaher authored Nov 23, 2024
1 parent dc80671 commit 8adecfd
Showing 10 changed files with 1,409 additions and 1,335 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -9,5 +9,6 @@ image_logs/
 trainer_config.json
 dataset/
 intensity/
+custom-images/

 !.idea/dictionaries/
161 changes: 102 additions & 59 deletions Foggy_CycleGAN.ipynb
@@ -68,6 +68,7 @@
},
"source": [
"import sys\n",
"\n",
"colab = 'google.colab' in sys.modules\n",
"import tensorflow as tf"
],
@@ -121,10 +122,10 @@
" os.environ['PROJECT_DIR'] = project_dir = '/content/Foggy-CycleGAN'\n",
" replace = True\n",
" if os.path.isdir(project_dir):\n",
" choice = input(\"Project already exists in folder \"+\n",
" \"{}\\nDelete the files and pull again? Enter Y/(N):\\n\"\n",
" .format(project_dir))\n",
" if choice.lower()=='y':\n",
" choice = input(\"Project already exists in folder \" +\n",
" \"{}\\nDelete the files and pull again? Enter Y/(N):\\n\"\n",
" .format(project_dir))\n",
" if choice.lower() == 'y':\n",
" !rm -r $PROJECT_DIR\n",
" print(\"Deleted folder {}\".format(project_dir))\n",
" else:\n",
@@ -171,24 +172,23 @@
"colab_type": "code",
"colab": {}
},
"source": [
"project_label = \"\" #@param {type:\"string\"}"
],
"source": "project_label = \"\" #@param {type:\"string\"}",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"source": [
"mount_path = None #to suppress warnings\n",
"mount_path = None #to suppress warnings\n",
"drive_project_path = None\n",
"if colab:\n",
" # noinspection PyUnresolvedReferences\n",
" from google.colab import drive\n",
"\n",
" mount_path = '/content/drive'\n",
" drive.mount(mount_path, force_remount=True)\n",
" drive_project_path = os.path.join(mount_path,\"My Drive/Colab Notebooks/Foggy-CycleGAN/\",project_label)\n",
" drive_datasets_path = os.path.join(mount_path,\"My Drive/Colab Notebooks/Datasets/\")\n",
" drive_project_path = os.path.join(mount_path, \"My Drive/Colab Notebooks/Foggy-CycleGAN/\", project_label)\n",
" drive_datasets_path = os.path.join(mount_path, \"My Drive/Colab Notebooks/Datasets/\")\n",
" os.environ['DRIVE_PROJECT'] = drive_project_path\n",
" os.environ['DRIVE_DATASETS'] = drive_datasets_path"
],
@@ -226,9 +226,7 @@
"colab_type": "code",
"colab": {}
},
"source": [
"test_split = 0.2 #@param {type:\"slider\", min:0.05, max:0.95, step:0.05}"
],
"source": "test_split = 0.2 #@param {type:\"slider\", min:0.05, max:0.95, step:0.05}",
"outputs": [],
"execution_count": null
},
@@ -238,7 +236,7 @@
"from lib.dataset import DatasetInitializer\n",
"\n",
"datasetInit = DatasetInitializer(256, 256)\n",
"datasetInit.dataset_path = '/content/dataset/' if colab else './dataset/'\n",
"datasetInit.dataset_path = '/content/dataset/' if colab else './dataset/'\n",
"(train_clear, train_fog), (test_clear, test_fog), (sample_clear, sample_fog) = datasetInit.prepare_dataset(\n",
" BATCH_SIZE,\n",
" test_split=test_split,\n",
@@ -275,6 +273,7 @@
},
"source": [
"from lib.models import ModelsBuilder\n",
"\n",
"OUTPUT_CHANNELS = 3\n",
"models_builder = ModelsBuilder()"
],
@@ -289,11 +288,11 @@
"colab": {}
},
"source": [
"use_transmission_map = False #@param{type: \"boolean\"}\n",
"use_gauss_filter = False #@param{type: \"boolean\"}\n",
"use_transmission_map = False #@param{type: \"boolean\"}\n",
"use_gauss_filter = False #@param{type: \"boolean\"}\n",
"if use_gauss_filter and not use_transmission_map:\n",
" raise Exception(\"Gauss filter requires transmission map\")\n",
"use_resize_conv = False #@param{type: \"boolean\"}\n",
"use_resize_conv = False #@param{type: \"boolean\"}\n",
"\n",
"generator_clear2fog = models_builder.build_generator(use_transmission_map=use_transmission_map,\n",
" use_gauss_filter=use_gauss_filter,\n",
@@ -356,7 +355,7 @@
"colab": {}
},
"source": [
"use_intensity_for_fog_discriminator = False #@param{type: \"boolean\"}\n",
"use_intensity_for_fog_discriminator = False #@param{type: \"boolean\"}\n",
"discriminator_fog = models_builder.build_discriminator(use_intensity=use_intensity_for_fog_discriminator)\n",
"discriminator_clear = models_builder.build_discriminator(use_intensity=False)"
],
@@ -415,10 +414,11 @@
},
"source": [
"from lib.train import Trainer\n",
"\n",
"trainer = Trainer(generator_clear2fog, generator_fog2clear,\n",
" discriminator_fog, discriminator_clear)\n",
" discriminator_fog, discriminator_clear)\n",
"\n",
"trainer.configure_checkpoint(weights_path = weights_path, load_optimizers=False)"
"trainer.configure_checkpoint(weights_path=weights_path, load_optimizers=False)"
],
"outputs": [],
"execution_count": null
@@ -435,6 +435,7 @@
},
"source": [
"from lib.plot import plot_generators_predictions\n",
"\n",
"for clear, fog in tf.data.Dataset.zip((sample_clear.take(1), sample_fog.take(1))):\n",
" plot_generators_predictions(generator_clear2fog, clear, generator_fog2clear, fog)"
],
@@ -453,6 +454,7 @@
},
"source": [
"from lib.plot import plot_discriminators_predictions\n",
"\n",
"for clear, fog in tf.data.Dataset.zip((sample_clear.take(1), sample_fog.take(1))):\n",
" plot_discriminators_predictions(discriminator_clear, clear, discriminator_fog, fog, use_intensity_for_fog_discriminator)"
],
@@ -479,9 +481,7 @@
"colab_type": "code",
"colab": {}
},
"source": [
"use_tensorboard = True #@param{type:\"boolean\"}"
],
"source": "use_tensorboard = True #@param{type:\"boolean\"}",
"outputs": [],
"execution_count": null
},
@@ -498,9 +498,10 @@
"source": [
"if use_tensorboard:\n",
" import tensorboard\n",
"\n",
" tb = tensorboard.program.TensorBoard()\n",
" if colab:\n",
" trainer.tensorboard_base_logdir = os.path.join(drive_project_path,\"tensorboard_logs/\")\n",
" trainer.tensorboard_base_logdir = os.path.join(drive_project_path, \"tensorboard_logs/\")\n",
" tb.configure(argv=[None, '--logdir', trainer.tensorboard_base_logdir])\n",
" url = tb.launch()\n",
" if colab:\n",
@@ -523,8 +524,8 @@
},
"source": [
"if colab:\n",
" trainer.image_log_path = os.path.join(drive_project_path,\"image_logs/\")\n",
" trainer.config_path = os.path.join(drive_project_path,\"trainer_config.json\")"
" trainer.image_log_path = os.path.join(drive_project_path, \"image_logs/\")\n",
" trainer.config_path = os.path.join(drive_project_path, \"trainer_config.json\")"
],
"outputs": [],
"execution_count": null
@@ -553,10 +554,10 @@
"colab": {}
},
"source": [
"use_transmission_map_loss=True #@param{type: \"boolean\"}\n",
"use_whitening_loss=True #@param{type: \"boolean\"}\n",
"use_rgb_ratio_loss=True #@param{type: \"boolean\"}\n",
"save_optimizers=False #@param{type: \"boolean\"}\n",
"use_transmission_map_loss = True #@param{type: \"boolean\"}\n",
"use_whitening_loss = True #@param{type: \"boolean\"}\n",
"use_rgb_ratio_loss = True #@param{type: \"boolean\"}\n",
"save_optimizers = False #@param{type: \"boolean\"}\n",
"\n",
"trainer.train(\n",
" train_clear, train_fog,\n",
@@ -597,7 +598,6 @@
"colab": {}
},
"source": [
"# TODO: store predictions\n",
"for clear, fog in zip(test_clear.take(5), test_fog.take(5)):\n",
" plot_generators_predictions(generator_clear2fog, clear, generator_fog2clear, fog)"
],
@@ -627,14 +627,15 @@
"\n",
"intensity_path = './intensity/'\n",
"from lib.tools import create_dir\n",
"\n",
"create_dir(intensity_path)\n",
"\n",
"image_clear = next(iter(test_clear))[0][0]\n",
"step = 0.05\n",
"for (ind, i) in enumerate(tf.range(0,1+step, step)):\n",
" fig = plot_clear2fog_intensity(generator_clear2fog, image_clear, i)\n",
"for (ind, i) in enumerate(tf.range(0, 1 + step, step)):\n",
" fig, _ = plot_clear2fog_intensity(generator_clear2fog, image_clear, i)\n",
" fig.savefig(os.path.join(intensity_path\n",
" , \"{:02d}_intensity_{:0.2f}.jpg\".format(ind,i)), bbox_inches='tight', pad_inches=0)\n",
" , \"{:02d}_intensity_{:0.2f}.jpg\".format(ind, i)), bbox_inches='tight', pad_inches=0)\n",
" if colab:\n",
" plt.show()\n",
" else:\n",
@@ -666,9 +667,7 @@
},
{
"cell_type": "markdown",
"source": [
"## Testing Custom images\n"
],
"source": "## Testing Custom images - Plot all Results",
"metadata": {
"collapsed": false,
"pycharm": {
@@ -680,24 +679,44 @@
"cell_type": "code",
"source": [
"from lib.plot import plot_clear2fog_intensity\n",
"from lib.tools import create_dir\n",
"from matplotlib import pyplot as plt\n",
"\n",
"intensity_path = './intensity/'\n",
"from lib.tools import create_dir\n",
"create_dir(intensity_path)\n",
"file_path = './Downloads/test-image.png'\n",
"input_path = './custom-images/input/'\n",
"output_path = './custom-images/output/'\n",
"create_dir(input_path)\n",
"create_dir(output_path)\n",
"\n",
"image_clear = tf.io.decode_png(tf.io.read_file(file_path), channels=3)\n",
"image_clear, _ = datasetInit.preprocess_image_test(image_clear, 0)\n",
"step = 0.05\n",
"for (ind, i) in enumerate(tf.range(0,1+step, step)):\n",
" fig = plot_clear2fog_intensity(generator_clear2fog, image_clear, i)\n",
" fig.savefig(os.path.join(intensity_path\n",
" , \"{:02d}_intensity_{:0.2f}.jpg\".format(ind,i)), bbox_inches='tight', pad_inches=0)\n",
" if colab:\n",
" plt.show()\n",
" else:\n",
" plt.close(fig)"
"all_files = os.listdir(input_path)\n",
"print(\"Files in input folder: \", all_files)\n",
"for file_name in all_files:\n",
" if not file_name.endswith('.png'):\n",
" print(\"Skipping file: \", file_name)\n",
" continue\n",
" file_path = os.path.join(input_path, file_name)\n",
" output_folder_path = os.path.join(output_path, file_name)\n",
" print(\"Creating folder: \", output_folder_path)\n",
" create_dir(output_folder_path)\n",
"\n",
" image_clear = tf.io.decode_png(tf.io.read_file(file_path), channels=3)\n",
" image_clear, _ = datasetInit.preprocess_image_test(image_clear, 0)\n",
" step = 0.05\n",
" for (ind, i) in enumerate(tf.range(0,1+step, step)):\n",
" fig, foggy_image = plot_clear2fog_intensity(generator_clear2fog, image_clear, i)\n",
"\n",
" figure_path = os.path.join(output_folder_path\n",
" , \"figure_{:02d}_{:0.2f}.jpg\".format(ind,i))\n",
" fig.savefig(figure_path, bbox_inches='tight', pad_inches=0)\n",
" plt.close(fig)\n",
" print(\"Saved figure: \", figure_path)\n",
"\n",
" foggy_image_uint8 = tf.image.convert_image_dtype(foggy_image, dtype=tf.uint8)\n",
" encoded_image = tf.io.encode_png(foggy_image_uint8)\n",
" foggy_image_path = os.path.join(output_folder_path\n",
" , \"foggy_{:02d}_{:0.2f}.jpg\".format(ind,i))\n",
" tf.io.write_file(foggy_image_path, encoded_image)\n",
" print(f\"Saved foggy image to {output_path}\")\n",
" print(\"Done with file: \", file_name)"
],
"metadata": {
"collapsed": false,
@@ -709,17 +728,41 @@
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"if colab:\n",
" !cd ./intensity; zip /content/intensity.zip *"
" !cd ./custom-images; zip -r /content/custom-images.zip *"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Testing Single Custom Image Output"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"from lib.plot import plot_clear2fog_intensity\n",
"import tensorflow as tf\n",
"\n",
"input_path = './custom-images/input/image.png' #@param {type:\"string\"}\n",
"fog_intensity = 0.4 #@param {type:\"slider\", min:0.05, max:1.00, step:0.05}\n",
"output_path = './custom-images/output/image.png' #@param {type:\"string\"}\n",
"\n",
"image_clear = tf.io.decode_png(tf.io.read_file(input_path), channels=3)\n",
"image_clear, _ = datasetInit.preprocess_image_test(image_clear, 0)\n",
"\n",
"\n",
"fig, foggy_image = plot_clear2fog_intensity(generator_clear2fog, image_clear, fog_intensity)\n",
"foggy_image_uint8 = tf.image.convert_image_dtype(foggy_image, dtype=tf.uint8)\n",
"encoded_image = tf.io.encode_png(foggy_image_uint8)\n",
"tf.io.write_file(output_path, encoded_image)\n",
"print(f\"Saved foggy image to {output_path}\")"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"execution_count": null
}
Binary file added custom-images/input/image.png
Binary file added custom-images/output/image.png
Binary file modified discriminator_clear.png
Binary file modified discriminator_fog.png
Binary file modified generator_clear2fog.png
Binary file modified generator_fog2clear.png
6 changes: 3 additions & 3 deletions lib/plot.py
@@ -156,17 +156,17 @@ def plot_clear2fog_intensity(model_clear2fog, image_clear, intensity=0.5,
     fig = plt.figure(figsize=(12, 6))

     display_list = [image_clear, prediction_clear2fog[0]]
-    if normalized_input:
-        display_list = [item * 0.5 + 0.5 for item in display_list]
     title = ['Clear', 'To Fog {:0.2}'.format(original_intensity)]

     for i in range(2):
         plt.subplot(1, 2, i + 1)
         plt.title(title[i])
         to_display = display_list[i]
+        if normalized_input:
+            to_display = to_display * 0.5 + 0.5
         plt.imshow(to_display)
         plt.axis('off')

     if close_fig:
         plt.close(fig)
-    return fig
+    return fig, display_list[1]
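
With this change, plot_clear2fog_intensity returns the figure together with the predicted foggy frame, so callers can save the prediction itself rather than only the rendered figure. A minimal usage sketch of that pattern, mirroring the "Testing Single Custom Image Output" cell above; generator_clear2fog and datasetInit are assumed to be already set up as in the notebook, and the paths are placeholders:

import tensorflow as tf

from lib.plot import plot_clear2fog_intensity

# Placeholder paths; generator_clear2fog and datasetInit come from the notebook setup above.
input_path = './custom-images/input/image.png'
output_path = './custom-images/output/image.png'

# Decode and preprocess the clear image the same way the notebook's test pipeline does.
image_clear = tf.io.decode_png(tf.io.read_file(input_path), channels=3)
image_clear, _ = datasetInit.preprocess_image_test(image_clear, 0)

# plot_clear2fog_intensity now returns (figure, foggy image) instead of only the figure.
fig, foggy_image = plot_clear2fog_intensity(generator_clear2fog, image_clear, 0.4)

# Persist the prediction as a PNG.
foggy_uint8 = tf.image.convert_image_dtype(foggy_image, dtype=tf.uint8)
tf.io.write_file(output_path, tf.io.encode_png(foggy_uint8))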
