From 32f3bc560c5b676d041fe7894ecedca79df07656 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?O=C4=9Fuzhan=20B=C3=BCy=C3=BCksolak?=
Date: Tue, 17 Oct 2023 20:59:59 +0300
Subject: [PATCH 1/3] Bugfixes: Qat resume with BN models, ckpt name

---
 ai8x.py  | 37 +++++++++++++++++++++++++++++++++++++
 train.py | 11 +++++++++++
 2 files changed, 48 insertions(+)

diff --git a/ai8x.py b/ai8x.py
index f018b9772..b4c3aa750 100644
--- a/ai8x.py
+++ b/ai8x.py
@@ -1794,6 +1794,43 @@ def _update_model(m):
 
     m.apply(_update_model)
 
+def update_optimizer(m, optimizer):
+    """
+    Update optimizer after model 'm' had a batchnorm fusion.
+    This is needed to update the optimizer state_dict to match the new model parameters.
+    """
+    old_state_dict = optimizer.state_dict()
+    old_groups = optimizer.param_groups
+    optimizer = type(optimizer)(m.parameters(), **optimizer.defaults)
+    new_state_dict = optimizer.state_dict()
+    groups = optimizer.param_groups
+    for x, g in enumerate(groups):
+        for p in g['params']:
+            if (len(p.shape) == 1 and p.shape[0] == 1):
+                continue
+            else:
+                nf_keys = []
+                key_reduce = 0
+                for key in old_state_dict['state'].keys():
+                    sub_keys = old_state_dict['state'][key].keys()
+                    if (old_groups[x]['params'][int(key)].shape == p.shape):
+                        for y, sub_key in enumerate(sub_keys):
+                            if y == 0:
+                                new_state_dict['state'][key-key_reduce] = {sub_key: old_state_dict['state'][key][sub_key]}
+                            else:
+                                new_state_dict['state'][key-key_reduce][sub_key] = old_state_dict['state'][key][sub_key]
+                        old_state_dict['state'].pop(key)
+                        break
+                    else:
+                        nf_keys.append(key)
+                        key_reduce += 1
+            continue
+        for key in nf_keys:
+            old_state_dict['state'].pop(key)
+        new_state_dict['param_groups'][x]['initial_lr'] = old_state_dict['param_groups'][x]['initial_lr']
+
+    optimizer.load_state_dict(new_state_dict)
+    return optimizer
 
 def fuse_bn_layers(m):
     """
diff --git a/train.py b/train.py
index c026db5b6..24fb79ad8 100644
--- a/train.py
+++ b/train.py
@@ -347,6 +347,10 @@ def main():
             # pylint: disable=unsubscriptable-object
             if checkpoint.get('epoch', None) >= qat_policy['start_epoch']:
                 ai8x.fuse_bn_layers(model)
+                if args.name:
+                    args.name = f'{args.name}_qat'
+                else:
+                    args.name = 'qat'
             # pylint: enable=unsubscriptable-object
         model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
             model, args.resumed_checkpoint_path, model_device=args.device)
@@ -359,6 +363,10 @@ def main():
             # pylint: disable=unsubscriptable-object
             if checkpoint.get('epoch', None) >= qat_policy['start_epoch']:
                 ai8x.fuse_bn_layers(model)
+                if args.name:
+                    args.name = f'{args.name}_qat'
+                else:
+                    args.name = 'qat'
             # pylint: enable=unsubscriptable-object
         model = apputils.load_lean_checkpoint(model, args.load_model_path,
                                               model_device=args.device)
@@ -513,6 +521,9 @@ def main():
             # Fuse the BN parameters into conv layers before Quantization Aware Training (QAT)
             ai8x.fuse_bn_layers(model)
+            # Update the optimizer to reflect fused batchnorm layers 
+            optimizer = ai8x.update_optimizer(model, optimizer)
+
             # Switch model from unquantized to quantized for QAT
             ai8x.initiate_qat(model, qat_policy)
 

From 3d8dd763d995119e1eaf8fe61e5790ba8886682b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?O=C4=9Fuzhan=20B=C3=BCy=C3=BCksolak?=
Date: Tue, 17 Oct 2023 21:50:02 +0300
Subject: [PATCH 2/3] linter updates

---
 ai8x.py  | 38 +++++++++++++++++++-------------------
 train.py |  2 +-
 2 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/ai8x.py b/ai8x.py
index b4c3aa750..0600312fb 100644
--- a/ai8x.py
+++ b/ai8x.py
@@ -1808,26 +1808,26 @@ def update_optimizer(m, optimizer):
         for p in g['params']:
             if (len(p.shape) == 1 and p.shape[0] == 1):
                 continue
-            else:
-                nf_keys = []
-                key_reduce = 0
-                for key in old_state_dict['state'].keys():
-                    sub_keys = old_state_dict['state'][key].keys()
-                    if (old_groups[x]['params'][int(key)].shape == p.shape):
-                        for y, sub_key in enumerate(sub_keys):
-                            if y == 0:
-                                new_state_dict['state'][key-key_reduce] = {sub_key: old_state_dict['state'][key][sub_key]}
-                            else:
-                                new_state_dict['state'][key-key_reduce][sub_key] = old_state_dict['state'][key][sub_key]
-                        old_state_dict['state'].pop(key)
-                        break
-                    else:
-                        nf_keys.append(key)
-                        key_reduce += 1
-            continue
-        for key in nf_keys:
-            old_state_dict['state'].pop(key)
-        new_state_dict['param_groups'][x]['initial_lr'] = old_state_dict['param_groups'][x]['initial_lr']
+            nf_keys = []
+            key_reduce = 0
+            for key in old_state_dict['state'].keys():
+                sub_keys = old_state_dict['state'][key].keys()
+                if old_groups[x]['params'][int(key)].shape == p.shape:
+                    for y, sub_key in enumerate(sub_keys):
+                        if y == 0:
+                            new_state_dict['state'][key-key_reduce] = \
+                                {sub_key: old_state_dict['state'][key][sub_key]}
+                        else:
+                            new_state_dict['state'][key-key_reduce][sub_key] = \
+                                old_state_dict['state'][key][sub_key]
+                    old_state_dict['state'].pop(key)
+                    break
+                nf_keys.append(key)
+                key_reduce += 1
+            for key in nf_keys:
+                old_state_dict['state'].pop(key)
+        new_state_dict['param_groups'][x]['initial_lr'] = \
+            old_state_dict['param_groups'][x]['initial_lr']
 
     optimizer.load_state_dict(new_state_dict)
     return optimizer
diff --git a/train.py b/train.py
index 24fb79ad8..25c43cefa 100644
--- a/train.py
+++ b/train.py
@@ -521,7 +521,7 @@ def main():
             # Fuse the BN parameters into conv layers before Quantization Aware Training (QAT)
             ai8x.fuse_bn_layers(model)
-            # Update the optimizer to reflect fused batchnorm layers 
+            # Update the optimizer to reflect fused batchnorm layers
             optimizer = ai8x.update_optimizer(model, optimizer)
 
             # Switch model from unquantized to quantized for QAT
             ai8x.initiate_qat(model, qat_policy)

From c3b02985a380bd4c1be239a6472d85fbbf78a4c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?O=C4=9Fuzhan=20B=C3=BCy=C3=BCksolak?=
Date: Tue, 17 Oct 2023 22:17:06 +0300
Subject: [PATCH 3/3] linter updates

---
 ai8x.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/ai8x.py b/ai8x.py
index 0600312fb..fa5683d22 100644
--- a/ai8x.py
+++ b/ai8x.py
@@ -1794,6 +1794,7 @@ def _update_model(m):
 
     m.apply(_update_model)
 
+
 def update_optimizer(m, optimizer):
     """
     Update optimizer after model 'm' had a batchnorm fusion.
@@ -1832,6 +1833,7 @@ def update_optimizer(m, optimizer):
     optimizer.load_state_dict(new_state_dict)
     return optimizer
 
+
 def fuse_bn_layers(m):
     """
     Fuse the bn layers before the quantization aware training starts.
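
Note (reviewer aid, not part of the patches): the series expects fuse_bn_layers() and the new
update_optimizer() to be called back to back when QAT begins, exactly as train.py now does. A
minimal sketch of that call order follows; the model constructor, optimizer settings and the
qat_policy dict are illustrative placeholders only, not code introduced by this series:

    import torch
    import ai8x

    # Placeholder setup: any ai8x model that still contains BatchNorm layers and
    # any torch optimizer would do (illustrative only, create_model() is hypothetical).
    model = create_model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    qat_policy = {'start_epoch': 10, 'weight_bits': 8}  # illustrative policy values

    # ... floating-point training runs until qat_policy['start_epoch'] is reached ...

    # Fold the BN statistics into the preceding conv weights, then rebuild the
    # optimizer so its state matches the reduced parameter list (this series),
    # and only then switch the model into QAT mode.
    ai8x.fuse_bn_layers(model)
    optimizer = ai8x.update_optimizer(model, optimizer)
    ai8x.initiate_qat(model, qat_policy)

Without the update_optimizer() call, an optimizer created before fusion would still reference the
now-removed BN parameters, which is the resume bug the first patch addresses.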