diff --git a/gpytorch/kernels/linear_kernel.py b/gpytorch/kernels/linear_kernel.py
index 51936766e..10102e1be 100644
--- a/gpytorch/kernels/linear_kernel.py
+++ b/gpytorch/kernels/linear_kernel.py
@@ -52,7 +52,7 @@ def __init__(
         variance_constraint: Optional[Interval] = None,
         **kwargs,
     ):
-        super(LinearKernel, self).__init__(**kwargs)
+        super().__init__(**kwargs)
         if variance_constraint is None:
             variance_constraint = Positive()
         self.register_parameter(
@@ -73,17 +73,17 @@ def variance(self) -> Tensor:
         return self.raw_variance_constraint.transform(self.raw_variance)
 
     @variance.setter
-    def variance(self, value: Union[float, Tensor]):
+    def variance(self, value: Union[float, Tensor]) -> None:
         self._set_variance(value)
 
-    def _set_variance(self, value: Union[float, Tensor]):
+    def _set_variance(self, value: Union[float, Tensor]) -> None:
         if not torch.is_tensor(value):
             value = torch.as_tensor(value).to(self.raw_variance)
         self.initialize(raw_variance=self.raw_variance_constraint.inverse_transform(value))
 
     def forward(
-        self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, last_dim_is_batch: Optional[bool] = False, **params
-    ) -> LinearOperator:
+        self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: Optional[bool] = False, **params
+    ) -> Union[Tensor, LinearOperator]:
         x1_ = x1 * self.variance.sqrt()
         if last_dim_is_batch:
             x1_ = x1_.transpose(-1, -2).unsqueeze(-1)
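
For context, a minimal usage sketch (the input shapes and the variance value are hypothetical, not from the patch) of why the updated annotations are the accurate ones: assigning a plain float exercises the patched `variance` setter, while `diag=True` yields a plain `Tensor` and the default path yields a lazily evaluated `LinearOperator`, matching the new `Union[Tensor, LinearOperator]` return type.

    import torch
    from gpytorch.kernels import LinearKernel

    x = torch.randn(10, 3)  # hypothetical inputs: 10 points, 3 features

    kernel = LinearKernel()
    # Goes through the variance setter -> _set_variance; the float is
    # converted with torch.as_tensor and passed through the inverse
    # transform of the Positive() constraint.
    kernel.variance = 0.5

    full = kernel(x, x)             # lazy LinearOperator, evaluated on demand
    diag = kernel(x, x, diag=True)  # plain Tensor holding only the diagonal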