forked from OndrejTexler/Few-Shot-Patch-Based-Training
logger.py
import tensorflow as tf
from tensorflow import summary
import os
import shutil


class Logger(object):
    def __init__(self, log_dir, suffix=None):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.create_file_writer(log_dir, filename_suffix=suffix)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        with self.writer.as_default():
            tf.summary.scalar(tag, value, step=step)
        self.writer.flush()
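
# Usage sketch (not part of the original file): write scalar summaries during a
# training loop and inspect them with TensorBoard. The log directory and tag
# names below are hypothetical, for illustration only.
#
#   logger = Logger("runs/example", suffix=".train")
#   for step, loss in enumerate([0.9, 0.5, 0.3]):
#       logger.scalar_summary("loss/train", loss, step)
#   # then: tensorboard --logdir runs/example
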
class ModelLogger(object):
    def __init__(self, log_dir, save_func):
        self.log_dir = log_dir
        self.save_func = save_func

    def save(self, model, epoch, isGenerator):
        # Checkpoints are named by epoch: model_00001.pth for the generator,
        # disc_00001.pth for the discriminator.
        if isGenerator:
            new_path = os.path.join(self.log_dir, "model_%05d.pth" % epoch)
        else:
            new_path = os.path.join(self.log_dir, "disc_%05d.pth" % epoch)
        self.save_func(model, new_path)

    def copy_file(self, source):
        # Copy an auxiliary file (e.g. the run configuration) into the log directory.
        shutil.copy(source, self.log_dir)
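
# Usage sketch (not part of the original file): save generator and discriminator
# checkpoints each epoch via a save callback. The torch.save-based callback, the
# model variables, and the paths are assumptions for illustration only.
#
#   import torch
#   model_logger = ModelLogger("runs/example/models",
#                              lambda m, path: torch.save(m.state_dict(), path))
#   model_logger.save(generator, epoch, isGenerator=True)       # -> model_00001.pth
#   model_logger.save(discriminator, epoch, isGenerator=False)  # -> disc_00001.pth
#   model_logger.copy_file("config.yaml")  # keep the run config next to the checkpoints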