Skip to content

Commit

Permalink
[TOPI] Memoize winograd matrix (#3687)
Browse files Browse the repository at this point in the history
* [TOPI] Memoize winograd matrix

* lint

* Fix name
  • Loading branch information
merrymercy authored and tqchen committed Aug 2, 2019
1 parent 33ab3c6 commit 7de8a3a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
16 changes: 11 additions & 5 deletions python/tvm/contrib/pickle_memoize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ class Cache(object):
----------
key: str
The file key to the function
save_at_exit: bool
Whether save the cache to file when the program exits
"""
cache_by_key = {}
def __init__(self, key):
def __init__(self, key, save_at_exit):
cache_dir = ".pkl_memoize_py{0}".format(sys.version_info[0])
if not os.path.exists(cache_dir):
os.mkdir(cache_dir)
Expand All @@ -49,6 +51,7 @@ def __init__(self, key):
else:
self.cache = {}
self.dirty = False
self.save_at_exit = save_at_exit

def save(self):
if self.dirty:
Expand All @@ -60,16 +63,19 @@ def save(self):
def _atexit():
"""Save handler."""
for value in Cache.cache_by_key.values():
value.save()
if value.save_at_exit:
value.save()


def memoize(key):
def memoize(key, save_at_exit=False):
"""Memoize the result of function and reuse multiple times.
Parameters
----------
key: str
The unique key to the file
save_at_exit: bool
Whether save the cache to file when the program exits
Returns
-------
Expand All @@ -81,9 +87,9 @@ def _register(f):
allow_types = (string_types, int, float)
fkey = key + "." + f.__name__ + ".pkl"
if fkey not in Cache.cache_by_key:
Cache.cache_by_key[fkey] = Cache(fkey)
Cache.cache_by_key[fkey] = Cache(fkey, save_at_exit)
cache = Cache.cache_by_key[fkey]
cargs = tuple(x.cell_contents for x in f.__closure__)
cargs = tuple(x.cell_contents for x in f.__closure__) if f.__closure__ else ()
cargs = (len(cargs),) + cargs

def _memoized_f(func, *args, **kwargs):
Expand Down
3 changes: 3 additions & 0 deletions topi/python/topi/nn/winograd_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from operator import mul
from functools import reduce
import numpy as np
from tvm.contrib.pickle_memoize import memoize
from ..util import const_matrix


Expand Down Expand Up @@ -131,6 +132,8 @@ def _interpolation_points(degree):

return np.array(in_pts[degree-1], dtype=np.float64)


@memoize("topi.nn.winograd_matrices", save_at_exit=False)
def winograd_transform_matrices(tile_size, kernel_size, out_dtype):
"""Compute the A, B, and G transform matrices for `tile_size` as a `tvm.Expr`.
"""
Expand Down

0 comments on commit 7de8a3a

Please sign in to comment.