forked from TeamSemiSuperCV/semi-super
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
868 lines (730 loc) · 32.2 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
# coding=utf-8
# Copyright 2020 The SimCLR Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The main training pipeline."""
import json
import math
import os
import glob
from absl import app
from absl import flags
from absl import logging
import data as data_lib
import metrics
import model as model_lib
import objective as obj_lib
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import plots as plots_lib
FLAGS = flags.FLAGS

flags.DEFINE_float(
    'learning_rate', 0.3,
    'Initial learning rate per batch size of 256.')

flags.DEFINE_enum(
    'learning_rate_scaling', 'linear', ['linear', 'sqrt'],
    'How to scale the learning rate as a function of batch size.')

flags.DEFINE_float(
    'warmup_epochs', 10,
    'Number of epochs of warmup.')

flags.DEFINE_float('weight_decay', 1e-6, 'Amount of weight decay to use.')

flags.DEFINE_float(
    'batch_norm_decay', 0.9,
    'Batch norm decay parameter.')

flags.DEFINE_integer(
    'train_batch_size', 512,
    'Batch size for training.')

flags.DEFINE_string(
    'train_split', 'train',
    'Split for training.')

flags.DEFINE_integer(
    'train_epochs', 100,
    'Number of epochs to train for.')

flags.DEFINE_integer(
    'train_steps', 0,
    'Number of steps to train for. If provided, overrides train_epochs.')

flags.DEFINE_integer(
    'eval_steps', 0,
    'Number of steps to eval for. If not provided, evals over entire dataset.')

flags.DEFINE_integer(
    'eval_batch_size', 256,
    'Batch size for eval.')

flags.DEFINE_integer(
    'checkpoint_epochs', 5,
    'Number of epochs between checkpoints/summaries.')

flags.DEFINE_integer(
    'checkpoint_steps', 0,
    'Number of steps between checkpoints/summaries. If provided, overrides '
    'checkpoint_epochs.')

flags.DEFINE_string(
    'eval_split', 'validation',
    'Split for evaluation.')

flags.DEFINE_string(
    'dataset', 'xray_orig',
    'Name of a dataset.')

flags.DEFINE_bool(
    'cache_dataset', False,
    'Whether to cache the entire dataset in memory. If the dataset is '
    'ImageNet, this is a very bad idea, but for smaller datasets it can '
    'improve performance.')

flags.DEFINE_enum(
    'mode', 'train', ['train', 'eval', 'train_then_eval'],
    'Whether to perform training or evaluation.')

flags.DEFINE_enum(
    'train_mode', 'pretrain', ['pretrain', 'finetune'],
    'The train mode controls different objectives and trainable components.')

flags.DEFINE_bool('lineareval_while_pretraining', True,
                  'Whether to finetune supervised head while pretraining.')

flags.DEFINE_string(
    'checkpoint', None,
    'Loading from the given checkpoint for fine-tuning if a finetuning '
    'checkpoint does not already exist in model_dir.')

flags.DEFINE_bool(
    'zero_init_logits_layer', False,
    'If True, zero initialize layers after avg_pool for supervised learning.')

flags.DEFINE_integer(
    'fine_tune_after_block', -1,
    'The layers after which block that we will fine-tune. -1 means fine-tuning '
    'everything. 0 means fine-tuning after stem block. 4 means fine-tuning '
    'just the linear head.')

flags.DEFINE_string(
    'master', None,
    'Address/name of the TensorFlow master to use. By default, use an '
    'in-process master.')

flags.DEFINE_string(
    'model_dir', None,
    'Model directory for training.')

flags.DEFINE_string(
    'data_dir', None,
    'Directory where dataset is stored.')

flags.DEFINE_bool(
    'use_tpu', True,
    'Whether to run on TPU.')

flags.DEFINE_string(
    'tpu_name', None,
    'The Cloud TPU to use for training. This should be either the name '
    'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
    'url.')

# Help text previously said "GCE project"; this flag is the TPU's zone.
flags.DEFINE_string(
    'tpu_zone', None,
    '[Optional] GCE zone where the Cloud TPU is located in. If not '
    'specified, we will attempt to automatically detect the GCE zone from '
    'metadata.')

flags.DEFINE_string(
    'gcp_project', None,
    '[Optional] Project name for the Cloud TPU-enabled project. If not '
    'specified, we will attempt to automatically detect the GCE project from '
    'metadata.')

flags.DEFINE_enum(
    'optimizer', 'lars', ['momentum', 'adam', 'lars'],
    'Optimizer to use.')

flags.DEFINE_float(
    'momentum', 0.9,
    'Momentum parameter.')

flags.DEFINE_string(
    'eval_name', None,
    'Name for eval.')

