Skip to content

Commit

Permalink
Fix QAT resume with BN models, checkpoint name (#260)
Browse files Browse the repository at this point in the history
  • Loading branch information
oguzhanbsolak authored Oct 18, 2023
1 parent 118bd98 commit 87351b3
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
39 changes: 39 additions & 0 deletions ai8x.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,45 @@ def _update_model(m):
m.apply(_update_model)


def update_optimizer(m, optimizer):
    """
    Rebuild `optimizer` so its parameter list matches model `m` after
    batchnorm fusion has removed some parameters.

    The optimizer is re-instantiated (same class, same defaults) over the
    fused model's parameters, and the per-parameter state (e.g. momentum
    buffers) of every surviving parameter is carried over from the old
    optimizer, matched by tensor shape. State entries belonging to
    parameters that no longer exist (the fused batchnorm parameters) are
    dropped. Returns the new optimizer instance.

    NOTE(review): parameters are matched to old state purely by shape, in
    order — this assumes fusion preserves the relative order of surviving
    parameters and that same-shape collisions resolve correctly by position.
    """
    old_state_dict = optimizer.state_dict()
    old_groups = optimizer.param_groups
    # Recreate the same optimizer class over the fused model's parameters.
    optimizer = type(optimizer)(m.parameters(), **optimizer.defaults)
    new_state_dict = optimizer.state_dict()

    for x, group in enumerate(optimizer.param_groups):
        for p in group['params']:
            # Single-element 1-D parameters are skipped — presumably the
            # quantization threshold params introduced around QAT, which
            # carry no transferable optimizer state. TODO confirm.
            if len(p.shape) == 1 and p.shape[0] == 1:
                continue
            key_reduce = 0   # old entries skipped before the match: the
            not_found = []   # new key index is the old one shifted down by it
            # Scan remaining old state entries for the first parameter whose
            # shape matches `p`. Iterating a snapshot list is safe against
            # the pops below and order-identical to the original dict view.
            for key in list(old_state_dict['state']):
                if old_groups[x]['params'][int(key)].shape == p.shape:
                    entry = old_state_dict['state'].pop(key)
                    if entry:  # only create a new entry when there is state
                        new_state_dict['state'][key - key_reduce] = dict(entry)
                    break
                not_found.append(key)
                key_reduce += 1
            # Entries scanned without a match belong to fused-away (BN)
            # parameters — remove them so later scans don't revisit them.
            for key in not_found:
                old_state_dict['state'].pop(key)
        # Preserve LR-scheduler bookkeeping when present. Checkpoints saved
        # without a scheduler have no 'initial_lr'; skip instead of raising.
        if 'initial_lr' in old_state_dict['param_groups'][x]:
            new_state_dict['param_groups'][x]['initial_lr'] = \
                old_state_dict['param_groups'][x]['initial_lr']

    optimizer.load_state_dict(new_state_dict)
    return optimizer


def fuse_bn_layers(m):
"""
Fuse the bn layers before the quantization aware training starts.
Expand Down
11 changes: 11 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,10 @@ def main():
# pylint: disable=unsubscriptable-object
if checkpoint.get('epoch', None) >= qat_policy['start_epoch']:
ai8x.fuse_bn_layers(model)
if args.name:
args.name = f'{args.name}_qat'
else:
args.name = 'qat'
# pylint: enable=unsubscriptable-object
model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
model, args.resumed_checkpoint_path, model_device=args.device)
Expand All @@ -359,6 +363,10 @@ def main():
# pylint: disable=unsubscriptable-object
if checkpoint.get('epoch', None) >= qat_policy['start_epoch']:
ai8x.fuse_bn_layers(model)
if args.name:
args.name = f'{args.name}_qat'
else:
args.name = 'qat'
# pylint: enable=unsubscriptable-object
model = apputils.load_lean_checkpoint(model, args.load_model_path,
model_device=args.device)
Expand Down Expand Up @@ -513,6 +521,9 @@ def main():
# Fuse the BN parameters into conv layers before Quantization Aware Training (QAT)
ai8x.fuse_bn_layers(model)

# Update the optimizer to reflect fused batchnorm layers
optimizer = ai8x.update_optimizer(model, optimizer)

# Switch model from unquantized to quantized for QAT
ai8x.initiate_qat(model, qat_policy)

Expand Down

0 comments on commit 87351b3

Please sign in to comment.