diff --git a/lib/train.py b/lib/train.py index be6c4bb..19403e2 100644 --- a/lib/train.py +++ b/lib/train.py @@ -134,10 +134,10 @@ def rgb_ratio_loss(self, clear_image, generated_image): def get_models_and_paths(self): import os - generator_clear2fog_weights_path = os.path.join(self.weights_path, 'generator_clear2fog.h5') - generator_fog2clear_weights_path = os.path.join(self.weights_path, 'generator_fog2clear.h5') - discriminator_clear_weights_path = os.path.join(self.weights_path, 'discriminator_clear.h5') - discriminator_fog_weights_path = os.path.join(self.weights_path, 'discriminator_fog.h5') + generator_clear2fog_weights_path = os.path.join(self.weights_path, 'generator_clear2fog.weights.h5') + generator_fog2clear_weights_path = os.path.join(self.weights_path, 'generator_fog2clear.weights.h5') + discriminator_clear_weights_path = os.path.join(self.weights_path, 'discriminator_clear.weights.h5') + discriminator_fog_weights_path = os.path.join(self.weights_path, 'discriminator_fog.weights.h5') models = [self.generator_clear2fog, self.generator_fog2clear, self.discriminator_clear,