flags.DEFINE_integer(
    'keep_checkpoint_max', 1,
    'Maximum number of checkpoints to keep.')

flags.DEFINE_integer(
    'keep_hub_module_max', 1,
    'Maximum number of Hub modules to keep.')

flags.DEFINE_float(
    'temperature', 0.1,
    'Temperature parameter for contrastive loss.')

# Boolean passed as `hidden_norm` to obj_lib.add_contrastive_loss; the help
# text previously duplicated the `temperature` description.
flags.DEFINE_boolean(
    'hidden_norm', True,
    'Whether to L2-normalize the hidden vectors in the contrastive loss.')

flags.DEFINE_enum(
    'proj_head_mode', 'nonlinear', ['none', 'linear', 'nonlinear'],
    'How the head projection is done.')

flags.DEFINE_integer(
    'proj_out_dim', 128,
    'Number of head projection dimension.')

flags.DEFINE_integer(
    'num_proj_layers', 3,
    'Number of non-linear head layers.')

flags.DEFINE_integer(
    'ft_proj_selector', 1,
    'Which layer of the projection head to use during fine-tuning. '
    '0 means no projection head, and -1 means the final layer.')

flags.DEFINE_boolean(
    'global_bn', True,
    'Whether to aggregate BN statistics across distributed cores.')

flags.DEFINE_integer(
    'width_multiplier', 1,
    'Multiplier to change width of network.')

flags.DEFINE_integer(
    'resnet_depth', 50,
    'Depth of ResNet.')

flags.DEFINE_float(
    'sk_ratio', 0.,
    'If it is bigger than 0, it will enable SK. Recommendation: 0.0625.')

flags.DEFINE_float(
    'se_ratio', 0.,
    'If it is bigger than 0, it will enable SE.')

flags.DEFINE_integer(
    'image_size', 224,
    'Input image size.')

flags.DEFINE_float(
    'color_jitter_strength', 1.0,
    'The strength of color jittering.')

flags.DEFINE_boolean(
    'use_blur', True,
    'Whether or not to use Gaussian blur for augmentation during pretraining.')

flags.DEFINE_boolean(
    'test_crop', False,
    'Whether or not to crop image during testing.')

flags.DEFINE_boolean(
    'eval_per_loop', False,
    'Eval every loop.')

flags.DEFINE_boolean(
    'save_best_loss', False,
    'Save best loss model on eval split.')

flags.DEFINE_boolean(
    'save_best_acc', False,
    'Save best acc model on eval split.')

# Help text previously read "Name of a dataset." (copy-paste error).
flags.DEFINE_string(
    'tmp_folder', '/tmp',
    'Folder for temporary files.')

flags.DEFINE_float(
    'area_range_min', 0.08,
    'The area range min value of crop.')

flags.DEFINE_boolean(
    'save_only_last_ckpt', True,
    'Save only last checkpoint. Intermediate ckpts are not saved.')

flags.DEFINE_float(
    'max_rot_angle', 0.0,
    'Max rotation during data augmentation')

flags.DEFINE_boolean(
    'include_rotation', False,
    'Include rotation in data augmentation.')

# Distillation (teacher -> student) options.
flags.DEFINE_boolean(
    'distill_mode', False,
    'Activate distillation mode.')

flags.DEFINE_string(
    'teacher_model_dir', None,
    'Load the given teacher model for distillation mode.')

flags.DEFINE_boolean(
    'keras_resnet50', False,
    'Use Keras ResNet50 as student model.')

flags.DEFINE_boolean(
    'vertical_flip', False,
    'Include vertical_flip in data augmentation.')
def get_salient_tensors_dict(include_projection_head):
  """Looks up the exported intermediate tensors in the default graph.

  Args:
    include_projection_head: If True, also fetch the projection head's input
      and output tensors.

  Returns:
    Dict mapping a short key (e.g. 'block_group1', 'logits_sup') to the
    graph tensor of the same role.
  """
  graph = tf.compat.v1.get_default_graph()

  # (result key, fully qualified tensor name), in lookup order.
  wanted = [('block_group%d' % g, 'resnet/block_group%d/block_group%d:0' % (g, g))
            for g in range(1, 5)]
  wanted += [
      ('initial_conv', 'resnet/initial_conv/Identity:0'),
      ('initial_max_pool', 'resnet/initial_max_pool/Identity:0'),
      ('final_avg_pool', 'resnet/final_avg_pool:0'),
      ('sup_head_input', 'projection_head/sup_head_input:0'),
      ('logits_sup', 'head_supervised/logits_sup:0'),
  ]
  if include_projection_head:
    wanted += [
        ('proj_head_input', 'projection_head/proj_head_input:0'),
        ('proj_head_output', 'projection_head/proj_head_output:0'),
    ]

  return {key: graph.get_tensor_by_name(name) for key, name in wanted}
