Skip to content

Commit

Permalink
Bump to version 1.5.2 (#2755)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored Feb 1, 2021
1 parent 69562bf commit 585beb9
Show file tree
Hide file tree
Showing 67 changed files with 88 additions and 72 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ install:
- pip install -U pip
# Keep track of pyro-api master branch
- pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
- pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
- pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
- pip install .[test]
- pip install coveralls
- pip freeze
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,5 +213,5 @@ def setup(app):
# @jpchen's hack to get rtd builder to install latest pytorch
# See similar line in the install section of .travis.yml
if 'READTHEDOCS' in os.environ:
os.system('pip install torch==1.6.0+cpu torchvision==0.7.0+cpu '
os.system('pip install torch==1.7.0+cpu torchvision==0.8.1+cpu '
'-f https://download.pytorch.org/whl/torch_stable.html')
2 changes: 1 addition & 1 deletion examples/air/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def per_param_optim_args(module_name, param_name):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="Pyro AIR example", argument_default=argparse.SUPPRESS)
parser.add_argument('-n', '--num-steps', type=int, default=int(1e8),
help='number of optimization steps to take')
Expand Down
2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="Baseball batting average using HMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=200, type=int)
parser.add_argument("--num-chains", nargs='?', default=4, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/autoname/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', default=200, type=int)
parser.add_argument('--jit', action='store_true')
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/autoname/scoping_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', default=200, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/autoname/tree_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', default=100, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/cevae/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="Causal Effect Variational Autoencoder")
parser.add_argument("--num-data", default=1000, type=int)
parser.add_argument("--feature-dim", default=5, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/epidemiology/regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(
description="Regional compartmental epidemiology modeling using HMC")
parser.add_argument("-p", "--population", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(
description="Compartmental epidemiology modeling using HMC")
parser.add_argument("-p", "--population", default=1000, type=float)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/forecast/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def transform(pred, truth):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="Bart Ridership Forecasting Example")
parser.add_argument("--train-window", default=2160, type=int)
parser.add_argument("--test-window", default=336, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/funsor/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="MAP Baum-Welch learning Bach Chorales")
parser.add_argument("-m", "--model", default="1", type=str,
help="one of: {}".format(", ".join(sorted(models.keys()))))
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/gp/sv-dkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description='Pyro GP MNIST Example')
parser.add_argument('--data-dir', type=str, default=None, metavar='PATH',
help='default directory to cache MNIST data')
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/oed/ab_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def main(num_vi_steps, num_bo_steps, seed):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="A/B test experiment design using VI")
parser.add_argument("-n", "--num-vi-steps", nargs="?", default=5000, type=int)
parser.add_argument('--num-bo-steps', nargs="?", default=5, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/timeseries/gp_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="contrib.timeseries example usage")
parser.add_argument("-n", "--num-steps", default=300, type=int)
parser.add_argument("-s", "--seed", default=0, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/cvae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
# parse command line arguments
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-nq', '--num-quadrant-inputs', metavar='N', type=int,
Expand Down
2 changes: 1 addition & 1 deletion examples/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def do_evaluation():

# parse command-line arguments and execute the main method
if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')

parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', type=int, default=5000)
Expand Down
2 changes: 1 addition & 1 deletion examples/eight_schools/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description='Eight Schools MCMC')
parser.add_argument('--num-samples', type=int, default=1000,
help='number of MCMC samples (default: 1000)')
Expand Down
2 changes: 1 addition & 1 deletion examples/eight_schools/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description='Eight Schools SVI')
parser.add_argument('--lr', type=float, default=0.01,
help='learning rate (default: 0.01)')
Expand Down
2 changes: 1 addition & 1 deletion examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="MAP Baum-Welch learning Bach Chorales")
parser.add_argument("-m", "--model", default="1", type=str,
help="one of: {}".format(", ".join(sorted(models.keys()))))
Expand Down
2 changes: 1 addition & 1 deletion examples/inclined_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-samples', default=500, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="Amortized Latent Dirichlet Allocation")
parser.add_argument("-t", "--num-topics", default=8, type=int)
parser.add_argument("-w", "--num-words", default=1024, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/lkj.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="Demonstrate the use of an LKJ Prior")
parser.add_argument("--num-samples", nargs="?", default=200, type=int)
parser.add_argument("--n", nargs="?", default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def guide(data):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="Mini Pyro demo")
parser.add_argument("-b", "--backend", default="minipyro")
parser.add_argument("-n", "--num-steps", default=1001, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description='Example illustrating NeuTra Reparametrizer')
parser.add_argument('-n', '--num-steps', default=10000, type=int,
help='number of SVI steps')
Expand Down
2 changes: 1 addition & 1 deletion examples/rsa/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-samples', default=10, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/rsa/hyperbole.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-samples', default=10, type=int)
parser.add_argument('--price', default=10000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/rsa/schelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-samples', default=10, type=int)
parser.add_argument('--depth', default=2, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/rsa/schelling_false.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-samples', default=10, type=int)
parser.add_argument('--depth', default=3, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/rsa/semantic_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def is_all_qud(world):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-samples', default=10, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/scanvi/scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
# Parse command line arguments
parser = argparse.ArgumentParser(description="single-cell ANnotation using Variational Inference")
parser.add_argument('-s', '--seed', default=0, type=int, help='rng seed')
Expand Down
2 changes: 1 addition & 1 deletion examples/sir_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="SIR epidemiology modeling using HMC")
parser.add_argument("-p", "--population", default=10, type=int)
parser.add_argument("-m", "--min-observations", default=3, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_gamma_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
# parse command line arguments
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', default=1500, type=int, help='number of training epochs')
Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description='Krylov KIT')
parser.add_argument('--num-data', type=int, default=750)
parser.add_argument('--num-steps', type=int, default=1000)
Expand Down
2 changes: 1 addition & 1 deletion examples/svi_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="Distributed training via Horovod")
parser.add_argument("-o", "--outfile")
parser.add_argument("-s", "--size", default=1000000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/toy_mixture_model_discrete_enumeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_true_pred_CPDs(CPD, posterior_param):


if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description="Toy mixture model")
parser.add_argument("-n", "--num-steps", default=4000, type=int)
parser.add_argument("-o", "--num-obs", default=10000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/vae/ss_vae_M2.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def main(args):
"-sup 3000 -zd 50 -hl 500 -lr 0.00042 -b1 0.95 -bs 200 -log ./tmp.log"

if __name__ == "__main__":
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')

parser = argparse.ArgumentParser(description="SS-VAE\n{}".format(EXAMPLE_RUN))

Expand Down
2 changes: 1 addition & 1 deletion examples/vae/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
# parse command line arguments
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', default=101, type=int, help='number of training epochs')
Expand Down
2 changes: 1 addition & 1 deletion examples/vae/vae_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def main(args):


if __name__ == '__main__':
assert pyro.__version__.startswith('1.5.1')
assert pyro.__version__.startswith('1.5.2')
parser = argparse.ArgumentParser(description='VAE using MNIST dataset')
parser.add_argument('-n', '--num-epochs', nargs='?', default=10, type=int)
parser.add_argument('--batch_size', nargs='?', default=128, type=int)
Expand Down
2 changes: 1 addition & 1 deletion pyro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pyro.util import set_rng_seed

# After changing this, run scripts/update_version.py
version_prefix = '1.5.1'
version_prefix = '1.5.2'

# Get the __version__ string from the auto-generated _version.py file, if exists.
try:
Expand Down
18 changes: 16 additions & 2 deletions pyro/ops/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@
_ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0)


def as_complex(x):
"""
Similar to :func:`torch.view_as_complex` but copies data in case strides
are not multiples of two.
"""
if any(stride % 2 for stride in x.stride()[:-1]):
# First try to normalize strides.
x = x.squeeze().reshape(x.shape)
if any(stride % 2 for stride in x.stride()[:-1]):
# Fall back to copying data.
x = x.clone()
return torch.view_as_complex(x)


def block_diag_embed(mat):
"""
Takes a tensor of shape (..., B, M, N) and returns a block diagonal tensor
Expand Down Expand Up @@ -278,7 +292,7 @@ def dct(x, dim=-1):
coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device))
M = Y.size(-1)
coef = torch.stack([coef_real[:M], -coef_real[-M:].flip(-1)], dim=-1)
X = torch.view_as_complex(coef) * Y
X = as_complex(coef) * Y
# NB: if we use the full-length version Y_full = fft(y, n=N), then
# the real part of the later half of X will be the flip
# of the negative of the imaginary part of the first half
Expand Down Expand Up @@ -320,7 +334,7 @@ def idct(x, dim=-1):
X = torch.stack([x[..., :M], xi], dim=-1)
coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device))
coef = torch.stack([coef_real[:M], coef_real[-M:].flip(-1)], dim=-1)
Y = torch.view_as_complex(coef) * torch.view_as_complex(X)
Y = as_complex(coef) * as_complex(X)
# Step 2
y = irfft(Y, n=N)
# Step 3
Expand Down
Loading

0 comments on commit 585beb9

Please sign in to comment.