From a3e0b5db04b80fd2deffd1dafcf3ac0e730c8e11 Mon Sep 17 00:00:00 2001
From: xvdp
Date: Sun, 6 Sep 2020 12:12:53 -0700
Subject: [PATCH 1/2] Merge fixes to hsic functions

distmat(): minimized memory allocation, added option to remove grad
kernelmat(): minimized mem alloc by using in-place ops; fixed devices so the
  output tensor is on the same device as the input
hsic_normalized_cca(): reuse temp tensors
---
 source/hsicbt/math/hsic.py | 92 +++++++++++++++++++++++---------------
 1 file changed, 56 insertions(+), 36 deletions(-)

diff --git a/source/hsicbt/math/hsic.py b/source/hsicbt/math/hsic.py
index 214880e..ade7d3c 100644
--- a/source/hsicbt/math/hsic.py
+++ b/source/hsicbt/math/hsic.py
@@ -1,7 +1,8 @@
 import torch
 import numpy as np
 from torch.autograd import Variable, grad
-
+# pylint: disable=no-member
+# pylint: disable=not-callable
 
 def sigma_estimation(X, Y):
     """ sigma from median distance """
@@ -16,43 +17,55 @@ def sigma_estimation(X, Y):
         med=1E-2
     return med
 
-def distmat(X):
-    """ distance matrix
+def distmat(X, requires_grad=True):
+    """ distance matrix |X.X - 2(X x Xt) + (X.X)t|
+    Args
+        X (tensor) shape (batchsize, dims)
+        requires_grad (bool[True]) False: removes gradient from output
     """
-    r = torch.sum(X*X, 1)
-    r = r.view([-1, 1])
-    a = torch.mm(X, torch.transpose(X,0,1))
-    D = r.expand_as(a) - 2*a + torch.transpose(r,0,1).expand_as(a)
-    D = torch.abs(D)
-    return D
-
-def kernelmat(X, sigma):
+    _cloned = False
+    if X.requires_grad and not requires_grad:
+        X = X.clone().detach()
+        _cloned = True
+    out = torch.mm(X, X.T).mul_(-2.0)
+    out.add_((X*X).sum(1, keepdim=True))
+    out.add_((X*X).sum(1, keepdim=True).T)
+    if _cloned:
+        del X
+    return out.abs_()
+
+def kernelmat(X, sigma=None, requires_grad=True):
     """ kernel matrix baker
+    Args
+        X (tensor) shape (batchsize, dims)
+        sigma (float [None]) from config
+        requires_grad (bool [True]) False: removes gradient from output
+
     """
-    m = int(X.size()[0])
-    dim = int(X.size()[1]) * 1.0
-    H = torch.eye(m) - (1./m) * torch.ones([m,m])
-    Dxx = distmat(X)
-
+    m, dim = X.size()
+    H = torch.eye(m, device=X.device).sub_(1/m)
+    Kx = distmat(X, requires_grad=requires_grad)
+
     if sigma:
-        variance = 2.*sigma*sigma*X.size()[1]
-        Kx = torch.exp( -Dxx / variance).type(torch.FloatTensor)   # kernel matrices
-        # print(sigma, torch.mean(Kx), torch.max(Kx), torch.min(Kx))
+        variance = 2.*sigma*sigma*dim
+        torch.exp_(Kx.mul_(-1.0/variance))
     else:
         try:
-            sx = sigma_estimation(X,X)
-            Kx = torch.exp( -Dxx / (2.*sx*sx)).type(torch.FloatTensor)
+            sx = sigma_estimation(X, X)
+            variance = 2.*sx*sx
+            torch.exp_(Kx.mul_(-1.0/variance))
         except RuntimeError as e:
             raise RuntimeError("Unstable sigma {} with maximum/minimum input ({},{})".format(
                 sx, torch.max(X), torch.min(X)))
-    Kxc = torch.mm(Kx,H)
-
+    Kxc = torch.mm(Kx, H)
+    del H
+    del Kx
     return Kxc
 
 def distcorr(X, sigma=1.0):
     X = distmat(X)
-    X = torch.exp( -X / (2.*sigma*sigma))
+    X = torch.exp(-X / (2.*sigma*sigma))
     return torch.mean(X)
 
 def compute_kernel(x, y):
@@ -137,21 +150,28 @@ def hsic_normalized(x, y, sigma=None, use_cuda=True, to_numpy=True):
     thehsic = Pxy/(Px*Py)
     return thehsic
 
-def hsic_normalized_cca(x, y, sigma, use_cuda=True, to_numpy=True):
+def hsic_normalized_cca(x, y, sigma=None, requires_grad=True):
     """
+    Args
+        x (tensor) shape (batchsize, dims)
+        y (tensor) shape (batchsize, dims)
+        sigma (float [None])
+        requires_grad (bool[True]) False: removes gradient from output
     """
-    m = int(x.size()[0])
-    Kxc = kernelmat(x, sigma=sigma)
-    Kyc = kernelmat(y, sigma=sigma)
-    epsilon = 1E-5
-    K_I = torch.eye(m)
-    Kxc_i = torch.inverse(Kxc + epsilon*m*K_I)
-    Kyc_i = torch.inverse(Kyc + epsilon*m*K_I)
-    Rx = (Kxc.mm(Kxc_i))
-    Ry = (Kyc.mm(Kyc_i))
-    Pxy = torch.sum(torch.mul(Rx, Ry.t()))
+    epsilon = 1E-5
+    m = x.size()[0]
+    K_I = torch.eye(m, device=x.device).mul_(epsilon*m)
 
-    return Pxy
+    Kc = kernelmat(x, sigma=sigma, requires_grad=requires_grad)
+    Rx = Kc.mm(Kc.add(K_I).inverse())
+
+    Kc = kernelmat(y, sigma=sigma, requires_grad=requires_grad)
+    Ry = Kc.mm(Kc.add(K_I).inverse())
+    out = Rx.mul_(Ry.t()).sum()
+    del Rx
+    del Ry
+    del Kc
+    del K_I
+    return out

From 08458c9a530938f80a902ee239864a89bbcc6ed1 Mon Sep 17 00:00:00 2001
From: xvdp
Date: Sun, 6 Sep 2020 12:36:39 -0700
Subject: [PATCH 2/2] removed option to remove gradient

---
 source/hsicbt/math/hsic.py | 22 ++++++----------------
 1 file changed, 6 insertions(+), 16 deletions(-)

diff --git a/source/hsicbt/math/hsic.py b/source/hsicbt/math/hsic.py
index ade7d3c..932e1f3 100644
--- a/source/hsicbt/math/hsic.py
+++ b/source/hsicbt/math/hsic.py
@@ -17,34 +17,25 @@ def sigma_estimation(X, Y):
         med=1E-2
     return med
 
-def distmat(X, requires_grad=True):
+def distmat(X):
     """ distance matrix |X.X - 2(X x Xt) + (X.X)t|
     Args
         X (tensor) shape (batchsize, dims)
-        requires_grad (bool[True]) False: removes gradient from output
     """
-    _cloned = False
-    if X.requires_grad and not requires_grad:
-        X = X.clone().detach()
-        _cloned = True
     out = torch.mm(X, X.T).mul_(-2.0)
     out.add_((X*X).sum(1, keepdim=True))
     out.add_((X*X).sum(1, keepdim=True).T)
-    if _cloned:
-        del X
     return out.abs_()
 
-def kernelmat(X, sigma=None, requires_grad=True):
+def kernelmat(X, sigma=None):
     """ kernel matrix baker
     Args
         X (tensor) shape (batchsize, dims)
         sigma (float [None]) from config
-        requires_grad (bool [True]) False: removes gradient from output
-
     """
     m, dim = X.size()
     H = torch.eye(m, device=X.device).sub_(1/m)
-    Kx = distmat(X, requires_grad=requires_grad)
+    Kx = distmat(X)
 
     if sigma:
         variance = 2.*sigma*sigma*dim
@@ -150,22 +141,21 @@ def hsic_normalized(x, y, sigma=None, use_cuda=True, to_numpy=True):
     thehsic = Pxy/(Px*Py)
     return thehsic
 
-def hsic_normalized_cca(x, y, sigma=None, requires_grad=True):
+def hsic_normalized_cca(x, y, sigma=None):
     """
     Args
         x (tensor) shape (batchsize, dims)
         y (tensor) shape (batchsize, dims)
         sigma (float [None])
-        requires_grad (bool[True]) False: removes gradient from output
     """
     epsilon = 1E-5
     m = x.size()[0]
     K_I = torch.eye(m, device=x.device).mul_(epsilon*m)
 
-    Kc = kernelmat(x, sigma=sigma, requires_grad=requires_grad)
+    Kc = kernelmat(x, sigma=sigma)
     Rx = Kc.mm(Kc.add(K_I).inverse())
 
-    Kc = kernelmat(y, sigma=sigma, requires_grad=requires_grad)
+    Kc = kernelmat(y, sigma=sigma)
     Ry = Kc.mm(Kc.add(K_I).inverse())
     out = Rx.mul_(Ry.t()).sum()
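
Not part of the patches: a minimal sanity-check sketch for the refactored functions. It assumes the module is importable as hsicbt.math.hsic and uses the post-PATCH-2/2 signatures (distmat(X), hsic_normalized_cca(x, y, sigma=None)); distmat_reference below simply restates the removed pre-patch formula for comparison, and the file name is hypothetical.

# sanity_check_hsic.py -- illustrative only, not shipped with the patches
import torch
from hsicbt.math.hsic import distmat, hsic_normalized_cca  # assumes the package is importable

def distmat_reference(X):
    # pre-patch distmat, restated here only to compare against the in-place version
    r = torch.sum(X*X, 1).view([-1, 1])
    a = torch.mm(X, torch.transpose(X, 0, 1))
    D = r.expand_as(a) - 2*a + torch.transpose(r, 0, 1).expand_as(a)
    return torch.abs(D)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(128, 64, device=device)
    y = torch.randn(128, 10, device=device)

    # the refactored |X.X - 2(X x Xt) + (X.X)t| should match the old expand/subtract form
    assert torch.allclose(distmat(x), distmat_reference(x), rtol=1e-4, atol=1e-4)

    # with the device fix, the HSIC value stays on the input device
    h = hsic_normalized_cca(x, y, sigma=5.0)
    assert h.device == x.device
    print(float(h))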