def build_saved_model(model, include_projection_head=True):
  """Wraps `model` in a `tf.Module` that can be passed to tf.saved_model.save.

  The wrapper's `__call__(inputs, trainable)` runs the model and returns the
  tensors gathered by `get_salient_tensors_dict`. Concrete functions are
  traced eagerly for `trainable=True` and `trainable=False` on float32 inputs
  of shape [None, None, None, 3].

  Args:
    model: The model to export.
    include_projection_head: Whether the exported outputs also contain the
      projection head input/output tensors.

  Returns:
    The `tf.Module` wrapper.
  """

  class SimCLRModel(tf.Module):
    """Saved model for exporting to hub."""

    def __init__(self, model):
      self.model = model
      # Stored as `trainable_variables_list` because `tf.Module` already
      # defines a `trainable_variables` getter.
      self.trainable_variables_list = model.trainable_variables

    @tf.function
    def __call__(self, inputs, trainable):
      self.model(inputs, training=trainable)
      return get_salient_tensors_dict(include_projection_head)

  wrapper = SimCLRModel(model)
  image_spec = tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32)
  # Trace the training-mode graph first, then the inference-mode graph.
  for is_trainable in (True, False):
    wrapper.__call__.get_concrete_function(image_spec, trainable=is_trainable)
  return wrapper
def save(model, global_step):
  """Exports `model` as a SavedModel for finetuning and inference.

  The Keras ResNet50 student (distill_mode + keras_resnet50) is written with
  `tf.keras.models.save_model` under `<model_dir>/saved_model_keras/<step>`.
  Every other model is wrapped by `build_saved_model` and written under
  `<model_dir>/saved_model/<step>`. An existing export for the same step is
  replaced. When `FLAGS.keep_hub_module_max > 0`, only that many of the
  newest numerically named exports are kept.

  Args:
    model: Model to export.
    global_step: Training step; used as the export subdirectory name.
  """
  use_keras_export = FLAGS.distill_mode and FLAGS.keras_resnet50
  if use_keras_export:
    export_dir = os.path.join(FLAGS.model_dir, 'saved_model_keras')
  else:
    module = build_saved_model(model)
    export_dir = os.path.join(FLAGS.model_dir, 'saved_model')

  target_dir = os.path.join(export_dir, str(global_step))
  if tf.io.gfile.exists(target_dir):
    tf.io.gfile.rmtree(target_dir)

  if use_keras_export:
    tf.keras.models.save_model(model, target_dir)
  else:
    tf.saved_model.save(module, target_dir)

  keep = FLAGS.keep_hub_module_max
  if keep > 0:
    # Prune older exports; only all-digit subdirectory names are considered.
    steps = sorted(int(name) for name in tf.io.gfile.listdir(export_dir)
                   if name.isdigit())
    for old_step in steps[:-keep]:
      tf.io.gfile.rmtree(os.path.join(export_dir, str(old_step)))
def save_best(model, global_step, model_best_metric):
  """Exports the current best model as a SavedModel.

  The export goes to `<model_dir>/saved_model_<model_best_metric>/<step>`,
  replacing any existing export for that step. When
  `FLAGS.keep_hub_module_max > 0`, only that many of the newest numerically
  named exports in that directory are kept.

  Args:
    model: Model to export.
    global_step: Training step; used as the export subdirectory name.
    model_best_metric: Tag appended to the export directory name ('acc' or
      'loss' in this file).
  """
  export_dir = os.path.join(FLAGS.model_dir,
                            'saved_model_{}'.format(model_best_metric))
  checkpoint_export_dir = os.path.join(export_dir, str(global_step))

  if FLAGS.distill_mode and FLAGS.keras_resnet50:
    # `save` exports the Keras ResNet50 student with tf.keras rather than via
    # `build_saved_model`, whose outputs are looked up by fixed graph tensor
    # names (e.g. 'resnet/block_group1/...'). Do the same here so best-model
    # export works for that student too.
    if tf.io.gfile.exists(checkpoint_export_dir):
      tf.io.gfile.rmtree(checkpoint_export_dir)
    tf.keras.models.save_model(model, checkpoint_export_dir)
  else:
    saved_model = build_saved_model(model)
    if tf.io.gfile.exists(checkpoint_export_dir):
      tf.io.gfile.rmtree(checkpoint_export_dir)
    tf.saved_model.save(saved_model, checkpoint_export_dir)

  if FLAGS.keep_hub_module_max > 0:
    # Delete old exported SavedModels.
    exported_steps = sorted(
        int(subdir) for subdir in tf.io.gfile.listdir(export_dir)
        if subdir.isdigit())
    for step_to_delete in exported_steps[:-FLAGS.keep_hub_module_max]:
      tf.io.gfile.rmtree(os.path.join(export_dir, str(step_to_delete)))
