Skip to content

Commit

Permalink
Resolved Arturus#27 from the original repo
Browse files Browse the repository at this point in the history
For more details go on this link: Arturus#27
  • Loading branch information
amankhandelia authored Jan 7, 2020
1 parent a9abb80 commit fc5ee83
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,22 +468,24 @@ def create_model(scope, index, prefix, seed):
forward_eval_pipe = None
avg_sgd = asgd_decay is not None
#asgd_decay = 0.99 if avg_sgd else None
train_model = Model(pipe, hparams, is_train=True, graph_prefix=prefix, asgd_decay=asgd_decay, seed=seed)
scope.reuse_variables()

eval_stages = []
if side_split:
side_eval_model = Model(side_eval_pipe, hparams, is_train=False,
#loss_mask=np.concatenate([np.zeros(50, dtype=np.float32), np.ones(10, dtype=np.float32)]),
seed=seed)
eval_stages.append((Stage.EVAL_SIDE, side_eval_model))
if avg_sgd:
eval_stages.append((Stage.EVAL_SIDE_EMA, side_eval_model))
if forward_split:
forward_eval_model = Model(forward_eval_pipe, hparams, is_train=False, seed=seed)
eval_stages.append((Stage.EVAL_FRWD, forward_eval_model))
if avg_sgd:
eval_stages.append((Stage.EVAL_FRWD_EMA, forward_eval_model))
with tf.variable_scope('model') as scope:
train_model = Model(pipe, hparams, is_train=True, graph_prefix=prefix, asgd_decay=asgd_decay, seed=seed)
# scope.reuse_variables()
with tf.variable_scope('model', reuse=True) as scope:
eval_stages = []
if side_split:

side_eval_model = Model(side_eval_pipe, hparams, is_train=False,
#loss_mask=np.concatenate([np.zeros(50, dtype=np.float32), np.ones(10, dtype=np.float32)]),
seed=seed)
eval_stages.append((Stage.EVAL_SIDE, side_eval_model))
if avg_sgd:
eval_stages.append((Stage.EVAL_SIDE_EMA, side_eval_model))
if forward_split:
forward_eval_model = Model(forward_eval_pipe, hparams, is_train=False, seed=seed)
eval_stages.append((Stage.EVAL_FRWD, forward_eval_model))
if avg_sgd:
eval_stages.append((Stage.EVAL_FRWD_EMA, forward_eval_model))

if write_summaries:
summ_path = f"{logdir}/{name}_{index}"
Expand Down

0 comments on commit fc5ee83

Please sign in to comment.