Distributed mode improvements, quantile function replacement
oguzhanbsolak committed Nov 21, 2024
1 parent 1cae6b0 commit 2ee12fb
Showing 2 changed files with 42 additions and 54 deletions.
ai8x.py — 54 changes: 8 additions & 46 deletions
@@ -437,45 +437,6 @@ def forward(self, _, x): # pylint: disable=arguments-differ
        return x


-def interp(x, xp, fp, method='linear'):
-    """
-    Simple PyTorch implementation of `np.interp`.
-    1D data only, length must be 2 or greater.
-    `method` must be "linear" or "lower".
-    """
-    # Find the index
-    n = len(xp) - 1
-    if n == 0:
-        return fp[0]
-    if x == 1.:
-        return fp[-1]
-    i = torch.clip(torch.searchsorted(xp, x, side='right').unsqueeze(0), 1, n) - 1
-    # Calculate fractional index
-    if method == 'linear':
-        g = x * n - i
-    else:
-        assert method == 'lower'
-        g = .0
-    # Interpolate result
-    return fp[i] + g * (fp[i + 1] - fp[i])
-
-
-def quantile(x, q, method='linear'):
-    """
-    Ersatz quantile function in PyTorch that works with torch.compile().
-    1D data only, len(x) must be 2 or greater.
-    `method` must be "linear" or "lower".
-    """
-    x = x.flatten()
-    n = len(x)
-    return interp(
-        q,
-        torch.linspace(1 / (2 * n), (2 * n - 1) / (2 * n), n, device=x.device),
-        torch.sort(x)[0],
-        method,
-    ).squeeze(0)
-
-
class OutputShiftLimit(nn.Module):
    """
    Calculate the clamped output shift when adjusting during quantization-aware training.
@@ -486,7 +447,7 @@ def __init__(self, shift_quantile=1.0):

    def forward(self, x, _):  # pylint: disable=arguments-differ
        """Forward prop"""
-        limit = quantile(x.abs(), self.shift_quantile)
+        limit = torch.quantile(x.abs(), self.shift_quantile)
        return -(1./limit).log2().floor().clamp(min=-15., max=15.)


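Note: the removed interp()/quantile() pair above was an ersatz quantile implementation written so that the output-shift calculation could run under torch.compile(); this commit drops it and has OutputShiftLimit call torch.quantile() directly. A minimal standalone sketch of what the layer computes after the change (the function name output_shift_limit and the 0.985 quantile value are illustrative, not repository code):

import torch

def output_shift_limit(x: torch.Tensor, shift_quantile: float = 1.0) -> torch.Tensor:
    # Mirror OutputShiftLimit.forward(): take the magnitude quantile of the
    # input tensor, then derive a clamped power-of-two output shift from it.
    limit = torch.quantile(x.abs(), shift_quantile)
    return -(1. / limit).log2().floor().clamp(min=-15., max=15.)

print(output_shift_limit(torch.randn(1024), shift_quantile=0.985))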
@@ -2276,15 +2237,16 @@ def stat_collect(train_loader, model, args):
model(inputs)


-def pre_qat(model, train_loader, args, qat_policy):
+def pre_qat(model, train_loader, args, qat_policy, local_rank=0):
    """
    Prepare the model for quantization aware training
    """
-    init_hist(model)
-    stat_collect(train_loader, model, args)
-    init_threshold(model, qat_policy["outlier_removal_z_score"])
-    release_hist(model)
-    apply_scales(model)
+    if local_rank <= 0:
+        init_hist(model)
+        stat_collect(train_loader, model, args)
+        init_threshold(model, qat_policy["outlier_removal_z_score"])
+        release_hist(model)
+        apply_scales(model)


def init_hist(model):
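Note: pre_qat() gains a local_rank argument so that, in a distributed (DDP) run, only the rank-0 process collects activation histograms and initializes the outlier-removal thresholds; the default of 0 keeps single-process behavior unchanged. A rough, self-contained illustration of the same gating pattern under a torchrun-style launch that exports LOCAL_RANK (the is_primary helper is an assumption for illustration, not repository code):

import os

def is_primary(local_rank: int) -> bool:
    # Matches the pre_qat() guard: the first local process, or a
    # non-distributed run where the argument keeps its default of 0.
    return local_rank <= 0

if __name__ == '__main__':
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if is_primary(local_rank):
        print(f'rank {local_rank}: collecting QAT statistics')
    else:
        print(f'rank {local_rank}: skipping statistics collection')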
train.py — 42 changes: 34 additions & 8 deletions
@@ -610,7 +610,7 @@ def flush(self):

msglogger.info('Collecting statistics for quantization aware training (QAT)...')

-ai8x.pre_qat(model, train_loader, args, qat_policy)
+ai8x.pre_qat(model, train_loader, args, qat_policy, local_rank)

# Update the optimizer to reflect fused batchnorm layers
optimizer = ai8x.update_optimizer(model, optimizer)
@@ -640,6 +640,12 @@ def flush(self):
torch._dynamo.reset() # pylint: disable=protected-access
model = torch.compile(model, mode=args.compiler_mode,
backend=args.compiler_backend)
+
+# TODO: Optimize DDP is currently not supported with QAT.
+# Once pytorch supports DDP with higher order ops,
+# we can enable optimize DDP with QAT.
+# https://github.com/pytorch/pytorch/issues/104674.
+torch._dynamo.config.optimize_ddp = False  # pylint: disable=protected-access
msglogger.info(
'torch.compile() successful, mode=%s, cache limit=%d',
args.compiler_mode,
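Note: torch._dynamo.config.optimize_ddp controls Dynamo's DDP graph optimization; it is disabled here because, per the linked PyTorch issue, it does not yet work with the higher-order ops used during QAT. A single-process sketch of the same disable-then-compile pattern (gloo backend on CPU so it runs without GPUs; the toy Linear model, port, and backend choice are assumptions, not repository code):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Minimal one-process group so a DDP wrapper can be constructed at all.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29501')
dist.init_process_group('gloo', rank=0, world_size=1)

# Turn off Dynamo's DDP optimization before tracing, as train.py does above.
torch._dynamo.config.optimize_ddp = False  # pylint: disable=protected-access

model = DistributedDataParallel(torch.nn.Linear(8, 8))
model = torch.compile(model)
print(model(torch.randn(2, 8)).shape)

dist.destroy_process_group()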
@@ -734,7 +740,7 @@ def flush(self):
if not args.dr:
test(test_loader, model, criterion, [pylogger], args=args, mode="ckpt")
test(test_loader, model, criterion, [pylogger], args=args, mode="best",
-ckpt_name=checkpoint_name)
+ckpt_name=checkpoint_name, local_rank=local_rank)

if args.copy_output_folder and local_rank <= 0:
msglogger.info('Copying output folder to: %s', args.copy_output_folder)
@@ -1067,19 +1073,39 @@ def validate(val_loader, model, criterion, loggers, args, epoch=-1, tflogger=Non
    return _validate(val_loader, model, criterion, loggers, args, epoch, tflogger)


-def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=None):
+def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=None, local_rank=0):
    """Model Test"""
    assert msglogger is not None
    if mode == 'ckpt':
        msglogger.info('--- test (ckpt) ---------------------')
        top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args)
    else:
        msglogger.info('--- test (best) ---------------------')
-        if ckpt_name is None:
-            best_ckpt_path = os.path.join(msglogger.logdir, 'best.pth.tar')
-        else:
-            best_ckpt_path = os.path.join(msglogger.logdir, ckpt_name + "_best.pth.tar")
-        model = apputils.load_lean_checkpoint(model, best_ckpt_path)
+        model, dynamo, ddp = model_wrapper.unwrap(model)
+        if local_rank <= 0:
+            if ckpt_name is None:
+                best_ckpt_path = os.path.join(msglogger.logdir, 'best.pth.tar')
+            else:
+                best_ckpt_path = os.path.join(msglogger.logdir, ckpt_name + "_best.pth.tar")
+            model = apputils.load_lean_checkpoint(model, best_ckpt_path)
+
+        if ddp:
+            model = DistributedDataParallel(
+                model,
+                device_ids=[local_rank] if args.device == 'cuda' else None,
+                output_device=local_rank if args.device == 'cuda' else None,
+            )
+
+        if dynamo:
+            torch._dynamo.reset()  # pylint: disable=protected-access
+            model = torch.compile(model, mode=args.compiler_mode,
+                                  backend=args.compiler_backend)
+            msglogger.info(
+                'torch.compile() successful, mode=%s, cache limit=%d',
+                args.compiler_mode,
+                torch._dynamo.config.cache_size_limit,  # pylint: disable=protected-access
+            )
+
        top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args)

    return top1, top5, vloss, mAP
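Note: when testing the "best" checkpoint, the model is now unwrapped before loading (and only rank 0 loads), then re-wrapped in DistributedDataParallel and re-compiled, so the reloaded weights are evaluated under the same wrappers used for training. A simplified stand-in for model_wrapper.unwrap() to show the unwrapping order (the real helper lives elsewhere in the repository; this version is an assumption for illustration):

import torch
from torch.nn.parallel import DistributedDataParallel

def unwrap(model):
    # Peel off the torch.compile wrapper first (compiled modules expose the
    # original module as `_orig_mod`), then the DDP wrapper, and report which
    # wrappers were present so the caller can restore them afterwards.
    dynamo = hasattr(model, '_orig_mod')
    if dynamo:
        model = model._orig_mod
    ddp = isinstance(model, DistributedDataParallel)
    if ddp:
        model = model.module
    return model, dynamo, ddp

if __name__ == '__main__':
    compiled = torch.compile(torch.nn.Linear(4, 4))
    base, dynamo, ddp = unwrap(compiled)
    print(type(base).__name__, dynamo, ddp)  # Linear True False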
