Merge pull request #14 from neuralmagic/more-aggressive-cleanup
Perform more aggressive cleanup during weight quantization and add tqdm
mgoin authored Jun 13, 2024
2 parents fc895fd + 5e876d2 commit ffea17e
Showing 1 changed file with 7 additions and 5 deletions.
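The "more aggressive cleanup" in the commit title amounts to calling cleanup_memory() once per quantized layer rather than once at the end of quantize_weights(). The helper itself is not part of this diff; as a rough sketch, assuming it follows the common garbage-collect-then-empty-CUDA-cache pattern (not confirmed by this commit):

import gc

import torch


def cleanup_memory():
    # Assumed reconstruction of the helper called per layer in the diff below:
    # drop unreferenced tensors, then release cached CUDA blocks so the
    # just-deleted fp16 weights stop counting against peak GPU memory.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()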
12 changes: 7 additions & 5 deletions auto_fp8/quantize.py
@@ -195,7 +195,8 @@ def quantize_weights(
     quantize_config: BaseQuantizeConfig,
     ignored_layers: List[str] = [],
 ):
-    for name, linear in model.named_modules():
+    named_modules = list(model.named_modules())
+    for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"):
         if (
             not isinstance(linear, torch.nn.Linear)
             or name in quantize_config.ignored_layers
@@ -205,7 +206,7 @@ def quantize_weights(
         quant_linear = FP8DynamicLinear(quant_weight, quant_scale, linear.bias)
         replace_module(model, name, quant_linear)
         del linear
-    cleanup_memory()
+        cleanup_memory()
 
 
 def quantize_activations(
@@ -214,6 +215,7 @@ def quantize_activations(
     calibration_tokens,
     ignored_layers: List[str] = [],
 ):
+    # Replace weight quantizer with a dynamic activation quantizer observer
     for name, dynamic_quant_linear in model.named_modules():
         if (
             not isinstance(dynamic_quant_linear, FP8DynamicLinear)
@@ -229,14 +231,14 @@
         del dynamic_quant_linear
     cleanup_memory()
 
-    # Calibration.
-    with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating") as pbar:
+    # Pass through calibration data to measure activation scales
+    with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
         for row_idx in range(calibration_tokens.shape[0]):
             model(calibration_tokens[row_idx].reshape(1, -1))
             cleanup_memory()
             pbar.update(1)
 
-    # Replace dynamic quantizer with StaticLinear for export
+    # Replace dynamic quantizer observer with StaticLinear for export
     for name, quantizer in model.named_modules():
         if (
             not isinstance(quantizer, FP8StaticLinearQuantizer)
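For context on the (quant_weight, quant_scale) pair handed to FP8DynamicLinear in the first hunk: per-tensor FP8 quantization conventionally scales by the tensor's maximum absolute value. A minimal sketch under that assumption; the actual per_tensor_quantize in auto_fp8/quantize.py is not shown in this diff and may differ:

import torch


def per_tensor_quantize(tensor: torch.Tensor):
    # Assumed max-abs scaling into float8 E4M3: one scale for the whole tensor.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = tensor.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (tensor / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale.float()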

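Putting the two patched functions together, a hypothetical driver script matching the signatures shown above; the model id, the calibration prompts, and the BaseQuantizeConfig fields are illustrative assumptions, not taken from this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_fp8.quantize import (  # the file this commit patches
    BaseQuantizeConfig,
    quantize_weights,
    quantize_activations,
)

model_id = "facebook/opt-125m"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

# 2-D tensor of token ids; the "Calibrating activation scales" loop feeds
# one row at a time via calibration_tokens[row_idx].reshape(1, -1).
calibration_tokens = tokenizer(
    ["The quick brown fox jumps over the lazy dog."] * 8,
    return_tensors="pt",
    padding=True,
).input_ids

config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
quantize_weights(model, config)
quantize_activations(model, config, calibration_tokens)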