diff --git a/ai8x.py b/ai8x.py
index 6ffbff4d5..665484aaa 100644
--- a/ai8x.py
+++ b/ai8x.py
@@ -2237,16 +2237,15 @@ def stat_collect(train_loader, model, args):
             model(inputs)
 
 
-def pre_qat(model, train_loader, args, qat_policy, local_rank=0):
+def pre_qat(model, train_loader, args, qat_policy):
     """
     Prepare the model for quantization aware training
     """
-    if local_rank <= 0:
-        init_hist(model)
-        stat_collect(train_loader, model, args)
-        init_threshold(model, qat_policy["outlier_removal_z_score"])
-        release_hist(model)
-        apply_scales(model)
+    init_hist(model)
+    stat_collect(train_loader, model, args)
+    init_threshold(model, qat_policy["outlier_removal_z_score"])
+    release_hist(model)
+    apply_scales(model)
 
 
 def init_hist(model):
diff --git a/train.py b/train.py
index 2916a12ae..f76eab4af 100755
--- a/train.py
+++ b/train.py
@@ -610,7 +610,7 @@ def flush(self):
 
             msglogger.info('Collecting statistics for quantization aware training (QAT)...')
 
-            ai8x.pre_qat(model, train_loader, args, qat_policy, local_rank)
+            ai8x.pre_qat(model, train_loader, args, qat_policy)
 
             # Update the optimizer to reflect fused batchnorm layers
             optimizer = ai8x.update_optimizer(model, optimizer)
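
The ai8x.py hunk drops the "local_rank <= 0" guard, so pre_qat() now always runs the full calibration sequence: build activation histograms, stream the calibration data through the model, derive clipping thresholds with outlier removal, release the histograms, and apply the resulting scales. If a caller still wanted the old behavior of running that statistics pass on a single process under distributed training, a minimal caller-side sketch could look like the following. pre_qat_rank0_only is a hypothetical helper, not part of this diff; it assumes a torch.distributed launcher, and the barrier only keeps the other ranks from racing ahead (it does not broadcast any state that pre_qat() computes on rank 0).

import torch.distributed as dist

import ai8x


def pre_qat_rank0_only(model, train_loader, args, qat_policy):
    """Hypothetical wrapper: run ai8x.pre_qat() on rank 0 only, then sync.

    Note: unlike the removed in-library guard, this does not propagate the
    computed scales to the other ranks; it only synchronizes them.
    """
    in_ddp = dist.is_available() and dist.is_initialized()
    if not in_ddp or dist.get_rank() == 0:
        ai8x.pre_qat(model, train_loader, args, qat_policy)
    if in_ddp:
        dist.barrier()  # wait until rank 0 has collected stats and applied scales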