Skip to content

Commit

Permalink
Adds casting to the default dtype in turbo (#60)
Browse files Browse the repository at this point in the history
* Adds casting to the default dtype in turbo

* Moves the default dtype to torch's

* Forgot to move one tensor
  • Loading branch information
miguelgondu authored Oct 24, 2024
1 parent f07c60a commit b5e3267
Showing 1 changed file with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from poli_baselines.core.step_by_step_solver import StepByStepSolver

DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEFAULT_DTYPE = torch.double
DEFAULT_DTYPE = torch.get_default_dtype()


NUM_RESTARTS = 10
Expand Down Expand Up @@ -72,8 +72,12 @@ def from_turbo(X):

self.device = device
self.to_turbo, self.from_turbo = make_transforms()
self.X_turbo = torch.tensor(self.to_turbo(x0)).to(self.device)
self.Y_turbo = torch.tensor(y0).to(self.device)
self.X_turbo = (
torch.tensor(self.to_turbo(x0))
.to(self.device)
.to(torch.get_default_dtype())
)
self.Y_turbo = torch.tensor(y0).to(self.device).to(torch.get_default_dtype())
self.batch_size = 1
dim = x0.shape[1]
self.state = TurboState(dim, batch_size=self.batch_size)
Expand Down Expand Up @@ -218,7 +222,7 @@ def generate_batch(
mask[ind, torch.randint(0, dim - 1, size=(len(ind),), device=device)] = 1

# Create candidate points from the perturbations and the mask
X_cand = x_center.expand(n_candidates, dim).clone().to(device)
X_cand = x_center.expand(n_candidates, dim).clone().to(device).to(DEFAULT_DTYPE)
X_cand[mask] = pert[mask]

# Sample on the candidate points
Expand Down

0 comments on commit b5e3267

Please sign in to comment.