Skip to content

Commit

Permalink
handle device for randn in euler step (huggingface#1124)
Browse files Browse the repository at this point in the history
* handle device for randn in euler step

* convert device to str
  • Loading branch information
patil-suraj authored Nov 3, 2022
1 parent 42bb459 commit 7b030a7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
11 changes: 10 additions & 1 deletion src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,16 @@ def step(
prev_sample = sample + derivative * dt

device = model_output.device if torch.is_tensor(model_output) else "cpu"
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
if str(device) == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
)

prev_sample = prev_sample + noise * sigma_up

if not return_dict:
Expand Down
11 changes: 10 additions & 1 deletion src/diffusers/schedulers/scheduling_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,16 @@ def step(
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

device = model_output.device if torch.is_tensor(model_output) else "cpu"
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
if str(device) == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
)

eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)

Expand Down

0 comments on commit 7b030a7

Please sign in to comment.