Skip to content

Commit

Permalink
fix time sampling from beta
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 18, 2024
1 parent bba5450 commit 5e44ce0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions pi_zero_pytorch/pi_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,10 @@ def default_sample_times(
):
""" they propose to sample times from Beta distribution - last part of appendix part B """

uniform = torch.rand(shape, device = device)
sampled = Beta(alpha, beta).sample().to(device)
return ((s - uniform) / s).clamp(0., 1.) * sampled
alpha = torch.full(shape, alpha, device = device)
beta = torch.full(shape, beta, device = device)
sampled = Beta(alpha, beta).sample()
return (1. - sampled) * s

def noise_assignment(data, noise):
device = data.device
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.0.47"
version = "0.0.48"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 5e44ce0

Please sign in to comment.