diff --git a/python/tvm/contrib/pickle_memoize.py b/python/tvm/contrib/pickle_memoize.py index c65797106fc5..b5abf9b9b7a6 100644 --- a/python/tvm/contrib/pickle_memoize.py +++ b/python/tvm/contrib/pickle_memoize.py @@ -34,9 +34,11 @@ class Cache(object): ---------- key: str The file key to the function + save_at_exit: bool + Whether save the cache to file when the program exits """ cache_by_key = {} - def __init__(self, key): + def __init__(self, key, save_at_exit): cache_dir = ".pkl_memoize_py{0}".format(sys.version_info[0]) if not os.path.exists(cache_dir): os.mkdir(cache_dir) @@ -49,6 +51,7 @@ def __init__(self, key): else: self.cache = {} self.dirty = False + self.save_at_exit = save_at_exit def save(self): if self.dirty: @@ -60,16 +63,19 @@ def save(self): def _atexit(): """Save handler.""" for value in Cache.cache_by_key.values(): - value.save() + if value.save_at_exit: + value.save() -def memoize(key): +def memoize(key, save_at_exit=False): """Memoize the result of function and reuse multiple times. Parameters ---------- key: str The unique key to the file + save_at_exit: bool + Whether save the cache to file when the program exits Returns ------- @@ -81,9 +87,9 @@ def _register(f): allow_types = (string_types, int, float) fkey = key + "." + f.__name__ + ".pkl" if fkey not in Cache.cache_by_key: - Cache.cache_by_key[fkey] = Cache(fkey) + Cache.cache_by_key[fkey] = Cache(fkey, save_at_exit) cache = Cache.cache_by_key[fkey] - cargs = tuple(x.cell_contents for x in f.__closure__) + cargs = tuple(x.cell_contents for x in f.__closure__) if f.__closure__ else () cargs = (len(cargs),) + cargs def _memoized_f(func, *args, **kwargs): diff --git a/topi/python/topi/nn/winograd_util.py b/topi/python/topi/nn/winograd_util.py index db57f7671618..464b63301b40 100644 --- a/topi/python/topi/nn/winograd_util.py +++ b/topi/python/topi/nn/winograd_util.py @@ -25,6 +25,7 @@ from operator import mul from functools import reduce import numpy as np +from tvm.contrib.pickle_memoize import memoize from ..util import const_matrix @@ -131,6 +132,8 @@ def _interpolation_points(degree): return np.array(in_pts[degree-1], dtype=np.float64) + +@memoize("topi.nn.winograd_matrices", save_at_exit=False) def winograd_transform_matrices(tile_size, kernel_size, out_dtype): """Compute the A, B, and G transform matrices for `tile_size` as a `tvm.Expr`. """