Skip to content

Commit

Permalink
KerasKorea#74 : translate more codes
Browse files Browse the repository at this point in the history
  • Loading branch information
jhp committed Oct 21, 2018
1 parent 5da23fe commit afd7387
Showing 1 changed file with 71 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,15 +240,13 @@ def __init__(self):
img_A_id = self.g_BA(img_A)
img_B_id = self.g_AB(img_B)

# For the combined model we will only train the generators
# generaotr를 학습할 때는 discriminator는 학습하지 않습니다.
self.d_A.trainable = False
self.d_B.trainable = False

# Discriminators determines validity of translated images
valid_A = self.d_A(fake_A)
valid_B = self.d_B(fake_B)

# Combined model trains generators to fool discriminators
self.combined = Model(inputs=[img_A, img_B],
outputs=[ valid_A, valid_B,
reconstr_A, reconstr_B,
Expand All @@ -266,6 +264,7 @@ def __init__(self):
> build_generator() 는 **Generator의 구조** 를 만듭니다. 이 코드에서는 U-Net 을 Generator 로 사용했습니다.
>
> conv2d는 downsampling을 통해 input image의 특징을 추출하고, deconv2d는 upsampling을 통해 이미지의 스타일을 바꿔(translation)주는 용도로 사용합니다.
>
```python
def build_generator(self):
"""U-Net Generator"""
Expand Down Expand Up @@ -308,7 +307,8 @@ def build_generator(self):
```
> #### **3. build_discriminator()**
>
> build_discriminator()가 하는 일은
> build_discriminator()가 하는 일은 이미지가 진짜인지 가짜인지 판별하는 것입니다. 여기서 진짜라는 의미는 우리가 input 으로 넣은 이미지이고 가짜라는 의미는 generator가 만든 이미지입니다.
>
```python
def build_discriminator(self):

Expand All @@ -331,7 +331,74 @@ def build_discriminator(self):

return Model(img, validity)
```
>
> #### **4. train()**
>
> train()은 batch만큼의 데이터로 이미지 데이터로 빌드한 generator, discriminator를 epochs 만큼 학습하는 과정입니다.
>
```python
def train(self, epochs, batch_size=1, sample_interval=50):

start_time = datetime.datetime.now()

# Adversarial loss ground truths
valid = np.ones((batch_size,) + self.disc_patch)
fake = np.zeros((batch_size,) + self.disc_patch)

for epoch in range(epochs):
for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):

# ----------------------
# Train Discriminators
# ----------------------

# 반대의 도메인으로 이미지 translate
fake_B = self.g_AB.predict(imgs_A)
fake_A = self.g_BA.predict(imgs_B)

# discriminators 학습
# (original images = real / translated = Fake)
dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

# Total disciminator loss
d_loss = 0.5 * np.add(dA_loss, dB_loss)


# ------------------
# Train Generators
# ------------------

# generators 학습
g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
[valid, valid,
imgs_A, imgs_B,
imgs_A, imgs_B])

elapsed_time = datetime.datetime.now() - start_time

# Plot the progress
print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
% ( epoch, epochs,
batch_i, self.data_loader.n_batches,
d_loss[0], 100*d_loss[1],
g_loss[0],
np.mean(g_loss[1:3]),
np.mean(g_loss[3:5]),
np.mean(g_loss[5:6]),
elapsed_time))

# If at save interval => save generated image samples
if batch_i % sample_interval == 0:
self.sample_images(epoch, batch_i)
```

<br></br>

### 참고문서
* [김태영의 케라스 블로그](https://tykimos.github.io/)
Expand Down

0 comments on commit afd7387

Please sign in to comment.