From 721f906fff63625f7fda4c5954f8bdb4e20e3745 Mon Sep 17 00:00:00 2001 From: TobyBoyne Date: Tue, 27 Feb 2024 17:01:10 +0000 Subject: [PATCH] Move lambda to global func --- gpytorch/kernels/index_kernel.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gpytorch/kernels/index_kernel.py b/gpytorch/kernels/index_kernel.py index 7fa5e01f3..df240ca4c 100644 --- a/gpytorch/kernels/index_kernel.py +++ b/gpytorch/kernels/index_kernel.py @@ -72,7 +72,7 @@ def __init__( if prior is not None: if not isinstance(prior, Prior): raise TypeError("Expected gpytorch.priors.Prior but got " + type(prior).__name__) - self.register_prior("IndexKernelPrior", prior, lambda m: m._eval_covar_matrix()) + self.register_prior("IndexKernelPrior", prior, _index_kernel_prior_closure) self.register_constraint("raw_var", var_constraint) @@ -109,3 +109,7 @@ def forward(self, i1, i2, **params): right_interp_indices=i2.expand(batch_shape + i2.shape[-2:]), ) return res + + +def _index_kernel_prior_closure(m): + return m._eval_covar_matrix()