add center argument to Normalize
Summary: This allows one to specify a center other than 0.5 for the normalized data. E.g., for GPs with linear kernels, the inputs should be centered at 0. See the discussion on cornellius-gp/gpytorch#2617.

Differential Revision: D68293784
sdaulton authored and facebook-github-bot committed Jan 16, 2025
1 parent ff040d0 commit 696585d
Showing 2 changed files with 19 additions and 7 deletions.
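
For orientation, here is a minimal usage sketch of the new argument (assuming this commit is applied; the data and shapes are illustrative only):

```python
import torch
from botorch.models.transforms.input import Normalize

# With the default center=0.5, learned-bounds normalization maps the data
# range onto [0, 1]. With center=0.0 it maps onto [-0.5, 0.5], which is the
# behavior wanted for, e.g., GPs with linear kernels.
X = 4.0 * torch.rand(10, 2) - 2.0  # illustrative raw inputs in [-2, 2]

tf = Normalize(d=2, center=0.0)
X_nlzd = tf(X)  # in train mode, forward() also updates the learned bounds

print(X_nlzd.min().item())  # ~ -0.5
print(X_nlzd.max().item())  # ~  0.5
```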
12 changes: 9 additions & 3 deletions botorch/models/transforms/input.py
@@ -614,7 +614,7 @@ def _update_coefficients(self, X: Tensor) -> None:
 
 
 class Normalize(AffineInputTransform):
-    r"""Normalize the inputs to the unit cube.
+    r"""Normalize the inputs to have unit range and be centered at 0.5 (by default).
 
     If no explicit bounds are provided this module is stateful: If in train mode,
     calling `forward` updates the module state (i.e. the normalizing bounds). If
@@ -635,6 +635,7 @@ def __init__(
         min_range: float = 1e-8,
         learn_bounds: bool | None = None,
         almost_zero: float = 1e-12,
+        center: float = 0.5,
     ) -> None:
         r"""Normalize the inputs to the unit cube.
 
@@ -662,6 +663,7 @@ def __init__(
                 NOTE: This only applies if `learn_bounds=True`.
             learn_bounds: Whether to learn the bounds in train mode. Defaults
                 to False if bounds are provided, otherwise defaults to True.
+            center: The center of the range for each parameter. Default: 0.5.
 
         Example:
             >>> t = Normalize(d=2)
@@ -704,10 +706,11 @@ def __init__(
                 "will not be updated and the transform will be a no-op.",
                 UserInputWarning,
             )
+        self.center = center
         super().__init__(
             d=d,
             coefficient=coefficient,
-            offset=offset,
+            offset=offset + (0.5 - center) * coefficient,
             indices=indices,
             batch_shape=batch_shape,
             transform_on_train=transform_on_train,
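
The shifted offset can be checked by hand: `AffineInputTransform` normalizes as `(X - offset) / coefficient`, so adding `(0.5 - center) * coefficient` to the offset translates the normalized values by `center - 0.5`. A small sketch of the arithmetic (not library code):

```python
def shifted_offset(offset: float, coefficient: float, center: float) -> float:
    """Mirrors the expression passed to super().__init__ in the diff above."""
    return offset + (0.5 - center) * coefficient

# Default: (X - offset) / coefficient maps [offset, offset + coefficient]
# onto [0, 1], centered at 0.5. With the shifted offset:
#   (X - offset') / coefficient = (X - offset) / coefficient - (0.5 - center)
# i.e. the normalized range becomes [center - 0.5, center + 0.5].

# Example: raw data in [2.0, 6.0] (offset=2.0, coefficient=4.0), center=0.0:
off = shifted_offset(2.0, 4.0, 0.0)  # 2.0 + 0.5 * 4.0 = 4.0
assert (2.0 - off) / 4.0 == -0.5     # raw min -> -0.5
assert (6.0 - off) / 4.0 == 0.5      # raw max -> +0.5
```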
@@ -745,7 +748,10 @@ def _update_coefficients(self, X) -> None:
         coefficient = torch.amax(X, dim=reduce_dims).unsqueeze(-2) - offset
         almost_zero = coefficient < self.min_range
         self._coefficient = torch.where(almost_zero, 1.0, coefficient)
-        self._offset = torch.where(almost_zero, 0.0, offset)
+        self._offset = (
+            torch.where(almost_zero, 0.0, offset)
+            + (0.5 - self.center) * self._coefficient
+        )
 
     def get_init_args(self) -> dict[str, Any]:
         r"""Get the arguments necessary to construct an exact copy of the transform."""
14 changes: 10 additions & 4 deletions test/models/transforms/test_input.py
@@ -7,6 +7,7 @@
 import itertools
 from abc import ABC
 from copy import deepcopy
+from itertools import product
 from random import randint
 
 import torch
@@ -259,17 +260,19 @@ def test_normalize(self) -> None:
                 nlz(X)
 
             # basic usage
-            for batch_shape in (torch.Size(), torch.Size([3])):
+            for batch_shape, center in product(
+                (torch.Size(), torch.Size([3])), [0.5, 0.0]
+            ):
                 # learned bounds
-                nlz = Normalize(d=2, batch_shape=batch_shape)
+                nlz = Normalize(d=2, batch_shape=batch_shape, center=center)
                 X = torch.randn(*batch_shape, 4, 2, device=self.device, dtype=dtype)
                 for _X in (torch.stack((X, X)), X):  # check batch_shape is obeyed
                     X_nlzd = nlz(_X)
                     self.assertEqual(nlz.mins.shape, batch_shape + (1, X.shape[-1]))
                     self.assertEqual(nlz.ranges.shape, batch_shape + (1, X.shape[-1]))
 
-                    self.assertEqual(X_nlzd.min().item(), 0.0)
-                    self.assertEqual(X_nlzd.max().item(), 1.0)
+                    self.assertAllClose(X_nlzd.min().item(), center - 0.5)
+                    self.assertAllClose(X_nlzd.max().item(), center + 0.5)
 
                 nlz.eval()
                 X_unnlzd = nlz.untransform(X_nlzd)
@@ -278,6 +281,9 @@ def test_normalize(self) -> None:
                     [X.min(dim=-2, keepdim=True)[0], X.max(dim=-2, keepdim=True)[0]],
                     dim=-2,
                 )
+                coeff = expected_bounds[..., 1, :] - expected_bounds[..., 0, :]
+                expected_bounds[..., 0, :] += (0.5 - center) * coeff
+                expected_bounds[..., 1, :] = expected_bounds[..., 0, :] + coeff
                 atol = 1e-6 if dtype is torch.float32 else 1e-12
                 rtol = 1e-4 if dtype is torch.float32 else 1e-8
                 self.assertAllClose(nlz.bounds, expected_bounds, atol=atol, rtol=rtol)
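
The adjusted expectation mirrors the offset shift on the model side: assuming `nlz.bounds` stacks the (shifted) offset and the offset plus range, consistent with the assertion above, both bounds move by `(0.5 - center) * coefficient` relative to the raw data min/max. A small numeric check of the test's arithmetic (illustrative values):

```python
import torch

# Illustrative check of the expected-bounds arithmetic used in the test.
X = torch.tensor([[1.0], [3.0]])
center = 0.0
lo, hi = X.min(), X.max()                 # raw data bounds: 1.0 and 3.0
coeff = hi - lo                           # 2.0
lo_shifted = lo + (0.5 - center) * coeff  # 1.0 + 0.5 * 2.0 = 2.0
hi_shifted = lo_shifted + coeff           # 4.0
# The stored bounds become [2.0, 4.0]; normalizing then maps
# raw 1.0 -> (1.0 - 2.0) / 2.0 = -0.5 and raw 3.0 -> 0.5,
# matching the center - 0.5 / center + 0.5 assertions earlier.
```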
