diff --git a/test.py b/test.py
deleted file mode 100644
index c910ccd..0000000
--- a/test.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import torch
-import thop
-
-m = torch.nn.Conv2d(128, 128, 1)
-x = torch.randn(1, 128, 16, 16)
-
-flops = thop.profile(m, inputs=(x,), verbose=True)
-fprint(flops)
diff --git a/test_jit.py b/test_jit.py
new file mode 100644
index 0000000..322c5b3
--- /dev/null
+++ b/test_jit.py
@@ -0,0 +1,6 @@
+import torch
+from torchvision.models import vgg11
+from thop import JitProfile
+model1 = vgg11()
+input1 = torch.rand(1, 3, 224, 224)
+print(JitProfile.calculate_macs(model1, input1))
diff --git a/thop/__init__.py b/thop/__init__.py
index 3362b4c..682eaf0 100644
--- a/thop/__init__.py
+++ b/thop/__init__.py
@@ -1,5 +1,5 @@
 from .utils import clever_format
 from .profile import profile, profile_origin
-
 import torch
+from .jit_profile import JitProfile
 default_dtype = torch.float64
\ No newline at end of file
diff --git a/thop/jit_profile.py b/thop/jit_profile.py
new file mode 100644
index 0000000..f02cb6d
--- /dev/null
+++ b/thop/jit_profile.py
@@ -0,0 +1,34 @@
+import torch
+import numpy as np
+from .trace.trace import trace
+from thop.vision.jit_handler import handlers
+
+
+class JitProfile():
+    @staticmethod
+    def calculate_params(model):
+        # Count the parameters of the scripted model.
+        script_model = torch.jit.script(model)
+        params = 0
+        for param in script_model.parameters():
+            params += np.prod(param.size())
+        return params
+
+    @staticmethod
+    def calculate_macs(model, args=(), reduction=sum):
+        # Trace the model and dispatch every graph node to its counter.
+        results = dict()
+        graph = trace(model, args)
+        for node in graph.nodes:
+            for operators, func in handlers:
+                if isinstance(operators, str):
+                    operators = [operators]
+                if node.operator in operators:
+                    if func is not None:
+                        results[node] = func(node)
+                    break
+
+        if reduction is not None:
+            return reduction(results.values())
+        else:
+            return results
diff --git a/thop/profile.py b/thop/profile.py
index cc1e8c4..6c40b54 100644
--- a/thop/profile.py
+++ b/thop/profile.py
@@ -35,6 +35,12 @@ def prYellow(skk): fprint("\033[93m{}\033[00m".format(skk))
     nn.BatchNorm1d: count_bn,
     nn.BatchNorm2d: count_bn,
     nn.BatchNorm3d: count_bn,
+    nn.LayerNorm: count_ln,
+    nn.InstanceNorm1d: count_in,
+    nn.InstanceNorm2d: count_in,
+    nn.InstanceNorm3d: count_in,
+    nn.PReLU: count_prelu,
+    nn.Softmax: count_softmax,
 
     nn.ReLU: zero_ops,
     nn.ReLU6: zero_ops,
@@ -67,8 +73,8 @@ def prYellow(skk): fprint("\033[93m{}\033[00m".format(skk))
     nn.RNN: count_rnn,
     nn.GRU: count_gru,
     nn.LSTM: count_lstm,
 
-    nn.Sequential: zero_ops,
+
 }
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"):
diff --git a/thop/trace/trace.py b/thop/trace/trace.py
new file mode 100644
index 0000000..2fe0e0c
--- /dev/null
+++ b/thop/trace/trace.py
@@ -0,0 +1,47 @@
+import torch
+import torch.jit
+
+from ..trace.type import Variable, Node, Graph
+
+
+def trace(model, args=()):
+    graph, _ = torch.jit._get_trace_graph(model, args)
+    variables = {}
+    # Register every value in the traced graph, keeping shape/dtype for tensors.
+    for x in graph.nodes():
+        for v in list(x.inputs()) or list(x.outputs()):
+            if 'tensor' in v.type().kind().lower():
+                variables[v] = Variable(
+                    name=v.debugName(),
+                    dtype=v.type().scalarType(),
+                    shape=v.type().sizes(),
+                )
+            else:
+                variables[v] = Variable(
+                    name=v.debugName(),
+                    dtype=str(v.type()),
+                )
+
+    nodes = []
+    for x in graph.nodes():
+        node = Node(
+            operator=x.kind(),
+            attributes={
+                s: getattr(x, x.kindOf(s))(s)
+                for s in x.attributeNames()
+            },
+            inputs=[variables[v] for v in x.inputs() if v in variables],
+            outputs=[variables[v] for v in x.outputs() if v in variables],
+            scope=x.scopeName()
+            .replace('Flatten/', '', 1)
+            .replace('Flatten', '', 1),
+        )
+        nodes.append(node)
+    graph = Graph(
+        name=model.__class__.__module__ + '.' + model.__class__.__name__,
+        variables=[v for v in variables.values()],
+        inputs=[variables[v] for v in graph.inputs() if v in variables],
+        outputs=[variables[v] for v in graph.outputs() if v in variables],
+        nodes=nodes,
+    )
+    return graph
diff --git a/thop/trace/type/__init__.py b/thop/trace/type/__init__.py
new file mode 100644
index 0000000..1016bb9
--- /dev/null
+++ b/thop/trace/type/__init__.py
@@ -0,0 +1,3 @@
+from .variable import Variable
+from .node import Node
+from .graph import Graph
\ No newline at end of file
diff --git a/thop/trace/type/graph.py b/thop/trace/type/graph.py
new file mode 100644
index 0000000..9b0d8d3
--- /dev/null
+++ b/thop/trace/type/graph.py
@@ -0,0 +1,10 @@
+__all__ = ['Graph']
+
+
+class Graph:
+    def __init__(self, name, variables, inputs, outputs, nodes):
+        self.name = name
+        self.variables = variables
+        self.inputs = inputs
+        self.outputs = outputs
+        self.nodes = nodes
\ No newline at end of file
diff --git a/thop/trace/type/node.py b/thop/trace/type/node.py
new file mode 100644
index 0000000..a4cc7b5
--- /dev/null
+++ b/thop/trace/type/node.py
@@ -0,0 +1,10 @@
+__all__ = ['Node']
+
+
+class Node:
+    def __init__(self, operator, attributes, inputs, outputs, scope):
+        self.operator = operator
+        self.attributes = attributes
+        self.inputs = inputs
+        self.outputs = outputs
+        self.scope = scope
\ No newline at end of file
diff --git a/thop/trace/type/variable.py b/thop/trace/type/variable.py
new file mode 100644
index 0000000..9f0775d
--- /dev/null
+++ b/thop/trace/type/variable.py
@@ -0,0 +1,42 @@
+__all__ = ['Variable']
+
+
+class Variable:
+    def __init__(self, name, dtype, shape=None):
+        self.name = name
+        self.dtype = dtype
+        self.shape = shape
+
+    @property
+    def name(self):
+        return self._name
+
+    @name.setter
+    def name(self, name):
+        self._name = name
+
+    @property
+    def dtype(self):
+        return self._dtype
+
+    @dtype.setter
+    def dtype(self, dtype):
+        self._dtype = dtype.lower()
+
+    @property
+    def shape(self):
+        return self._shape
+
+    @shape.setter
+    def shape(self, shape):
+        self._shape = shape
+
+    @property
+    def ndim(self):
+        return len(self.shape)
+
+    def size(self):
+        return self.shape
+
+    def dim(self):
+        return self.ndim
\ No newline at end of file
diff --git a/thop/vision/__init__.py b/thop/vision/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/thop/vision/basic_hooks.py b/thop/vision/basic_hooks.py
index e3d7d7d..671fa6a 100644
--- a/thop/vision/basic_hooks.py
+++ b/thop/vision/basic_hooks.py
@@ -1,6 +1,6 @@
 import argparse
 import logging
-
+from .counter import *
 import torch
 import torch.nn as nn
 from torch.nn.modules.conv import _ConvNd
@@ -12,11 +12,11 @@ def count_parameters(m, x, y):
     total_params = 0
     for p in m.parameters():
         total_params += torch.DoubleTensor([p.numel()])
-    m.total_params[0] = total_params
+    m.total_params[0] = counter_parameters(m.parameters())
 
 
 def zero_ops(m, x, y):
-    m.total_ops += torch.DoubleTensor([int(0)])
+    m.total_ops += counter_zero_ops()
 
 
 def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
@@ -26,9 +26,8 @@ def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
     bias_ops = 1 if m.bias is not None else 0
 
     # N x Cout x H x W x  (Cin x Kw x Kh + bias)
-    total_ops = y.nelement() * (m.in_channels // m.groups * kernel_ops + bias_ops)
-
-    m.total_ops += torch.DoubleTensor([int(total_ops)])
+    m.total_ops += counter_conv(bias_ops, torch.zeros(m.weight.size()[2:]).numel(),
+                                y.nelement(), m.in_channels, m.groups)
 
 
 def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
@@ -36,24 +35,41 @@ def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
 
     # N x H x W (exclude Cout)
     output_size = torch.zeros((y.size()[:1] + y.size()[2:])).numel()
-    # Cout x Cin x Kw x Kh
-    kernel_ops = m.weight.nelement()
-    if m.bias is not None:
-        # Cout x 1
-        kernel_ops += + m.bias.nelement()
-    # x N x H x W x Cout x (Cin x Kw x Kh + bias)
-    m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)])
+    # # Cout x Cin x Kw x Kh
+    # kernel_ops = m.weight.nelement()
+    # if m.bias is not None:
+    #     # Cout x 1
+    #     kernel_ops += + m.bias.nelement()
+    # # x N x H x W x Cout x (Cin x Kw x Kh + bias)
+    # m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)])
+    m.total_ops += counter_conv(m.bias.nelement() if m.bias is not None else 0,
+                                m.weight.nelement(), output_size, 1, 1)
 
 
 def count_bn(m, x, y):
     x = x[0]
-
-    nelements = x.numel()
+    if not m.training:
+        m.total_ops += counter_norm(x.numel())
+
+
+def count_ln(m, x, y):
+    x = x[0]
     if not m.training:
-        # subtract, divide, gamma, beta
-        total_ops = 2 * nelements
-
-    m.total_ops += torch.DoubleTensor([int(total_ops)])
+        m.total_ops += counter_norm(x.numel())
+
+
+def count_in(m, x, y):
+    x = x[0]
+    if not m.training:
+        m.total_ops += counter_norm(x.numel())
+
+
+def count_prelu(m, x, y):
+    x = x[0]
+
+    nelements = x.numel()
+    if not m.training:
+        m.total_ops += counter_relu(nelements)
 
 
 def count_relu(m, x, y):
@@ -61,71 +77,45 @@ def count_relu(m, x, y):
 
     nelements = x.numel()
 
-    m.total_ops += torch.DoubleTensor([int(nelements)])
+    m.total_ops += counter_relu(nelements)
 
 
 def count_softmax(m, x, y):
     x = x[0]
+    nfeatures = x.size()[m.dim]
+    batch_size = x.numel() // nfeatures
 
-    batch_size, nfeatures = x.size()
-
-    total_exp = nfeatures
-    total_add = nfeatures - 1
-    total_div = nfeatures
-    total_ops = batch_size * (total_exp + total_add + total_div)
-
-    m.total_ops += torch.DoubleTensor([int(total_ops)])
+    m.total_ops += counter_softmax(batch_size, nfeatures)
 
 
 def count_avgpool(m, x, y):
     # total_add = torch.prod(torch.Tensor([m.kernel_size]))
     # total_div = 1
     # kernel_ops = total_add + total_div
-    kernel_ops = 1
     num_elements = y.numel()
-    total_ops = kernel_ops * num_elements
-
-    m.total_ops += torch.DoubleTensor([int(total_ops)])
+    m.total_ops += counter_avgpool(num_elements)
 
 
 def count_adap_avgpool(m, x, y):
-    kernel = torch.DoubleTensor([*(x[0].shape[2:])]) // torch.DoubleTensor([*(y.shape[2:])])
+    kernel = torch.DoubleTensor(
+        [*(x[0].shape[2:])]) // torch.DoubleTensor([*(y.shape[2:])])
     total_add = torch.prod(kernel)
-    total_div = 1
-    kernel_ops = total_add + total_div
     num_elements = y.numel()
-    total_ops = kernel_ops * num_elements
-
-    m.total_ops += torch.DoubleTensor([int(total_ops)])
+    m.total_ops += counter_adap_avg(total_add, num_elements)
 
 
 # TODO: verify the accuracy
 def count_upsample(m, x, y):
     if m.mode not in ("nearest", "linear", "bilinear", "bicubic",):  # "trilinear"
-        logging.warning("mode %s is not implemented yet, take it a zero op" % m.mode)
-        return zero_ops(m, x, y)
+        logging.warning(
+            "mode %s is not implemented yet, take it a zero op" % m.mode)
+        return counter_zero_ops()
 
     if m.mode == "nearest":
-        return zero_ops(m, x, y)
+        return counter_zero_ops()
 
     x = x[0]
-    if m.mode == "linear":
-        total_ops = y.nelement() * 5  # 2 muls + 3 add
-    elif m.mode == "bilinear":
-        # https://en.wikipedia.org/wiki/Bilinear_interpolation
-        total_ops = y.nelement() * 11  # 6 muls + 5 adds
-    elif m.mode == "bicubic":
-        # https://en.wikipedia.org/wiki/Bicubic_interpolation
-        # Product matrix [4x4] x [4x4] x [4x4]
-        ops_solve_A = 224  # 128 muls + 96 adds
-        ops_solve_p = 35  # 16 muls + 12 adds + 4 muls + 3 adds
-        total_ops = y.nelement() * (ops_solve_A + ops_solve_p)
-    elif m.mode == "trilinear":
-        # https://en.wikipedia.org/wiki/Trilinear_interpolation
-        # can viewed as 2 bilinear + 1 linear
-        total_ops = y.nelement() * (13 * 2 + 5)
-
-    m.total_ops += torch.DoubleTensor([int(total_ops)])
+    m.total_ops += counter_upsample(m.mode, y.nelement())
 
 
 # nn.Linear
@@ -135,6 +125,5 @@ def count_linear(m, x, y):
     # total_add = m.in_features - 1
     # total_add += 1 if m.bias is not None else 0
     num_elements = y.numel()
-    total_ops = total_mul * num_elements
 
-    m.total_ops += torch.DoubleTensor([int(total_ops)])
+    m.total_ops += counter_linear(total_mul, num_elements)
diff --git a/thop/vision/counter.py b/thop/vision/counter.py
new file mode 100644
index 0000000..0b2c5eb
--- /dev/null
+++ b/thop/vision/counter.py
@@ -0,0 +1,105 @@
+import torch
+import numpy as np
+
+
+def counter_parameters(para_list):
+    total_params = 0
+    for p in para_list:
+        total_params += torch.DoubleTensor([p.nelement()])
+    return total_params
+
+
+def counter_zero_ops():
+    return torch.DoubleTensor([int(0)])
+
+
+def counter_conv(bias, kernel_size, output_size, in_channel, group):
+    """Inputs are all plain numbers."""
+    return torch.DoubleTensor([output_size * (in_channel / group * kernel_size + bias)])
+
+
+def counter_norm(input_size):
+    """Input is a number, not an array or tensor."""
+    return torch.DoubleTensor([2 * input_size])
+
+
+def counter_relu(input_size):
+    return torch.DoubleTensor([int(input_size)])
+
+
+def counter_softmax(batch_size, nfeatures):
+    total_exp = nfeatures
+    total_add = nfeatures - 1
+    total_div = nfeatures
+    total_ops = batch_size * (total_exp + total_add + total_div)
+    return torch.DoubleTensor([int(total_ops)])
+
+
+def counter_avgpool(input_size):
+    return torch.DoubleTensor([int(input_size)])
+
+
+def counter_adap_avg(kernel_size, output_size):
+    total_div = 1
+    kernel_op = kernel_size + total_div
+    return torch.DoubleTensor([int(kernel_op * output_size)])
+
+
+def counter_upsample(mode: str, output_size):
+    total_ops = output_size
+    if mode == "linear":
+        total_ops *= 5
+    elif mode == "bilinear":
+        total_ops *= 11
+    elif mode == "bicubic":
+        ops_solve_A = 224  # 128 muls + 96 adds
+        ops_solve_p = 35  # 16 muls + 12 adds + 4 muls + 3 adds
+        total_ops *= (ops_solve_A + ops_solve_p)
+    elif mode == "trilinear":
+        total_ops *= (13 * 2 + 5)
+    return torch.DoubleTensor([int(total_ops)])
+
+
+def counter_linear(in_feature, num_elements):
+    return torch.DoubleTensor([int(in_feature * num_elements)])
+
+
+def counter_matmul(input_size, output_size):
+    """an error to be fixed"""
+    input_size = np.array(input_size)
+    output_size = np.array(output_size)
+    return np.prod(input_size) * output_size[-1]
+
+
+def counter_mul(input_size):
+    return input_size
+
+
+def counter_pow(input_size):
+    return input_size
+
+
+def counter_sqrt(input_size):
+    return input_size
+
+
+def counter_div(input_size):
+    return input_size
+
+# jit profile
+
+
+def counter_addmm(input_size1, input_size2):
+    n, m = input_size1
+    m, p = input_size2
+    return n * m * p
+
+
+def counter_addmv(input_size):
+    n, m = input_size
+    return n * m
+
+def counter_bmm(input_size1, input_size2):
+    b, n, m = input_size1
+    b, m, p = input_size2
+    return b * m * n * p
diff --git a/thop/vision/jit_handler.py b/thop/vision/jit_handler.py
new file mode 100644
index 0000000..45df0c6
--- /dev/null
+++ b/thop/vision/jit_handler.py
@@ -0,0 +1,145 @@
+import numpy as np
+from thop.vision.counter import (counter_mul, counter_addmm, counter_addmv,
+                                 counter_bmm, counter_matmul, counter_avgpool,
+                                 counter_relu)
+
+__all__ = ['handlers']
+
+
+def addmm(node):
+    # [n, p] = aten::addmm([n, p], [n, m], [m, p], *, *)
+    n, m = node.inputs[1].shape
+    m, p = node.inputs[2].shape
+
+    return counter_addmm(node.inputs[1].shape, node.inputs[2].shape)
+
+
+def addmv(node):
+    # [n] = aten::addmv([n], [n, m], [m], *, *)
+    return counter_addmv(node.inputs[1].shape)
+
+
+def bmm(node):
+    # [b, n, p] = aten::bmm([b, n, m], [b, m, p])
+    b, n, m = node.inputs[0].shape
+    b, m, p = node.inputs[1].shape
+    return counter_bmm(node.inputs[0].shape, node.inputs[1].shape)
+
+
+def matmul(node):
+    if node.inputs[0].ndim == 1 and node.inputs[1].ndim == 1:
+        # [] = aten::matmul([n], [n])
+        n = node.inputs[0].shape[0]
+        return counter_mul(n)
+    elif node.inputs[0].ndim == 1 and node.inputs[1].ndim == 2:
+        # [m] = aten::matmul([n], [n, m])
+        # n, m = node.inputs[1].shape
+        # return n * m
+        return counter_mul(np.prod(node.inputs[1].shape))
+    elif node.inputs[0].ndim == 2 and node.inputs[1].ndim == 1:
+        # [n] = aten::matmul([n, m], [m])
+        # n, m = node.inputs[0].shape
+        # return n * m
+        return counter_mul(np.prod(node.inputs[0].shape))
+    elif node.inputs[0].ndim == 2 and node.inputs[1].ndim == 2:
+        # [n, p] = aten::matmul([n, m], [m, p])
+        # n, m = node.inputs[0].shape
+        # m, p = node.inputs[1].shape
+        # return n * m * p
+        return counter_matmul(node.inputs[0].shape, node.inputs[1].shape)
+    elif node.inputs[0].ndim == 1:
+        # [..., m] = aten::matmul([n], [..., n, m])
+        # *b, n, m = node.inputs[1].shape
+        # return np.prod(b) * n * m
+        return counter_mul(np.prod(node.inputs[1].shape))
+    elif node.inputs[1].ndim == 1:
+        # [..., n] = aten::matmul([..., n, m], [m])
+        # *b, n, m = node.inputs[0].shape
+        return counter_mul(np.prod(node.inputs[0].shape))
+    else:
+        # [..., n, p] = aten::matmul([..., n, m], [..., m, p])
+        # *b, n, p = node.outputs[0].shape
+        # *_, n, m = node.inputs[0].shape
+        # *_, m, p = node.inputs[1].shape
+        # return np.prod(b) * n * m * p
+        return counter_matmul(node.outputs[0].shape, node.inputs[0].shape[-2:])
+
+
+def mul(node):
+    return counter_mul(np.prod(node.outputs[0].shape))
+
+
+def convolution(node):
+    if node.outputs[0].shape[1] == node.inputs[1].shape[0]:
+        oc, ic, *ks = node.inputs[1].shape
+    else:
+        ic, oc, *ks = node.inputs[1].shape
+    os = node.outputs[0].shape
+    return np.prod(os) * ic * np.prod(ks)
+
+
+def norm(node):
+    if node.operator in ['aten::batch_norm', 'aten::instance_norm']:
+        affine = node.inputs[1].shape is not None
+    elif node.operator in ['aten::layer_norm', 'aten::group_norm']:
+        affine = node.inputs[2].shape is not None
+    else:
+        raise ValueError(node.operator)
+
+    os = node.outputs[0].shape
+    return np.prod(os) if affine else 0
+
+
+def avg_pool_or_mean(node):
+    os = node.outputs[0].shape
+    # return np.prod(os)
+    return counter_avgpool(np.prod(os))
+
+
+def leaky_relu(node):
+    return counter_relu(np.prod(node.outputs[0].shape))
+
+
+def upsample_bilinear2d(node):
+    os = node.outputs[0].shape
+    return np.prod(os) * 4
+
+
+handlers = (
+    ('aten::addmm', addmm),
+    ('aten::addmv', addmv),
+    ('aten::bmm', bmm),
+    (('aten::linear', 'aten::matmul'), matmul),
+    (('aten::mul', 'aten::mul_'), mul),
+    ('aten::_convolution', convolution),
+    (('aten::batch_norm', 'aten::instance_norm', 'aten::layer_norm',
+      'aten::group_norm'), norm),
+    (('aten::adaptive_avg_pool1d', 'aten::adaptive_avg_pool2d',
+      'aten::adaptive_avg_pool3d', 'aten::avg_pool1d', 'aten::avg_pool2d',
+      'aten::avg_pool3d', 'aten::mean'), avg_pool_or_mean),
+    ('aten::leaky_relu', leaky_relu),
+    ('aten::upsample_bilinear2d', upsample_bilinear2d),
+    (('aten::adaptive_max_pool1d', 'aten::adaptive_max_pool2d',
+      'aten::adaptive_max_pool3d', 'aten::add', 'aten::add_',
+      'aten::alpha_dropout', 'aten::cat', 'aten::chunk', 'aten::clamp',
+      'aten::clone', 'aten::constant_pad_nd', 'aten::contiguous',
+      'aten::detach', 'aten::div', 'aten::div_', 'aten::dropout',
+      'aten::dropout_', 'aten::embedding', 'aten::eq', 'aten::feature_dropout',
+      'aten::flatten', 'aten::floor', 'aten::floor_divide', 'aten::gt',
+      'aten::hardtanh_', 'aten::hardtanh', 'aten::index', 'aten::int', 'aten::log_softmax',
+      'aten::lt', 'aten::max_pool1d', 'aten::max_pool1d_with_indices',
+      'aten::max_pool2d', 'aten::max_pool2d_with_indices', 'aten::max_pool3d',
+      'aten::max_pool3d_with_indices', 'aten::max_unpool1d',
+      'aten::max_unpool2d', 'aten::max_unpool3d', 'aten::ne',
+      'aten::reflection_pad1d', 'aten::reflection_pad2d',
+      'aten::reflection_pad3d', 'aten::relu', 'aten::relu_',
+      'aten::replication_pad1d', 'aten::replication_pad2d',
+      'aten::replication_pad3d', 'aten::rsub', 'aten::select', 'aten::sigmoid',
+      'aten::size', 'aten::slice', 'aten::softmax', 'aten::softshrink',
+      'aten::squeeze', 'aten::stack', 'aten::sub', 'aten::sum', 'aten::t',
+      'aten::tanh', 'aten::threshold', 'aten::to', 'aten::transpose',
+      'aten::upsample_nearest2d', 'aten::view', 'aten::zeros',
+      'prim::constant', 'prim::listconstruct', 'prim::listunpack',
+      'prim::numtotensor', 'prim::tupleconstruct'), None),
+)
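
Not part of the patch itself: a minimal usage sketch of the JitProfile API added above, in the spirit of test_jit.py. It assumes torchvision is installed; resnet18 and the input shape are arbitrary example choices, not anything mandated by this change.

import torch
from torchvision.models import resnet18
from thop import JitProfile

model = resnet18()
dummy_input = torch.rand(1, 3, 224, 224)

# MACs: trace the model and dispatch each aten:: node to its counter
# registered in thop/vision/jit_handler.py.
macs = JitProfile.calculate_macs(model, dummy_input)

# Params: sum over the parameters of torch.jit.script(model).
params = JitProfile.calculate_params(model)

print(macs, params)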