
Commit

Merge pull request #3 from alexander-group/dev
Add features requested in Issue #1
noahfranz13 authored Jul 17, 2024
2 parents 6d4b6ba + a37c4ae commit 8b5332f
Showing 11 changed files with 385 additions and 395 deletions.
2 changes: 1 addition & 1 deletion src/syncfit/_version.py
@@ -1 +1 @@
__version__ = "0.2.3"
__version__ = "0.3.0"
28 changes: 18 additions & 10 deletions src/syncfit/analysis/plotter.py
@@ -4,6 +4,7 @@

import matplotlib.pyplot as plt
from .util import *
from ..models import MQModel

def plot_chains(sampler, labels, fig=None, axes=None):
'''
@@ -37,8 +38,8 @@ def plot_chains(sampler, labels, fig=None, axes=None):

return fig, axes

def plot_best_fit(model, sampler, nu, F, Ferr, nkeep=1000, lum_dist=None, t=None,
p=None, method='random', fig=None, ax=None, day=None):
def plot_best_fit(model, sampler, nu, F, Ferr, nkeep=1000, lum_dist=None,
nu_arr=None, method='random', fig=None, ax=None, day=None):
'''
Plot best fit model
@@ -49,7 +50,7 @@ def plot_best_fit(model, sampler, nu, F, Ferr, nkeep=1000, lum_dist=None, t=None
F [list]: The observed flux densities
Ferr [list]: The observed flux error
nkeep [int]: Number of values to keep
p [float]: p-value used, if not None
nu_arr [list]: List of nus for the best-fit lines to be plotted with
method [str]: Either 'max' or 'last' or 'random', default is max.
- max: takes the nkeep maximum probability values
- last: takes the last nkeep values from the chain
@@ -60,7 +61,9 @@ def plot_best_fit(model, sampler, nu, F, Ferr, nkeep=1000, lum_dist=None, t=None
Returns:
matplotlib fig, ax
'''

if isinstance(model, MQModel) and model.t is not None:
t = model.t

flat_samples, log_prob = extract_output(sampler)

if method == 'max':
@@ -71,20 +74,25 @@ def plot_best_fit(model, sampler, nu, F, Ferr, nkeep=1000, lum_dist=None, t=None
toplot = flat_samples[-nkeep*10:][np.random.randint(0, nkeep*10, nkeep)]
else:
raise ValueError('method must be either last or max!')

nu_plot = np.arange(1e8,3e11,1e7)

if nu_arr is None:
nu_plot = np.arange(1e8,3e11,1e7)
else:
nu_plot = nu_arr

if ax is None:
fig, ax = plt.subplots(figsize=(4,4))

if t is not None or lum_dist is not None:
kwargs = dict(t=t,lum_dist=lum_dist)
if lum_dist is not None and model.t is None:
kwargs = dict(lum_dist=lum_dist)
elif lum_dist is not None and model.t is not None:
kwargs = dict(lum_dist=lum_dist, t=t)
else:
kwargs={}

for val in toplot:
if p is not None:
res = model.SED(nu_plot, p, *val, **kwargs)
if model.p is not None:
res = model.SED(nu_plot, model.p, *val, **kwargs)
else:
res = model.SED(nu_plot, *val, **kwargs)

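
For context on the new nu_arr argument, here is a minimal sketch that mirrors what plot_best_fit now does internally when a custom frequency grid is supplied (evaluate model.SED over nu_arr instead of the old hard-coded np.arange grid). The B1B2 import path and the sample parameter values are assumptions for illustration, not part of this commit.

import numpy as np
import matplotlib.pyplot as plt
from syncfit.models.b1b2_model import B1B2  # import path assumed

# evaluate a single SED over a user-chosen grid, as plot_best_fit does per sample
model = B1B2()                                   # p left free
nu_grid = np.logspace(8, 11.5, 200)              # would be passed as nu_arr=...
sed = model.SED(nu_grid, 2.5, 0.0, 9.0, 10.0)    # p, log_F_nu, log_nu_a, log_nu_m

fig, ax = plt.subplots(figsize=(4, 4))
ax.loglog(nu_grid, sed, color='grey', alpha=0.5)
ax.set_xlabel('Frequency [Hz]')
ax.set_ylabel('Flux density [mJy]')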
46 changes: 35 additions & 11 deletions src/syncfit/mcmc.py
@@ -15,7 +15,7 @@
def do_dynesty(nu:list[float], F_mJy:list[float], F_error:list[float],
lum_dist:float=None, t:float=None,
model:SyncfitModel=MQModel, fix_p:float=None,
upperlimits:list[bool]=None, ncores:int=1, seed:int=None,
upperlimits:list[bool]=None, ncores:int=1, seed:int=None, prior=None,
run_kwargs={}, dynesty_kwargs={}, logprob_kwargs={}
) -> tuple[list[float],list[float]]:
"""
@@ -28,12 +28,14 @@ def do_dynesty(nu:list[float], F_mJy:list[float], F_error:list[float],
model (SyncfitModel): Model class to use from syncfit.fitter.models. Can also be a custom model
but it must be a subclass of SyncfitModel!
lum_dist (float): luminosity distance in cgs units. Only needed for MQModel. Default is None.
t (float): observation time in seconds. Only needed for MQModel. Default is None.
t (float): observation time in days. Only needed for MQModel. Default is None.
fix_p (float): Will fix the p value to whatever you give, do not provide p in theta_init
if this is the case!
upperlimits (list[bool]): True if the point is an upperlimit, False otherwise.
ncores (int) : The number of cores to run on, default is 1 and won't multiprocess
seed (int): The seed for the random number generator passed to dynesty,
prior (dict) : dictionary defining the prior ranges. Keys must be same as model.get_labels().
Value should be a list of length 2 like [min, max], both exclusive.
run_kwargs (dict) : kwargs to pass to dynesty.run_sampler
dynesty_kwargs (dict) : kwargs to pass to dynesty.DynamicNestedSampler
logprob_kwargs (dict) : kwargs to pass to the logprob. For the most part this is
@@ -42,19 +44,35 @@ def do_dynesty(nu:list[float], F_mJy:list[float], F_error:list[float],
Returns:
flat_samples, log_prob
"""
if isinstance(model(), MQModel) and (lum_dist is None or t is None):
# instantiate a new model object
test_model = model() # just for now
if isinstance(test_model, MQModel) and (lum_dist is None):
raise ValueError('lum_dist and t required for MQModel!')

if isinstance(test_model, MQModel):
model = model(p=fix_p, t=t)
else:
model = model(p=fix_p)

# get the extra args
dynesty_args = model.get_kwargs(nu, F_mJy, F_error, lum_dist=lum_dist, t=t, p=fix_p)
dynesty_args = model.get_kwargs(nu, F_mJy, F_error, lum_dist=lum_dist, t=t)

# combine these with the logprob_kwargs
# make the logprob_kwargs second so it overwrites anything we set here
dynesty_args = dynesty_args | logprob_kwargs

ndim = len(model.get_labels(p=fix_p))
ndim = model.ndim
rstate = np.random.default_rng(seed)


# set the model prior instance variable
if prior is not None:
model.prior = prior

if set(model.prior.keys()) != set(model.labels):
raise ValueError(
f'Prior dictionary keys ({model.prior.keys()}) do not match the labels ({model.labels})!'
)

# construct the sampler and run it
# NOTE: I give it the lnprob instead of loglik because there can be some other
# priors that are built into the lnprob that can not be in the dynesty prior
@@ -75,13 +93,13 @@ def do_dynesty(nu:list[float], F_mJy:list[float], F_error:list[float],
'The override decorator syntax is not currently supported for dynesty!'
)

return dsampler
return model, dsampler

def do_emcee(theta_init:list[float], nu:list[float], F_mJy:list[float],
F_error:list[float], lum_dist:float=None, t:float=None,
model:SyncfitModel=SyncfitModel, niter:int=2000,
nwalkers:int=100, fix_p:float=None, upperlimits:list[bool]=None,
day:str=None, plot:bool=False, ncores:int=1
day:str=None, plot:bool=False, ncores:int=1, prior=None
) -> tuple[list[float],list[float]]:
"""
Fit the data with the given model using the emcee package.
@@ -106,7 +124,10 @@ def do_emcee(theta_init:list[float], nu:list[float], F_mJy:list[float],
Returns:
flat_samples, log_prob
"""
if isinstance(model(), MQModel) and (lum_dist is None or t is None):
# instantiate a new model object
model = model(p=fix_p)

if isinstance(model, MQModel) and (lum_dist is None or t is None):
raise ValueError('lum_dist and t required for MQModel!')

### Fill in initial guesses and number of parameters
@@ -121,12 +142,15 @@ def do_emcee(theta_init:list[float], nu:list[float], F_mJy:list[float],
upperlimits = np.array(upperlimits)

pos, labels, emcee_args = model.unpack_util(theta_init, nu, F_mJy, F_error,
nwalkers, p=fix_p, lum_dist=lum_dist,
nwalkers, lum_dist=lum_dist,
t=t, upperlimit=upperlimits)

# setup and run the MCMC
nwalkers, ndim = pos.shape

if set(model.prior.keys()) != set(model.labels):
raise ValueError('Prior dictionary keys do not match the labels!')

with Pool(ncores) as pool:
sampler = emcee.EnsembleSampler(
nwalkers,
@@ -156,4 +180,4 @@ def do_emcee(theta_init:list[float], nu:list[float], F_mJy:list[float],
else:
fig, ax = plot_best_fit(model, sampler, emcee_args['nu'], emcee_args['F'])

return sampler
return model, sampler
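
A hedged end-to-end sketch of the updated interface: do_dynesty now takes an optional prior dict and returns the model instance together with the sampler, which can then be handed straight to plot_best_fit. The data arrays, prior ranges, and import paths below are illustrative assumptions, not values from this commit.

import numpy as np
from syncfit.mcmc import do_dynesty
from syncfit.models.b1b2_model import B1B2          # import path assumed
from syncfit.analysis.plotter import plot_best_fit

# hypothetical radio SED (frequencies in Hz, flux densities in mJy)
nu = np.array([1.4e9, 3e9, 6e9, 10e9, 15e9])
F_mJy = np.array([0.6, 1.1, 1.4, 1.0, 0.7])
F_error = np.full_like(F_mJy, 0.1)

# keys must match model.get_labels(); each value is [min, max], both exclusive
prior = dict(log_F_nu=[-4, 2], log_nu_a=[6, 11], log_nu_m=[7, 15])

# p is fixed here, so it is left out of both the prior dict and theta
model, dsampler = do_dynesty(
    nu, F_mJy, F_error,
    model=B1B2,
    fix_p=2.5,
    prior=prior,
    ncores=1,
    seed=42,
)

fig, ax = plot_best_fit(model, dsampler, nu, F_mJy, F_error, nkeep=500,
                        nu_arr=np.logspace(8, 11.5, 200), method='random')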
71 changes: 23 additions & 48 deletions src/syncfit/models/b1b2_b3b4_weighted_model.py
@@ -9,15 +9,34 @@ class B1B2_B3B4_Weighted(SyncfitModel):
This is a specialized model that uses a weighted combination of the B1B2 model and
the B3B4 model. The idea of this model is from XXXYXYXYX et al. (YYYY).
'''

def __init__(self, p=None):
super().__init__(p=p)

# then set the default prior for this model
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,12],
log_nu_m=[6,12]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,12],
log_nu_m=[6,12]
)


def get_labels(p=None):
def get_labels(self, p=None):
if p is None:
return ['p','log F_v', 'log nu_a','log nu_m']
return ['p','log_F_nu', 'log_nu_a','log_nu_m']
else:
return ['log F_v', 'log nu_a','log nu_m']
return ['log_F_nu', 'log_nu_a','log_nu_m']

# the model, must be named SED!!!
def SED(nu, p, log_F_nu, log_nu_a, log_nu_m, **kwargs):
def SED(self, nu, p, log_F_nu, log_nu_a, log_nu_m, **kwargs):
### Spectrum 1
b1 = 2
b2 = 1/3
@@ -56,47 +75,3 @@ def SED(nu, p, log_F_nu, log_nu_a, log_nu_m, **kwargs):
F = (w1*F1+w2*F2) / (w1+w2)

return F

def lnprior(theta, nu, F, upperlimit, p=None, **kwargs):
''' Priors: '''
uppertest = SyncfitModel._is_below_upperlimits(
nu, F, upperlimit, theta, B1B2_B3B4_Weighted.SED, p=p
)

if p is None:
p, log_F_nu, log_nu_a, log_nu_m= theta
else:
log_F_nu, log_nu_a, log_nu_m= theta

if 2< p < 4 and -4 < log_F_nu < 2 and 6 < log_nu_a < 12 and 6 < log_nu_m < 12 and uppertest:
return 0.0

else:
return -np.inf

def dynesty_prior(theta, nu, F, upperlimit, p=None, **kwargs):
'''
Prior transform for dynesty
'''
if p is None:
p, log_F_nu, log_nu_a, log_nu_m= theta
fixed_p = False,
else:
fixed_p = True
log_F_nu, log_nu_a, log_nu_m= theta

# log_F_nu between -4 and 2
log_F_nu = log_F_nu*6 - 4

# log_nu_a between 6 and 11
log_nu_a = log_nu_a*5 + 6

# same transform to log_nu_c
log_nu_m = log_nu_m*5 + 6

if not fixed_p:
# p should be between 2 and 4
p = 2*p + 2

return p,log_F_nu,log_nu_a,log_nu_m
return log_F_nu,log_nu_a,log_nu_m
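
A short sketch of what the new constructor exposes, assuming the class is importable from the module shown in the file path; the printed values simply restate the labels and prior ranges defined in the diff above.

from syncfit.models.b1b2_b3b4_weighted_model import B1B2_B3B4_Weighted  # path assumed

# free p: 'p' appears in both the prior ranges and the parameter labels
free_p = B1B2_B3B4_Weighted()
print(free_p.get_labels())        # ['p', 'log_F_nu', 'log_nu_a', 'log_nu_m']
print(free_p.prior['p'])          # [2, 4]

# fixed p: 'p' drops out of the prior and of the labels
fixed_p = B1B2_B3B4_Weighted(p=2.2)
print(fixed_p.get_labels(p=2.2))  # ['log_F_nu', 'log_nu_a', 'log_nu_m']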
83 changes: 39 additions & 44 deletions src/syncfit/models/b1b2_model.py
@@ -10,14 +10,33 @@ class B1B2(SyncfitModel):
(nu_m). This model uses nu_m > nu_a, the opposite of the B4B5 model.
'''

def get_labels(p=None):
def __init__(self, p=None):
super().__init__(p=p)

# then set the default prior for this model
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[7,15]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[7,15]
)


def get_labels(self, p=None):
if p is None:
return ['p','log F_v', 'log nu_a','log nu_m']
return ['p','log_F_nu', 'log_nu_a','log_nu_m']
else:
return ['log F_v', 'log nu_a','log nu_m']
return ['log_F_nu', 'log_nu_a','log_nu_m']

# the model, must be named SED!!!
def SED(nu, p, log_F_nu, log_nu_a, log_nu_m, **kwargs):
def SED(self, nu, p, log_F_nu, log_nu_a, log_nu_m, **kwargs):
b1 = 2
b2 = 1/3
b3 = (1-p)/2
@@ -34,49 +53,25 @@ def SED(nu, p, log_F_nu, log_nu_a, log_nu_m, **kwargs):

return F_nu * term1 * term2

def lnprior(theta, nu, F, upperlimit, p=None, **kwargs):
''' Priors: '''
def lnprior(self, theta, nu, F, upperlimit, **kwargs):
'''
Logarithmic prior function that can be changed based on the SED model.
'''
uppertest = SyncfitModel._is_below_upperlimits(
nu, F, upperlimit, theta, B1B2.SED, p=p
nu, F, upperlimit, theta, self.SED
)

if p is None:
p, log_F_nu, log_nu_a, log_nu_m = theta
else:
log_F_nu, log_nu_a, log_nu_m = theta

if 2< p < 4 and -4 < log_F_nu < 2 and 6 < log_nu_a < 12 and log_nu_m > log_nu_a and uppertest:
packed_theta = self.pack_theta(theta)

all_res = []
for param, val in self.prior.items():
res = val[0] < packed_theta[param] < val[1]
all_res.append(res)

if (all(all_res) and
uppertest and
packed_theta['log_nu_m'] > packed_theta['log_nu_a']
):
return 0.0

else:
return -np.inf


def dynesty_transform(theta, nu, F, upperlimit, p=None, **kwargs):
'''
Prior transform for dynesty
'''

if p is None:
p, log_F_nu, log_nu_a, log_nu_m = theta
fixed_p = False,
else:
fixed_p = True
log_F_nu, log_nu_a, log_nu_m = theta


# log_F_nu between -4 and 2
log_F_nu = log_F_nu*6 - 4

# log_nu_a between 6 and 11
log_nu_a = log_nu_a*5 + 6

# same transform to log_nu_m
log_nu_m = log_nu_m*5 + 6

if not fixed_p:
# p should be between 2 and 4
p = 2*p + 2

return p,log_F_nu,log_nu_a,log_nu_m
return log_F_nu,log_nu_a,log_nu_m
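
Both model files follow the same refactor pattern: prior ranges move into a self.prior dict set in __init__, and get_labels/SED/lnprior become instance methods. As a hedged template for a custom model written the same way (the SyncfitModel import path and the base class building its dynesty transform from self.prior are assumptions inferred from this diff, not verified library behaviour):

import numpy as np
from syncfit.models import SyncfitModel  # base-class import path assumed

class SinglePowerLaw(SyncfitModel):
    '''Toy model: F_nu = 10**log_F0 * (nu / 1 GHz)**alpha.'''

    def __init__(self, p=None):
        super().__init__(p=p)
        # per-parameter [min, max] prior ranges, keyed exactly like get_labels()
        self.prior = dict(
            log_F0=[-4, 2],
            alpha=[-3, 3],
        )

    def get_labels(self, p=None):
        # this toy model has no p, so the labels do not depend on it
        return ['log_F0', 'alpha']

    # the model, must be named SED (per the convention noted in the diff)
    def SED(self, nu, log_F0, alpha, **kwargs):
        return 10**log_F0 * (np.asarray(nu) / 1e9)**alpha

Since B1B2_B3B4_Weighted above no longer defines lnprior or dynesty_prior at all, the base class presumably derives both from self.prior, so a minimal subclass like this one should not need to reimplement them; that is worth confirming against SyncfitModel before relying on it.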
(6 more changed files not shown)
