Add Winograd matrices computation. (apache#3553)
cbalint13 authored and wweic committed Aug 9, 2019
1 parent e81d7ce commit 77521b9
Showing 6 changed files with 184 additions and 187 deletions.
61 changes: 7 additions & 54 deletions topi/python/topi/arm_cpu/conv2d.py
@@ -20,20 +20,19 @@

import warnings

import numpy as np

import tvm
from tvm import autotvm
import tvm.contrib.nnpack

from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
schedule_conv2d_winograd_nnpack_without_weight_transform
from ..util import traverse_inline, get_const_tuple, const_matrix
from ..util import traverse_inline, get_const_tuple
from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
conv2d_winograd_without_weight_transform, \
conv2d_winograd_nnpack_without_weight_transform, \
depthwise_conv2d_nchw
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices

@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
@@ -330,57 +329,14 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)

assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

if tile_size == 4:
G_data = np.array([
[1 / 4.0, 0, 0],
[-1 / 6.0, -1 / 6.0, -1 / 6.0],
[-1 / 6.0, 1 / 6.0, -1 / 6.0],
[1 / 24.0, 1 / 12.0, 1 / 6.0],
[1 / 24.0, -1 / 12.0, 1 / 6.0],
[0, 0, 1]], dtype=np.float32)

B_data = np.array([
[4, 0, 0, 0, 0, 0],
[0, -4, 4, -2, 2, 4],
[-5, -4, -4, -1, -1, 0],
[0, 1, -1, 2, -2, -5],
[1, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]], out_dtype)

A_data = np.array([
[1, 0, 0, 0],
[1, 1, 1, 1],
[1, -1, 1, -1],
[1, 2, 4, 8],
[1, -2, 4, -8],
[0, 0, 0, 1]], out_dtype)
elif tile_size == 2:
G_data = np.array([
[1, 0, 0],
[1.0/2, 1.0/2, 1.0/2],
[1.0/2, -1.0/2, 1.0/2],
[0, 0, 1]], np.float32)

B_data = np.array([
[1, 0, 0, 0],
[0, 1, -1, 1],
[-1, 1, 1, 0],
[0, 0, 0, -1]], out_dtype)

A_data = np.array([
[1, 0],
[1, 1],
[1, -1],
[0, -1]], out_dtype)
else:
raise ValueError("Unsupported tile size for winograd: " + str(tile_size))

m = A_data.shape[1]
r = 3
r = KW
m = tile_size
alpha = m + r - 1
A, B, G = winograd_transform_matrices(m, r, out_dtype)

K = CO
C = CI

@@ -405,15 +361,13 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
if pre_computed:
U = kernel
else:
G = const_matrix(G_data, 'G')
r_kh = tvm.reduce_axis((0, KH), 'r_kh')
r_kw = tvm.reduce_axis((0, KW), 'r_kw')
U = tvm.compute((alpha, alpha, K // VK, C, VK), lambda eps, nu, k, c, kk:
tvm.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) *
G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U')

# transform image
B = const_matrix(B_data, 'B')
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
V = tvm.compute((alpha, alpha, P // VP, C, VP), lambda eps, nu, b, c, bb:
@@ -427,7 +381,6 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
V[eps][nu][b // VP][c][b % VP], axis=c), name='M')

# inverse transform
A = const_matrix(A_data, 'A')
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw:
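The constants deleted above for tile_size == 2 satisfy the 1-D Winograd identity y = A.T @ ((G @ g) * (B.T @ d)), and the new winograd_transform_matrices(m, r, out_dtype) helper is assumed to return matrices with the same property for every supported (m, r). The NumPy check below is an editorial sketch, not code from this commit; the constants are copied from the deleted lines.

```python
import numpy as np

m, r = 2, 3                # F(m, r): m outputs per tile, r-tap filter
alpha = m + r - 1          # input tile length

# Constants copied from the lines deleted above (tile_size == 2).
G = np.array([[1, 0, 0],
              [0.5, 0.5, 0.5],
              [0.5, -0.5, 0.5],
              [0, 0, 1]])
B = np.array([[1, 0, 0, 0],
              [0, 1, -1, 1],
              [-1, 1, 1, 0],
              [0, 0, 0, -1]])
A = np.array([[1, 0],
              [1, 1],
              [1, -1],
              [0, -1]])

rng = np.random.default_rng(0)
d = rng.standard_normal(alpha)    # one 1-D input tile
g = rng.standard_normal(r)        # one 1-D filter

y_winograd = A.T @ ((G @ g) * (B.T @ d))                  # transform, multiply, inverse-transform
y_direct = np.array([d[i:i + r] @ g for i in range(m)])   # plain correlation
assert np.allclose(y_winograd, y_direct)
```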
61 changes: 7 additions & 54 deletions topi/python/topi/cuda/conv2d_winograd.py
@@ -17,15 +17,14 @@
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Winograd template for cuda backend"""

import numpy as np

import tvm
from tvm import autotvm

from .. import nn
from ..nn import conv2d, group_conv2d_nchw, conv2d_winograd_without_weight_transform
from ..util import get_const_int, get_const_tuple, const_matrix, traverse_inline
from ..util import get_const_int, get_const_tuple, traverse_inline
from ..generic import schedule_conv2d_winograd_without_weight_transform
from ..nn.winograd_util import winograd_transform_matrices


def _infer_tile_size(data, kernel):
@@ -54,7 +53,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
CO, CI, KH, KW = get_const_tuple(kernel.shape)
HPAD, WPAD, _, _ = nn.get_pad_tuple(padding, kernel)
HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 and KH == 3 and KW == 3
assert HSTR == 1 and WSTR == 1 and KH == KW
else: # kernel tensor is pre-transformed. this op is created by
# alter op layout, do not check
# dilation is not supported
@@ -65,62 +64,18 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty

data_pad = nn.pad(data, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad")

if tile_size == 4:
G_data = np.array([
[1 / 4.0, 0, 0],
[-1 / 6.0, -1 / 6.0, -1 / 6.0],
[-1 / 6.0, 1 / 6.0, -1 / 6.0],
[1 / 24.0, 1 / 12.0, 1 / 6.0],
[1 / 24.0, -1 / 12.0, 1 / 6.0],
[0, 0, 1]], dtype=np.float32)

B_data = np.array([
[4, 0, 0, 0, 0, 0],
[0, -4, 4, -2, 2, 4],
[-5, -4, -4, -1, -1, 0],
[0, 1, -1, 2, -2, -5],
[1, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]], out_dtype)

A_data = np.array([
[1, 0, 0, 0],
[1, 1, 1, 1],
[1, -1, 1, -1],
[1, 2, 4, 8],
[1, -2, 4, -8],
[0, 0, 0, 1]], out_dtype)
elif tile_size == 2:
G_data = np.array([
[1, 0, 0],
[1.0/2, 1.0/2, 1.0/2],
[1.0/2, -1.0/2, 1.0/2],
[0, 0, 1]], np.float32)

B_data = np.array([
[1, 0, 0, 0],
[0, 1, -1, 1],
[-1, 1, 1, 0],
[0, 0, 0, -1]], out_dtype)

A_data = np.array([
[1, 0],
[1, 1],
[1, -1],
[0, -1]], out_dtype)
else:
raise ValueError("Unsupported tile size for winograd: " + str(tile_size))

m = A_data.shape[1]
r = 3
r = KW
m = tile_size
alpha = m + r - 1
A, B, G = winograd_transform_matrices(m, r, out_dtype)

H = (H + 2 * HPAD - KH) // HSTR + 1
W = (W + 2 * WPAD - KW) // WSTR + 1
nH, nW = (H + m-1) // m, (W + m-1) // m
P = N * nH * nW

# transform kernel
if not pre_computed:
G = const_matrix(G_data, 'G')
r_kh = tvm.reduce_axis((0, KH), name='r_kh')
r_kw = tvm.reduce_axis((0, KW), name='r_kw')
kernel_pack = tvm.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co:
@@ -136,7 +91,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
[p % nW * m + nu], name='d')

# transform data
B = const_matrix(B_data)
r_a = tvm.reduce_axis((0, alpha), 'r_a')
r_b = tvm.reduce_axis((0, alpha), 'r_a')
data_pack = tvm.compute((alpha, alpha, CI, P), lambda eps, nu, ci, p:
@@ -151,7 +105,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
axis=[ci]), name='bgemm')

# inverse transform
A = const_matrix(A_data)
r_a = tvm.reduce_axis((0, alpha), 'r_a')
r_b = tvm.reduce_axis((0, alpha), 'r_a')
inverse = tvm.compute((CO, P, m, m), lambda co, p, vh, vw:
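The CUDA schedule above spells out the 2-D version of the same identity with tvm.compute: kernel_pack is G g Gᵀ, data_pack is Bᵀ d B, bgemm multiplies them per channel, and the inverse transform applies A on both sides. The single-tile NumPy sketch below illustrates that pipeline with the F(2,3) constants removed in this diff and compares it against a direct 3x3 correlation; it is an illustration only, not code from the commit.

```python
import numpy as np

m, r = 2, 3
alpha = m + r - 1

# F(2,3) constants from the deleted lines.
G = np.array([[1, 0, 0], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0, 0, 1]])
B = np.array([[1, 0, 0, 0], [0, 1, -1, 1], [-1, 1, 1, 0], [0, 0, 0, -1]])
A = np.array([[1, 0], [1, 1], [1, -1], [0, -1]])

rng = np.random.default_rng(1)
d = rng.standard_normal((alpha, alpha))   # one alpha x alpha input tile
g = rng.standard_normal((r, r))           # one 3x3 filter

U = G @ g @ G.T              # kernel transform (kernel_pack for this filter)
V = B.T @ d @ B              # data transform (data_pack for this tile)
Y = A.T @ (U * V) @ A        # elementwise product + inverse transform -> m x m outputs

Y_direct = np.array([[np.sum(d[i:i + r, j:j + r] * g) for j in range(m)]
                     for i in range(m)])
assert np.allclose(Y, Y_direct)
```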
60 changes: 6 additions & 54 deletions topi/python/topi/mali/conv2d.py
@@ -16,16 +16,15 @@
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
"""conv2d schedule on ARM Mali GPU"""
import numpy as np

import tvm
from tvm import autotvm
from tvm.autotvm.task.space import get_factors

from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
from ..util import traverse_inline, get_const_int, get_const_tuple, const_matrix
from ..util import traverse_inline, get_const_int, get_const_tuple
from ..nn import conv2d, conv2d_winograd_without_weight_transform, \
get_pad_tuple, pad, conv2d_alter_layout
from ..nn.winograd_util import winograd_transform_matrices

# reuse some compute declarations from ARM CPU
from ..arm_cpu.conv2d import _decl_spatial_pack, _alter_conv2d_layout_arm
@@ -226,57 +225,13 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)

assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

if tile_size == 4:
G_data = np.array([
[1 / 4.0, 0, 0],
[-1 / 6.0, -1 / 6.0, -1 / 6.0],
[-1 / 6.0, 1 / 6.0, -1 / 6.0],
[1 / 24.0, 1 / 12.0, 1 / 6.0],
[1 / 24.0, -1 / 12.0, 1 / 6.0],
[0, 0, 1]], out_dtype)

B_data = np.array([
[4, 0, 0, 0, 0, 0],
[0, -4, 4, -2, 2, 4],
[-5, -4, -4, -1, -1, 0],
[0, 1, -1, 2, -2, -5],
[1, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]], out_dtype)

A_data = np.array([
[1, 0, 0, 0],
[1, 1, 1, 1],
[1, -1, 1, -1],
[1, 2, 4, 8],
[1, -2, 4, -8],
[0, 0, 0, 1]], out_dtype)
elif tile_size == 2:
G_data = np.array([
[1, 0, 0],
[1.0/2, 1.0/2, 1.0/2],
[1.0/2, -1.0/2, 1.0/2],
[0, 0, 1]], out_dtype)

B_data = np.array([
[1, 0, 0, 0],
[0, 1, -1, 1],
[-1, 1, 1, 0],
[0, 0, 0, -1]], out_dtype)

A_data = np.array([
[1, 0],
[1, 1],
[1, -1],
[0, -1]], out_dtype)
else:
raise ValueError("Unsupported tile size for winograd: " + str(tile_size))

m = A_data.shape[1]
r = 3
r = KW
m = tile_size
alpha = m + r - 1
A, B, G = winograd_transform_matrices(m, r, out_dtype)

H = (IH + 2 * HPAD - 3) // HSTR + 1
W = (IW + 2 * WPAD - 3) // WSTR + 1
@@ -321,15 +276,13 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
if pre_computed:
U = kernel
else:
G = const_matrix(G_data, 'G')
r_kh = tvm.reduce_axis((0, KH), 'r_kh')
r_kw = tvm.reduce_axis((0, KW), 'r_kw')
U = tvm.compute((alpha, alpha, CO // bna, CI, bna), lambda eps, nu, co, ci, vco:
tvm.sum(kernel[co * bna + vco][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw],
axis=[r_kh, r_kw]), name='U')

# transform image
B = const_matrix(B_data, 'B')
r_a = tvm.reduce_axis((0, alpha), 'r_a')
r_b = tvm.reduce_axis((0, alpha), 'r_b')
V = tvm.compute((alpha, alpha, P_round // bnb, CI, bnb), lambda eps, nu, p, ci, vp:
@@ -342,7 +295,6 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
tvm.sum(U[eps][nu][co // bna][ci][co % bna] *
V[eps][nu][p // bnb][ci][p % bnb], axis=ci), name='M')

A = const_matrix(A_data, 'A')
r_a = tvm.reduce_axis((0, alpha), 'r_a')
r_b = tvm.reduce_axis((0, alpha), 'r_b')
Y = tvm.compute((CO, P, m, m), lambda co, p, vh, vw:
31 changes: 6 additions & 25 deletions topi/python/topi/nn/conv2d.py
@@ -19,12 +19,12 @@
"""Conv2D operators"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import numpy as np
import tvm

from .pad import pad
from .util import get_pad_tuple
from ..util import simplify, const_matrix, get_const_tuple
from ..util import simplify, get_const_tuple
from .winograd_util import winograd_transform_matrices

# workload description of conv2d
Workload = namedtuple('Workload',
@@ -425,7 +425,7 @@ def conv2d_winograd_weight_transform(kernel, tile_size):
Parameters
----------
kernel: Tensor
The raw kernel tensor with layout "NCHW". Only 3x3 kernel is supported for now
The raw kernel tensor with layout "NCHW".
tile_size: int
Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)
@@ -434,34 +434,15 @@
output : tvm.Tensor
4-D with shape [alpha, alpha, CO, CI]
"""
K = 3

shape = get_const_tuple(kernel.shape)
assert shape[2:] == (K, K), "Only support 3x3 kernel"
assert shape[2] == shape[3], "Only support NxN kernel"

K = shape[3]
r = tile_size + K - 1
shape = (r, r) + shape[:2]

if tile_size == 2:
G_data = np.array([
[1, 0, 0],
[1.0/2, 1.0/2, 1.0/2],
[1.0/2, -1.0/2, 1.0/2],
[0, 0, 1],
], dtype=kernel.dtype)
elif tile_size == 4:
G_data = np.array([
[1 / 4.0, 0, 0],
[-1 / 6.0, -1 / 6.0, -1 / 6.0],
[-1 / 6.0, 1 / 6.0, -1 / 6.0],
[1 / 24.0, 1 / 12.0, 1 / 6.0],
[1 / 24.0, -1 / 12.0, 1 / 6.0],
[0, 0, 1]
], dtype=kernel.dtype)
else:
raise ValueError("Unsupoorted tile size:" + tile_size)
_, _, G = winograd_transform_matrices(tile_size, K, kernel.dtype)

G = const_matrix(G_data, 'G')
r_kh = tvm.reduce_axis((0, K), name='r_kh')
r_kw = tvm.reduce_axis((0, K), name='r_kw')
return tvm.compute(shape, lambda eps, nu, co, ci:
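As a reference for the conv2d_winograd_weight_transform docstring in the last hunk: the op computes U[eps, nu, co, ci] as the sum over (kh, kw) of G[eps, kh] * kernel[co, ci, kh, kw] * G[nu, kw], i.e. G @ kernel[co, ci] @ G.T for every channel pair, with alpha = tile_size + K - 1. A hedged NumPy equivalent is sketched below; the name winograd_weight_transform_ref is made up for illustration, and G is taken from the tile_size == 2 constants this commit replaces with winograd_transform_matrices.

```python
import numpy as np

def winograd_weight_transform_ref(kernel, G):
    """kernel: [CO, CI, K, K], G: [alpha, K] -> U: [alpha, alpha, CO, CI]."""
    return np.einsum('ak,oikl,bl->aboi', G, kernel, G)

# F(2,3) kernel-transform matrix, copied from the deleted lines (tile_size == 2).
G_2x2_3x3 = np.array([[1, 0, 0],
                      [0.5, 0.5, 0.5],
                      [0.5, -0.5, 0.5],
                      [0, 0, 1]])

kernel = np.random.randn(8, 4, 3, 3)                 # CO=8, CI=4, K=3
U = winograd_weight_transform_ref(kernel, G_2x2_3x3)
print(U.shape)                                       # (4, 4, 8, 4) == (alpha, alpha, CO, CI)
```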