def try_restore_from_checkpoint(model, global_step, optimizer):
  """Restores state from model_dir, otherwise warm-starts from FLAGS.checkpoint.

  If `FLAGS.model_dir` already contains a checkpoint, model weights, global
  step and optimizer state are all restored from it. Otherwise, if
  `FLAGS.checkpoint` is set and the Keras ResNet50 student is not in use,
  only the model weights are restored from `FLAGS.checkpoint`. With
  `FLAGS.zero_init_logits_layer`, the supervised head's trainable weights are
  then set to zero.

  Args:
    model: Model whose variables are restored.
    global_step: Step variable tracked together with the model.
    optimizer: Optimizer whose state is tracked together with the model.

  Returns:
    A `tf.train.CheckpointManager` over `FLAGS.model_dir`, used by the caller
    to save further checkpoints.
  """
  checkpoint = tf.train.Checkpoint(
      model=model, global_step=global_step, optimizer=optimizer)
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=FLAGS.model_dir,
      max_to_keep=FLAGS.keep_checkpoint_max)
  latest_ckpt = checkpoint_manager.latest_checkpoint
  if latest_ckpt:
    # Resume: restore model weights, global step and optimizer state.
    logging.info('Restoring from latest checkpoint: %s', latest_ckpt)
    checkpoint_manager.checkpoint.restore(latest_ckpt).expect_partial()
  elif FLAGS.checkpoint and not FLAGS.keras_resnet50:
    # Warm start: restore model weights only, not global step / optimizer.
    logging.info('Restoring from given checkpoint: %s', FLAGS.checkpoint)
    tf.train.Checkpoint(model=model).restore(FLAGS.checkpoint).expect_partial()
    if FLAGS.zero_init_logits_layer:
      # NOTE(review): if the supervised head has not been built yet (model not
      # called), `trainable_weights` may be empty and this is a no-op —
      # confirm against model_lib.
      output_layer_parameters = model.supervised_head.trainable_weights
      # `Variable.op` is undefined under eager execution; log `name` instead.
      logging.info('Initializing output layer parameters %s to zero',
                   [x.name for x in output_layer_parameters])
      for x in output_layer_parameters:
        x.assign(tf.zeros_like(x))

  return checkpoint_manager
def json_serializable(val):
  """Returns whether `val` can be encoded by `json.dumps`.

  Args:
    val: Any Python object.

  Returns:
    True if `json.dumps(val)` succeeds. False if it raises `TypeError`
    (unsupported type or dict key) or `ValueError` (e.g. a circular
    reference), so callers can filter values instead of crashing.
  """
  try:
    json.dumps(val)
  except (TypeError, ValueError):
    return False
  return True
def _dump_result_json(path, result):
  """Writes `result` to `path` as a JSON object with every value cast to float."""
  with tf.io.gfile.GFile(path, 'w') as f:
    json.dump({k: float(v) for k, v in result.items()}, f)


def _export_if_best(model, result, metric_key, model_best_metric, is_better):
  """Exports `model` via `save_best` if `result[metric_key]` beats the record.

  The best result so far is kept in
  `<model_dir>/result_best_<model_best_metric>.json`. If that file does not
  exist yet, the current model counts as the best one. On export, the file is
  overwritten with `result`.

  Args:
    model: Model to export.
    result: Eval metrics for the current checkpoint, including 'global_step'.
    metric_key: Key of `result` to compare, e.g. 'eval/supervised_loss'.
    model_best_metric: Tag used in the export dir and JSON file name.
    is_better: Callable `(current_value, best_value) -> bool`.
  """
  result_best_json_path = os.path.join(
      FLAGS.model_dir, 'result_best_{}.json'.format(model_best_metric))
  if tf.io.gfile.exists(result_best_json_path):
    with tf.io.gfile.GFile(result_best_json_path, 'r') as f:
      result_best = json.load(f)
    if not is_better(result[metric_key], result_best[metric_key]):
      return
  save_best(model, global_step=result['global_step'],
            model_best_metric=model_best_metric)
  _dump_result_json(result_best_json_path, result)


