From afd7387dfc4d7f03586da44398f0fdddfa631ec1 Mon Sep 17 00:00:00 2001 From: jhp Date: Sun, 21 Oct 2018 20:07:38 +0900 Subject: [PATCH] #74 : translate more codes --- ...4 \353\260\224\352\276\270\352\270\260.md" | 75 ++++++++++++++++++- 1 file changed, 71 insertions(+), 4 deletions(-) diff --git "a/74_\353\224\245\353\237\254\353\213\235(CycleGAN)\354\235\204 \354\235\264\354\232\251\355\225\264 Fornite \353\245\274 PUBG \353\241\234 \353\260\224\352\276\270\352\270\260.md" "b/74_\353\224\245\353\237\254\353\213\235(CycleGAN)\354\235\204 \354\235\264\354\232\251\355\225\264 Fornite \353\245\274 PUBG \353\241\234 \353\260\224\352\276\270\352\270\260.md" index 0665a88..bc906ad 100644 --- "a/74_\353\224\245\353\237\254\353\213\235(CycleGAN)\354\235\204 \354\235\264\354\232\251\355\225\264 Fornite \353\245\274 PUBG \353\241\234 \353\260\224\352\276\270\352\270\260.md" +++ "b/74_\353\224\245\353\237\254\353\213\235(CycleGAN)\354\235\204 \354\235\264\354\232\251\355\225\264 Fornite \353\245\274 PUBG \353\241\234 \353\260\224\352\276\270\352\270\260.md" @@ -240,15 +240,13 @@ def __init__(self): img_A_id = self.g_BA(img_A) img_B_id = self.g_AB(img_B) - # For the combined model we will only train the generators + # generaotr를 학습할 때는 discriminator는 학습하지 않습니다. self.d_A.trainable = False self.d_B.trainable = False - # Discriminators determines validity of translated images valid_A = self.d_A(fake_A) valid_B = self.d_B(fake_B) - # Combined model trains generators to fool discriminators self.combined = Model(inputs=[img_A, img_B], outputs=[ valid_A, valid_B, reconstr_A, reconstr_B, @@ -266,6 +264,7 @@ def __init__(self): > build_generator() 는 **Generator의 구조** 를 만듭니다. 이 코드에서는 U-Net 을 Generator 로 사용했습니다. > > conv2d는 downsampling을 통해 input image의 특징을 추출하고, deconv2d는 upsampling을 통해 이미지의 스타일을 바꿔(translation)주는 용도로 사용합니다. +> ```python def build_generator(self): """U-Net Generator""" @@ -308,7 +307,8 @@ def build_generator(self): ``` > #### **3. build_discriminator()** > -> build_discriminator()가 하는 일은 +> build_discriminator()가 하는 일은 이미지가 진짜인지 가짜인지 판별하는 것입니다. 여기서 진짜라는 의미는 우리가 input 으로 넣은 이미지이고 가짜라는 의미는 generator가 만든 이미지입니다. +> ```python def build_discriminator(self): @@ -331,7 +331,74 @@ def build_discriminator(self): return Model(img, validity) ``` +> +> #### **4. train()** +> +> train()은 batch만큼의 데이터로 이미지 데이터로 빌드한 generator, discriminator를 epochs 만큼 학습하는 과정입니다. +> +```python +def train(self, epochs, batch_size=1, sample_interval=50): + + start_time = datetime.datetime.now() + + # Adversarial loss ground truths + valid = np.ones((batch_size,) + self.disc_patch) + fake = np.zeros((batch_size,) + self.disc_patch) + + for epoch in range(epochs): + for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)): + + # ---------------------- + # Train Discriminators + # ---------------------- + + # 반대의 도메인으로 이미지 translate + fake_B = self.g_AB.predict(imgs_A) + fake_A = self.g_BA.predict(imgs_B) + + # discriminators 학습 + # (original images = real / translated = Fake) + dA_loss_real = self.d_A.train_on_batch(imgs_A, valid) + dA_loss_fake = self.d_A.train_on_batch(fake_A, fake) + dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake) + + dB_loss_real = self.d_B.train_on_batch(imgs_B, valid) + dB_loss_fake = self.d_B.train_on_batch(fake_B, fake) + dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake) + + # Total disciminator loss + d_loss = 0.5 * np.add(dA_loss, dB_loss) + + + # ------------------ + # Train Generators + # ------------------ + + # generators 학습 + g_loss = self.combined.train_on_batch([imgs_A, imgs_B], + [valid, valid, + imgs_A, imgs_B, + imgs_A, imgs_B]) + + elapsed_time = datetime.datetime.now() - start_time + + # Plot the progress + print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \ + % ( epoch, epochs, + batch_i, self.data_loader.n_batches, + d_loss[0], 100*d_loss[1], + g_loss[0], + np.mean(g_loss[1:3]), + np.mean(g_loss[3:5]), + np.mean(g_loss[5:6]), + elapsed_time)) + + # If at save interval => save generated image samples + if batch_i % sample_interval == 0: + self.sample_images(epoch, batch_i) +``` +

### 참고문서 * [김태영의 케라스 블로그](https://tykimos.github.io/)