Adding Jit profile #150
Open
HaoKang-Timmy wants to merge 26 commits into Lyken17:master from HaoKang-Timmy:jit
Changes from all commits (26 commits):
b25be5e  Add hook for nn.Transformer
03064d2  Add a function to count macs of nn.Transformer
fd9f2ec  add a example to test transformer
d33dbe2  some hooks and edit the softmax
d632a80  a lot changes
e1633e6  edit softmax hook
8918cf5  Merge branch 'master' into master
f1a7805  update nn.Transformer
9194cd2  edit bugs in nn.Transformers
f6fa0b0  change about examples
dd9f41d  change about count_Transformer
9d1d0d3  delete print information
2034b86  onnx fixed
9492d96  onnx_basic_fixed
4c447f3  delete transformer
6a6a0f0  onnx_counter
1729c7f  onnx_counter
8d8e738  finish onnx
61f83d1  already test models
977a79a  Merge branch 'master' into onnx
5ba599a  finish jit
e157545  counter
4efdb87  jit finished
fac3450  jit finished
cfa6f8d  jit finished
d251fe5  no need to change this
@@ -0,0 +1,6 @@
import torch
from torchvision.models import vgg11
from thop import JitProfile

model1 = vgg11()
input1 = torch.rand(1, 3, 224, 224)
print(JitProfile.calculate_macs(model1, input1))
@@ -1,5 +1,5 @@
from .utils import clever_format
from .profile import profile, profile_origin

import torch
from .jit_profile import JitProfile
default_dtype = torch.float64
@@ -0,0 +1,31 @@
import torch
import numpy as np
from .trace.trace import trace
from thop.vision.jit_handler import handlers


class JitProfile():
    def calculate_params(model):
        script_model = torch.jit.script(model)
        params = 0
        for param in script_model.parameters():
            params += np.prod(param.size())
            print(param)

        return params

    def calculate_macs(model, args=(), reduction=sum):
        results = dict()
        graph = trace(model, args)
        for node in graph.nodes:
            for operators, func in handlers:
                if isinstance(operators, str):
                    operators = [operators]
                if node.operator in operators:
                    if func is not None:
                        results[node] = func(node)
                    break

        if reduction is not None:
            return reduction(results.values())
        else:
            return results
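
For reference, a minimal usage sketch of the two entry points added here (mirroring the vgg11 example above; note that calculate_params also prints each parameter tensor while summing, and passing reduction=None returns the raw per-node dict instead of a single total):

import torch
from torchvision.models import vgg11
from thop import JitProfile

model = vgg11()
dummy = torch.rand(1, 3, 224, 224)

params = JitProfile.calculate_params(model)      # scripts the model and sums parameter sizes
macs = JitProfile.calculate_macs(model, dummy)   # summed MACs (reduction=sum by default)
per_node = JitProfile.calculate_macs(model, dummy, reduction=None)  # dict keyed by traced Node
print(params, macs, len(per_node))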
@@ -0,0 +1,49 @@
import torch
import torch.jit
from ..trace.type import Variable, Node, Graph


def trace(model, args=()):
    graph, _ = torch.jit._get_trace_graph(model, args)
    variables = {}
    for x in graph.nodes():
        for v in list(x.inputs()) or list(x.outputs()):
            if 'tensor' in v.type().kind().lower():
                variables[v] = Variable(
                    name=v.debugName(),
                    dtype=v.type().scalarType(),
                    shape=v.type().sizes(),
                )
            else:
                variables[v] = Variable(
                    name=v.debugName(),
                    dtype=str(v.type()),
                )

    nodes = []
    for x in graph.nodes():
        node = Node(
            operator=x.kind(),
            attributes={
                s: getattr(x, x.kindOf(s))(s)
                for s in x.attributeNames()
            },
            inputs=[variables[v] for v in x.inputs() if v in variables],
            outputs=[variables[v] for v in x.outputs() if v in variables],
            scope=x.scopeName()
                .replace('Flatten/', '', 1)
                .replace('Flatten', '', 1),
        )
        nodes.append(node)

    graph = Graph(
        name=model.__class__.__module__ + '.' + model.__class__.__name__,
        variables=[v for v in variables.values()],
        inputs=[variables[v] for v in graph.inputs() if v in variables],
        outputs=[variables[v] for v in graph.outputs() if v in variables],
        nodes=nodes,
    )
    return graph
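
As a small illustration (not part of the PR), the Graph returned by trace() can be inspected directly, which is handy for spotting which ATen operators still lack a handler; the module path below is inferred from the relative imports above and should be treated as an assumption:

import torch
from torchvision.models import vgg11
from thop.trace.trace import trace  # path assumed from `from .trace.trace import trace`

graph = trace(vgg11(), torch.rand(1, 3, 224, 224))
# Distinct operator kinds seen in the traced graph, e.g. 'aten::_convolution', 'aten::relu_', ...
print(sorted({node.operator for node in graph.nodes}))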
@@ -0,0 +1,3 @@
from .variable import Variable
from .node import Node
from .graph import Graph
@@ -0,0 +1,10 @@
__all__ = ['Graph']


class Graph:
    def __init__(self, name, variables, inputs, outputs, nodes):
        self.name = name
        self.variables = variables
        self.inputs = inputs
        self.outputs = outputs
        self.nodes = nodes
@@ -0,0 +1,10 @@
__all__ = ['Node']


class Node:
    def __init__(self, operator, attributes, inputs, outputs, scope):
        self.operator = operator
        self.attributes = attributes
        self.inputs = inputs
        self.outputs = outputs
        self.scope = scope
@@ -0,0 +1,42 @@
__all__ = ['Variable']


class Variable:
    def __init__(self, name, dtype, shape=None):
        self.name = name
        self.dtype = dtype
        self.shape = shape

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, dtype):
        self._dtype = dtype.lower()

    @property
    def shape(self):
        return self._shape

    @shape.setter
    def shape(self, shape):
        self._shape = shape

    @property
    def ndim(self):
        return len(self.shape)

    def size(self):
        return self.shape

    def dim(self):
        return self.ndim
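
Variable (together with Node and Graph) is a lightweight stand-in for the tensors in the traced graph; size() and dim() mirror the tensor API so handler code can work from shape metadata alone. A tiny illustration, with the import path inferred from this PR's relative imports (an assumption):

from thop.trace.type import Variable

v = Variable(name="input.1", dtype="Float", shape=[1, 3, 224, 224])
print(v.dtype)   # "float" -- the setter lower-cases the scalar type string
print(v.size())  # [1, 3, 224, 224]
print(v.dim())   # 4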
Empty file.
@@ -1,6 +1,6 @@
import argparse
import logging

from .counter import *
import torch
import torch.nn as nn
from torch.nn.modules.conv import _ConvNd

@@ -12,11 +12,11 @@ def count_parameters(m, x, y):
    total_params = 0
    for p in m.parameters():
        total_params += torch.DoubleTensor([p.numel()])
    m.total_params[0] = total_params
    m.total_params[0] = counter_parameters(m.parameters())
Review comment: Why call count_parameters() here?

def zero_ops(m, x, y):
    m.total_ops += torch.DoubleTensor([int(0)])
    m.total_ops += counter_zero_ops()


def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):

@@ -26,106 +26,96 @@ def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
    bias_ops = 1 if m.bias is not None else 0

    # N x Cout x H x W x (Cin x Kw x Kh + bias)
    total_ops = y.nelement() * (m.in_channels // m.groups * kernel_ops + bias_ops)

    m.total_ops += torch.DoubleTensor([int(total_ops)])
    m.total_ops += counter_conv(bias_ops, torch.zeros(m.weight.size()
                                [2:]).numel(), y.nelement(), m.in_channels, m.groups)


def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
    x = x[0]

    # N x H x W (exclude Cout)
    output_size = torch.zeros((y.size()[:1] + y.size()[2:])).numel()
    # Cout x Cin x Kw x Kh
    kernel_ops = m.weight.nelement()
    if m.bias is not None:
        # Cout x 1
        kernel_ops += + m.bias.nelement()
    # x N x H x W x Cout x (Cin x Kw x Kh + bias)
    m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)])
    # # Cout x Cin x Kw x Kh
    # kernel_ops = m.weight.nelement()
    # if m.bias is not None:
    #     # Cout x 1
    #     kernel_ops += + m.bias.nelement()
    # # x N x H x W x Cout x (Cin x Kw x Kh + bias)
    # m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)])
    m.total_ops += counter_conv(m.bias.nelement(),
                                m.weight.nelement(), output_size)


def count_bn(m, x, y):
    x = x[0]
    if not m.training:
        m.total_ops += counter_norm(x.numel())

    nelements = x.numel()


def count_ln(m, x, y):
    x = x[0]
    if not m.training:
        # subtract, divide, gamma, beta
        total_ops = 2 * nelements
        m.total_ops += counter_norm(x.numel())

    m.total_ops += torch.DoubleTensor([int(total_ops)])


def count_in(m, x, y):
    x = x[0]
    if not m.training:
        m.total_ops += counter_norm(x.numel())


def count_prelu(m, x, y):
    x = x[0]

    nelements = x.numel()
    if not m.training:
        m.total_ops += counter_relu(nelements)


def count_relu(m, x, y):
    x = x[0]

    nelements = x.numel()

    m.total_ops += torch.DoubleTensor([int(nelements)])
    m.total_ops += counter_relu(nelements)


def count_softmax(m, x, y):
    x = x[0]
    nfeatures = x.size()[m.dim]
    batch_size = x.numel() // nfeatures

    batch_size, nfeatures = x.size()

    total_exp = nfeatures
    total_add = nfeatures - 1
    total_div = nfeatures
    total_ops = batch_size * (total_exp + total_add + total_div)

    m.total_ops += torch.DoubleTensor([int(total_ops)])
    m.total_ops += counter_softmax(batch_size, nfeatures)


def count_avgpool(m, x, y):
    # total_add = torch.prod(torch.Tensor([m.kernel_size]))
    # total_div = 1
    # kernel_ops = total_add + total_div
    kernel_ops = 1
    num_elements = y.numel()
    total_ops = kernel_ops * num_elements

    m.total_ops += torch.DoubleTensor([int(total_ops)])
    m.total_ops += counter_avgpool(num_elements)


def count_adap_avgpool(m, x, y):
    kernel = torch.DoubleTensor([*(x[0].shape[2:])]) // torch.DoubleTensor([*(y.shape[2:])])
    kernel = torch.DoubleTensor(
        [*(x[0].shape[2:])]) // torch.DoubleTensor([*(y.shape[2:])])
    total_add = torch.prod(kernel)
    total_div = 1
    kernel_ops = total_add + total_div
    num_elements = y.numel()
    total_ops = kernel_ops * num_elements

    m.total_ops += torch.DoubleTensor([int(total_ops)])
    m.total_ops += counter_adap_avg(total_add, num_elements)


# TODO: verify the accuracy
def count_upsample(m, x, y):
    if m.mode not in ("nearest", "linear", "bilinear", "bicubic",):  # "trilinear"
        logging.warning("mode %s is not implemented yet, take it a zero op" % m.mode)
        return zero_ops(m, x, y)
        logging.warning(
            "mode %s is not implemented yet, take it a zero op" % m.mode)
        return counter_zero_ops()

    if m.mode == "nearest":
        return zero_ops(m, x, y)
        return counter_zero_ops()

    x = x[0]
    if m.mode == "linear":
        total_ops = y.nelement() * 5  # 2 muls + 3 add
    elif m.mode == "bilinear":
        # https://en.wikipedia.org/wiki/Bilinear_interpolation
        total_ops = y.nelement() * 11  # 6 muls + 5 adds
    elif m.mode == "bicubic":
        # https://en.wikipedia.org/wiki/Bicubic_interpolation
        # Product matrix [4x4] x [4x4] x [4x4]
        ops_solve_A = 224  # 128 muls + 96 adds
        ops_solve_p = 35  # 16 muls + 12 adds + 4 muls + 3 adds
        total_ops = y.nelement() * (ops_solve_A + ops_solve_p)
    elif m.mode == "trilinear":
        # https://en.wikipedia.org/wiki/Trilinear_interpolation
        # can viewed as 2 bilinear + 1 linear
        total_ops = y.nelement() * (13 * 2 + 5)

    m.total_ops += torch.DoubleTensor([int(total_ops)])
    m.total_ops += counter_upsample(m.mode, y.nelement())


# nn.Linear

@@ -135,6 +125,5 @@ def count_linear(m, x, y):
    # total_add = m.in_features - 1
    # total_add += 1 if m.bias is not None else 0
    num_elements = y.numel()
    total_ops = total_mul * num_elements

    m.total_ops += torch.DoubleTensor([int(total_ops)])
    m.total_ops += counter_linear(total_mul, num_elements)
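
The counter_* helpers live in thop's counter module, which is not part of this diff. As a rough sketch only, based on the formulas the old hook bodies used, a few of them would look roughly like this (signatures inferred from the calls above, so treat them as assumptions rather than the actual implementation):

import torch

def counter_zero_ops():
    # ops that contribute nothing, e.g. nearest-neighbour upsample
    return torch.DoubleTensor([0])

def counter_relu(input_size):
    # one comparison per element
    return torch.DoubleTensor([int(input_size)])

def counter_norm(input_size):
    # subtract mean and divide by std: roughly 2 ops per element
    return torch.DoubleTensor([2 * int(input_size)])

def counter_softmax(batch_size, nfeatures):
    # exp per feature, (nfeatures - 1) adds, one division per feature
    return torch.DoubleTensor([int(batch_size * (nfeatures + nfeatures - 1 + nfeatures))])

def counter_conv(bias, kernel_size, output_size, in_channels=1, groups=1):
    # output elements x (Cin / groups x kernel elements + bias)
    return torch.DoubleTensor([int(output_size * (in_channels // groups * kernel_size + bias))])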
Review comment: This is an O(N) lookup and will become slow as the number of handlers grows. I suggest reworking it into a dictionary.
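
A rough sketch of the dictionary-based lookup being suggested (names are illustrative, not part of the PR): build the operator-to-handler map once, then each node resolves its handler in O(1) instead of scanning the handler list.

handlers_map = {}
for operators, func in handlers:
    if isinstance(operators, str):
        operators = [operators]
    for op in operators:
        handlers_map[op] = func

def calculate_macs(model, args=(), reduction=sum):
    results = dict()
    graph = trace(model, args)
    for node in graph.nodes:
        func = handlers_map.get(node.operator)
        if func is not None:
            results[node] = func(node)
    return reduction(results.values()) if reduction is not None else results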