Great question. I expect that modeling the sequential structure would help. You can do this with two different GPs; the tricky part is passing the features through:

```python
import torch
from botorch.models import SingleTaskGP, ModelList, GenericDeterministicModel
from botorch.models.transforms import Standardize, Normalize
from botorch.models.transforms.input import FilterFeatures
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch import fit_gpytorch_mll
# generate data
p = torch.linspace(0, 1, 20, dtype=torch.float64).unsqueeze(-1) # first input
b = torch.sin(p) * 2 # first output
q = torch.rand(20, dtype=torch.float64).unsqueeze(-1) # second input
c = -(q - 0.5) ** 2 + q # final output
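# NOTE: in a real two-step process c would also depend on b; this toy function
# depends only on q just to keep the example short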
## Define the first model
# fit b as a function of p and q, ignoring q
model_1a = SingleTaskGP(
    train_X=torch.concatenate((p, q), axis=1),
    train_Y=b,
    outcome_transform=Standardize(m=1),
    input_transform=FilterFeatures(feature_indices=torch.tensor([0])),
)
mll = ExactMarginalLogLikelihood(model_1a.likelihood, model_1a)
_ = fit_gpytorch_mll(mll)
# deterministically map p, q => q
model_1b = GenericDeterministicModel(lambda x: x[..., 1].unsqueeze(-1))
model_1 = ModelList(model_1a, model_1b)
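# model_1 now returns two outputs for each (p, q) input: the GP prediction of b
# and the pass-through value of q, which together match the inputs of the second model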
## Define the second model
# fit the final output c as a function of b and q
model2 = SingleTaskGP(
    train_X=torch.concatenate((b, q), axis=1),
    train_Y=c,
    input_transform=Normalize(d=2),
    outcome_transform=Standardize(m=1),
)
mll = ExactMarginalLogLikelihood(model2.likelihood, model2)
_ = fit_gpytorch_mll(mll)
```

To sample the final output as a function of the inputs (which you might not need to do), you can sample from the posterior of the first model and push those samples through the second:

```python
x_test = torch.concatenate((p, q), axis=1)
y_samps = model_1.posterior(x_test).sample(torch.Size([64]))  # 64 x 20 x 2
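# each sample is a joint draw of (b, q) at the test points; pushing it through
# model2 and averaging over samples gives a (recursive) estimate of the final output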
posterior_mean = model2.posterior(y_samps).mean.mean(0)
```

Then, to optimize an acquisition function, we can wrap the second model in a `LearnedObjective` so that candidates are scored on the (estimated) final output:

```python
from botorch.optim import optimize_acqf
from botorch.acquisition.objective import LearnedObjective
from botorch.acquisition import qLogNoisyExpectedImprovement
# optimize
acqf = qLogNoisyExpectedImprovement(
    model=model_1,
    X_baseline=torch.concatenate((p, q), axis=1),  # the observed (p, q) inputs
    objective=LearnedObjective(pref_model=model2),
)
candidates, _ = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0., 0.], [1., 1.]], dtype=torch.float64),
    q=1,
    num_restarts=8,
    # low number since recursive sampling can take a while
    raw_samples=8,
)
```

A weakness of this approach is that it decides …
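If it's useful, here is a minimal sketch (purely illustrative, reusing the toy variable names from above) of splitting the returned candidate back into the settings for the two steps:

```python
# candidates has shape (q, d) = (1, 2): column 0 is p, column 1 is q
p_next = candidates[..., 0].item()
q_next = candidates[..., 1].item()
print(f"run step 1 at p={p_next:.3f}, then step 2 at q={q_next:.3f}")
```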
Hi,
I am looking to use botorch to build models/optimization routines for a type of chemical process that looks like this:
Material A -> Material B (subject to some input features P)
Material B -> Material C (subject to some input features Q)
What I would like to do is manipulate all input features (P, Q) to maximize the yield to C. The interesting thing, though, is that I can measure B.
So, I can presumably build a model on the intermediate task B, with which the final task C should share a lot of information, because telescoped processes that form high/low C will very likely correlate with intermediate processes that form high/low B.
I was thinking that I could use multitask modeling to do this. With botorch right now, it is possible to do that only over the subset of features P with tasks B and C. However, task C introduces new parameters (Q), so I'm not sure how to handle that with the current methods.
Is there any way to set this up so that task C can still benefit from the efficiency of learning the related task B, but then extend that model to include the new features Q?
Since I run both steps in sequence, I could include the Q features in the B task, but this seems like a bad choice since we know that the Q features have no influence on B, so I would think the model would just overfit on them.