Skip to content

Commit

Permalink
Merge pull request #6 from alexander-group/dev
Browse files Browse the repository at this point in the history
make it so you can fix any parameter using the logprob_kwargs argumen…
  • Loading branch information
noahfranz13 authored Jul 19, 2024
2 parents 8a6c409 + 2ada550 commit 4a05ad3
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 138 deletions.
2 changes: 1 addition & 1 deletion src/syncfit/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.1"
__version__ = "0.3.2"
24 changes: 9 additions & 15 deletions src/syncfit/analysis/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def plot_chains(sampler, labels, fig=None, axes=None):

return fig, axes

def plot_best_fit(model, sampler, nu, F, Ferr, nkeep=1000, lum_dist=None,
nu_arr=None, method='random', fig=None, ax=None, day=None):
def plot_best_fit(model, sampler, nu, F, Ferr, nkeep=1000,
nu_arr=None, method='random', fig=None, ax=None, day=None, **kwargs):
'''
Plot best fit model
Expand All @@ -62,8 +62,11 @@ def plot_best_fit(model, sampler, nu, F, Ferr, nkeep=1000, lum_dist=None,
matplotlib fig, ax
'''
if isinstance(model, MQModel) and model.t is not None:
t = model.t

kwargs['t'] = model.t

if model.p is not None:
kwargs['p'] = model.p

flat_samples, log_prob = extract_output(sampler)

if method == 'max':
Expand All @@ -82,19 +85,10 @@ def plot_best_fit(model, sampler, nu, F, Ferr, nkeep=1000, lum_dist=None,

if ax is None:
fig, ax = plt.subplots(figsize=(4,4))

if lum_dist is not None and model.t is None:
kwargs = dict(lum_dist=lum_dist)
elif lum_dist is not None and model.t is not None:
kwargs = dict(lum_dist=lum_dist, t=t)
else:
kwargs={}

for val in toplot:
if model.p is not None:
res = model.SED(nu_plot, model.p, *val, **kwargs)
else:
res = model.SED(nu_plot, *val, **kwargs)
packed_theta = model.pack_theta(val, **kwargs)
res = model.SED(nu_plot, **packed_theta)

ax.plot(nu_plot, res,
'-', color='cornflowerblue', lw = 0.5, alpha = 0.1)
Expand Down
7 changes: 4 additions & 3 deletions src/syncfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import emcee
import dynesty
from multiprocessing import Pool
from warnings import warn
from .analysis import *
from .models.mq_model import MQModel
from .models.syncfit_model import SyncfitModel
Expand Down Expand Up @@ -48,11 +49,11 @@ def do_dynesty(nu:list[float], F_mJy:list[float], F_error:list[float],
test_model = model() # just for now
if isinstance(test_model, MQModel) and (lum_dist is None):
raise ValueError('lum_dist and t reequired for MQModel!')

if isinstance(test_model, MQModel):
model = model(p=fix_p, t=t)
model = model(prior=prior, p=fix_p, t=t)
else:
model = model(p=fix_p)
model = model(prior=prior, p=fix_p)

# get the extra args
dynesty_args = model.get_kwargs(nu, F_mJy, F_error, lum_dist=lum_dist, t=t)
Expand Down
29 changes: 16 additions & 13 deletions src/syncfit/models/b1b2_b3b4_weighted_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,24 @@ class B1B2_B3B4_Weighted(SyncfitModel):
the B3B4 model. The idea of this model is from XXXYXYXYX et al. (YYYY).
'''

def __init__(self, p=None):
def __init__(self, prior=None, p=None):
# then set the default prior for this model
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,12],
log_nu_m=[6,12]
)
if prior is None:
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,12],
log_nu_m=[6,12]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,12],
log_nu_m=[6,12]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,12],
log_nu_m=[6,12]
)
self.prior = prior

super().__init__(self.prior, p=p)

Expand Down
33 changes: 18 additions & 15 deletions src/syncfit/models/b1b2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,26 @@ class B1B2(SyncfitModel):
(nu_m). This model uses nu_m > nu_a, the opposite of the B4B5 model.
'''

def __init__(self, p=None):
def __init__(self, prior=None, p=None):
# then set the default prior for this model
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[7,15]
)
if prior is None:
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[7,15]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[7,15]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[7,15]
)
self.prior = prior

