diff --git a/ai8x.py b/ai8x.py
index fa5683d22..fbe44ad75 100644
--- a/ai8x.py
+++ b/ai8x.py
@@ -1805,12 +1805,13 @@ def update_optimizer(m, optimizer):
     optimizer = type(optimizer)(m.parameters(), **optimizer.defaults)
     new_state_dict = optimizer.state_dict()
     groups = optimizer.param_groups
+
     for x, g in enumerate(groups):
+        key_reduce = 0
         for p in g['params']:
             if (len(p.shape) == 1 and p.shape[0] == 1):
                 continue
             nf_keys = []
-            key_reduce = 0
             for key in old_state_dict['state'].keys():
                 sub_keys = old_state_dict['state'][key].keys()
                 if old_groups[x]['params'][int(key)].shape == p.shape:
@@ -1827,9 +1828,10 @@ def update_optimizer(m, optimizer):
                     key_reduce += 1
             for key in nf_keys:
                 old_state_dict['state'].pop(key)
-        new_state_dict['param_groups'][x]['initial_lr'] = \
-            old_state_dict['param_groups'][x]['initial_lr']
-
+        for key in old_state_dict['param_groups'][x].keys():
+            if key != 'params':
+                new_state_dict['param_groups'][x][key] = \
+                    old_state_dict['param_groups'][x][key]
     optimizer.load_state_dict(new_state_dict)
     return optimizer
diff --git a/losses/dummyloss.py b/losses/dummyloss.py
new file mode 100644
index 000000000..952ecefde
--- /dev/null
+++ b/losses/dummyloss.py
@@ -0,0 +1,35 @@
+###################################################################################################
+#
+# Copyright (C) 2023 Analog Devices, Inc. All Rights Reserved.
+#
+# Analog Devices, Inc. Default Copyright Notice:
+# https://www.analog.com/en/about-adi/legal-and-risk-oversight/intellectual-property/copyright-notice.html
+#
+###################################################################################################
+"""
+Dummy loss to use in knowledge distillation when the student loss weight is 0
+"""
+
+import torch
+from torch import nn
+
+
+class DummyLoss(nn.Module):
+    """
+    Class for dummy loss
+    """
+    def __init__(self, device='cpu'):
+        """
+        Initializes the loss
+        """
+        super().__init__()
+
+        self.device = device
+
+    # pylint: disable=unused-argument
+    def forward(self, output, target):
+        """
+        Returns a constant 0.0 loss tensor on the configured device
+        """
+
+        return torch.tensor(0.0, device=self.device)
diff --git a/parsecmd.py b/parsecmd.py
index e372a8347..2f905b33c 100644
--- a/parsecmd.py
+++ b/parsecmd.py
@@ -76,7 +76,10 @@ def get_parser(model_names, dataset_names):
     parser.add_argument('--avg-pool-rounding', action='store_true', default=False,
                         help='when simulating, use "round()" in AvgPool operations '
                              '(default: use "floor()")')
-
+    parser.add_argument('--copy-output-folder', type=str, default=None, metavar='PATH',
+                        help='path to copy the output folder to (default: None)')
+    parser.add_argument('--kd-relationbased', action='store_true', default=False,
+                        help='enables relation-based knowledge distillation')
     qat_args = parser.add_argument_group('Quantization Arguments')
     qat_args.add_argument('--qat-policy', dest='qat_policy',
                           default=os.path.join('policies', 'qat_policy.yaml'),
diff --git a/train.py b/train.py
index 25c43cefa..26742cabf 100644
--- a/train.py
+++ b/train.py
@@ -66,6 +66,7 @@
 import resource  # pylint: disable=import-error
 # pylint: enable=wrong-import-position
+import shutil
 import sys
 import time
 import traceback
@@ -112,9 +113,10 @@
 import parse_qat_yaml
 import parsecmd
 import sample
+from losses.dummyloss import DummyLoss
 from losses.multiboxloss import MultiBoxLoss
 from nas import parse_nas_yaml
-from utils import object_detection_utils, parse_obj_detection_yaml
+from utils import kd_relationbased, object_detection_utils, parse_obj_detection_yaml
 
 # from range_linear_ai84 import PostTrainLinearQuantizerAI84
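Reviewer note: when args.kd_student_wt is 0, the overridden criterion below contributes nothing, so the overall KD loss reduces to the weighted distillation term alone. A minimal sketch of that arithmetic, using the DummyLoss class added above (tensor values are made up for illustration):

import torch

from losses.dummyloss import DummyLoss

criterion = DummyLoss(device='cpu')
output = torch.randn(8, 10)                # hypothetical student logits
target = torch.randint(0, 10, (8,))        # hypothetical labels
student_loss = criterion(output, target)   # always tensor(0.)

# Weighted sum as computed by the KD policies; the distillation value
# here is a placeholder.
distill_wt, student_wt = 1.0, 0.0
distillation_loss = torch.tensor(0.5)
overall_loss = student_wt * student_loss + distill_wt * distillation_loss
print(overall_loss.item())                 # 0.5 -- the distillation term only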
@@ -404,6 +406,11 @@ def main():
         else:
             criterion = nn.MSELoss().to(args.device)
 
+    # Override criterion with a dummy loss when the student loss weight is 0
+    if args.kd_student_wt == 0:
+        criterion = DummyLoss(device=args.device).to(args.device)
+        msglogger.info("WARNING: kd_student_wt == 0, overriding criterion with DummyLoss")
+
     if optimizer is None:
         optimizer = create_optimizer(model, args)
         msglogger.info('Optimizer Type: %s', type(optimizer))
@@ -470,12 +477,16 @@ def main():
 
     args.kd_policy = None
     if args.kd_teacher:
-        teacher = create_model(supported_models, dimensions, args)
+        teacher = create_model(supported_models, dimensions, args, mode='kd_teacher')
         if args.kd_resume:
             teacher = apputils.load_lean_checkpoint(teacher, args.kd_resume)
         dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt,
                                                 args.kd_teacher_wt)
-        args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher, args.kd_temp, dlw)
+        if args.kd_relationbased:
+            args.kd_policy = kd_relationbased.RelationBasedKDPolicy(model, teacher, dlw)
+        else:
+            args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher,
+                                                                   args.kd_temp, dlw)
         compression_scheduler.add_policy(args.kd_policy, starting_epoch=args.kd_start_epoch,
                                          ending_epoch=args.epochs, frequency=1)
@@ -621,6 +632,10 @@ def main():
 
         # Finally run results on the test set
         test(test_loader, model, criterion, [pylogger], activations_collectors,
             args=args)
+
+        if args.copy_output_folder:
+            msglogger.info('Copying output folder to: %s', args.copy_output_folder)
+            shutil.copytree(msglogger.logdir, args.copy_output_folder, dirs_exist_ok=True)
 
     return None
@@ -628,19 +643,26 @@
 OBJECTIVE_LOSS_KEY = 'Objective Loss'
 
 
-def create_model(supported_models, dimensions, args):
+def create_model(supported_models, dimensions, args, mode='default'):
     """Create the model"""
-    module = next(item for item in supported_models if item['name'] == args.cnn)
+    if mode == 'default':
+        module = next(item for item in supported_models if item['name'] == args.cnn)
+    elif mode == 'kd_teacher':
+        module = next(item for item in supported_models if item['name'] == args.kd_teacher)
 
     # Override distiller's input shape detection. This is not a very clean way to do it since
     # we're replacing a protected member.
     distiller.utils._validate_input_shape = (  # pylint: disable=protected-access
         lambda _a, _b: (1, ) + dimensions[:module['dim'] + 1]
     )
-
-    Model = locate(module['module'] + '.' + args.cnn)
-    if not Model:
-        raise RuntimeError("Model " + args.cnn + " not found\n")
+    if mode == 'default':
+        Model = locate(module['module'] + '.' + args.cnn)
+        if not Model:
+            raise RuntimeError("Model " + args.cnn + " not found\n")
+    elif mode == 'kd_teacher':
+        Model = locate(module['module'] + '.' + args.kd_teacher)
+        if not Model:
+            raise RuntimeError("Model " + args.kd_teacher + " not found\n")
 
     # Set model parameters
     if args.act_mode_8bit:
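Reviewer note: create_model() above now dispatches on `mode`, but a call with any other mode value would leave `module` and `Model` unbound. A defensive sketch of the name lookup (a hypothetical helper, assuming only the two modes this PR uses):

def resolve_model_name(args, mode='default'):
    """Map a create_model() mode to the model name it should load."""
    if mode == 'default':
        return args.cnn
    if mode == 'kd_teacher':
        return args.kd_teacher
    raise ValueError(f"create_model: unsupported mode '{mode}'")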
@@ -801,7 +823,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
             loss = criterion(output, target)
 
             # TODO Early exit mechanism for Object Detection case is NOT implemented yet
-            if not args.obj_detection:
+            if not args.obj_detection and not args.kd_relationbased:
                 if not args.earlyexit_lossweights:
                     # Measure accuracy if the conditions are set. For `Last Batch` only accuracy
                     # calculation last two batches are used as the last batch might include just a few
@@ -997,7 +1019,9 @@ def traverse_pass2(m):
 def _validate(data_loader, model, criterion, loggers, args, epoch=-1, tflogger=None):
     """Execute the validation/test loop."""
-    losses = {'objective_loss': tnt.AverageValueMeter()}
+    losses = OrderedDict([(OVERALL_LOSS_KEY, tnt.AverageValueMeter()),
+                          (OBJECTIVE_LOSS_KEY, tnt.AverageValueMeter())])
+
     if args.obj_detection:
         map_calculator = MeanAveragePrecision(
             # box_format='xyxy',  # Enable in torchmetrics > 0.6
@@ -1150,7 +1174,10 @@ def save_tensor(t, f, regression=True):
             else:
                 inputs, target = inputs.to(args.device), target.to(args.device)
                 # compute output from model
-                output = model(inputs)
+                if args.kd_relationbased:
+                    output = args.kd_policy.forward(inputs)
+                else:
+                    output = model(inputs)
 
                 if args.out_fold_ratio != 1:
                     output = ai8x.unfold_batch(output, args.out_fold_ratio)
@@ -1176,10 +1203,14 @@ def save_tensor(t, f, regression=True):
                 if not args.earlyexit_thresholds:
                     # compute loss
                     loss = criterion(output, target)
+                    if args.kd_relationbased:
+                        agg_loss = args.kd_policy.before_backward_pass(None, None, None, None,
+                                                                       loss, None)
+                        losses[OVERALL_LOSS_KEY].add(agg_loss.overall_loss.item())
                     # measure accuracy and record loss
-                    losses['objective_loss'].add(loss.item())
+                    losses[OBJECTIVE_LOSS_KEY].add(loss.item())
 
-                    if not args.obj_detection:
+                    if not args.obj_detection and not args.kd_relationbased:
                         if len(output.data.shape) <= 2 or args.regression:
                             classerr.add(output.data, target)
                         else:
@@ -1205,7 +1236,13 @@ def save_tensor(t, f, regression=True):
                         class_preds.append(class_preds_batch)
 
             if not args.earlyexit_thresholds:
-                if args.obj_detection:
+                if args.kd_relationbased:
+                    stats = (
+                        '',
+                        OrderedDict([('Loss', losses[OBJECTIVE_LOSS_KEY].mean),
+                                     ('Overall Loss', losses[OVERALL_LOSS_KEY].mean)])
+                    )
+                elif args.obj_detection:
                     # Only run compute() if there is at least one new update()
                     if have_mAP:
                         # Remove [0] in new torchmetrics
@@ -1213,14 +1250,14 @@ def save_tensor(t, f, regression=True):
                         have_mAP = False
                     stats = (
                         '',
-                        OrderedDict([('Loss', losses['objective_loss'].mean),
+                        OrderedDict([('Loss', losses[OBJECTIVE_LOSS_KEY].mean),
                                      ('mAP', mAP)])
                     )
                 else:
                     if not args.regression:
                         stats = (
                             '',
-                            OrderedDict([('Loss', losses['objective_loss'].mean),
+                            OrderedDict([('Loss', losses[OBJECTIVE_LOSS_KEY].mean),
                                          ('Top1', classerr.value(1))])
                         )
                         if args.num_classes > 5:
@@ -1228,7 +1265,7 @@ def save_tensor(t, f, regression=True):
                     else:
                         stats = (
                             '',
-                            OrderedDict([('Loss', losses['objective_loss'].mean),
+                            OrderedDict([('Loss', losses[OBJECTIVE_LOSS_KEY].mean),
                                          ('MSE', classerr.value())])
                         )
             else:
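Reviewer note: _validate() now tracks two meters, keyed by OBJECTIVE_LOSS_KEY (the plain criterion loss) and OVERALL_LOSS_KEY (the weighted student + distillation sum returned by before_backward_pass()). A tiny sketch of the torchnet bookkeeping this relies on (loss values made up):

import torchnet.meter as tnt

# Same keys as OVERALL_LOSS_KEY / OBJECTIVE_LOSS_KEY in train.py
meters = {'Overall Loss': tnt.AverageValueMeter(),
          'Objective Loss': tnt.AverageValueMeter()}

for objective, overall in [(0.31, 0.22), (0.27, 0.19)]:  # two fake batches
    meters['Objective Loss'].add(objective)
    meters['Overall Loss'].add(overall)

print(meters['Overall Loss'].mean)  # running mean over batches: 0.205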
@@ -1296,26 +1333,33 @@ def select_n_random(data, labels, features, n=100):
 
     if not args.earlyexit_thresholds:
+        if args.kd_relationbased:
+
+            msglogger.info('==> Overall Loss: %.3f\n',
+                           losses[OVERALL_LOSS_KEY].mean)
+
+            return 0, 0, losses[OVERALL_LOSS_KEY].mean, 0
+
         if args.obj_detection:
             msglogger.info('==> mAP: %.5f    Loss: %.3f\n',
                            mAP,
-                           losses['objective_loss'].mean)
+                           losses[OBJECTIVE_LOSS_KEY].mean)
 
-            return 0, 0, losses['objective_loss'].mean, mAP
+            return 0, 0, losses[OBJECTIVE_LOSS_KEY].mean, mAP
 
         if not args.regression:
             if args.num_classes > 5:
                 msglogger.info('==> Top1: %.3f    Top5: %.3f    Loss: %.3f\n',
                                classerr.value()[0], classerr.value()[1],
-                               losses['objective_loss'].mean)
+                               losses[OBJECTIVE_LOSS_KEY].mean)
             else:
                 msglogger.info('==> Top1: %.3f    Loss: %.3f\n',
-                               classerr.value()[0], losses['objective_loss'].mean)
+                               classerr.value()[0], losses[OBJECTIVE_LOSS_KEY].mean)
         else:
             msglogger.info('==> MSE: %.5f    Loss: %.3f\n',
-                           classerr.value(), losses['objective_loss'].mean)
-        return classerr.value(), .0, losses['objective_loss'].mean, 0
+                           classerr.value(), losses[OBJECTIVE_LOSS_KEY].mean)
+        return classerr.value(), .0, losses[OBJECTIVE_LOSS_KEY].mean, 0
 
     if args.display_confusion:
         msglogger.info('==> Confusion:\n%s\n', str(confusion.value()))
@@ -1325,7 +1369,7 @@ def select_n_random(data, labels, features, n=100):
                                  dataformats='HWC')
     if not args.regression:
         return classerr.value(1), classerr.value(min(args.num_classes, 5)), \
-            losses['objective_loss'].mean, 0
+            losses[OBJECTIVE_LOSS_KEY].mean, 0
     # else:
     total_top1, total_top5, losses_exits_stats = earlyexit_validate_stats(args)
     return total_top1, total_top5, losses_exits_stats[args.num_exits-1], 0
@@ -1342,7 +1386,21 @@ def update_training_scores_history(perf_scores_history, model, top1, top5, mAP,
                                        'top1': top1, 'top5': top5, 'mAP': mAP,
                                        'vloss': -vloss, 'epoch': epoch}))
 
-    if args.obj_detection:
+    if args.kd_relationbased:
+
+        # Keep perf_scores_history sorted from best to worst based on overall loss
+        # overall_loss = student_loss*student_weight + distillation_loss*distillation_weight
+        if not args.sparsity_perf:
+            perf_scores_history.sort(key=operator.attrgetter('vloss', 'epoch'),
+                                     reverse=True)
+        else:
+            perf_scores_history.sort(key=operator.attrgetter('params_nnz_cnt', 'vloss', 'epoch'),
+                                     reverse=True)
+        for score in perf_scores_history[:args.num_best_scores]:
+            msglogger.info('==> Best [Overall Loss: %f on epoch: %d]',
+                           -score.vloss, score.epoch)
+
+    elif args.obj_detection:
 
         # Keep perf_scores_history sorted from best to worst
         if not args.sparsity_perf:
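Reviewer note: update_training_scores_history() stores `-vloss`, so a single descending ("best first") sort works for loss-based ranking exactly as it does for the accuracy- and mAP-based branches. A small sketch of why the negation is needed (field names mirror the code above, values are made up):

import operator
from types import SimpleNamespace

# Two fake history entries: epoch 2 has the lower (better) overall loss.
history = [SimpleNamespace(vloss=-0.42, epoch=1),
           SimpleNamespace(vloss=-0.17, epoch=2)]
history.sort(key=operator.attrgetter('vloss', 'epoch'), reverse=True)
print(history[0].epoch, -history[0].vloss)  # 2 0.17 -- best entry first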
diff --git a/utils/kd_relationbased.py b/utils/kd_relationbased.py
new file mode 100644
index 000000000..7ade8c293
--- /dev/null
+++ b/utils/kd_relationbased.py
@@ -0,0 +1,115 @@
+###################################################################################################
+#
+# Copyright (C) 2023 Analog Devices, Inc. All Rights Reserved.
+#
+# Analog Devices, Inc. Default Copyright Notice:
+# https://www.analog.com/en/about-adi/legal-and-risk-oversight/intellectual-property/copyright-notice.html
+#
+###################################################################################################
+#
+# Portions Copyright (c) 2018 Intel Corporation
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Relation-based Knowledge Distillation Policy"""
+
+from collections import namedtuple
+
+import torch
+from torch import nn
+
+from distiller.policy import LossComponent, PolicyLoss, ScheduledTrainingPolicy
+
+DistillationLossWeights = namedtuple('DistillationLossWeights',
+                                     ['distill', 'student', 'teacher'])
+
+
+class RelationBasedKDPolicy(ScheduledTrainingPolicy):
+    """
+    Relation-based knowledge distillation policy, based on
+    distiller's ScheduledTrainingPolicy class.
+    """
+    def __init__(self, student_model, teacher_model,
+                 loss_weights=DistillationLossWeights(0.5, 0.5, 0)):
+        super().__init__()
+
+        self.student = student_model
+        self.teacher = teacher_model
+        self.teacher_output = None
+        self.student_output = None
+        self.loss_wts = loss_weights
+        self.distillation_loss = nn.MSELoss()
+        self.overall_loss = None
+
+        # self.active is always True because testing is based on the overall loss
+        # and is performed outside of the epoch loop
+        self.active = True
+
+    def forward(self, *inputs):
+        """
+        Performs forward propagation through both the student and teacher models and
+        caches the outputs. This function MUST be used instead of calling the student
+        model directly.
+
+        Returns:
+            The student model's output, to be consistent with what a
+            calling script would expect
+        """
+        if not self.active:
+            return self.student(*inputs)
+
+        with torch.no_grad():
+            self.teacher_output = self.teacher(*inputs)
+
+        out = self.student(*inputs)
+        self.student_output = out.clone()
+
+        return out
+
+    # pylint: disable=unused-argument
+    def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
+        """
+        Not used
+        """
+
+    # pylint: disable=unused-argument
+    def on_epoch_end(self, model, zeros_mask_dict, meta, **kwargs):
+        """
+        Not used
+        """
+
+    # pylint: disable=unused-argument
+    def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
+                             zeros_mask_dict, optimizer=None):
+        """
+        Returns the overall loss, which is a weighted sum of the student loss and
+        the distillation loss
+        """
+
+        if self.student_output is None or self.teacher_output is None:
+            raise RuntimeError("RelationBasedKDPolicy: Student and/or teacher outputs "
+                               "were not cached. Make sure to call "
+                               "RelationBasedKDPolicy.forward() in your script instead of "
+                               "calling the model directly.")
+
+        distillation_loss = self.distillation_loss(self.student_output, self.teacher_output)
+
+        overall_loss = self.loss_wts.student * loss + self.loss_wts.distill * distillation_loss
+
+        # For logging purposes, return the un-scaled distillation loss so it is
+        # comparable between runs with different loss weights
+        return PolicyLoss(overall_loss,
+                          [LossComponent('Distill Loss', distillation_loss)])
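Reviewer note: a minimal end-to-end sketch of how train.py drives the new policy, distilled from the changes above. The models and data here are toy stand-ins, and the snippet assumes the distiller package from this repo's environment:

import torch
from torch import nn

from utils.kd_relationbased import DistillationLossWeights, RelationBasedKDPolicy

# Hypothetical stand-ins for the student and teacher networks
student = nn.Linear(16, 4)
teacher = nn.Linear(16, 4)
policy = RelationBasedKDPolicy(student, teacher,
                               DistillationLossWeights(0.5, 0.5, 0))

criterion = nn.CrossEntropyLoss()
inputs = torch.randn(8, 16)
target = torch.randint(0, 4, (8,))

# The forward pass MUST go through the policy so both outputs are cached
output = policy.forward(inputs)
student_loss = criterion(output, target)

# overall = 0.5 * student_loss + 0.5 * MSE(student_output, teacher_output)
policy_loss = policy.before_backward_pass(None, None, None, None,
                                          student_loss, None)
policy_loss.overall_loss.backward()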