Fix LSTM int8 quantization model size issue (pytorch#23577)
Summary:
Pull Request resolved: pytorch#23577

This diff fixes a model size issue introduced in pytorch#23291. After that PR, the model size after int8 quantization is the same as that of the original unquantized model. The reason is that we save the original weights for int8 quantization even though they are no longer needed. This diff fixes that by only saving the original weights on the fp16 quantization path.

Reviewed By: llyfacebook

Differential Revision: D16557619

fbshipit-source-id: f924ae8d155a0d525b86a7440b3c7147d5bead0a
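
Anything registered through register_buffer ends up in the module's state_dict and therefore in the saved model, which is why keeping the unpacked fp32 weights on the int8 path pushed the quantized model back up to the original size. Below is a minimal, self-contained sketch of that pattern (a hypothetical toy module, not the actual QuantizedRNNBase code, with a plain int8 cast standing in for FBGEMM packing):

import io

import torch


class ToyQuantizedLinear(torch.nn.Module):
    """Toy illustration of the buffer-registration pattern in this diff."""

    def __init__(self, other, dtype=torch.int8):
        super().__init__()
        weight = other.weight.detach()
        if dtype == torch.int8:
            # int8 path: keep only the (much smaller) quantized weight.
            # Registering the fp32 original here as well is what inflated
            # the model size before this fix.
            scale = weight.abs().max() / 127.0
            self.register_buffer('quantized_weight',
                                 (weight / scale).round().to(torch.int8))
            self.register_buffer('scale', scale.reshape(1))
        else:
            # fp16 path: the original weight is still kept around,
            # mirroring the behavior after this diff.
            self.register_buffer('orig_weight', weight)
            self.register_buffer('packed_weight', weight.to(torch.float16))


def state_dict_bytes(module):
    # Serialize the state dict into memory and report its size in bytes.
    buf = io.BytesIO()
    torch.save(module.state_dict(), buf)
    return buf.tell()


linear = torch.nn.Linear(1024, 1024)
print('fp32 bytes:', state_dict_bytes(linear))
print('int8 bytes:', state_dict_bytes(ToyQuantizedLinear(linear, torch.int8)))
print('fp16 bytes:', state_dict_bytes(ToyQuantizedLinear(linear, torch.float16)))

With the buffers set up this way, the int8 state dict comes out at roughly a quarter of the fp32 one, while the fp16 variant stays larger because it intentionally carries the original weights alongside the packed ones.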
mingzhe09088 authored and facebook-github-bot committed Aug 2, 2019
1 parent 3107f1d commit 29881c7
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions torch/jit/quantized.py
@@ -290,9 +290,6 @@ def process_weights(ihhh, layer, suffix, dtype):
             weight = getattr(other, weight_name)
             bias = getattr(other, bias_name)
 
-            orig_weights.append(weight_name)
-            self.register_buffer(weight_name, weight)
-
             if dtype == torch.int8:
                 # for each layer, for each direction we need to quantize and pack
                 # weights and pack parameters in this order:
@@ -318,6 +315,8 @@ def process_weights(ihhh, layer, suffix, dtype):
                 packed_weight = torch.fbgemm_pack_gemm_matrix_fp16(
                     weight.clone().float())
 
+                orig_weights.append(weight_name)
+                self.register_buffer(weight_name, weight)
                 params = [packed_weight, bias]
                 pos_names = ['packed', 'b']
                 ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names]
@@ -336,7 +335,13 @@ def process_weights(ihhh, layer, suffix, dtype):
 
         self._packed_weights = packed_weights
         self._quantized_weights = quantized_weights
-        self._orig_weights = orig_weights
+        # For int8 quantization, _orig_weights is not needed in the quantization logic,
+        # however there is a JIT compilation error without it. This is just used to
+        # workaround that error.
+        if dtype == torch.int8:
+            self._orig_weights = self._packed_weights
+        else:
+            self._orig_weights = orig_weights
 
     @torch.jit.script_method
     def check_input(self, input, batch_sizes):