def perform_evaluation(model, builder, eval_steps, ckpt, strategy, topology,
                       training_complete):
  """Evaluates checkpoint `ckpt` on the eval split.

  Restores `ckpt` into `model` and runs `eval_steps` distributed eval steps.
  It then writes summaries and JSON files (`result_<eval_split>.json`,
  `result_<step>.json`, `flags.json`) to `FLAGS.model_dir`. SavedModels are
  exported when `training_complete` is set, and for the best-accuracy /
  best-loss models when `FLAGS.save_best_acc` / `FLAGS.save_best_loss` are
  set.

  Args:
    model: Model to evaluate; its variables are overwritten from `ckpt`.
    builder: TFDS dataset builder for the eval data.
    eval_steps: Number of eval batches to run.
    ckpt: Path of the checkpoint to restore.
    strategy: `tf.distribute` strategy to run under.
    topology: TPU topology, or None when not running on TPU.
    training_complete: If True, also export the model with `save`.

  Returns:
    Dict mapping metric names (plus 'global_step') to numpy values, or None
    when evaluation is skipped (pretraining without linear eval).
  """
  if FLAGS.train_mode == 'pretrain' and not FLAGS.lineareval_while_pretraining:
    logging.info('Skipping eval during pretraining without linear eval.')
    return None

  # Build input pipeline.
  ds = data_lib.build_distributed_dataset(builder, FLAGS.eval_batch_size, False,
                                          strategy, topology)
  summary_writer = tf.summary.create_file_writer(FLAGS.model_dir)

  # Build metrics.
  with strategy.scope():
    regularization_loss = tf.keras.metrics.Mean('eval/regularization_loss')
    eval_sup_loss_metric = tf.keras.metrics.Mean('eval/supervised_loss')
    label_top_1_accuracy = tf.keras.metrics.Accuracy(
        'eval/label_top_1_accuracy')
    label_top_5_accuracy = tf.keras.metrics.TopKCategoricalAccuracy(
        5, 'eval/label_top_5_accuracy')
    label_recall = tf.keras.metrics.Recall(name='eval/recall')
    label_precision = tf.keras.metrics.Precision(name='eval/precision')
    all_metrics = [
        regularization_loss, eval_sup_loss_metric, label_top_1_accuracy,
        label_top_5_accuracy, label_recall, label_precision
    ]

    # Restore checkpoint.
    logging.info('Restoring from %s', ckpt)
    checkpoint = tf.train.Checkpoint(
        model=model, global_step=tf.Variable(0, dtype=tf.int64))
    checkpoint.restore(ckpt).expect_partial()
    global_step = checkpoint.global_step
    logging.info('Performing eval at step %d', global_step.numpy())

  def single_step(features, labels):
    if FLAGS.distill_mode and FLAGS.keras_resnet50:
      # The Keras student returns a single output (the supervised logits).
      supervised_head_outputs = model(features, training=False)
    else:
      _, supervised_head_outputs = model(features, training=False)
    assert supervised_head_outputs is not None
    outputs = supervised_head_outputs
    l = labels['labels']
    metrics.update_finetune_metrics_eval(label_top_1_accuracy,
                                         label_top_5_accuracy,
                                         label_recall,
                                         label_precision,
                                         outputs, l)
    if FLAGS.distill_mode and FLAGS.keras_resnet50:
      reg_loss = model_lib.add_weight_decay_keras(
          model, adjust_per_optimizer=True)
    else:
      reg_loss = model_lib.add_weight_decay(model, adjust_per_optimizer=True)
    regularization_loss.update_state(reg_loss)
    eval_sup_loss = obj_lib.add_supervised_loss(labels=l, logits=outputs)
    eval_sup_loss_metric.update_state(eval_sup_loss)

  with strategy.scope():

    @tf.function
    def run_single_step(iterator):
      images, labels = next(iterator)
      features, labels = images, {'labels': labels}
      strategy.run(single_step, (features, labels))

    iterator = iter(ds)
    for i in range(eval_steps):
      run_single_step(iterator)
      logging.info('Completed eval for %d / %d steps', i + 1, eval_steps)
    logging.info('Finished eval for %s', ckpt)

  # Write summaries
  cur_step = global_step.numpy()
  logging.info('Writing summaries for %d step', cur_step)
  with summary_writer.as_default():
    metrics.log_and_write_metrics_to_summary(all_metrics, cur_step)
    summary_writer.flush()

  # Record results as JSON: latest result for this split plus a per-step copy.
  # (Both FLAGS.mode branches previously did exactly the same thing.)
  result = {metric.name: metric.result().numpy() for metric in all_metrics}
  result['global_step'] = global_step.numpy()
  logging.info(result)
  _dump_result_json(
      os.path.join(FLAGS.model_dir, 'result_{}.json'.format(FLAGS.eval_split)),
      result)
  _dump_result_json(
      os.path.join(FLAGS.model_dir, 'result_%d.json' % result['global_step']),
      result)

  flag_json_path = os.path.join(FLAGS.model_dir, 'flags.json')
  with tf.io.gfile.GFile(flag_json_path, 'w') as f:
    # Some flag value types e.g. datetime.timedelta are not json serializable,
    # filter those out.
    serializable_flags = {
        key: val
        for key, val in FLAGS.flag_values_dict().items()
        if json_serializable(val)
    }
    json.dump(serializable_flags, f)

  # Export as SavedModel for finetuning and inference.
  if training_complete:
    save(model, global_step=result['global_step'])

  # Export best-accuracy / best-loss models.
  if FLAGS.save_best_acc:
    _export_if_best(model, result, 'eval/label_top_1_accuracy', 'acc',
                    lambda current, best: current > best)
  if FLAGS.save_best_loss:
    _export_if_best(model, result, 'eval/supervised_loss', 'loss',
                    lambda current, best: current < best)

  return result
