Skip to content

Commit

Permalink
add model override decorator functionality and reversion
Browse files Browse the repository at this point in the history
  • Loading branch information
noahfranz13 committed Jun 14, 2024
1 parent 7cee994 commit 6381123
Show file tree
Hide file tree
Showing 12 changed files with 168 additions and 51 deletions.
2 changes: 1 addition & 1 deletion docs/source.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Source Documentation
:members:
:inherited-members:

.. autoclass:: syncfit.models.BaseModel
.. autoclass:: syncfit.models.SyncfitModel
:members:

`syncfit.analysis`
Expand Down
157 changes: 133 additions & 24 deletions docs/tutorials/2_fit_custom_model.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/syncfit/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.1"
__version__ = "0.1.0"
8 changes: 4 additions & 4 deletions src/syncfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import emcee
from .analysis import *
from .models.b5_model import B5
from .models.base_model import BaseModel
from .models.syncfit_model import SyncfitModel

def do_emcee(theta_init:list[float], nu:list[float], F_muJy:list[float],
F_error:list[float], model:BaseModel=B5, niter:int=2000,
F_error:list[float], model:SyncfitModel=SyncfitModel, niter:int=2000,
nwalkers:int=100, fix_p:float=None, upperlimits:list[bool]=None,
day:str=None, plot:bool=False
) -> tuple[list[float],list[float]]:
Expand All @@ -22,8 +22,8 @@ def do_emcee(theta_init:list[float], nu:list[float], F_muJy:list[float],
nu (list): list of frequencies in GHz
F_muJy (list): list of fluxes in micro janskies
F_error (list): list of flux error in micro janskies
model (BaseModel): Model class to use from syncfit.fitter.models. Can also be a custom model
but it must be a subclass of BaseModel!
model (SyncfitModel): Model class to use from syncfit.fitter.models. Can also be a custom model
but it must be a subclass of SyncfitModel!
niter (int): The number of iterations to run on.
nwalkers (int): The number of walkers to use for emcee
fix_p (float): Will fix the p value to whatever you give, do not provide p in theta_init
Expand Down
2 changes: 1 addition & 1 deletion src/syncfit/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from syncfit.models.b1b2_model import B1B2
from syncfit.models.b1b2_b3b4_weighted_model import B1B2_B3B4_Weighted
from syncfit.models.b5b3_model import B5B3
from syncfit.models.base_model import BaseModel
from syncfit.models.syncfit_model import SyncfitModel
6 changes: 3 additions & 3 deletions src/syncfit/models/b1b2_b3b4_weighted_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Various models to use in MCMC fitting
'''
import numpy as np
from .base_model import BaseModel
from .syncfit_model import SyncfitModel

class B1B2_B3B4_Weighted(BaseModel):
class B1B2_B3B4_Weighted(SyncfitModel):
'''
This is a specialized model that uses a weighted combination of the B1B2 model and
the B3B4 model. The idea of this model is from XXXYXYXYX et al. (YYYY).
Expand Down Expand Up @@ -59,7 +59,7 @@ def SED(nu, p, log_F_nu, log_nu_a, log_nu_m):

def lnprior(theta, nu, F, upperlimit, p=None, **kwargs):
''' Priors: '''
uppertest = BaseModel._is_below_upperlimits(
uppertest = SyncfitModel._is_below_upperlimits(
nu, F, upperlimit, theta, B1B2_B3B4_Weighted.SED, p=p
)

Expand Down
6 changes: 3 additions & 3 deletions src/syncfit/models/b1b2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Various models to use in MCMC fitting
'''
import numpy as np
from .base_model import BaseModel
from .syncfit_model import SyncfitModel

class B1B2(BaseModel):
class B1B2(SyncfitModel):
'''
Two-break model for the self-absorption break (nu_a) and the minimal energy break
(nu_m). This model uses nu_m > nu_a, the opposite of the B4B5 model.
Expand Down Expand Up @@ -36,7 +36,7 @@ def SED(nu, p, log_F_nu, log_nu_a, log_nu_m):

def lnprior(theta, nu, F, upperlimit, p=None, **kwargs):
''' Priors: '''
uppertest = BaseModel._is_below_upperlimits(
uppertest = SyncfitModel._is_below_upperlimits(
nu, F, upperlimit, theta, B1B2.SED, p=p
)

Expand Down
6 changes: 3 additions & 3 deletions src/syncfit/models/b4b5_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Various models to use in MCMC fitting
'''
import numpy as np
from .base_model import BaseModel
from .syncfit_model import SyncfitModel

class B4B5(BaseModel):
class B4B5(SyncfitModel):
'''
Two-break model for a combination of the self-absorption break (nu_a) and the
minimal energy break (nu_m). This model requires that nu_m < nu_a, you should
Expand Down Expand Up @@ -37,7 +37,7 @@ def SED(nu, p, log_F_nu, log_nu_a, log_nu_m):

def lnprior(theta, nu, F, upperlimit, p=None, **kwargs):
''' Priors: '''
uppertest = BaseModel._is_below_upperlimits(
uppertest = SyncfitModel._is_below_upperlimits(
nu, F, upperlimit, theta, B4B5.SED, p=p
)

Expand Down
6 changes: 3 additions & 3 deletions src/syncfit/models/b4b5b3_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Various models to use in MCMC fitting
'''
import numpy as np
from .base_model import BaseModel
from .syncfit_model import SyncfitModel

class B4B5B3(BaseModel):
class B4B5B3(SyncfitModel):
'''
Three-break model using the self-absorption break (nu_a), cooling break (nu_c),
and minimum energy break (nu_m). This model always requires that nu_m < nu_a < nu_c.
Expand Down Expand Up @@ -41,7 +41,7 @@ def SED(nu, p, log_F_nu, log_nu_a, log_nu_m, log_nu_c):

def lnprior(theta, nu, F, upperlimit, p=None, **kwargs):
''' Priors: '''
uppertest = BaseModel._is_below_upperlimits(
uppertest = SyncfitModel._is_below_upperlimits(
nu, F, upperlimit, theta, B4B5B3.SED, p=p
)

Expand Down
6 changes: 3 additions & 3 deletions src/syncfit/models/b5_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Various models to use in MCMC fitting
'''
import numpy as np
from .base_model import BaseModel
from .syncfit_model import SyncfitModel

class B5(BaseModel):
class B5(SyncfitModel):
'''
Single break model for just the self-absorption break.
'''
Expand All @@ -30,7 +30,7 @@ def SED(nu, p, log_F_nu, log_nu_a):

def lnprior(theta, nu, F, upperlimit, p=None, **kwargs):
''' Priors: '''
uppertest = BaseModel._is_below_upperlimits(
uppertest = SyncfitModel._is_below_upperlimits(
nu, F, upperlimit, theta, B5.SED, p=p
)

Expand Down
6 changes: 3 additions & 3 deletions src/syncfit/models/b5b3_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Various models to use in MCMC fitting
'''
import numpy as np
from .base_model import BaseModel
from .syncfit_model import SyncfitModel

class B5B3(BaseModel):
class B5B3(SyncfitModel):
'''
Two-break model that uses both the self-absorption break and the cooling break.
This model forces the cooling break to always be larger than the self-absorption
Expand Down Expand Up @@ -38,7 +38,7 @@ def SED(nu, p, log_F_nu, log_nu_a, log_nu_c):

def lnprior(theta, nu, F, upperlimit, p=None, **kwargs):
''' Priors: '''
uppertest = BaseModel._is_below_upperlimits(
uppertest = SyncfitModel._is_below_upperlimits(
nu, F, upperlimit, theta, B5B3.SED, p=p
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
import numpy as np

class _BaseModelMeta(type):
class _SyncfitModelMeta(type):
'''
This just gives all the subclasses for BaseModel the same docstrings
for the inherited abstract methods
Expand All @@ -18,7 +18,7 @@ def __new__(mcls, classname, bases, cls_dict):
member.__doc__ = getattr(bases[-1], name).__doc__
return cls

class BaseModel(object, metaclass=_BaseModelMeta):
class SyncfitModel(object, metaclass=_SyncfitModelMeta):
'''
An Abstract Base Class to define the basic methods that all syncfit
models must contain. This will help maintain some level of standard for the models
Expand Down Expand Up @@ -184,3 +184,11 @@ def __subclasshook__(cls, C):
if all(any(arg in B.__dict__ for B in C.__mro__) for arg in reqs):
return True
return NotImplemented

# add a register method so users don't have to create a new class
@classmethod
def override(cls,func):
'''
This method should be used as a decorator to override other methods
'''
exec(f'cls.{func.__name__} = func')

0 comments on commit 6381123

Please sign in to comment.