PairwiseGP GPU compatibility (#537)
Summary:
Pull Request resolved: #537

Enabling PairwiseGP to work with GPU

Reviewed By: Balandat

Differential Revision: D23570258

fbshipit-source-id: ce0e1f8cbf381e3cec7a6a9cc39cc58b8e63c36a
ItsMrLin authored and facebook-github-bot committed Sep 10, 2020
1 parent 7258e56 commit de486e3
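
Before the diff, a quick usage sketch of the workflow this change enables (hypothetical variable names; per the comments in the diff, scipy.optimize.fsolve is CPU-only, so the model is fitted on CPU and then moved to GPU for sampling):

    import torch
    from botorch.models.pairwise_gp import PairwiseGP

    # toy data: 4 items in 2-D; each row of `comparisons` says the
    # item at index [i, 0] beat the item at index [i, 1]
    datapoints = torch.rand(4, 2, dtype=torch.float64)
    comparisons = torch.tensor([[0, 1], [2, 3]])

    model = PairwiseGP(datapoints, comparisons)  # fit on CPU (fsolve is CPU-only)
    # ... optimize hyperparameters here, e.g. with
    # PairwiseLaplaceMarginalLogLikelihood and fit_gpytorch_model ...

    model = model.to(device="cuda")  # registered buffers move with the module
    test_X = torch.rand(3, 2, dtype=torch.float64, device="cuda")
    samples = model.posterior(test_X).rsample(torch.Size([16]))  # sampled on GPU
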
Showing 2 changed files with 120 additions and 59 deletions.
170 changes: 111 additions & 59 deletions botorch/models/pairwise_gp.py
@@ -60,6 +60,7 @@ def __init__(
datapoints: Tensor,
comparisons: Tensor,
covar_module: Optional[Module] = None,
noise_module: Optional[HomoskedasticNoise] = None,
**kwargs,
) -> None:
super().__init__()
@@ -74,33 +75,38 @@ def __init__(
comparisons[i] is a noisy indicator suggesting that the utility value
of the comparisons[i, 0]-th datapoint is greater than that of the
comparisons[i, 1]-th.
covar_module: Covariance module
noise_module: Noise module
"""

# Compatibility variables with fit_gpytorch_*: Dummy likelihood
# Likelihood is tightly tied to this model and
# it doesn't make much sense to keep it separate
self.likelihood = None

self.datapoints = None
self.comparisons = None
# TODO: remove these variables from `state_dict()` so that when calling
# `load_state_dict()`, only the hyperparameters are copied over
self.register_buffer("datapoints", None)
self.register_buffer("comparisons", None)
self.register_buffer("utility", None)
self.register_buffer("covar_chol", None)
self.register_buffer("likelihood_hess", None)
self.register_buffer("hlcov_eye", None)
self.register_buffer("covar", None)
self.register_buffer("covar_inv", None)

self.train_inputs = []
self.train_targets = None

self.pred_cov_fac_need_update = True
self._input_batch_shape = torch.Size()
self.dim = None
# will be set to match datapoints' dtype and device
# since scipy.optimize.fsolve only works on cpu, it'd be the
# fastest to fit the model on cpu and take samples on gpu to avoid
# overhead of moving data back and forth during fitting time
self.tkwargs = {}
# See set_train_data for additional compatibility variables
self.set_train_data(datapoints, comparisons, update_model=False)
self.covar = None
self.covar_inv = None
self.covar_chol = None
self.likelihood_hess = None
self.utility = None
self.hlcov_eye = None
self._std_norm = torch.distributions.normal.Normal(
torch.zeros(1, **self.tkwargs), torch.ones(1, **self.tkwargs)
)

# Set optional parameters
# jitter to add for numerical stability
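
The switch from plain attributes to register_buffer above is what makes Module.to(...) (and hence the self.to(self.datapoints) call later in __init__) relocate these tensors together with the module's parameters. A minimal sketch of the mechanism, using a hypothetical module:

    import torch
    from torch import nn

    class Holder(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.plain = torch.zeros(3)                       # invisible to .to()
            self.register_buffer("buffered", torch.zeros(3))  # tracked state

    m = Holder().to(torch.float64)
    print(m.plain.dtype)     # torch.float32, unchanged
    print(m.buffered.dtype)  # torch.float64, converted by .to()

The flip side, noted in the TODO above, is that buffers also appear in state_dict(), so load_state_dict() will copy them along with the hyperparameters.
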
@@ -124,26 +130,33 @@ def __init__(
# Do not optimize constant mean prior
for param in self.mean_module.parameters():
param.requires_grad = False
self.noise_covar = HomoskedasticNoise(
noise_prior=SmoothedBoxPrior(-5, 5, 0.5, transform=torch.log),
noise_constraint=GreaterThan(1e-4), # if None, 1e-4 by default
batch_shape=self._input_batch_shape,
)

# set covariance module
if noise_module is None:
noise_module = HomoskedasticNoise(
noise_prior=SmoothedBoxPrior(-5, 5, 0.5, transform=torch.log),
noise_constraint=GreaterThan(1e-4), # if None, 1e-4 by default
batch_shape=self._input_batch_shape,
)
self.noise_module = noise_module

# set covariance module
if covar_module is None:
ls_prior = GammaPrior(1.2, 0.5)
ls_prior_mode = (ls_prior.concentration - 1) / ls_prior.rate
self.covar_module = RBFKernel(
covar_module = RBFKernel(
batch_shape=self._input_batch_shape,
ard_num_dims=self.dim,
lengthscale_prior=ls_prior,
lengthscale_constraint=Positive(
transform=None, initial_value=ls_prior_mode
),
)
else:
self.covar_module = covar_module
self.covar_module = covar_module

self._x0 = None # will store temporary results for warm-starting
if self.datapoints is not None and self.comparisons is not None:
self.to(dtype=self.datapoints.dtype, device=self.datapoints.device)
self._update() # Find f_map for initial parameters

self.to(self.datapoints)
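
With the new noise_module argument, a caller can inject a preconfigured noise model instead of the HomoskedasticNoise built by default. A sketch, assuming the same gpytorch classes this file imports:

    import torch
    from botorch.models.pairwise_gp import PairwiseGP
    from gpytorch.constraints import GreaterThan
    from gpytorch.likelihoods.noise_models import HomoskedasticNoise

    datapoints = torch.rand(4, 2, dtype=torch.float64)  # toy data
    comparisons = torch.tensor([[0, 1], [2, 3]])

    # e.g. a stricter noise floor than the default GreaterThan(1e-4)
    noise_module = HomoskedasticNoise(noise_constraint=GreaterThan(1e-2))
    model = PairwiseGP(datapoints, comparisons, noise_module=noise_module)
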
@@ -180,11 +193,11 @@ def __deepcopy__(self, memo) -> PairwiseGP:

@property
def std_noise(self) -> Tensor:
return self.noise_covar.noise
return self.noise_module.noise

@std_noise.setter
def std_noise(self, value: Tensor) -> None:
self.noise_covar.initialize(noise=value)
self.noise_module.initialize(noise=value)

@property
def num_outputs(self) -> int:
@@ -207,14 +220,18 @@ def _calc_covar(
covar = self.covar_module(X1, X2)
if observation_noise:
noise_shape = self._input_batch_shape + self.covar.shape[-1:]
noise = self.noise_covar(shape=noise_shape)
noise = self.noise_module(shape=noise_shape)
covar = AddedDiagLazyTensor(covar, noise)
return covar.evaluate()

def _batch_chol_inv(self, mat_chol: Tensor) -> Tensor:
r"""Wrapper to perform (batched) cholesky inverse"""
# TODO: get rid of this once cholesky_inverse supports batch mode
batch_eye = torch.eye(mat_chol.shape[-1], **self.tkwargs)
batch_eye = torch.eye(
mat_chol.shape[-1],
dtype=self.datapoints.dtype,
device=self.datapoints.device,
)

if len(mat_chol.shape) == 2:
mat_inv = torch.cholesky_inverse(mat_chol)
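
As the TODO notes, torch.cholesky_inverse had no batch support at the time, so the batched branch has to assemble the inverse from a batch of identity matrices instead. One way to express that fallback (a sketch; the exact batched branch is collapsed in this hunk):

    import torch

    def batch_chol_inv(mat_chol: torch.Tensor) -> torch.Tensor:
        # solve (L @ L.T) @ X = I batch-wise, i.e. X = (L @ L.T)^-1
        eye = torch.eye(
            mat_chol.shape[-1], dtype=mat_chol.dtype, device=mat_chol.device
        ).expand(mat_chol.shape).contiguous()
        return torch.cholesky_solve(eye, mat_chol)

    A = torch.randn(5, 3, 3, dtype=torch.float64)
    A = A @ A.transpose(-1, -2) + 3 * torch.eye(3, dtype=torch.float64)  # SPD
    L = torch.linalg.cholesky(A)  # torch.cholesky(A) in 2020-era PyTorch
    assert torch.allclose(batch_chol_inv(L), torch.inverse(A))
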
@@ -302,14 +319,19 @@ def _calc_z(
z_logcdf: log CDF of z
hazard: hazard function defined as pdf(z)/cdf(z)
"""
scaled_util = (utility / (math.sqrt(2) * std_noise)).unsqueeze(-1)

scaled_util = (utility / (math.sqrt(2) * std_noise)).unsqueeze(-1).to(D)
z = (D @ scaled_util).squeeze(-1)
std_norm = torch.distributions.normal.Normal(
torch.zeros(1, dtype=z.dtype, device=z.device),
torch.ones(1, dtype=z.dtype, device=z.device),
)
# Clamp z for stable log transformation. This also prevents extreme values
# from appearing in the hess matrix, which should help with numerical
# stability and avoid extreme fitted hyperparameters
z = z.clamp(-self._zlim, self._zlim)
z_logpdf = self._std_norm.log_prob(z)
z_cdf = self._std_norm.cdf(z)
z_logpdf = std_norm.log_prob(z)
z_cdf = std_norm.cdf(z)
z_logcdf = torch.log(z_cdf)
hazard = torch.exp(z_logpdf - z_logcdf)
return z, z_logpdf, z_logcdf, hazard
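
Here the standard normal is built on the fly from z's own dtype and device (instead of the self._std_norm cached in set_train_data, which this diff drops), so the CDF/PDF evaluations stay on whatever device the utilities live on. The pattern in isolation:

    import torch

    def std_normal_like(z: torch.Tensor) -> torch.distributions.Normal:
        # parameters allocated on z's device/dtype, so log_prob/cdf
        # of z cannot hit a cross-device mismatch
        zero = torch.zeros(1, dtype=z.dtype, device=z.device)
        one = torch.ones(1, dtype=z.dtype, device=z.device)
        return torch.distributions.Normal(zero, one)

    z = torch.randn(8)  # works unchanged for a CUDA tensor
    dist = std_normal_like(z)
    hazard = torch.exp(dist.log_prob(z) - torch.log(dist.cdf(z)))
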
@@ -388,17 +410,21 @@ def _grad_posterior_f(
covar_inv: A Tensor of shape `batch_size x n x n`, as in self.covar_inv
ret_np: return a numpy array if true, otherwise a Tensor
"""
prior_mean = self._prior_mean(datapoints)

if ret_np:
utility = torch.tensor(utility, **self.tkwargs)
utility = torch.tensor(utility, dtype=self.datapoints.dtype)
prior_mean = prior_mean.cpu()

b = self._grad_likelihood_f_sum(utility, D, std_noise)

# g_ = covar_inv x (utility - pred_prior)
p = (utility - self._prior_mean(datapoints)).unsqueeze(-1)
p = (utility - prior_mean).unsqueeze(-1).to(covar_chol)
g_ = torch.cholesky_solve(p, covar_chol).squeeze(-1)
g = g_ + b

if ret_np:
return g.numpy()
return g.cpu().numpy()
else:
return g

@@ -430,7 +456,8 @@ def _hess_posterior_f(
ret_np: return a numpy array if true, otherwise a Tensor
"""
if ret_np:
utility = torch.tensor(utility, **self.tkwargs)
utility = torch.tensor(utility, dtype=self.datapoints.dtype)

hl = self._hess_likelihood_f_sum(utility, D, DT, std_noise)
hess = hl + covar_inv
return hess.numpy() if ret_np else hess
@@ -466,7 +493,9 @@ def _update_utility_derived_values(self) -> None:
"""
hl = self.likelihood_hess # "C" from page 27, [Brochu2010tutorial]_
hlcov = hl @ self.covar
eye = torch.eye(hlcov.size(-1)).expand(hlcov.shape)
eye = torch.eye(
hlcov.size(-1), dtype=self.datapoints.dtype, device=self.datapoints.device
).expand(hlcov.shape)
self.hlcov_eye = hlcov + eye

self.pred_cov_fac_need_update = False
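
This is the recurring fix throughout the diff: torch.eye allocates a float32 tensor on the CPU by default, so combining it with a CUDA (or float64) covariance term needs the dtype and device threaded through explicitly. A tiny helper illustrating the rule (hypothetical; the diff inlines it at each site):

    import torch

    def eye_like(t: torch.Tensor, n: int) -> torch.Tensor:
        # identity matrix matching t's dtype and device
        return torch.eye(n, dtype=t.dtype, device=t.device)

    hlcov = torch.rand(2, 4, 4, dtype=torch.float64)  # imagine device="cuda"
    hlcov_eye = hlcov + eye_like(hlcov, 4).expand(hlcov.shape)
    # a bare torch.eye(4) would be float32 and, against a CUDA tensor,
    # would raise a device-mismatch RuntimeError
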
@@ -501,18 +530,18 @@ def _update(self, **kwargs) -> None:
x0 = self._x0

if len(self._input_batch_shape) > 0:
# batch mode, do optimize.fsolve sequentially
# batch mode, do optimize.fsolve sequentially on CPU
# TODO: enable vectorization/parallelization here
x0 = x0.reshape(-1, self.n)
dp_v = self.datapoints.view(-1, self.n, self.dim)
D_v = self.D.view(-1, self.m, self.n)
DT_v = self.DT.view(-1, self.n, self.m)
dp_v = self.datapoints.view(-1, self.n, self.dim).cpu()
D_v = self.D.view(-1, self.m, self.n).cpu()
DT_v = self.DT.view(-1, self.n, self.m).cpu()
# Use `expand` here since we need to expand std_noise along
# the batch shape dimensions if we start off as a non-batch model
# but are later conditioned on batched new data
sn_v = self.std_noise.expand(*init_x0_size[:-1], 1).reshape(-1)
ch_v = self.covar_chol.view(-1, self.n, self.n)
ci_v = self.covar_inv.view(-1, self.n, self.n)
sn_v = self.std_noise.expand(*init_x0_size[:-1], 1).reshape(-1).cpu()
ch_v = self.covar_chol.view(-1, self.n, self.n).cpu()
ci_v = self.covar_inv.view(-1, self.n, self.n).cpu()
x = np.empty(x0.shape)
for i in range(x0.shape[0]):
fsolve_args = (
@@ -537,13 +566,14 @@ def _update(self, **kwargs) -> None:
)
x = x.reshape(*init_x0_size)
else:
# fsolve only works on CPU
fsolve_args = (
self.datapoints,
self.D,
self.DT,
self.std_noise,
self.covar_chol,
self.covar_inv,
self.datapoints.cpu(),
self.D.cpu(),
self.DT.cpu(),
self.std_noise.cpu(),
self.covar_chol.cpu(),
self.covar_inv.cpu(),
True,
)
with warnings.catch_warnings():
Expand All @@ -559,7 +589,9 @@ def _update(self, **kwargs) -> None:
)

self._x0 = x.copy() # save for warm-starting
f = torch.tensor(x, **self.tkwargs)
f = torch.tensor(
x, dtype=self.datapoints.dtype, device=self.datapoints.device
)

# To perform hyperparameter optimization, this need to be recalculated
# when calling forward() in order to obtain correct gradients
Expand All @@ -569,8 +601,11 @@ def _update(self, **kwargs) -> None:
f, self.D, self.DT, self.std_noise
)

# Lazy update utility pred_cov_fac
# Lazy update hlcov_eye, which is used in calculating posterior during training
self.pred_cov_fac_need_update = True
# fill in dummy values for hlcov_eye so that load_state_dict can function
hlcov_eye_size = torch.Size((*self.likelihood_hess.shape[:-2], self.n, self.n))
self.hlcov_eye = torch.empty(hlcov_eye_size)

self.utility = f.clone().requires_grad_(True)

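The CPU round-trip above exists because scipy.optimize.fsolve only understands NumPy arrays: inputs are moved to CPU, the root-finder runs there, and the solution is cast back to the datapoints' dtype and device. A self-contained sketch of that bridge, with a toy objective standing in for _grad_posterior_f:

    import numpy as np
    import torch
    from scipy.optimize import fsolve

    def grad_np(x_np: np.ndarray, target: torch.Tensor) -> np.ndarray:
        # torch math on CPU, result handed back to scipy as numpy
        x = torch.tensor(x_np, dtype=target.dtype)
        return (x - target.cpu()).numpy()  # root at x == target

    target = torch.tensor([1.0, -2.0], dtype=torch.float64)  # may live on GPU
    x = fsolve(grad_np, np.zeros(2), args=(target,), xtol=1e-6)
    f = torch.tensor(x, dtype=target.dtype, device=target.device)  # back out
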
@@ -605,12 +640,13 @@ def _transform_batch_shape(self, X: Tensor, X_new: Tensor) -> Tuple[Tensor, Tensor]:
# if X has fewer dimension, try to expand it to X_new's shape
return X.expand(X_new_bs + X.shape[-2:]), X_new

def _util_newton_updates(self, x0, max_iter=100, xtol=None) -> Tensor:
def _util_newton_updates(self, x0, max_iter=1, xtol=None) -> Tensor:
r"""Make `max_iter` newton updates on utility.
This is used in `forward` to calculate and fill in gradient into tensors.
Instead of doing utility -= H^-1 @ g, use the substitution method.
See more explanation in _update_utility_derived_values.
By default, only one iteration is needed, just to fill in the gradients.
Args:
x0: A `batch_size x n` dimension tensor, initial values
@@ -636,7 +672,11 @@ def _util_newton_updates(self, x0, max_iter=100, xtol=None) -> Tensor:
hl = self._hess_likelihood_f_sum(x, D, DT, sn)
cov_hl = covar @ hl
if eye is None:
eye = torch.eye(cov_hl.size(-1)).expand(cov_hl.shape)
eye = torch.eye(
cov_hl.size(-1),
dtype=self.datapoints.dtype,
device=self.datapoints.device,
).expand(cov_hl.shape)
cov_hl = cov_hl + eye # add 1 to cov_hl
g = self._grad_posterior_f(x, dp, D, DT, sn, ch, ci)
cov_g = covar @ g.unsqueeze(-1)
@@ -675,9 +715,6 @@ def set_train_data(
"dtype": self.datapoints.dtype,
"device": self.datapoints.device,
}
self._std_norm = torch.distributions.normal.Normal(
torch.zeros(1, **self.tkwargs), torch.ones(1, **self.tkwargs)
)

# Compatibility variables with fit_gpytorch_*
# alias for datapoints (train_inputs) and comparisons ("train_targets" here)
@@ -699,13 +736,14 @@ def set_train_data(
# TODO: make D a sparse matrix once pytorch has better support for
# sparse tensors
D_size = torch.Size((*(self._input_batch_shape), self.m, self.n))
self.D = torch.zeros(D_size, **self.tkwargs)
self.D = torch.zeros(
D_size, dtype=self.datapoints.dtype, device=self.datapoints.device
)
comp_view = self.comparisons.view(-1, self.m, 2).long()
for i, sub_D in enumerate(self.D.view(-1, self.m, self.n)):
sub_D.scatter_(1, comp_view[i, :, [0]], 1)
sub_D.scatter_(1, comp_view[i, :, [1]], -1)
self.DT = self.D.transpose(-1, -2)

if update_model:
self._update()

@@ -758,7 +796,11 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
# (A + B)^-1 = A^-1 - A^-1 @ (I + BA^-1)^-1 @ BA^-1
# where A = covar_inv, B = hl
hl_cov = hl @ covar
eye = torch.eye(hl_cov.size(-1)).expand(hl_cov.shape)
eye = torch.eye(
hl_cov.size(-1),
dtype=self.datapoints.dtype,
device=self.datapoints.device,
).expand(hl_cov.shape)
hl_cov_I = hl_cov + eye # add I to hl_cov
train_covar_map = covar - covar @ torch.solve(hl_cov, hl_cov_I).solution
output_mean, output_covar = self.utility, train_covar_map
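
The train_covar_map line uses the matrix identity quoted in the comment, (A + B)^-1 = A^-1 - A^-1 @ (I + B @ A^-1)^-1 @ B @ A^-1 with A = covar_inv and B = hl, so the posterior covariance comes out of covar directly without ever forming (covar_inv + hl)^-1. A quick numerical check of the identity:

    import torch

    torch.manual_seed(0)
    n = 4
    K = torch.randn(n, n, dtype=torch.float64)
    K = K @ K.T + n * torch.eye(n, dtype=torch.float64)  # SPD, plays "covar"
    H = torch.randn(n, n, dtype=torch.float64)
    H = H @ H.T + n * torch.eye(n, dtype=torch.float64)  # SPD, plays "hl"

    lhs = torch.inverse(torch.inverse(K) + H)  # (A + B)^-1 with A = K^-1
    eye = torch.eye(n, dtype=torch.float64)
    # torch.solve(H @ K, eye + H @ K).solution in 2020-era PyTorch
    rhs = K - K @ torch.linalg.solve(eye + H @ K, H @ K)
    assert torch.allclose(lhs, rhs)
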
Expand Down Expand Up @@ -805,7 +847,15 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
output_mean, output_covar = pred_mean, pred_covar

try:
diag_jitter = torch.eye(output_covar.size(-1)).expand(output_covar.shape)
if self.datapoints is None:
diag_jitter = torch.eye(output_covar.size(-1))
else:
diag_jitter = torch.eye(
output_covar.size(-1),
dtype=self.datapoints.dtype,
device=self.datapoints.device,
)
diag_jitter = diag_jitter.expand(output_covar.shape)
diag_jitter = diag_jitter * self._jitter
# Preemptively adding jitter to diagonal to prevent the use of _add_jitter
# given that torch.cholesky may be very slow on non-pd matrix input
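
The diag_jitter block adds a small scaled identity, matching the datapoints' dtype and device (with a plain CPU fallback when datapoints is None), to keep a near-singular covariance positive definite before factorization. The pattern on its own, with an assumed jitter value:

    import torch

    def add_jitter(covar: torch.Tensor, jitter: float = 1e-6) -> torch.Tensor:
        # scaled identity on the diagonal, matching covar's dtype/device
        # and broadcast over any leading batch dimensions
        eye = torch.eye(
            covar.size(-1), dtype=covar.dtype, device=covar.device
        ).expand(covar.shape)
        return covar + jitter * eye

    K = torch.zeros(2, 3, 3, dtype=torch.float64)  # rank-deficient batch
    L = torch.linalg.cholesky(add_jitter(K))       # succeeds thanks to jitter
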
Expand Down Expand Up @@ -850,8 +900,8 @@ def posterior(
post = self(X)

if observation_noise:
noise_covar = self.noise_covar(shape=post.mean.shape).evaluate()
post = MultivariateNormal(post.mean, post.covariance_matrix + noise_covar)
noise_module = self.noise_module(shape=post.mean.shape).evaluate()
post = MultivariateNormal(post.mean, post.covariance_matrix + noise_module)

return GPyTorchPosterior(post)

@@ -938,7 +988,9 @@ def forward(self, post: Posterior, comp: Tensor) -> Tensor:
part1 = -log_posterior

part2 = model.covar @ model.likelihood_hess
eye = torch.eye(part2.size(-1)).expand(part2.shape)
eye = torch.eye(
part2.size(-1), dtype=model.datapoints.dtype, device=model.datapoints.device
).expand(part2.shape)
part2 = part2 + eye
part2 = -0.5 * torch.logdet(part2)

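For context, part1 and part2 together form the Laplace approximation to the log marginal likelihood: with K the prior covariance (model.covar) and C the likelihood Hessian at the MAP utility f (model.likelihood_hess), the approximate log evidence is

    log Z ~= log p(comp | f) - (1/2) f^T K^-1 f - (1/2) logdet(I + K @ C)

part1 carries the unnormalized log-posterior terms at the MAP utility, and part2 is the determinant correction. Writing it as logdet(I + K @ C) rather than logdet(K) + logdet(K^-1 + C) avoids factorizing K^-1 + C, which is why part2 is built from model.covar @ model.likelihood_hess plus the identity.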