super().__init__(p=p)
super().__init__(self.prior, p=p)

# the model, must be named SED!!!
def SED(self, nu, p, log_F_nu, log_nu_a, log_nu_m, **kwargs):
Expand All @@ -51,7 +54,7 @@ def lnprior(self, theta, nu, F, upperlimit, **kwargs):
Logarithmic prior function that can be changed based on the SED model.
'''
uppertest = SyncfitModel._is_below_upperlimits(
nu, F, upperlimit, theta, self.SED
nu, F, upperlimit, theta, self.SED, **kwargs
)

packed_theta = self.pack_theta(theta)
Expand Down
33 changes: 18 additions & 15 deletions src/syncfit/models/b4b5_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,27 @@ class B4B5(SyncfitModel):
subclass this class and overwrite lnprior to redefine this.
'''

def __init__(self, p=None):
def __init__(self, prior=None, p=None):
# then set the default prior for this model
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[-np.inf,6]
)
if prior is None:
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[-np.inf,6]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[-np.inf,6]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[-np.inf,6]
)
self.prior = prior


super().__init__(p=p)
super().__init__(self.prior, p=p)

# the model, must be named SED!!!
def SED(self, nu, p, log_F_nu, log_nu_a, log_nu_m, **kwargs):
Expand All @@ -53,7 +56,7 @@ def lnprior(self, theta, nu, F, upperlimit, **kwargs):
Logarithmic prior function that can be changed based on the SED model.
'''
uppertest = SyncfitModel._is_below_upperlimits(
nu, F, upperlimit, theta, self.SED
nu, F, upperlimit, theta, self.SED, **kwargs
)

packed_theta = self.pack_theta(theta)
Expand Down
37 changes: 20 additions & 17 deletions src/syncfit/models/b4b5b3_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,27 @@ class B4B5B3(SyncfitModel):
and minimum energy break (nu_m). This model always requires that nu_m < nu_a < nu_c.
'''

def __init__(self, p=None):
def __init__(self, prior=None, p=None):
# then set the default prior for this model
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[0, 6],
log_nu_c=[7,15]
)
if prior is None:
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[0, 6],
log_nu_c=[7,15]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[0, 6],
log_nu_c=[7,15]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_m=[0, 6],
log_nu_c=[7,15]
)

self.prior = prior

super().__init__(self.prior, p=p)

# the model, must be named SED!!!
Expand Down Expand Up @@ -58,7 +61,7 @@ def lnprior(self, theta, nu, F, upperlimit, **kwargs):
Logarithmic prior function that can be changed based on the SED model.
'''
uppertest = SyncfitModel._is_below_upperlimits(
nu, F, upperlimit, theta, self.SED
nu, F, upperlimit, theta, self.SED, **kwargs
)

packed_theta = self.pack_theta(theta)
Expand Down
25 changes: 14 additions & 11 deletions src/syncfit/models/b5_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@ class B5(SyncfitModel):
Single break model for just the self-absorption break.
'''

def __init__(self, p=None):
def __init__(self, prior=None, p=None):
# then set the default prior for this model
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,11]
)
if prior is None:
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,11]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,11]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,11]
)
self.prior = prior

super().__init__(self.prior, p=p)

Expand Down
31 changes: 17 additions & 14 deletions src/syncfit/models/b5b3_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,24 @@ class B5B3(SyncfitModel):
break.
'''

def __init__(self, p=None):
def __init__(self, prior=None, p=None):
# then set the default prior for this model
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_c=[7,15]
)
if prior is None:
if p is None:
self.prior = dict(
p=[2,4],
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_c=[7,15]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_c=[7,15]
)
else:
self.prior = dict(
log_F_nu=[-4,2],
log_nu_a=[6,11],
log_nu_c=[7,15]
)
self.prior = prior

super().__init__(self.prior, p=p)

Expand All @@ -52,7 +55,7 @@ def lnprior(self, theta, nu, F, upperlimit, **kwargs):
Logarithmic prior function that can be changed based on the SED model.
'''
uppertest = SyncfitModel._is_below_upperlimits(
nu, F, upperlimit, theta, self.SED
nu, F, upperlimit, theta, self.SED, **kwargs
)

packed_theta = self.pack_theta(theta)
Expand Down
Loading

0 comments on commit 4a05ad3

Please sign in to comment.