Pr262 #262

Closed · wants to merge 10 commits
10 changes: 6 additions & 4 deletions ai8x.py
@@ -1805,12 +1805,13 @@ def update_optimizer(m, optimizer):
    optimizer = type(optimizer)(m.parameters(), **optimizer.defaults)
    new_state_dict = optimizer.state_dict()
    groups = optimizer.param_groups

    for x, g in enumerate(groups):
        key_reduce = 0
        for p in g['params']:
            if (len(p.shape) == 1 and p.shape[0] == 1):
                continue
            nf_keys = []
            key_reduce = 0
            for key in old_state_dict['state'].keys():
                sub_keys = old_state_dict['state'][key].keys()
                if old_groups[x]['params'][int(key)].shape == p.shape:
@@ -1827,9 +1828,10 @@ def update_optimizer(m, optimizer):
                key_reduce += 1
            for key in nf_keys:
                old_state_dict['state'].pop(key)
        new_state_dict['param_groups'][x]['initial_lr'] = \
            old_state_dict['param_groups'][x]['initial_lr']

        for key in old_state_dict['param_groups'][x].keys():
            if key != 'params':
                new_state_dict['param_groups'][x][key] = \
                    old_state_dict['param_groups'][x][key]
    optimizer.load_state_dict(new_state_dict)
    return optimizer
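
Review note: the two hunks above generalize the hyperparameter hand-off during the optimizer rebuild, copying every non-'params' key of each param group instead of only 'initial_lr'. A standalone sketch of the same idea (not the PR's code; the SGD settings and the 'initial_lr' value are illustrative):

import torch
from torch import nn

model = nn.Linear(4, 2)
old_opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
old_opt.param_groups[0]['initial_lr'] = 0.1  # e.g. set by an LR scheduler

# Rebuild the optimizer for the (possibly modified) model, then carry over
# every per-group setting except the parameter references themselves.
new_opt = type(old_opt)(model.parameters(), **old_opt.defaults)
new_state = new_opt.state_dict()
for x, group in enumerate(old_opt.state_dict()['param_groups']):
    for key, value in group.items():
        if key != 'params':
            new_state['param_groups'][x][key] = value
new_opt.load_state_dict(new_state)
print(new_opt.param_groups[0]['initial_lr'])  # 0.1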

35 changes: 35 additions & 0 deletions losses/dummyloss.py
@@ -0,0 +1,35 @@
###################################################################################################
#
# Copyright (C) 2023 Analog Devices, Inc. All Rights Reserved.
#
# Analog Devices, Inc. Default Copyright Notice:
# https://www.analog.com/en/about-adi/legal-and-risk-oversight/intellectual-property/copyright-notice.html
#
###################################################################################################
"""
Dummy Loss to use in knowledge distillation when student loss weight is 0
"""

import torch
from torch import nn


class DummyLoss(nn.Module):
    """
    Class for dummy loss
    """
    def __init__(self, device='cpu'):
        """
        Initializes the loss
        """
        super().__init__()

        self.device = device

    # pylint: disable=unused-argument
    def forward(self, output, target):
        """
        returns 0.0
        """

        return torch.tensor(0.0, device=self.device)
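
The new module ignores both arguments and returns a zero scalar on the requested device, so it can stand in wherever a criterion is called. A quick behavioral check (illustrative usage, not part of the diff):

import torch

from losses.dummyloss import DummyLoss

criterion = DummyLoss(device='cpu')
output = torch.randn(8, 10)           # shapes are irrelevant; both args are ignored
target = torch.randint(0, 10, (8,))
print(criterion(output, target))      # tensor(0.)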
5 changes: 4 additions & 1 deletion parsecmd.py
@@ -76,7 +76,10 @@ def get_parser(model_names, dataset_names):
    parser.add_argument('--avg-pool-rounding', action='store_true', default=False,
                        help='when simulating, use "round()" in AvgPool operations '
                             '(default: use "floor()")')

    parser.add_argument('--copy-output-folder', type=str, default=None, metavar='PATH',
                        help='Path to copy output folder (default: None)')
    parser.add_argument('--kd-relationbased', action='store_true', default=False,
                        help='enables Relation Based Knowledge Distillation')
    qat_args = parser.add_argument_group('Quantization Arguments')
    qat_args.add_argument('--qat-policy', dest='qat_policy',
                          default=os.path.join('policies', 'qat_policy.yaml'),
110 changes: 84 additions & 26 deletions train.py
@@ -66,6 +66,7 @@
import resource # pylint: disable=import-error
# pylint: enable=wrong-import-position

import shutil
import sys
import time
import traceback
@@ -112,9 +113,10 @@
import parse_qat_yaml
import parsecmd
import sample
from losses.dummyloss import DummyLoss
from losses.multiboxloss import MultiBoxLoss
from nas import parse_nas_yaml
from utils import object_detection_utils, parse_obj_detection_yaml
from utils import kd_relationbased, object_detection_utils, parse_obj_detection_yaml

# from range_linear_ai84 import PostTrainLinearQuantizerAI84

@@ -404,6 +406,11 @@ def main():
    else:
        criterion = nn.MSELoss().to(args.device)

    # Override criterion with dummy loss when student weight is 0
    if args.kd_student_wt == 0:
        criterion = DummyLoss(device=args.device).to(args.device)
        msglogger.info("WARNING: kd_student_wt == 0, Overwriting criterion with a dummy loss")

    if optimizer is None:
        optimizer = create_optimizer(model, args)
        msglogger.info('Optimizer Type: %s', type(optimizer))
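
This override is sound because the distillation policy combines losses as a weighted sum (see the comment later in this diff: overall_loss = student_loss*student_weight + distillation_loss*distillation_weight). With the student weight at 0, a criterion that always returns zero changes nothing while skipping pointless work. Schematic check (not the PR's code):

import torch

def overall_kd_loss(student_loss, distill_loss, w_student, w_distill):
    return w_student * student_loss + w_distill * distill_loss

zero = torch.tensor(0.0)  # what DummyLoss returns
print(overall_kd_loss(zero, torch.tensor(1.5), 0.0, 1.0))  # tensor(1.5000)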
@@ -470,12 +477,16 @@

    args.kd_policy = None
    if args.kd_teacher:
        teacher = create_model(supported_models, dimensions, args)
        teacher = create_model(supported_models, dimensions, args, mode='kd_teacher')
        if args.kd_resume:
            teacher = apputils.load_lean_checkpoint(teacher, args.kd_resume)
        dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt,
                                                args.kd_teacher_wt)
        args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher, args.kd_temp, dlw)
        if args.kd_relationbased:
            args.kd_policy = kd_relationbased.RelationBasedKDPolicy(model, teacher, dlw)
        else:
            args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher,
                                                                   args.kd_temp, dlw)
        compression_scheduler.add_policy(args.kd_policy, starting_epoch=args.kd_start_epoch,
                                         ending_epoch=args.epochs, frequency=1)
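
RelationBasedKDPolicy lives in utils/kd_relationbased.py, which is outside this excerpt. Its call sites in this diff imply that it accepts (model, teacher, dlw), exposes forward(inputs), and returns an object carrying overall_loss from before_backward_pass. A skeleton consistent with those call sites and with distiller's policy interface (an interface guess, not the PR's implementation; the plain MSE term is a placeholder for a real relation-based loss over pairwise distances/angles):

from collections import namedtuple

import torch
from torch import nn

PolicyLoss = namedtuple('PolicyLoss', ['overall_loss', 'loss_components'])


class RelationBasedKDPolicySketch:
    """Illustrative skeleton only."""

    def __init__(self, student, teacher, loss_weights):
        self.student = student
        self.teacher = teacher
        self.loss_weights = loss_weights  # distiller.DistillationLossWeights
        self.student_out = None
        self.teacher_out = None

    def forward(self, *inputs):
        # Run both networks; keep the outputs for the loss hook below.
        with torch.no_grad():
            self.teacher_out = self.teacher(*inputs)
        self.student_out = self.student(*inputs)
        return self.student_out

    def before_backward_pass(self, model, epoch, minibatch_id,
                             minibatches_per_epoch, loss, zeros_mask_dict,
                             optimizer=None):
        distill_loss = nn.functional.mse_loss(self.student_out, self.teacher_out)
        overall = (self.loss_weights.student * loss
                   + self.loss_weights.distill * distill_loss)
        return PolicyLoss(overall, [loss, distill_loss])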

@@ -621,26 +632,37 @@ def main():

    # Finally run results on the test set
    test(test_loader, model, criterion, [pylogger], activations_collectors, args=args)

    if args.copy_output_folder:
        msglogger.info('Copying output folder to: %s', args.copy_output_folder)
        shutil.copytree(msglogger.logdir, args.copy_output_folder, dirs_exist_ok=True)
    return None
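
One portability note on the copy step: dirs_exist_ok was added to shutil.copytree in Python 3.8, and with it a copy into an existing destination merges instead of raising FileExistsError. Minimal standalone demonstration (paths are temporary and illustrative):

import shutil
import tempfile
from pathlib import Path

src = Path(tempfile.mkdtemp())
(src / 'log.txt').write_text('hello')
dst = Path(tempfile.mkdtemp())  # destination already exists

shutil.copytree(src, dst, dirs_exist_ok=True)  # merges rather than failing
print((dst / 'log.txt').read_text())           # hello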


OVERALL_LOSS_KEY = 'Overall Loss'
OBJECTIVE_LOSS_KEY = 'Objective Loss'


def create_model(supported_models, dimensions, args):
def create_model(supported_models, dimensions, args, mode='default'):
"""Create the model"""
module = next(item for item in supported_models if item['name'] == args.cnn)
if mode == 'default':
module = next(item for item in supported_models if item['name'] == args.cnn)
elif mode == 'kd_teacher':
module = next(item for item in supported_models if item['name'] == args.kd_teacher)

# Override distiller's input shape detection. This is not a very clean way to do it since
# we're replacing a protected member.
distiller.utils._validate_input_shape = ( # pylint: disable=protected-access
lambda _a, _b: (1, ) + dimensions[:module['dim'] + 1]
)

Model = locate(module['module'] + '.' + args.cnn)
if not Model:
raise RuntimeError("Model " + args.cnn + " not found\n")
if mode == 'default':
Model = locate(module['module'] + '.' + args.cnn)
if not Model:
raise RuntimeError("Model " + args.cnn + " not found\n")
elif mode == 'kd_teacher':
Model = locate(module['module'] + '.' + args.kd_teacher)
if not Model:
raise RuntimeError("Model " + args.kd_teacher + " not found\n")
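
locate resolves a dotted path string to a Python object at runtime and returns None on failure, which is why each call above is followed by an "if not Model" guard (in train.py it comes from pydoc; the stdlib example below only demonstrates the behavior):

from pydoc import locate

OrderedDict = locate('collections.OrderedDict')
print(OrderedDict is not None)            # True
print(locate('collections.NoSuchClass'))  # None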

    # Set model parameters
    if args.act_mode_8bit:
@@ -801,7 +823,7 @@ def train(train_loader, model, criterion, optimizer, epoch,

        loss = criterion(output, target)
        # TODO Early exit mechanism for Object Detection case is NOT implemented yet
        if not args.obj_detection:
        if not args.obj_detection and not args.kd_relationbased:
            if not args.earlyexit_lossweights:
                # Measure accuracy if the conditions are set. For `Last Batch` only accuracy
                # calculation last two batches are used as the last batch might include just a few
@@ -997,7 +1019,9 @@ def traverse_pass2(m):

def _validate(data_loader, model, criterion, loggers, args, epoch=-1, tflogger=None):
"""Execute the validation/test loop."""
losses = {'objective_loss': tnt.AverageValueMeter()}
losses = OrderedDict([(OVERALL_LOSS_KEY, tnt.AverageValueMeter()),
(OBJECTIVE_LOSS_KEY, tnt.AverageValueMeter())])

if args.obj_detection:
map_calculator = MeanAveragePrecision(
# box_format='xyxy', # Enable in torchmetrics > 0.6
@@ -1150,7 +1174,10 @@ def save_tensor(t, f, regression=True):
            else:
                inputs, target = inputs.to(args.device), target.to(args.device)
            # compute output from model
            output = model(inputs)
            if args.kd_relationbased:
                output = args.kd_policy.forward(inputs)
            else:
                output = model(inputs)
            if args.out_fold_ratio != 1:
                output = ai8x.unfold_batch(output, args.out_fold_ratio)

@@ -1176,10 +1203,14 @@ def save_tensor(t, f, regression=True):
            if not args.earlyexit_thresholds:
                # compute loss
                loss = criterion(output, target)
                if args.kd_relationbased:
                    agg_loss = args.kd_policy.before_backward_pass(None, None, None, None,
                                                                   loss, None)
                    losses[OVERALL_LOSS_KEY].add(agg_loss.overall_loss.item())
                # measure accuracy and record loss
                losses['objective_loss'].add(loss.item())
                losses[OBJECTIVE_LOSS_KEY].add(loss.item())

                if not args.obj_detection:
                if not args.obj_detection and not args.kd_relationbased:
                    if len(output.data.shape) <= 2 or args.regression:
                        classerr.add(output.data, target)
                    else:
@@ -1205,30 +1236,36 @@ def save_tensor(t, f, regression=True):
                class_preds.append(class_preds_batch)

    if not args.earlyexit_thresholds:
        if args.obj_detection:
        if args.kd_relationbased:
            stats = (
                '',
                OrderedDict([('Loss', losses[OBJECTIVE_LOSS_KEY].mean),
                             ('Overall Loss', losses[OVERALL_LOSS_KEY].mean)])
            )
        elif args.obj_detection:
            # Only run compute() if there is at least one new update()
            if have_mAP:
                # Remove [0] in new torchmetrics
                mAP = map_calculator.compute()['map_50'][0]
                have_mAP = False
            stats = (
                '',
                OrderedDict([('Loss', losses['objective_loss'].mean),
                OrderedDict([('Loss', losses[OBJECTIVE_LOSS_KEY].mean),
                             ('mAP', mAP)])
            )
        else:
            if not args.regression:
                stats = (
                    '',
                    OrderedDict([('Loss', losses['objective_loss'].mean),
                    OrderedDict([('Loss', losses[OBJECTIVE_LOSS_KEY].mean),
                                 ('Top1', classerr.value(1))])
                )
                if args.num_classes > 5:
                    stats[1]['Top5'] = classerr.value(5)
            else:
                stats = (
                    '',
                    OrderedDict([('Loss', losses['objective_loss'].mean),
                    OrderedDict([('Loss', losses[OBJECTIVE_LOSS_KEY].mean),
                                 ('MSE', classerr.value())])
                )
    else:
@@ -1296,26 +1333,33 @@ def select_n_random(data, labels, features, n=100):

    if not args.earlyexit_thresholds:

        if args.kd_relationbased:

            msglogger.info('==> Overall Loss: %.3f\n',
                           losses[OVERALL_LOSS_KEY].mean)

            return 0, 0, losses[OVERALL_LOSS_KEY].mean, 0

        if args.obj_detection:

            msglogger.info('==> mAP: %.5f Loss: %.3f\n',
                           mAP,
                           losses['objective_loss'].mean)
                           losses[OBJECTIVE_LOSS_KEY].mean)

            return 0, 0, losses['objective_loss'].mean, mAP
            return 0, 0, losses[OBJECTIVE_LOSS_KEY].mean, mAP

        if not args.regression:
            if args.num_classes > 5:
                msglogger.info('==> Top1: %.3f Top5: %.3f Loss: %.3f\n',
                               classerr.value()[0], classerr.value()[1],
                               losses['objective_loss'].mean)
                               losses[OBJECTIVE_LOSS_KEY].mean)
            else:
                msglogger.info('==> Top1: %.3f Loss: %.3f\n',
                               classerr.value()[0], losses['objective_loss'].mean)
                               classerr.value()[0], losses[OBJECTIVE_LOSS_KEY].mean)
        else:
            msglogger.info('==> MSE: %.5f Loss: %.3f\n',
                           classerr.value(), losses['objective_loss'].mean)
        return classerr.value(), .0, losses['objective_loss'].mean, 0
                           classerr.value(), losses[OBJECTIVE_LOSS_KEY].mean)
        return classerr.value(), .0, losses[OBJECTIVE_LOSS_KEY].mean, 0

    if args.display_confusion:
        msglogger.info('==> Confusion:\n%s\n', str(confusion.value()))
@@ -1325,7 +1369,7 @@ def select_n_random(data, labels, features, n=100):
                          dataformats='HWC')
    if not args.regression:
        return classerr.value(1), classerr.value(min(args.num_classes, 5)), \
            losses['objective_loss'].mean, 0
            losses[OBJECTIVE_LOSS_KEY].mean, 0
    # else:
    total_top1, total_top5, losses_exits_stats = earlyexit_validate_stats(args)
    return total_top1, total_top5, losses_exits_stats[args.num_exits-1], 0
@@ -1342,7 +1386,21 @@ def update_training_scores_history(perf_scores_history, model, top1, top5, mAP,
        'top1': top1, 'top5': top5, 'mAP': mAP, 'vloss': -vloss,
        'epoch': epoch}))

    if args.obj_detection:
    if args.kd_relationbased:

        # Keep perf_scores_history sorted from best to worst based on overall loss
        # overall_loss = student_loss*student_weight + distillation_loss*distillation_weight
        if not args.sparsity_perf:
            perf_scores_history.sort(key=operator.attrgetter('vloss', 'epoch'),
                                     reverse=True)
        else:
            perf_scores_history.sort(key=operator.attrgetter('params_nnz_cnt', 'vloss', 'epoch'),
                                     reverse=True)
        for score in perf_scores_history[:args.num_best_scores]:
            msglogger.info('==> Best [Overall Loss: %f on epoch: %d]',
                           -score.vloss, score.epoch)

    elif args.obj_detection:

        # Keep perf_scores_history sorted from best to worst
        if not args.sparsity_perf: