
add in bioem marginal distribution into cryojax.inference.distributions #208

Open
geoffwoollard opened this issue Apr 15, 2024 · 3 comments

@geoffwoollard (Collaborator):

Can use this as a template:

https://github.com/mjo22/cryojax/blob/main/src/cryojax/inference/distributions/_gaussian_distributions.py
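
For orientation, here is a minimal JAX sketch of the marginalized (saddle-point) BioEM log-likelihood that such a distribution class would wrap, following SI eqs. 4 and 10 of 10.1016/j.jsb.2013.10.006. The free-function form and the name `bioem_log_likelihood` are illustrative only, not the actual cryojax API:

```python
import jax.numpy as jnp


def bioem_log_likelihood(simulated, observed):
    """Marginalized BioEM log-likelihood (saddle-point approximation), up to constants."""
    n_pix = observed.size
    # Pixel-sum sufficient statistics (BioEM notation)
    co, cc = observed.sum(), simulated.sum()
    coo = (observed**2).sum()
    ccc = (simulated**2).sum()
    coc = (observed * simulated).sum()
    term1 = n_pix * (ccc * coo - coc**2) + 2 * co * coc * cc - ccc * co**2 - coo * cc**2
    term2 = (n_pix - 2) * (n_pix * ccc - cc**2)
    return (1.5 - n_pix / 2) * jnp.log(term1) + (n_pix / 2 - 2) * jnp.log(term2)
```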

geoffwoollard self-assigned this Apr 15, 2024
@geoffwoollard (Collaborator, Author):

```python
import math
from typing import Dict, Optional, Tuple

import torch

# Assumed available from the surrounding project: ImageConfig, safe_log,
# and cryonerf.nn.affine.make_circular_mask.


class BioemSensor(torch.nn.Module):
    """BioEM marginalized likelihood.

    Implements equations 4 and 10 on p. 6 of the SI of 10.1016/j.jsb.2013.10.006.

    Model: observed ~ N * simulated + mu + Gaussian noise with std sigma,
    with a flat (uniform) prior on N and mu (method="N-mu"),
    or with a saddle-point approximation over lambda (method="saddle-approx").

    The loss is invariant to the sign of the simulated and observed images
    (either can change sign arbitrarily without affecting the loss).

    Notes
    -----
    Numerical issues when reconstructing with empirical data:
        N-mu: NaNs in `up` and `down` after iteration 631/4750, batch size 2.
        saddle-approx: NaNs in `term1` and `term2` at iteration 2333/4750, batch size 2.
    """

    sigma: torch.Tensor

    def __init__(self, image: ImageConfig, 
                 sigma: float, 
                 N_hi: float = 1.0,
                 N_lo: float = 0.1,
                 mu_hi: float = +10.0,
                 mu_lo: float = -10.0,
                 mask_radius: Optional[float] = None, 
                 method: str = 'saddle-approx'):
        super().__init__()

        self.register_buffer('sigma', torch.tensor(sigma))
        self.register_buffer('N_hi', torch.tensor(N_hi))
        self.register_buffer('N_lo', torch.tensor(N_lo))
        self.register_buffer('mu_hi', torch.tensor(mu_hi))
        self.register_buffer('mu_lo', torch.tensor(mu_lo))

        self.mask_radius = mask_radius
        self.method = method

        if mask_radius is not None:
            self.register_buffer(
                'mask',
                cryonerf.nn.affine.make_circular_mask(
                    (image.height, image.width), self.mask_radius
                )
            )
        else:
            self.mask = None

    def likelihood(
        self,
        simulated: torch.Tensor,
        observed: torch.Tensor,
        generator: Optional[torch.Generator] = None,
    ):
        # scale = 1 / (2 * sigma^2); fall back to 1 when sigma == 0
        scale = torch.where(self.sigma > 0, 0.5 / self.sigma.square(), torch.ones_like(self.sigma))

        if self.mask is not None:
            observed = observed * self.mask
            simulated = simulated * self.mask

        eps = torch.finfo(torch.float32).eps

        ccc = simulated.pow(2).sum(dim=(-1, -2))
        near_zero = torch.isclose(ccc, torch.zeros_like(ccc))
        if near_zero.any():
            print('WARNING: simulated image all zeros, so ccc is too close to zero. Injecting noise to avoid NaNs.')
            noise_level = (2 * scale).sqrt().pow(-1)  # equals sigma when sigma > 0
            noise = noise_level * torch.randn(simulated.shape, generator=generator, device=simulated.device, dtype=simulated.dtype)
            simulated = torch.where(near_zero.reshape(-1, 1, 1), simulated + noise, simulated)
            ccc = simulated.pow(2).sum(dim=(-1, -2))

        # Remaining pixel-sum sufficient statistics (BioEM notation)
        co = observed.sum(dim=(-1, -2))
        cc = simulated.sum(dim=(-1, -2))
        coo = observed.pow(2).sum(dim=(-1, -2))
        coc = (observed * simulated).sum(dim=(-1, -2))

        n_pix = observed.shape[-1] * observed.shape[-2]

        if self.method == 'N-mu':
            # TODO: include missing piece
            up = n_pix * (ccc * coo - coc * coc) + 2 * co * coc * cc - ccc * co * co - coo * cc * cc
            down = n_pix * ccc - cc * cc
            # protect against 0/0
            up_over_down = torch.where(torch.logical_and(up == 0, down == 0), torch.ones_like(up), up / down)
            # neglect constant factors
            neg_log_prob = scale * up_over_down + 0.5 * safe_log(down.clamp(min=eps)) + (2 - n_pix) * safe_log(scale * 2)
            assert not neg_log_prob.isnan().any(), 'TODO: numerically stabilize... up={}|down={}'.format(up, down)

        elif self.method == 'saddle-approx':
            term1 = n_pix * (ccc * coo - coc * coc) + 2 * co * coc * cc - ccc * co * co - coo * cc * cc
            term2 = (n_pix - 2) * (n_pix * ccc - cc * cc)
            neg_log_prob = -(1.5 - n_pix / 2) * safe_log(term1.clamp(min=eps)) - (n_pix / 2 - 2) * safe_log(term2.clamp(min=eps))
            assert not neg_log_prob.isnan().any(), 'TODO: numerically stabilize... term1={}|term2={}'.format(term1, term2)

        elif self.method == 'N-mu-gaussian-prior-N':
            a = -n_pix * scale

            a2 = (cc * cc / n_pix - ccc) * scale
            b2 = (coc - cc * co / n_pix) * scale
            c2 = (co * co / n_pix - coo) * scale

            # Hard-coded Gaussian prior N ~ Normal(mu_N, lambda_N^2)
            lambda_N = 100
            mu_N = 1
            a3 = -1 / (2 * lambda_N * lambda_N)
            b3 = mu_N / (lambda_N * lambda_N)
            c3 = -mu_N * mu_N / (2 * lambda_N * lambda_N)

            neg_log_prob = 0.5 * safe_log(-a2 - a3) + 0.5 * safe_log(-a) + (b2 + b3) ** 2 / (4 * (a2 + a3)) - (c2 + c3) + math.log(lambda_N)


        else:
            raise NotImplementedError(f"unknown method: {self.method!r}")

        # Ad hoc prior pulling the norm of the simulated image toward 1 (disabled by default)
        do_prior = False
        if do_prior:
            beta = 0
            neg_log_prob_prior = (ccc.sqrt() - 1).pow(2)
            neg_log_prob += beta * n_pix * neg_log_prob_prior

        neg_log_prob /= n_pix


        likelihood_scale = simulated.new_tensor(n_pix)

        return neg_log_prob, {'likelihood_scale': likelihood_scale, 'neg_log_prob': neg_log_prob}


    def sample(self, simulated: torch.Tensor, generator: Optional[torch.Generator] = None):
        # Draw N ~ Uniform(N_lo, N_hi) and mu ~ Uniform(mu_lo, mu_hi) per batch element
        N = self.N_lo + (self.N_hi - self.N_lo) * torch.rand(simulated.shape[0], generator=generator, device=simulated.device, dtype=simulated.dtype).reshape(-1, 1, 1)
        mu = self.mu_lo + (self.mu_hi - self.mu_lo) * torch.rand(simulated.shape[0], generator=generator, device=simulated.device, dtype=simulated.dtype).reshape(-1, 1, 1)
        noise = torch.randn(
            simulated.shape, generator=generator, device=simulated.device, dtype=simulated.dtype
        )
        return N * simulated + noise.mul_(self.sigma) + mu, {}

    def forward(
        self,
        shot_info: Dict[str, torch.Tensor],
        simulated: torch.Tensor,
        observed: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        if observed is None:
            return self.sample(simulated, generator=generator)
        else:
            return self.likelihood(simulated, observed, generator=generator)
```
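
For reference, with $P$ pixels and sufficient statistics $C_o = \sum_i o_i$, $C_c = \sum_i c_i$, $C_{oo} = \sum_i o_i^2$, $C_{cc} = \sum_i c_i^2$, $C_{oc} = \sum_i o_i c_i$, the `saddle-approx` branch above computes, up to additive constants,

$$
-\log p \;=\; -\left(\tfrac{3}{2} - \tfrac{P}{2}\right)\log T_1 \;-\; \left(\tfrac{P}{2} - 2\right)\log T_2,
$$

$$
T_1 = P\left(C_{cc}C_{oo} - C_{oc}^2\right) + 2 C_o C_{oc} C_c - C_{cc} C_o^2 - C_{oo} C_c^2,
\qquad
T_2 = (P-2)\left(P C_{cc} - C_c^2\right),
$$

with both terms clamped to machine epsilon before the log, and the result divided by $P$.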

@geoffwoollard (Collaborator, Author):

[Screenshot attached: Screen Shot 2024-04-15 at 10:02:32 AM]

geoffwoollard added the enhancement (New feature or request) label Apr 15, 2024
@mjo22 (Owner) commented Apr 15, 2024:

Check out cryojax.inference.distributions.AbstractMarginalDistribution; this was my idea for implementing this. The caveat is that it will also require an implementation of the unmarginalized likelihood, which I think would be a good idea anyway, at least for the case of the BioEM likelihood.
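
For the unmarginalized piece, a minimal sketch assuming the plain Gaussian forward model from the docstring above (the function name and signature are hypothetical, not the actual cryojax API):

```python
import jax.numpy as jnp


def unmarginalized_log_likelihood(simulated, observed, N, mu, sigma):
    """log p(observed | simulated, N, mu, sigma), where observed ~ Normal(N * simulated + mu, sigma^2)."""
    n_pix = observed.size
    residual = observed - (N * simulated + mu)
    return -0.5 * jnp.sum(residual**2) / sigma**2 - n_pix * jnp.log(sigma) - 0.5 * n_pix * jnp.log(2.0 * jnp.pi)
```

Marginalizing this over flat priors on N and mu is what yields the closed-form `N-mu` expression in the snippet above.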
