Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Compact interval fixes #28

Open
wants to merge 82 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
0487963
refactored compact flows
jonkhler May 7, 2021
059ef83
rough comments for approx_inverse
Olllom May 7, 2021
d8d325d
fixes in compact interval flows
Olllom May 7, 2021
7acf3a7
remove expand_as
Olllom May 7, 2021
1324350
reduce dlogx in Mixture
Olllom May 7, 2021
cc568ed
moved reduction to WrapTransformer...
Olllom May 7, 2021
eaffeb8
added missing args and kwargs
Olllom May 7, 2021
ade7117
added missing reduce
Olllom May 7, 2021
4456055
raise error in ksection
Olllom May 8, 2021
7563859
fixed shape bug
jonkhler May 9, 2021
748e679
some speed comparison with spline flows
jonkhler May 9, 2021
3845a66
Merge branch 'compact-interval-flows' of github.com:noegroup/bgflow i…
Olllom May 10, 2021
6a0c473
Merge branch 'compact-interval-flows' into compact-interval-fixes
Olllom May 10, 2021
3de759b
delete ksection
Olllom May 10, 2021
16a8d01
added missing args/kwargs
Olllom May 10, 2021
ae727bb
make softplus picklable
Olllom May 10, 2021
21c0028
add sirens
Olllom May 11, 2021
99f2a9c
added stupid code for debugging
Olllom May 11, 2021
0e58c1d
...
Olllom May 11, 2021
aa7ed2c
...
Olllom May 11, 2021
7fb1e27
...
Olllom May 11, 2021
caf167e
fixed broken sigmas/alphas; working grid inversion?
jonkhler May 11, 2021
43a8465
...
Olllom May 12, 2021
5fa0cc4
minor bug fixes; added trainable smooth ramp exponent; added moebius …
jonkhler May 12, 2021
8f47517
Merge branch 'compact-interval-flows' of github.com:noegroup/bgflow i…
Olllom May 17, 2021
6d07b64
Merge branch 'compact-interval-flows' into compact-interval-fixes
Olllom May 17, 2021
218289a
bugfix in cond shape
Olllom May 17, 2021
656e327
...
Olllom May 17, 2021
0012f10
...
Olllom May 17, 2021
671636e
...
Olllom May 17, 2021
1c29be1
...
Olllom May 17, 2021
8f09e7a
...
Olllom May 17, 2021
a66ec76
...
Olllom May 17, 2021
ffa3ede
...
Olllom May 17, 2021
68b8faa
...
Olllom May 17, 2021
e5984dc
...
Olllom May 17, 2021
2f373ee
siren fixes
Olllom May 18, 2021
6cc0ac9
better initialization of log sigmas
jonkhler May 19, 2021
fd67ead
added moebius transformation
jonkhler May 19, 2021
6dd5e6c
Merge branch 'compact-interval-flows' of github.com:noegroup/bgflow i…
Olllom May 19, 2021
95b82fd
fix siren bug
Olllom May 19, 2021
9c122ac
argh
Olllom May 19, 2021
8f7ac0a
optional siren init
Olllom May 19, 2021
dd43fe3
...
Olllom May 19, 2021
39269b4
Andreas's shape fixes for multi-dim inputs
jonkhler May 19, 2021
3f8d8a8
Merge branch 'compact-interval-flows' of github.com:noegroup/bgflow i…
Olllom May 19, 2021
316b6ce
Merge branch 'compact-interval-flows' into compact-interval-fixes
Olllom May 19, 2021
73104e8
import moebius stuff
Olllom May 19, 2021
b5778c2
added non-compact sigmoid transformers
jonkhler May 20, 2021
0e2fc45
sloppy uniform
Olllom May 20, 2021
529762c
add log sigma bound to non compact affine trafos
jonkhler May 20, 2021
d538dec
Merge branch 'compact-interval-flows' of github.com:noegroup/bgflow i…
Olllom May 25, 2021
5be0139
Merge branch 'compact-interval-flows' into compact-interval-fixes
Olllom May 25, 2021
e1b03bc
elementwise_jacobian
Olllom May 25, 2021
6550bbc
...
Olllom May 25, 2021
4ff80c4
...
Olllom May 25, 2021
5f8288c
...
Olllom May 25, 2021
faaca30
...
Olllom May 25, 2021
5984dde
...
Olllom May 25, 2021
7043109
Merge branch 'main' into compact-interval-fixes
Olllom May 25, 2021
52f9e32
Merge pull request #8 from noegroup/compact-interval-fixes
jonkhler May 25, 2021
02d8ccd
modulify distributions
Olllom May 25, 2021
75666dc
...
Olllom May 25, 2021
5f510a4
...
Olllom May 25, 2021
28b3749
...
Olllom May 25, 2021
269eb6d
...
Olllom May 25, 2021
a13c6cd
...
Olllom May 25, 2021
3959b9a
...
Olllom May 25, 2021
748ad99
...
Olllom May 25, 2021
8a4fa8f
...
Olllom May 25, 2021
202101e
...
Olllom May 25, 2021
8705b60
...
Olllom May 25, 2021
47afe22
...
Olllom May 25, 2021
0c457a5
...
Olllom May 25, 2021
bb5b562
...
Olllom May 25, 2021
e9230c2
...
Olllom May 25, 2021
ce4ca32
...
Olllom May 25, 2021
33b0ba9
...
Olllom May 25, 2021
df24ac2
...
Olllom May 25, 2021
191465c
Merge branch 'main' into compact-interval-fixes
Olllom Oct 6, 2021
e29d07f
Merge branch 'compact-interval-flows' into compact-interval-fixes
Olllom Jan 20, 2022
8be78aa
Merge branch 'main' into compact-interval-fixes
Olllom Jan 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions bgflow/distribution/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import torch
from .energy import Energy
from .sampling import Sampler
from torch.distributions import constraints