def _restore_latest_or_from_pretrain(checkpoint_manager):
  """Restores the latest checkpoint, or the pretrain checkpoint when finetuning.

  If `checkpoint_manager` already has a checkpoint, everything tracked by
  `checkpoint_manager.checkpoint` is restored from it. Otherwise, in finetune
  mode, the same objects are restored from `FLAGS.checkpoint`. The supervised
  head's trainable weights are then zeroed if `FLAGS.zero_init_logits_layer`
  is set.

  Args:
    checkpoint_manager: tf.train.CheckpointManager whose `checkpoint` has a
      `model` attribute.

  Raises:
    ValueError: In finetune mode when no checkpoint exists yet and
      `FLAGS.checkpoint` is not set.
  """
  latest_ckpt = checkpoint_manager.latest_checkpoint
  if latest_ckpt:
    # The model is not build yet so some variables may not be available in
    # the object graph. Those are lazily initialized. To suppress the warning
    # in that case we specify `expect_partial`.
    logging.info('Restoring from %s', latest_ckpt)
    checkpoint_manager.checkpoint.restore(latest_ckpt).expect_partial()
  elif FLAGS.train_mode == 'finetune':
    # Restore from pretrain checkpoint. Raise instead of `assert` so the check
    # survives `python -O`.
    if not FLAGS.checkpoint:
      raise ValueError('Missing pretrain checkpoint.')
    logging.info('Restoring from %s', FLAGS.checkpoint)
    checkpoint_manager.checkpoint.restore(FLAGS.checkpoint).expect_partial()
    # TODO(iamtingchen): Can we instead use a zeros initializer for the
    # supervised head?
    if FLAGS.zero_init_logits_layer:
      model = checkpoint_manager.checkpoint.model
      output_layer_parameters = model.supervised_head.trainable_weights
      # `Variable.op` is undefined under eager execution; log `name` instead.
      logging.info('Initializing output layer parameters %s to zero',
                   [x.name for x in output_layer_parameters])
      for x in output_layer_parameters:
        x.assign(tf.zeros_like(x))
