Skip to content

Commit

Permalink
Improve typing annotations for linear kernel (#2613)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisyeh96 authored Dec 6, 2024
1 parent ed021f1 commit 8940078
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions gpytorch/kernels/linear_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
variance_constraint: Optional[Interval] = None,
**kwargs,
):
super(LinearKernel, self).__init__(**kwargs)
super().__init__(**kwargs)
if variance_constraint is None:
variance_constraint = Positive()
self.register_parameter(
Expand All @@ -73,17 +73,17 @@ def variance(self) -> Tensor:
return self.raw_variance_constraint.transform(self.raw_variance)

@variance.setter
def variance(self, value: Union[float, Tensor]):
def variance(self, value: Union[float, Tensor]) -> None:
self._set_variance(value)

def _set_variance(self, value: Union[float, Tensor]):
def _set_variance(self, value: Union[float, Tensor]) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_variance)
self.initialize(raw_variance=self.raw_variance_constraint.inverse_transform(value))

def forward(
self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, last_dim_is_batch: Optional[bool] = False, **params
) -> LinearOperator:
self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: Optional[bool] = False, **params
) -> Union[Tensor, LinearOperator]:
x1_ = x1 * self.variance.sqrt()
if last_dim_is_batch:
x1_ = x1_.transpose(-1, -2).unsqueeze(-1)
Expand Down

0 comments on commit 8940078

Please sign in to comment.