Quantization #2

Draft: wants to merge 124 commits into main

Changes from 121 commits (124 commits total)

Commits
58813dc
hacky version of quantize pass implemented
Sep 18, 2020
40f988d
quantize / dequantize instead of casting
electriclilies Sep 18, 2020
24b7fa1
more updates
electriclilies Sep 21, 2020
6f1df09
Add preprocessing.
jwfromm Sep 21, 2020
cd7121f
start of requantize pass
electriclilies Sep 22, 2020
8d116cd
requantize pass runs
electriclilies Sep 22, 2020
24dea2e
requantize after conv2d in most cases
electriclilies Sep 24, 2020
2acc2b4
some notes to self
electriclilies Sep 24, 2020
f4615dc
testing maxpool no option
electriclilies Sep 24, 2020
94ea3aa
added skipping layers and quantization of dense operator
electriclilies Sep 29, 2020
01269b2
added variables for scale, zp, and skip_layers
electriclilies Sep 30, 2020
dfdea03
some stuff
electriclilies Oct 2, 2020
16a1499
first attempt at calibration_map for conv2d
electriclilies Oct 5, 2020
bc885b8
remove node_map
electriclilies Oct 5, 2020
84dbc3e
attempt at global calibration pass with Let
electriclilies Oct 6, 2020
514fec5
global calibration runs!
electriclilies Oct 8, 2020
c726aed
remove prints
electriclilies Oct 8, 2020
d0e7f23
test script compares global calibration to unquantized
electriclilies Oct 12, 2020
2aba530
calibration pass in progress, pad bug
electriclilies Oct 13, 2020
f172f01
subgraphs run!
electriclilies Oct 14, 2020
2a5dcdd
separate quantization for weights and activation
electriclilies Oct 14, 2020
1f6dac3
refactor calibration_map
Oct 20, 2020
2487817
merge with main
electriclilies Oct 20, 2020
fe23ae8
change calibration_callback
electriclilies Oct 20, 2020
0841b14
cleaning up code
electriclilies Oct 21, 2020
cb2b412
add is_weight util
electriclilies Oct 21, 2020
7c067ec
unit testing
electriclilies Oct 22, 2020
230275c
Merge remote-tracking branch 'upstream/main' into quantization
electriclilies Oct 22, 2020
460f5b8
add more testing for quantization
electriclilies Oct 23, 2020
afd2690
debugging bind_params_by_name
electriclilies Oct 23, 2020
03de111
fix calibration callback bug
electriclilies Oct 26, 2020
2631792
change calibration_callback to return a dictionary
electriclilies Oct 26, 2020
ffd1e1d
first attempt at kl_divergence
electriclilies Oct 29, 2020
1b7850b
forgot some files
electriclilies Oct 29, 2020
4ac5012
average mean calibration runs
electriclilies Nov 2, 2020
f87678a
change qnn.add to add and qnn.mul to multiply
electriclilies Nov 2, 2020
49df674
changes to turn subgraph_fn to subgraph_mod
electriclilies Nov 2, 2020
ef00747
Merge remote-tracking branch 'upstream/main' into quantization
electriclilies Nov 2, 2020
027e47b
fix bad merge
electriclilies Nov 2, 2020
9ab5619
some debug statements
electriclilies Nov 4, 2020
9d54885
Merge remote-tracking branch 'upstream/main' into quantization
electriclilies Nov 4, 2020
7a5535c
add dynamic dequantize
electriclilies Nov 4, 2020
ef30a98
register quantize and dequantize as opaque
electriclilies Nov 4, 2020
3cd5b4f
make tests better
electriclilies Nov 4, 2020
f96398d
black
electriclilies Nov 4, 2020
b0e7cec
remove main fn
electriclilies Nov 4, 2020
95cdd5a
fix black again
electriclilies Nov 4, 2020
6faea23
Merge branch 'dequantize_expr_scale' into quantization
electriclilies Nov 4, 2020
b1898cb
_average_mean_calibration is fastgit status!
electriclilies Nov 4, 2020
441f88c
remove warning message
electriclilies Nov 5, 2020
b7d2163
move tests
electriclilies Nov 5, 2020
fe1150c
fix import
electriclilies Nov 5, 2020
e54bc6d
fix import again
electriclilies Nov 5, 2020
e2fc37b
try again
electriclilies Nov 5, 2020
76b1aa0
Merge branch 'dequantize_expr_scale' into quantization
electriclilies Nov 5, 2020
33aa7c5
import run_infer_type
electriclilies Nov 5, 2020
966082e
fix import again
electriclilies Nov 5, 2020
4e70bb2
fix import
electriclilies Nov 5, 2020
684ad95
Merge branch 'dequantize_expr_scale' into quantization
electriclilies Nov 5, 2020
9ce1f6f
fix add/mul unit tests
electriclilies Nov 5, 2020
a4894fb
small updates
electriclilies Nov 9, 2020
da12e5b
trying to find bug
electriclilies Nov 9, 2020
a83852c
merge and refactor
electriclilies Nov 16, 2020
3b92e5d
Assign type to scale, zp if channelwise
electriclilies Nov 16, 2020
1847a58
calibration inputs don't require names, some requantize
electriclilies Nov 19, 2020
970fa14
clean up prints
electriclilies Nov 19, 2020
acb1859
push onnx file for matt
electriclilies Nov 19, 2020
ac1963a
Fix pattern matcher bug
electriclilies Nov 19, 2020
1b1dcd5
first attempt at requantizer
electriclilies Nov 20, 2020
1004fd8
add cifar model
electriclilies Nov 20, 2020
941fff9
most of requantize working, except resnets
electriclilies Nov 24, 2020
7c04964
trying to get requantize to work with resnets
electriclilies Dec 2, 2020
337c0cb
not working! :(
electriclilies Dec 3, 2020
9db2647
allow pattern matcher to optionally rewrite overlapping patterns
Dec 7, 2020
80c6970
Merge pull request #1 from electriclilies/lily/quantization
Dec 7, 2020
57057be
update demo
electriclilies Dec 7, 2020
b98289a
Merge branch 'quantization' of github.com:electriclilies/incubator-tv…
electriclilies Dec 7, 2020
2781f0b
let qnn support int32s, quantized BERT builds!
electriclilies Dec 11, 2020
050b431
merge w/ main
electriclilies Dec 11, 2020
bd85023
remove prints
electriclilies Dec 14, 2020
9d43962
added relay node and attrs to calibration map
electriclilies Dec 15, 2020
c8ae881
per channel quantization works now!
electriclilies Dec 17, 2020
1dcc00b
per channel quantization mostly working
electriclilies Dec 18, 2020
b1dd207
good progress on C++ version of quantization
electriclilies Jan 4, 2021
b992b50
more progress on c++ quantization
electriclilies Jan 5, 2021
c6e997d
merge main in
electriclilies Jan 5, 2021
1187e65
most of C++ implemented!
electriclilies Jan 7, 2021
11013ff
C++ done
electriclilies Jan 8, 2021
986e0b1
new quantizer and calibrater
electriclilies Jan 12, 2021
96e5509
add more calibrater and dont add scale/zps to partitioned func args i…
electriclilies Jan 12, 2021
bbfd691
AverageMeanCalibrater and rearranging files
electriclilies Jan 14, 2021
15ec074
AverageMeanQuantize demo works
electriclilies Jan 14, 2021
94ca0ee
skip layers and per channel avg mean quantize
electriclilies Jan 25, 2021
515ae8c
requantize debugging
electriclilies Jan 25, 2021
c303254
fix reshape
electriclilies Jan 25, 2021
18a6da9
requantize working on resnet18 now
electriclilies Jan 25, 2021
cb34a79
moved new_quantize to relay.transform.quantize
electriclilies Jan 27, 2021
262d545
moving to relay.transform.quantize
electriclilies Jan 27, 2021
88a4d4c
Added checks to requantizer
electriclilies Jan 27, 2021
04c919e
changed tolerance
electriclilies Jan 27, 2021
4534266
cleaning some things up
electriclilies Jan 28, 2021
ca6381f
add proper pass, try to add VM
electriclilies Jan 29, 2021
63a1ada
remove vm
electriclilies Jan 29, 2021
d42da94
lots of little bugs, tests
electriclilies Feb 2, 2021
cd848bc
merged with main, allow poverlap in pattern groups
electriclilies Feb 2, 2021
236a2ba
fix some lint
electriclilies Feb 2, 2021
4c48417
Update test_op_qnn_dequantize.py
Feb 2, 2021
bfcab0a
lint and docstrings
electriclilies Feb 3, 2021
7f9b9e9
Merge branch 'quantization' of github.com:electriclilies/incubator-tv…
electriclilies Feb 3, 2021
e90d777
adding docs and lint
electriclilies Feb 3, 2021
102bf73
run black and clang, add docs
electriclilies Feb 4, 2021
96e6879
lint + more doc strings
electriclilies Feb 4, 2021
4f73067
fix spelling of calibrator
electriclilies Feb 4, 2021
14d9618
cleaning up code
electriclilies Feb 5, 2021
a738cf5
black
electriclilies Feb 5, 2021
fa72484
finish quantize and calibrate tests, more docs
electriclilies Feb 5, 2021
8411824
problem with average max quantize??
electriclilies Feb 9, 2021
3f05c65
move dataset manager and add DenseBiasAddPattern
electriclilies Feb 11, 2021
d7e9fc2
fix imports
electriclilies Feb 11, 2021
08e1c67
fix bias add
electriclilies Feb 12, 2021
2496eb1
tutorials
electriclilies Feb 16, 2021
f52a81f
requantizer tests
electriclilies Feb 17, 2021
62a43b5
more tests, fixing bugs in tests
electriclilies Feb 17, 2021
e7a0c8d
test relay pass:
electriclilies Feb 18, 2021
3 changes: 2 additions & 1 deletion include/tvm/relay/dataflow_matcher.h
Original file line number Diff line number Diff line change
@@ -87,7 +87,8 @@ bool MatchPattern(DFPattern pattern, Expr expr);
* \return Return An Expr with every match of the pattern inside the callbacks rewritten by the
* functions inside the callbacks
*/
Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod = IRModule());
Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod = IRModule(),
                     int allow_overlapping_groups = 0);

/*!
* \brief Partition all matches of a DFPattern inside an Expr into separate Function calls
5 changes: 5 additions & 0 deletions include/tvm/relay/qnn/attrs.h
@@ -78,13 +78,18 @@ struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
/*! \brief Attribute for dequantize operator */
struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
  int axis;
  DataType out_dtype;

  TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
    TVM_ATTR_FIELD(axis)
        .describe(
            "The channel axis for channel wise dequantization. Default value is -1,"
            "which corresponds to the last axis.")
        .set_default(-1);
    TVM_ATTR_FIELD(out_dtype)
        .describe("The datatype we are dequantizing to (float32 or int32). Defaults to float32.")
        .set_default(DataType::Float(32));
  }
};
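The attrs above parameterize the standard affine quantization scheme. As a hedged NumPy sketch of the mapping `qnn.quantize` performs (an illustrative reference, not the TVM kernel; `quantize_ref` is a made-up name):

```python
import numpy as np

def quantize_ref(x, scale, zero_point, out_dtype="int8"):
    # Affine quantization: q = clip(round(x / scale) + zero_point, dtype_min, dtype_max).
    info = np.iinfo(out_dtype)
    q = np.round(x / scale) + zero_point
    return np.clip(q, info.min, info.max).astype(out_dtype)

x = np.array([-1.0, 0.0, 0.5, 100.0], dtype="float32")
print(quantize_ref(x, scale=0.5, zero_point=0).tolist())  # [-2, 0, 1, 127]
```

Note how 100.0 saturates to 127, the int8 maximum; `out_dtype` only changes the clipping range and the storage type.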

18 changes: 18 additions & 0 deletions python/tvm/data/__init__.py
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
from ._dataset_manager import DatasetManager, TFDatasetManager, RandomDatasetManager
136 changes: 136 additions & 0 deletions python/tvm/data/_dataset_manager.py
@@ -0,0 +1,136 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Wrapper classes to expose datasets during quantization."""

import numpy as np

class DatasetManager:
    """Simple wrapper class to expose datasets in quantization."""

    def get_next_batch(self):
        """Returns the next batch of data.

        Returns
        -------
        inputs : List
            The inputs to be provided to the graph.
            The list is of the form [batched_input_1, batched_input_2, ..., batched_input_n]

        labels : List
            The expected outputs of the graph.
            The length of labels should be equal to the batch size.
        """
        raise NotImplementedError

    def batch_size(self):
        """Returns the size of each batch the dataset manager has.

        Returns
        -------
        batch_size : int
            The number of inputs in each batch.
        """
        raise NotImplementedError

    def num_batches(self):
        """Returns the number of batches the dataset manager has.

        Returns
        -------
        num_batches : int
            The number of batches the dataset manager contains.
        """
        raise NotImplementedError

    def is_empty(self):
        """Checks whether the dataset manager has gone through all its batches.

        Returns
        -------
        is_empty : bool
            True if there are no batches left, False otherwise.
        """
        raise NotImplementedError

    def reset(self):
        """Resets the counter in the dataset manager to the beginning."""
        raise NotImplementedError


class TFDatasetManager(DatasetManager):
    """DatasetManager wrapping a tensorflow dataset."""

    def __init__(self, tf_dataset, batch_size, total_batches):
        self.idx = 0
        self.total_batches = total_batches
        self.batch_sz = batch_size
        self.tf_dataset = tf_dataset
        self.tf_iter = iter(self.tf_dataset)

    def get_next_batch(self):
        if self.is_empty():
            raise IndexError
        self.idx += 1

        data, label = next(self.tf_iter)

        return [data.numpy()], label.numpy()

    def num_batches(self):
        return self.total_batches

    def batch_size(self):
        return self.batch_sz

    def is_empty(self):
        return self.idx >= self.total_batches

    def reset(self):
        self.tf_iter = iter(self.tf_dataset)
        self.idx = 0


class RandomDatasetManager(DatasetManager):
    """DatasetManager that creates a random input of a specific shape.

    This class is mostly used for testing, and as an example of how to
    implement a DatasetManager.
    """

    def __init__(self, data_shape, dtype, batch_size, total_batches):
        self.idx = 0
        self.data_shape = data_shape
        self.dtype = dtype
        self.batch_sz = batch_size
        self.total_batches = total_batches

    def get_next_batch(self):
        if self.is_empty():
            raise IndexError
        self.idx += 1
        return [np.random.randn(*self.data_shape).astype(self.dtype)], [None]

    def batch_size(self):
        return self.batch_sz

    def num_batches(self):
        return self.total_batches

    def is_empty(self):
        return self.idx >= self.total_batches

    def reset(self):
        self.idx = 0
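A short usage sketch of the interface above, the way a calibration loop might drain it (the class is re-declared here in condensed form so the snippet stands alone; shapes and batch counts are arbitrary):

```python
import numpy as np

class RandomDatasetManager:
    """Condensed copy of the RandomDatasetManager above, for illustration."""

    def __init__(self, data_shape, dtype, batch_size, total_batches):
        self.idx = 0
        self.data_shape = data_shape
        self.dtype = dtype
        self.batch_sz = batch_size
        self.total_batches = total_batches

    def get_next_batch(self):
        if self.is_empty():
            raise IndexError
        self.idx += 1
        return [np.random.randn(*self.data_shape).astype(self.dtype)], [None]

    def is_empty(self):
        return self.idx >= self.total_batches

    def reset(self):
        self.idx = 0

# Drain the manager the way a calibration pass would, then rewind it.
manager = RandomDatasetManager((1, 3, 8, 8), "float32", batch_size=1, total_batches=3)
seen = 0
while not manager.is_empty():
    inputs, labels = manager.get_next_batch()
    assert inputs[0].shape == (1, 3, 8, 8)
    seen += 1
manager.reset()
print(seen)  # 3
```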
10 changes: 5 additions & 5 deletions python/tvm/relay/dataflow_pattern/__init__.py
@@ -799,7 +799,7 @@ def __init__(self, require_type=False):
self.pattern = None
self.require_type = require_type

def rewrite(self, expr: Expr) -> Expr:
def rewrite(self, expr: Expr, allow_overlapping_groups: bool = False) -> Expr:
"""
Rewrite expression with this callback

@@ -813,7 +813,7 @@ def rewrite(self, expr: Expr) -> Expr:
result : tvm.relay.Expr
The Expression with matched subgraphs rewritten by the callbacks.
"""
return rewrite(self, expr)
return rewrite(self, expr, allow_overlapping_groups=allow_overlapping_groups)

def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Expr:
"""
@@ -843,7 +843,8 @@ def __init__(self, pattern, callback, require_type):
self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type)


def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr:
def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None,
allow_overlapping_groups: bool = False) -> Expr:
"""
Rewrite expression with the given callbacks.

@@ -868,8 +869,7 @@ def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr:
for callback in callbacks:
assert callback.pattern is not None
tmp.append(_DFPatternCallback(callback.pattern, callback.callback, callback.require_type))

return ffi.rewrite(tmp, expr, mod)
return ffi.rewrite(tmp, expr, mod, allow_overlapping_groups)


def partition(
10 changes: 5 additions & 5 deletions python/tvm/relay/qnn/op/qnn.py
@@ -20,10 +20,9 @@
from __future__ import absolute_import as _abs
from tvm.relay.expr import Tuple, TupleWrapper
from tvm.relay.op.nn.utils import get_pad_tuple2d
from . import _make
from ... import op as reg
from ...op import OpPattern

from . import _make

def requantize(
data,
@@ -118,7 +117,7 @@ def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
return _make.quantize(data, output_scale, output_zero_point, axis, out_dtype)


def dequantize(data, input_scale, input_zero_point, axis=-1):
def dequantize(data, input_scale, input_zero_point, axis=-1, out_dtype="float32"):
r"""Dequantize op
This operator takes quantized int8 and uint8 as input and produces
dequantized float32 (or int32) as output. The output shape is the same as the input shape. The input
@@ -134,13 +133,15 @@ def dequantize(data, input_scale, input_zero_point, axis=-1):
The input scale.
axis : int
The channel axis for quantization. Default value is -1 which corresponds to the last axis.
out_dtype : str, optional
The output type to dequantize to. Can be either float32 or int32.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""

return _make.dequantize(data, input_scale, input_zero_point, axis)
return _make.dequantize(data, input_scale, input_zero_point, axis, out_dtype)
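Dequantize inverts the affine quantization mapping, and the new `out_dtype` argument only changes the dtype of the result. A NumPy sketch of that behavior (`dequantize_ref` is a hypothetical reference helper, not a qnn API):

```python
import numpy as np

def dequantize_ref(q, scale, zero_point, out_dtype="float32"):
    # real_value = scale * (quantized_value - zero_point)
    return (scale * (q.astype("int32") - zero_point)).astype(out_dtype)

q = np.array([0, 128, 255], dtype="uint8")
print(dequantize_ref(q, scale=0.5, zero_point=128).tolist())    # [-64.0, 0.0, 63.5]
print(dequantize_ref(q, 2.0, 128, out_dtype="int32").tolist())  # [-256, 0, 254]
```

Widening to int32 before subtracting the zero point avoids uint8 underflow; the final cast is all `out_dtype` controls.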


def concatenate(data, input_scales, input_zero_points, output_scale, output_zero_point, axis):
@@ -611,7 +612,6 @@ def subtract(
output_zero_point,
)


# register fuse pattern for qnn ops
reg.register_pattern("qnn.quantize", OpPattern.OPAQUE)
reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE)
47 changes: 47 additions & 0 deletions python/tvm/relay/transform/quantize/__init__.py
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The namespace containing quantization and calibration passes"""
from ._calibration_callback import (
    CalibrationCallback,
    GlobalCalibrationCallback,
    AverageMaxCalibrationCallback,
)
from ._quantizer_patterns import (
    QuantizerPattern,
    Conv2DBiasAddPattern,
    Conv2DPattern,
    DensePattern,
    DenseBiasAddPattern,
    AddPattern,
    MultiplyPattern,
    PerChannelPattern,
)
from ._average_max_channel_patterns import (
    AverageMaxPerChannelConv2DBiasAddPattern,
    AverageMaxPerChannelConv2DPattern,
    AverageMaxPerChannelDenseBiasAddPattern,
    AverageMaxPerChannelDensePattern,
)

from ._quantizer_pattern_utils import all_patterns, average_max_per_channel_patterns

from ._quantizer import Quantizer
from ._calibrator import QuantizationCalibrator
from ._requantizer import Requantizer

from . import _ffi as ffi