diff --git a/ai8x.py b/ai8x.py
index d46bd1163..6ffbff4d5 100644
--- a/ai8x.py
+++ b/ai8x.py
@@ -437,45 +437,6 @@ def forward(self, _, x):  # pylint: disable=arguments-differ
         return x
 
 
-def interp(x, xp, fp, method='linear'):
-    """
-    Simple PyTorch implementation of `np.interp`.
-    1D data only, length must be 2 or greater.
-    `method` must be "linear" or "lower".
-    """
-    # Find the index
-    n = len(xp) - 1
-    if n == 0:
-        return fp[0]
-    if x == 1.:
-        return fp[-1]
-    i = torch.clip(torch.searchsorted(xp, x, side='right').unsqueeze(0), 1, n) - 1
-    # Calculate fractional index
-    if method == 'linear':
-        g = x * n - i
-    else:
-        assert method == 'lower'
-        g = .0
-    # Interpolate result
-    return fp[i] + g * (fp[i + 1] - fp[i])
-
-
-def quantile(x, q, method='linear'):
-    """
-    Ersatz quantile function in PyTorch that works with torch.compile().
-    1D data only, len(x) must be 2 or greater.
-    `method` must be "linear" or "lower".
-    """
-    x = x.flatten()
-    n = len(x)
-    return interp(
-        q,
-        torch.linspace(1 / (2 * n), (2 * n - 1) / (2 * n), n, device=x.device),
-        torch.sort(x)[0],
-        method,
-    ).squeeze(0)
-
-
 class OutputShiftLimit(nn.Module):
     """
     Calculate the clamped output shift when adjusting during quantization-aware training.
@@ -486,7 +447,7 @@ def __init__(self, shift_quantile=1.0):
 
     def forward(self, x, _):  # pylint: disable=arguments-differ
         """Forward prop"""
-        limit = quantile(x.abs(), self.shift_quantile)
+        limit = torch.quantile(x.abs(), self.shift_quantile)
         return -(1./limit).log2().floor().clamp(min=-15., max=15.)
 
 
@@ -2276,15 +2237,16 @@ def stat_collect(train_loader, model, args):
             model(inputs)
 
 
-def pre_qat(model, train_loader, args, qat_policy):
+def pre_qat(model, train_loader, args, qat_policy, local_rank=0):
     """
     Prepare the model for quantization aware training
     """
-    init_hist(model)
-    stat_collect(train_loader, model, args)
-    init_threshold(model, qat_policy["outlier_removal_z_score"])
-    release_hist(model)
-    apply_scales(model)
+    if local_rank <= 0:
+        init_hist(model)
+        stat_collect(train_loader, model, args)
+        init_threshold(model, qat_policy["outlier_removal_z_score"])
+        release_hist(model)
+        apply_scales(model)
 
 
 def init_hist(model):
diff --git a/train.py b/train.py
index 7c74889f6..2916a12ae 100755
--- a/train.py
+++ b/train.py
@@ -610,7 +610,7 @@ def flush(self):
 
             msglogger.info('Collecting statistics for quantization aware training (QAT)...')
 
-            ai8x.pre_qat(model, train_loader, args, qat_policy)
+            ai8x.pre_qat(model, train_loader, args, qat_policy, local_rank)
 
             # Update the optimizer to reflect fused batchnorm layers
             optimizer = ai8x.update_optimizer(model, optimizer)
@@ -640,6 +640,12 @@ def flush(self):
                 torch._dynamo.reset()  # pylint: disable=protected-access
                 model = torch.compile(model, mode=args.compiler_mode,
                                       backend=args.compiler_backend)
+
+                # TODO: Optimize DDP is currently not supported with QAT.
+                # Once pytorch supports DDP with higher order ops,
+                # we can enable optimize DDP with QAT.
+                # https://github.com/pytorch/pytorch/issues/104674.
+                torch._dynamo.config.optimize_ddp = False  # pylint: disable=protected-access
                 msglogger.info(
                     'torch.compile() successful, mode=%s, cache limit=%d',
                     args.compiler_mode,
@@ -734,7 +740,7 @@ def flush(self):
 
     if not args.dr:
         test(test_loader, model, criterion, [pylogger], args=args, mode="ckpt")
         test(test_loader, model, criterion, [pylogger], args=args, mode="best",
-             ckpt_name=checkpoint_name)
+             ckpt_name=checkpoint_name, local_rank=local_rank)
 
     if args.copy_output_folder and local_rank <= 0:
         msglogger.info('Copying output folder to: %s', args.copy_output_folder)
@@ -1067,7 +1073,7 @@ def validate(val_loader, model, criterion, loggers, args, epoch=-1, tflogger=Non
 
     return _validate(val_loader, model, criterion, loggers, args, epoch, tflogger)
 
 
-def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=None):
+def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=None, local_rank=0):
     """Model Test"""
     assert msglogger is not None
     if mode == 'ckpt':
@@ -1075,11 +1081,31 @@ def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=No
         top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args)
     else:
         msglogger.info('--- test (best) ---------------------')
-        if ckpt_name is None:
-            best_ckpt_path = os.path.join(msglogger.logdir, 'best.pth.tar')
-        else:
-            best_ckpt_path = os.path.join(msglogger.logdir, ckpt_name + "_best.pth.tar")
-        model = apputils.load_lean_checkpoint(model, best_ckpt_path)
+        model, dynamo, ddp = model_wrapper.unwrap(model)
+        if local_rank <= 0:
+            if ckpt_name is None:
+                best_ckpt_path = os.path.join(msglogger.logdir, 'best.pth.tar')
+            else:
+                best_ckpt_path = os.path.join(msglogger.logdir, ckpt_name + "_best.pth.tar")
+            model = apputils.load_lean_checkpoint(model, best_ckpt_path)
+
+        if ddp:
+            model = DistributedDataParallel(
+                model,
+                device_ids=[local_rank] if args.device == 'cuda' else None,
+                output_device=local_rank if args.device == 'cuda' else None,
+            )
+
+        if dynamo:
+            torch._dynamo.reset()  # pylint: disable=protected-access
+            model = torch.compile(model, mode=args.compiler_mode,
+                                  backend=args.compiler_backend)
+            msglogger.info(
+                'torch.compile() successful, mode=%s, cache limit=%d',
+                args.compiler_mode,
+                torch._dynamo.config.cache_size_limit,  # pylint: disable=protected-access
+            )
+
         top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args)
     return top1, top5, vloss, mAP
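Note: the new `test()` path relies on `model_wrapper.unwrap()`, which is not part of this patch. It is assumed to return the bare module plus flags indicating whether it was wrapped by torch.compile() and/or DistributedDataParallel, so the best checkpoint can be loaded into the plain module on rank 0 and the same wrappers re-applied afterwards. A minimal sketch of what such a helper might look like, under that assumption (hypothetical, not the repository's actual model_wrapper implementation):

    # Hypothetical sketch only -- not the actual model_wrapper module.
    # Assumes the wrapping order used in test() above: torch.compile()
    # outermost, DistributedDataParallel inside it, plain module innermost.
    import torch
    from torch.nn.parallel import DistributedDataParallel

    def unwrap(model):
        """Strip torch.compile() and DDP wrappers; report which were present."""
        dynamo = isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
        if dynamo:
            model = model._orig_mod  # original module kept by torch.compile()
        ddp = isinstance(model, DistributedDataParallel)
        if ddp:
            model = model.module  # underlying module kept by DDP
        return model, dynamo, ddp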