__all__ = ["TorchDistribution", "CustomDistribution", "UniformDistribution"]
__all__ = ["TorchDistribution", "CustomDistribution", "UniformDistribution", "SloppyUniform"]


class CustomDistribution(Energy, Sampler):
Expand Down Expand Up @@ -67,9 +68,47 @@ def __getattr__(self, name):
raise AttributeError(msg)


class _SloppyUniform(torch.distributions.Uniform):
def __init__(self, *args, tol=1e-5, **kwargs):
super().__init__(*args, **kwargs)
self.tol = tol

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self):
return constraints.interval(self.low-self.tol, self.high+self.tol)


class SloppyUniform(torch.nn.Module):
def __init__(self, low, high, validate_args=None, tol=1e-5):
super().__init__()
self.register_buffer("low", low)
self.register_buffer("high", high)
self.tol = tol
self.validate_args = validate_args

def __getattr__(self, name):
try:
return super().__getattr__(name=name)
except AttributeError:
uniform = _SloppyUniform(self.low, self.high, self.validate_args, tol=self.tol)
if hasattr(uniform, name):
return getattr(uniform, name)
except:
raise AttributeError(f"SloppyUniform has no attribute {name}")


class UniformDistribution(TorchDistribution):
"""Shortcut"""
def __init__(self, low, high, validate_args=None, n_event_dims=1):
uniform = torch.distributions.Uniform(low, high, validate_args)
def __init__(self, low, high, tol=1e-5, validate_args=None, n_event_dims=1):
uniform = SloppyUniform(low, high, validate_args, tol=tol)
independent = torch.distributions.Independent(uniform, n_event_dims)
super().__init__(independent)
super().__init__(independent)
self.uniform = uniform

def _energy(self, x):
try:
y = - self._delegate.log_prob(x)[:,None]
assert torch.all(torch.isfinite(y))
return y
except (ValueError, AssertionError):
return -self._delegate.log_prob(self._delegate.sample(sample_shape=x.shape[:-1]))[:,None]
60 changes: 60 additions & 0 deletions bgflow/distribution/modulify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch
from torch import nn
from torch.distributions.utils import lazy_property


__all__ = ["distribution_module"]


def distribution_module(cls):
"""A class decorator to "modulify" torch distributions."""
# get the actual distribution class from the mro()
cls_distribution = []
for parent in cls.mro()[1:]:
if (
issubclass(parent, torch.distributions.Distribution)
and parent != torch.distributions.Distribution
):
cls_distribution.append(parent)
assert len(cls_distribution) == 1
cls_distribution = cls_distribution[0]

class DistributionModule:
def __init__(self, *args, **kwargs):
"""Initialize distribution and register plain tensors as buffers."""
super().__init__(*args, **kwargs)

# after initializing re-register distribution parameters as buffers
k_tensors = []
for k in self.__dict__:

# check if it's a lazy_property
# in this case we should simply move on instead of evaluating it.
if self._is_lazy_property(k):
continue

# check if it's a "plain" tensor
if isinstance(getattr(self, k), torch.Tensor):
k_tensors.append(k)

for k in k_tensors:
# for a "plain "tensor we will now register it as buffer.
val = getattr(self, k)
delattr(self, k)
self.register_buffer(k, val)

def __getattribute__(self, name):
"""Return attribute with lazy_property special check."""
if type(self)._is_lazy_property(name):
# deleting the attribute from the instance will simply "reset"
# the lazy_property
delattr(self, name)
return super().__getattribute__(name)

@classmethod
def _is_lazy_property(cls, name):
return isinstance(getattr(cls, name, object), lazy_property)

return type(
cls.__name__, (DistributionModule, cls_distribution, nn.Module), {}
)
27 changes: 26 additions & 1 deletion bgflow/nn/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ..utils.types import is_list_or_tuple


__all__ = ["DenseNet", "MeanFreeDenseNet"]
__all__ = ["DenseNet", "MeanFreeDenseNet", "SirenDenseNet"]


class DenseNet(torch.nn.Module):
Expand Down Expand Up @@ -53,3 +53,28 @@ class MeanFreeDenseNet(DenseNet):
def forward(self, x):
y = self._layers(x)
return y - y.mean(dim=1, keepdim=True)


class Sin(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.sin(x)


class SirenDenseNet(DenseNet):
def __init__(self, *args, scale_first_weights=True, initialize=True, **kwargs):
super().__init__(*args, **kwargs, activation=Sin())
if initialize:
self._init_siren_weights(self._layers, scale_first_weights)

@staticmethod
def _init_siren_weights(layers, scale_first_weights=True):
with torch.no_grad():
linear_layers = [layer for layer in layers if isinstance(layer, torch.nn.Linear)]
for layer in linear_layers:
n = layer.weight.shape[-1]
layer.weight.data = -np.sqrt(6./n) + 2.0*np.sqrt(6./n) * torch.rand_like(layer.weight.data)
if scale_first_weights:
linear_layers[0].weight.data *= 30.
1 change: 1 addition & 0 deletions bgflow/nn/flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .estimator import *
from .stochastic import *
from .transformer import *
from .root_finding import *

from .affine import *
from .coupling import *
Expand Down
3 changes: 3 additions & 0 deletions bgflow/nn/flow/root_finding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .approx_inverse import *
from .bisection import *
from .newton import *
Loading