Adding Jit profile #150

Open
wants to merge 26 commits into base: master
Changes from all commits
8 changes: 0 additions & 8 deletions test.py

This file was deleted.

6 changes: 6 additions & 0 deletions test_jit.py
@@ -0,0 +1,6 @@
import torch
from torchvision.models import vgg11
from thop import JitProfile
model1 = vgg11()
input1 = torch.rand(1,3,224,224)
print(JitProfile.calculate_macs(model1,input1))
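For reference, the PR also adds a JitProfile.calculate_params method (shown in thop/jit_profile.py below); a minimal usage sketch under that assumption, not part of this diff:

import torch
from torchvision.models import vgg11
from thop import JitProfile

model1 = vgg11()
# calculate_params scripts the model with torch.jit.script and sums the sizes of its parameters
print(JitProfile.calculate_params(model1))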
2 changes: 1 addition & 1 deletion thop/__init__.py
@@ -1,5 +1,5 @@
from .utils import clever_format
from .profile import profile, profile_origin

import torch
from .jit_profile import JitProfile
default_dtype = torch.float64
31 changes: 31 additions & 0 deletions thop/jit_profile.py
@@ -0,0 +1,31 @@
import torch
import numpy as np
from .trace.trace import trace
from thop.vision.jit_handler import handlers


class JitProfile():
def calculate_params(model):
script_model = torch.jit.script(model)
params = 0
for param in script_model.parameters():
params += np.prod(param.size())
print(param)
return params

def calculate_macs(model, args=(),reduction = sum):
results = dict()
graph = trace(model, args)
for node in graph.nodes:
for operators, func in handlers:
Owner comment: This is an O(N) operation and will become slow as the number of handlers increases. I suggest reworking it into a dictionary. (A sketch of such a rework follows this file's diff.)

if isinstance(operators, str):
operators = [operators]
if node.operator in operators:
if func is not None:
results[node] = func(node)
break

if reduction is not None:
return reduction(results.values())
else:
return results
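Following the reviewer's comment above, a minimal sketch of the dictionary-based rework (hypothetical, not part of this diff; it assumes handlers is the same iterable of (operators, func) pairs imported from thop.vision.jit_handler):

# Build the operator -> handler table once, so each node lookup is O(1)
# instead of scanning every handler entry per node.
handler_table = {}
for operators, func in handlers:
    if isinstance(operators, str):
        operators = [operators]
    for op in operators:
        handler_table[op] = func

for node in graph.nodes:
    func = handler_table.get(node.operator)
    if func is not None:
        results[node] = func(node)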
8 changes: 7 additions & 1 deletion thop/profile.py
@@ -35,6 +35,12 @@ def prYellow(skk): fprint("\033[93m{}\033[00m".format(skk))
nn.BatchNorm1d: count_bn,
nn.BatchNorm2d: count_bn,
nn.BatchNorm3d: count_bn,
nn.LayerNorm: count_ln,
nn.InstanceNorm1d: count_in,
nn.InstanceNorm2d: count_in,
nn.InstanceNorm3d: count_in,
nn.PReLU: count_prelu,
nn.Softmax: count_softmax,

nn.ReLU: zero_ops,
nn.ReLU6: zero_ops,
@@ -67,8 +73,8 @@ def prYellow(skk): fprint("\033[93m{}\033[00m".format(skk))
nn.RNN: count_rnn,
nn.GRU: count_gru,
nn.LSTM: count_lstm,

nn.Sequential: zero_ops,

}

if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"):
49 changes: 49 additions & 0 deletions thop/trace/trace.py
@@ -0,0 +1,49 @@
import torch
import torch.jit
from ..trace.type import Variable,Node,Graph
def trace(model, args = ()):
graph, _ = torch.jit._get_trace_graph(model,args)
variables = {}
#print(graph.__dir__())
#print(graph.inputs)
#print(graph)
for x in graph.nodes():
#print(x)
for v in list(x.inputs()) or list(x.outputs()):
if 'tensor' in v.type().kind().lower():
variables[v] = Variable(
name=v.debugName(),
dtype=v.type().scalarType(),
shape=v.type().sizes(),
)
else:
variables[v] = Variable(
name=v.debugName(),
dtype=str(v.type()),
)
pass
nodes = []
for x in graph.nodes():
node = Node(
operator=x.kind(),
attributes={
s: getattr(x, x.kindOf(s))(s)
for s in x.attributeNames()
},
inputs=[variables[v] for v in x.inputs() if v in variables],
outputs=[variables[v] for v in x.outputs() if v in variables],
scope=x.scopeName() \
.replace('Flatten/', '', 1) \
.replace('Flatten', '', 1),
)
nodes.append(node)
graph = Graph(
name=model.__class__.__module__ + '.' + model.__class__.__name__,
variables=[v for v in variables.values()],
inputs=[variables[v] for v in graph.inputs() if v in variables],
outputs=[variables[v] for v in graph.outputs() if v in variables],
nodes=nodes,
)
return graph

3 changes: 3 additions & 0 deletions thop/trace/type/__init__.py
@@ -0,0 +1,3 @@
from .variable import Variable
from .node import Node
from .graph import Graph
10 changes: 10 additions & 0 deletions thop/trace/type/graph.py
@@ -0,0 +1,10 @@
__all__ = ['Graph']


class Graph:
def __init__(self, name, variables, inputs, outputs, nodes):
self.name = name
self.variables = variables
self.inputs = inputs
self.outputs = outputs
self.nodes = nodes
10 changes: 10 additions & 0 deletions thop/trace/type/node.py
@@ -0,0 +1,10 @@
__all__ = ['Node']


class Node:
def __init__(self, operator, attributes, inputs, outputs, scope):
self.operator = operator
self.attributes = attributes
self.inputs = inputs
self.outputs = outputs
self.scope = scope
42 changes: 42 additions & 0 deletions thop/trace/type/variable.py
@@ -0,0 +1,42 @@
__all__ = ['Variable']


class Variable:
def __init__(self, name, dtype, shape=None):
self.name = name
self.dtype = dtype
self.shape = shape

@property
def name(self):
return self._name

@name.setter
def name(self, name):
self._name = name

@property
def dtype(self):
return self._dtype

@dtype.setter
def dtype(self, dtype):
self._dtype = dtype.lower()

@property
def shape(self):
return self._shape

@shape.setter
def shape(self, shape):
self._shape = shape

@property
def ndim(self):
return len(self.shape)

def size(self):
return self.shape

def dim(self):
return self.ndim
Empty file removed thop/vision/__init__.py
Empty file.
105 changes: 47 additions & 58 deletions thop/vision/basic_hooks.py
@@ -1,6 +1,6 @@
import argparse
import logging

from .counter import *
import torch
import torch.nn as nn
from torch.nn.modules.conv import _ConvNd
@@ -12,11 +12,11 @@ def count_parameters(m, x, y):
total_params = 0
for p in m.parameters():
total_params += torch.DoubleTensor([p.numel()])
m.total_params[0] = total_params
m.total_params[0] = counter_parameters(m.parameters())
Owner comment: Why call counter_parameters() here? total_params already provides the number of parameters of the model.
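A minimal sketch of the simplification the reviewer's question points at (hypothetical, not part of this diff): the loop already accumulates the parameter count, so the helper call can be dropped and the accumulated value assigned directly, as in the removed line above.

def count_parameters(m, x, y):
    total_params = 0
    for p in m.parameters():
        total_params += torch.DoubleTensor([p.numel()])
    # total_params already holds the summed parameter count; assign it directly
    m.total_params[0] = total_params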



def zero_ops(m, x, y):
m.total_ops += torch.DoubleTensor([int(0)])
m.total_ops += counter_zero_ops()


def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
@@ -26,106 +26,96 @@ def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
bias_ops = 1 if m.bias is not None else 0

# N x Cout x H x W x (Cin x Kw x Kh + bias)
total_ops = y.nelement() * (m.in_channels // m.groups * kernel_ops + bias_ops)

m.total_ops += torch.DoubleTensor([int(total_ops)])
m.total_ops += counter_conv(bias_ops, torch.zeros(m.weight.size()
[2:]).numel(), y.nelement(), m.in_channels, m.groups)


def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
x = x[0]

# N x H x W (exclude Cout)
output_size = torch.zeros((y.size()[:1] + y.size()[2:])).numel()
# Cout x Cin x Kw x Kh
kernel_ops = m.weight.nelement()
if m.bias is not None:
# Cout x 1
kernel_ops += + m.bias.nelement()
# x N x H x W x Cout x (Cin x Kw x Kh + bias)
m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)])
# # Cout x Cin x Kw x Kh
# kernel_ops = m.weight.nelement()
# if m.bias is not None:
# # Cout x 1
# kernel_ops += + m.bias.nelement()
# # x N x H x W x Cout x (Cin x Kw x Kh + bias)
# m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)])
m.total_ops += counter_conv(m.bias.nelement(),
m.weight.nelement(), output_size)


def count_bn(m, x, y):
x = x[0]
if not m.training:
m.total_ops += counter_norm(x.numel())

nelements = x.numel()

def count_ln(m, x, y):
x = x[0]
if not m.training:
# subtract, divide, gamma, beta
total_ops = 2 * nelements
m.total_ops += counter_norm(x.numel())

m.total_ops += torch.DoubleTensor([int(total_ops)])

def count_in(m, x, y):
x = x[0]
if not m.training:
m.total_ops += counter_norm(x.numel())


def count_prelu(m, x, y):
x = x[0]

nelements = x.numel()
if not m.training:
m.total_ops += counter_relu(nelements)


def count_relu(m, x, y):
x = x[0]

nelements = x.numel()

m.total_ops += torch.DoubleTensor([int(nelements)])
m.total_ops += counter_relu(nelements)


def count_softmax(m, x, y):
x = x[0]
nfeatures = x.size()[m.dim]
batch_size = x.numel()//nfeatures

batch_size, nfeatures = x.size()

total_exp = nfeatures
total_add = nfeatures - 1
total_div = nfeatures
total_ops = batch_size * (total_exp + total_add + total_div)

m.total_ops += torch.DoubleTensor([int(total_ops)])
m.total_ops += counter_softmax(batch_size, nfeatures)


def count_avgpool(m, x, y):
# total_add = torch.prod(torch.Tensor([m.kernel_size]))
# total_div = 1
# kernel_ops = total_add + total_div
kernel_ops = 1
num_elements = y.numel()
total_ops = kernel_ops * num_elements

m.total_ops += torch.DoubleTensor([int(total_ops)])
m.total_ops += counter_avgpool(num_elements)


def count_adap_avgpool(m, x, y):
kernel = torch.DoubleTensor([*(x[0].shape[2:])]) // torch.DoubleTensor([*(y.shape[2:])])
kernel = torch.DoubleTensor(
[*(x[0].shape[2:])]) // torch.DoubleTensor([*(y.shape[2:])])
total_add = torch.prod(kernel)
total_div = 1
kernel_ops = total_add + total_div
num_elements = y.numel()
total_ops = kernel_ops * num_elements

m.total_ops += torch.DoubleTensor([int(total_ops)])
m.total_ops += counter_adap_avg(total_add, num_elements)


# TODO: verify the accuracy
def count_upsample(m, x, y):
if m.mode not in ("nearest", "linear", "bilinear", "bicubic",): # "trilinear"
logging.warning("mode %s is not implemented yet, take it a zero op" % m.mode)
return zero_ops(m, x, y)
logging.warning(
"mode %s is not implemented yet, take it a zero op" % m.mode)
return counter_zero_ops()

if m.mode == "nearest":
return zero_ops(m, x, y)
return counter_zero_ops()

x = x[0]
if m.mode == "linear":
total_ops = y.nelement() * 5 # 2 muls + 3 add
elif m.mode == "bilinear":
# https://en.wikipedia.org/wiki/Bilinear_interpolation
total_ops = y.nelement() * 11 # 6 muls + 5 adds
elif m.mode == "bicubic":
# https://en.wikipedia.org/wiki/Bicubic_interpolation
# Product matrix [4x4] x [4x4] x [4x4]
ops_solve_A = 224 # 128 muls + 96 adds
ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds
total_ops = y.nelement() * (ops_solve_A + ops_solve_p)
elif m.mode == "trilinear":
# https://en.wikipedia.org/wiki/Trilinear_interpolation
# can viewed as 2 bilinear + 1 linear
total_ops = y.nelement() * (13 * 2 + 5)

m.total_ops += torch.DoubleTensor([int(total_ops)])
m.total_ops += counter_upsample(m.mode, y.nelement())


# nn.Linear
Expand All @@ -135,6 +125,5 @@ def count_linear(m, x, y):
# total_add = m.in_features - 1
# total_add += 1 if m.bias is not None else 0
num_elements = y.numel()
total_ops = total_mul * num_elements

m.total_ops += torch.DoubleTensor([int(total_ops)])
m.total_ops += counter_linear(total_mul, num_elements)