def main(argv):
  """Trains and/or evaluates the model according to `FLAGS.mode`.

  - 'eval': evaluates each new checkpoint in `FLAGS.model_dir` until one at
    or past the final training step has been evaluated.
  - 'train' / 'train_then_eval': trains for the configured number of steps,
    checkpointing and logging every `checkpoint_steps` steps and optionally
    evaluating per loop or once at the end. It then generates plots.

  Args:
    argv: Positional command-line arguments left after flag parsing.

  Raises:
    app.UsageError: If extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
  builder.download_and_prepare()
  num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
  num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
  num_classes = builder.info.features['label'].num_classes

  train_steps = model_lib.get_train_steps(num_train_examples)
  eval_steps = FLAGS.eval_steps or int(
      math.ceil(num_eval_examples / FLAGS.eval_batch_size))
  epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))

  logging.info('# train examples: %d', num_train_examples)
  logging.info('# train_steps: %d', train_steps)
  logging.info('# eval examples: %d', num_eval_examples)
  logging.info('# eval steps: %d', eval_steps)

  # Clamp to at least 1. With 0 steps per loop, `train_multiple_steps` never
  # advances the step counter and the training loop below spins forever. This
  # happens e.g. when `epoch_steps` rounds to 0 because the train split is
  # smaller than half a batch.
  checkpoint_steps = max(
      1, FLAGS.checkpoint_steps or (FLAGS.checkpoint_epochs * epoch_steps))

  topology = None
  if FLAGS.use_tpu:
    if FLAGS.tpu_name:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver(FLAGS.master)
    tf.config.experimental_connect_to_cluster(cluster)
    topology = tf.tpu.experimental.initialize_tpu_system(cluster)
    logging.info('Topology:')
    logging.info('num_tasks: %d', topology.num_tasks)
    logging.info('num_tpus_per_task: %d', topology.num_tpus_per_task)
    strategy = tf.distribute.TPUStrategy(cluster)
  else:
    # For (multiple) GPUs.
    strategy = tf.distribute.MirroredStrategy()
    logging.info('Running using MirroredStrategy on %d replicas',
                 strategy.num_replicas_in_sync)

  with strategy.scope():
    if FLAGS.distill_mode and FLAGS.keras_resnet50:
      input_shape = (FLAGS.image_size, FLAGS.image_size, 3)
      # Alternative: tf.keras.applications.ResNet50(weights=None,
      #   input_shape=input_shape, classes=num_classes,
      #   classifier_activation=None)
      model = model_lib.resnet50_mod(input_shape, num_classes)
      logging.info('Loaded Keras ResNet50 as student model')
    else:
      model = model_lib.Model(num_classes)

    if FLAGS.distill_mode:
      logging.info('Distillation mode active')
      logging.info('Restoring teacher model from: %s', FLAGS.teacher_model_dir)
      teacher_model = tf.saved_model.load(FLAGS.teacher_model_dir)

  if FLAGS.mode == 'eval':
    for ckpt in tf.train.checkpoints_iterator(
        FLAGS.model_dir, min_interval_secs=15):
      result = perform_evaluation(model, builder, eval_steps, ckpt, strategy,
                                  topology, training_complete=True)
      if result is None:
        # perform_evaluation returns None when eval is disabled by flags
        # (pretraining without linear eval); indexing it would raise
        # TypeError.
        logging.info('Eval skipped. Exiting...')
        return
      if result['global_step'] >= train_steps:
        logging.info('Eval complete. Exiting...')
        return
  else:
    summary_writer = tf.summary.create_file_writer(FLAGS.model_dir)
    with strategy.scope():
      # Build input pipeline.
      ds = data_lib.build_distributed_dataset(builder, FLAGS.train_batch_size,
                                              True, strategy, topology)

      # Build LR schedule and optimizer.
      learning_rate = model_lib.WarmUpAndCosineDecay(FLAGS.learning_rate,
                                                     num_train_examples)
      optimizer = model_lib.build_optimizer(learning_rate)

      # Build metrics.
      all_metrics = []  # For summaries.
      weight_decay_metric = tf.keras.metrics.Mean('train/weight_decay')
      total_loss_metric = tf.keras.metrics.Mean('train/total_loss')
      all_metrics.extend([weight_decay_metric, total_loss_metric])
      if FLAGS.train_mode == 'pretrain':
        contrast_loss_metric = tf.keras.metrics.Mean('train/contrast_loss')
        contrast_acc_metric = tf.keras.metrics.Mean('train/contrast_acc')
        contrast_entropy_metric = tf.keras.metrics.Mean(
            'train/contrast_entropy')
        all_metrics.extend([
            contrast_loss_metric, contrast_acc_metric, contrast_entropy_metric
        ])
      if FLAGS.train_mode == 'finetune' or FLAGS.lineareval_while_pretraining:
        supervised_loss_metric = tf.keras.metrics.Mean('train/supervised_loss')
        supervised_acc_metric = tf.keras.metrics.Mean('train/supervised_acc')
        all_metrics.extend([supervised_loss_metric, supervised_acc_metric])

      # Restore checkpoint if available.
      checkpoint_manager = try_restore_from_checkpoint(
          model, optimizer.iterations, optimizer)

    steps_per_loop = checkpoint_steps

    def single_step(features, labels):
      with tf.GradientTape() as tape:
        # Log summaries on the last step of the training loop to match
        # logging frequency of other scalar summaries.
        #
        # Notes:
        # 1. Summary ops on TPUs get outside compiled so they do not affect
        #    performance.
        # 2. Summaries are recorded only on replica 0. So effectively this
        #    summary would be written once per host when should_record == True.
        # 3. optimizer.iterations is incremented in the call to apply_gradients.
        #    So we use  `iterations + 1` here so that the step number matches
        #    those of scalar summaries.
        # 4. We intentionally run the summary op before the actual model
        #    training so that it can run in parallel.
        should_record = tf.equal((optimizer.iterations + 1) % steps_per_loop, 0)
        with tf.summary.record_if(should_record):
          # Only log augmented images for the first tower.
          tf.summary.image(
              'image', features[:, :, :, :3], step=optimizer.iterations + 1)

        if FLAGS.distill_mode and FLAGS.keras_resnet50:
          # The Keras student has no projection head; it returns only the
          # supervised logits.
          projection_head_outputs, supervised_head_outputs = None, model(
              features, training=True)
        else:
          projection_head_outputs, supervised_head_outputs = model(
              features, training=True)

        if FLAGS.distill_mode:
          teacher_outputs = teacher_model(
              features, trainable=False)['logits_sup']
        else:
          teacher_outputs = None

        loss = None
        if projection_head_outputs is not None:
          outputs = projection_head_outputs
          con_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
              outputs,
              hidden_norm=FLAGS.hidden_norm,
              temperature=FLAGS.temperature,
              strategy=strategy)
          if loss is None:
            loss = con_loss
          else:
            loss += con_loss
          metrics.update_pretrain_metrics_train(contrast_loss_metric,
                                                contrast_acc_metric,
                                                contrast_entropy_metric,
                                                con_loss, logits_con,
                                                labels_con)

        if supervised_head_outputs is not None:
          outputs = supervised_head_outputs
          l = labels['labels']
          if FLAGS.train_mode == 'pretrain' and FLAGS.lineareval_while_pretraining:
            l = tf.concat([l, l], 0)
          if FLAGS.distill_mode:
            sup_loss = obj_lib.add_kd_loss(teacher_logits=teacher_outputs,
                                           student_logits=outputs,
                                           temperature=FLAGS.temperature)
          else:
            sup_loss = obj_lib.add_supervised_loss(labels=l, logits=outputs)
          if loss is None:
            loss = sup_loss
          else:
            loss += sup_loss
          if FLAGS.distill_mode:
            metrics.update_finetune_metrics_train(supervised_loss_metric,
                                                  supervised_acc_metric,
                                                  sup_loss,
                                                  teacher_outputs, outputs)
          else:
            metrics.update_finetune_metrics_train(supervised_loss_metric,
                                                  supervised_acc_metric,
                                                  sup_loss,
                                                  l, outputs)

        if FLAGS.distill_mode and FLAGS.keras_resnet50:
          weight_decay = model_lib.add_weight_decay_keras(
              model, adjust_per_optimizer=True)
        else:
          weight_decay = model_lib.add_weight_decay(
              model, adjust_per_optimizer=True)
        weight_decay_metric.update_state(weight_decay)
        loss += weight_decay
        total_loss_metric.update_state(loss)
        # The default behavior of `apply_gradients` is to sum gradients from all
        # replicas so we divide the loss by the number of replicas so that the
        # mean gradient is applied.
        loss = loss / strategy.num_replicas_in_sync
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

    with strategy.scope():

      @tf.function
      def train_multiple_steps(iterator):
        # `tf.range` is needed so that this runs in a `tf.while_loop` and is
        # not unrolled.
        for _ in tf.range(steps_per_loop):
          # Drop the "while" prefix created by tf.while_loop which otherwise
          # gets prefixed to every variable name. This does not affect training
          # but does affect the checkpoint conversion script.
          # TODO(b/161712658): Remove this.
          with tf.name_scope(''):
            images, labels = next(iterator)
            features, labels = images, {'labels': labels}
            strategy.run(single_step, (features, labels))

      global_step = optimizer.iterations
      cur_step = global_step.numpy()
      iterator = iter(ds)
      while cur_step < train_steps:
        # Calls to tf.summary.xyz lookup the summary writer resource which is
        # set by the summary writer's context manager.
        with summary_writer.as_default():
          train_multiple_steps(iterator)
          cur_step = global_step.numpy()
          training_complete = cur_step >= train_steps
          if not FLAGS.save_only_last_ckpt or training_complete:
            checkpoint_manager.save(cur_step)
          logging.info('Completed: %d / %d steps', cur_step, train_steps)
          metrics.log_and_write_metrics_to_summary_json(all_metrics, cur_step)
          if FLAGS.eval_per_loop and not FLAGS.save_only_last_ckpt:
            perform_evaluation(model, builder, eval_steps,
                               checkpoint_manager.latest_checkpoint, strategy,
                               topology, training_complete)
          elif FLAGS.eval_per_loop:
            # Intermediate checkpoints are not saved, so there is nothing new
            # to evaluate; eval runs once after training instead.
            logging.warning('Skipping eval per loop because save_only_last_ckpt is set to True')
          tf.summary.scalar(
              'learning_rate',
              learning_rate(tf.cast(global_step, dtype=tf.float32)),
              global_step)
          summary_writer.flush()
        for metric in all_metrics:
          metric.reset_states()
      logging.info('Training complete...')

    if ((FLAGS.mode == 'train_then_eval' and not FLAGS.eval_per_loop) or
        (FLAGS.eval_per_loop and FLAGS.save_only_last_ckpt)):
      perform_evaluation(model, builder, eval_steps,
                         checkpoint_manager.latest_checkpoint, strategy,
                         topology, training_complete=True)
    plots_lib.gen_plots()
if __name__ == '__main__':
  # Opt in to TF2 behavior before anything else runs.
  tf.compat.v1.enable_v2_behavior()
  # Soft device placement is needed for outside compilation of summaries on
  # TPU.
  tf.config.set_soft_device_placement(enabled=True)
  app.run(main)