Skip to content

Commit

Permalink
Revert local_rank change for pre_qat
Browse files — browse the repository at this point in the history
  • Loading branch information
oguzhanbsolak committed Nov 21, 2024
1 parent 2ee12fb commit 92568a2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
13 changes: 6 additions & 7 deletions ai8x.py
Original file line number Diff line number Diff line change
Expand Up @@ -2237,16 +2237,15 @@ def stat_collect(train_loader, model, args):
model(inputs)


def pre_qat(model, train_loader, args, qat_policy):
    """
    Prepare the model for quantization aware training (QAT).

    Runs the pre-QAT calibration pipeline on every rank (the previous
    ``local_rank <= 0`` gate was reverted by this commit, so the steps
    below are no longer restricted to the main process):

      1. init_hist      -- attach activation histogram collectors to the model
      2. stat_collect   -- drive the model over ``train_loader`` to fill them
      3. init_threshold -- derive clipping thresholds, discarding outliers
                           beyond ``qat_policy["outlier_removal_z_score"]``
      4. release_hist   -- detach/free the histogram collectors
      5. apply_scales   -- bake the derived scales into the model

    :param model: model to be prepared for QAT (modified in place)
    :param train_loader: data loader used to collect activation statistics
    :param args: parsed command-line arguments forwarded to ``stat_collect``
    :param qat_policy: QAT policy dict; must contain the key
        ``"outlier_removal_z_score"`` used for threshold initialization
    """
    # NOTE(review): helper bodies are not visible in this view; the step
    # summaries above are inferred from their names — confirm against ai8x.py.
    init_hist(model)
    stat_collect(train_loader, model, args)
    init_threshold(model, qat_policy["outlier_removal_z_score"])
    release_hist(model)
    apply_scales(model)


def init_hist(model):
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def flush(self):

msglogger.info('Collecting statistics for quantization aware training (QAT)...')

ai8x.pre_qat(model, train_loader, args, qat_policy, local_rank)
ai8x.pre_qat(model, train_loader, args, qat_policy)

# Update the optimizer to reflect fused batchnorm layers
optimizer = ai8x.update_optimizer(model, optimizer)
Expand Down

0 comments on commit 92568a2

Please sign in to comment.