Skip to content

Commit

Permalink
Ready for interfacing with Graddnodi
Browse files Browse the repository at this point in the history
  • Loading branch information
CaderIdris committed Nov 1, 2023
1 parent 311c5a9 commit 579dee8
Show file tree
Hide file tree
Showing 7 changed files with 4,657 additions and 4,570 deletions.
7,912 changes: 3,957 additions & 3,955 deletions docs/calidhayte/calibrate.html

Large diffs are not rendered by default.

1,212 changes: 638 additions & 574 deletions docs/calidhayte/graphs.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/search.js

Large diffs are not rendered by default.

50 changes: 24 additions & 26 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,50 +3,49 @@ anyio==4.0.0; python_version >= '3.8'
argon2-cffi==23.1.0; python_version >= '3.7'
argon2-cffi-bindings==21.2.0; python_version >= '3.6'
arrow==1.3.0; python_version >= '3.8'
asttokens==2.4.0
asttokens==2.4.1
async-lru==2.0.4; python_version >= '3.8'
attrs==23.1.0; python_version >= '3.7'
babel==2.13.0; python_version >= '3.7'
backcall==0.2.0
babel==2.13.1; python_version >= '3.7'
beautifulsoup4==4.12.2; python_full_version >= '3.6.0'
bleach==6.1.0; python_version >= '3.8'
cachetools==5.3.1; python_version >= '3.7'
cachetools==5.3.2; python_version >= '3.7'
certifi==2023.7.22; python_version >= '3.6'
cffi==1.16.0; python_version >= '3.8'
chardet==5.2.0; python_version >= '3.7'
charset-normalizer==3.3.1; python_full_version >= '3.7.0'
charset-normalizer==3.3.2; python_full_version >= '3.7.0'
colorama==0.4.6; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6'
comm==0.1.4; python_version >= '3.6'
coverage[toml]==7.3.2; python_version >= '3.8'
debugpy==1.8.0; python_version >= '3.8'
decorator==5.1.1; python_version >= '3.5'
defusedxml==0.7.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
distlib==0.3.7
executing==2.0.0
executing==2.0.1; python_version >= '3.5'
fastjsonschema==2.18.1
filelock==3.12.4; python_version >= '3.8'
filelock==3.13.1; python_version >= '3.8'
flake8==6.1.0; python_full_version >= '3.8.1'
fqdn==1.5.1
idna==3.4; python_version >= '3.5'
iniconfig==2.0.0; python_version >= '3.7'
ipykernel==6.25.2; python_version >= '3.8'
ipython==8.16.1; python_version >= '3.9'
ipykernel==6.26.0; python_version >= '3.8'
ipython==8.17.2; python_version >= '3.9'
ipython-genutils==0.2.0
ipywidgets==8.1.1; python_version >= '3.7'
isoduration==20.11.0
jedi==0.19.1; python_version >= '3.6'
jinja2==3.1.2; python_version >= '3.7'
json5==0.9.14
jsonpointer==2.4
jsonschema[format-nongpl]==4.19.1; python_version >= '3.8'
jsonschema[format-nongpl]==4.19.2; python_version >= '3.8'
jsonschema-specifications==2023.7.1; python_version >= '3.8'
jupyter==1.0.0
jupyter-client==8.4.0; python_version >= '3.8'
jupyter-client==8.5.0; python_version >= '3.8'
jupyter-console==6.6.3; python_version >= '3.7'
jupyter-core==5.4.0; python_version >= '3.8'
jupyter-core==5.5.0; python_version >= '3.8'
jupyter-events==0.8.0; python_version >= '3.8'
jupyter-lsp==2.2.0; python_version >= '3.8'
jupyter-server==2.8.0; python_version >= '3.8'
jupyter-server==2.9.1; python_version >= '3.8'
jupyter-server-terminals==0.4.4; python_version >= '3.8'
jupyterlab==4.0.7; python_version >= '3.8'
jupyterlab-pygments==0.2.2; python_version >= '3.7'
Expand All @@ -59,7 +58,7 @@ mistune==3.0.2; python_version >= '3.7'
mypy==1.6.1; python_version >= '3.8'
mypy-extensions==1.0.0; python_version >= '3.5'
nbclient==0.8.0; python_full_version >= '3.8.0'
nbconvert==7.9.2; python_version >= '3.8'
nbconvert==7.10.0; python_version >= '3.8'
nbformat==5.9.2; python_version >= '3.8'
nest-asyncio==1.5.8; python_version >= '3.5'
notebook==7.0.6; python_version >= '3.8'
Expand All @@ -70,10 +69,9 @@ pandocfilters==1.5.0; python_version >= '2.7' and python_version not in '3.0, 3.
parso==0.8.3; python_version >= '3.6'
pdoc==14.1.0; python_version >= '3.8'
pexpect==4.8.0; sys_platform != 'win32'
pickleshare==0.7.5
platformdirs==3.11.0; python_version >= '3.7'
pluggy==1.3.0; python_version >= '3.8'
prometheus-client==0.17.1; python_version >= '3.6'
prometheus-client==0.18.0; python_version >= '3.8'
prompt-toolkit==3.0.39; python_full_version >= '3.7.0'
psutil==5.9.6; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'
ptyprocess==0.7.0
Expand All @@ -83,7 +81,7 @@ pycparser==2.21
pyflakes==3.1.0; python_version >= '3.8'
pygments==2.16.1; python_version >= '3.7'
pyproject-api==1.6.1; python_version >= '3.8'
pytest==7.4.2; python_version >= '3.7'
pytest==7.4.3; python_version >= '3.7'
pytest-cov==4.1.0; python_version >= '3.7'
pytest-html==4.0.2; python_version >= '3.8'
pytest-metadata==3.0.0; python_version >= '3.7'
Expand All @@ -92,7 +90,7 @@ python-json-logger==2.0.7; python_version >= '3.6'
pyyaml==6.0.1; python_version >= '3.6'
pyzmq==25.1.1; python_version >= '3.6'
qtconsole==5.4.4; python_version >= '3.7'
qtpy==2.4.0; python_version >= '3.7'
qtpy==2.4.1; python_version >= '3.7'
referencing==0.30.2; python_version >= '3.8'
requests==2.31.0; python_version >= '3.7'
rfc3339-validator==0.1.4; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
Expand All @@ -107,19 +105,19 @@ terminado==0.17.1; python_version >= '3.7'
tinycss2==1.2.1; python_version >= '3.7'
tornado==6.3.3; python_version >= '3.8'
tox==4.11.3; python_version >= '3.8'
traitlets==5.11.2; python_version >= '3.8'
traitlets==5.13.0; python_version >= '3.8'
types-python-dateutil==2.8.19.14
typing-extensions==4.8.0; python_version >= '3.8'
uri-template==1.3.0
urllib3==2.0.7; python_version >= '3.7'
virtualenv==20.24.5; python_version >= '3.7'
wcwidth==0.2.8
virtualenv==20.24.6; python_version >= '3.7'
wcwidth==0.2.9
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.4; python_version >= '3.8'
widgetsnbextension==4.0.9; python_version >= '3.7'
arviz==0.16.1; python_version >= '3.9'
bambi==0.12.0; python_version >= '3.8'
bambi==0.13.0; python_version >= '3.8'
-e .
cloudpickle==3.0.0; python_version >= '3.8'
cons==0.4.6; python_version >= '3.6'
Expand All @@ -136,14 +134,14 @@ joblib==1.3.2; python_version >= '3.7'
kiwisolver==1.4.5; python_version >= '3.7'
llvmlite==0.41.1; python_version >= '3.8'
logical-unification==0.4.6; python_version >= '3.6'
matplotlib==3.8.0; python_version >= '3.9'
matplotlib==3.8.1; python_version >= '3.9'
minikanren==1.0.3; python_version >= '3.6'
multipledispatch==1.0.0
numba==0.58.1; python_version >= '3.8'
numpy==1.25.2; python_version >= '3.9'
pandas==2.1.1; python_version >= '3.9'
pandas==2.1.2; python_version >= '3.9'
pillow==10.1.0; python_version >= '3.8'
pymc==5.9.0; python_version >= '3.9'
pymc==5.9.1; python_version >= '3.9'
pyparsing==3.1.1; python_full_version >= '3.6.8'
pytensor==2.17.3; python_version < '3.12' and python_version >= '3.9'
pytz==2023.3.post1
Expand All @@ -158,4 +156,4 @@ tqdm==4.66.1; python_version >= '3.7'
tzdata==2023.3; python_version >= '2'
xarray==2023.10.1; python_version >= '3.9'
xarray-einstats==0.6.0; python_version >= '3.9'
xgboost==2.0.0; python_version >= '3.8'
xgboost==2.0.1; python_version >= '3.8'
6 changes: 3 additions & 3 deletions src/calidhayte/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import scipy
from scipy.stats import uniform
import sklearn as skl
from sklearn import cross_decomposition as cd
from sklearn import ensemble as en
from sklearn import gaussian_process as gp
from sklearn import isotonic as iso
Expand Down Expand Up @@ -1445,8 +1444,9 @@ def ransac(
]
] = {
'estimator': [
lm.LinearRegression()
# TODO: ADD
lm.LinearRegression(),
lm.TheilSenRegressor(),
lm.LassoLarsCV()
]
},
**kwargs
Expand Down
44 changes: 33 additions & 11 deletions src/calidhayte/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,20 +225,40 @@ def lin_reg_plot(self, title=None):
def shap(self, pipeline_keys: list[str], title=None):
x = self.x
y = self.y
pipeline = self.models[pipeline_keys[0]][pipeline_keys[1]][pipeline_keys[2]]

pipeline = self.models[
pipeline_keys[0]
][
pipeline_keys[1]
][
pipeline_keys[2]
]

if not self.plots.get(pipeline_keys[0]):
self.plots[pipeline_keys[0]] = dict()
if not self.plots[pipeline_keys[0]].get(pipeline_keys[1]):
self.plots[pipeline_keys[0]][pipeline_keys[1]] = dict()
if not self.plots[pipeline_keys[0]][pipeline_keys[1]].get(pipeline_keys[2]):
self.plots[pipeline_keys[0]][pipeline_keys[1]][pipeline_keys[2]] = dict()
if not self.plots[
pipeline_keys[0]
][
pipeline_keys[1]
].get(pipeline_keys[2]):

self.plots[
pipeline_keys[0]
][
pipeline_keys[1]][pipeline_keys[2]] = dict()
with plt.rc_context({'backend': self.backend}), \
plt.style.context(self.style):
shap_df = get_shap(x, y, pipeline)
self.plots[pipeline_keys[0]][pipeline_keys[1]][pipeline_keys[2]]['Shap'] = shap_plot(shap_df, x)


self.plots[
pipeline_keys[0]
][
pipeline_keys[1]
][
pipeline_keys[2]
][
'Shap'
] = shap_plot(shap_df, x)

def save_plots(
self,
Expand Down Expand Up @@ -408,6 +428,7 @@ def ecdf_plot(
fig.suptitle(title)
return fig


def shap_plot(shaps: pd.DataFrame, x: pd.DataFrame):
"""
"""
Expand Down Expand Up @@ -470,11 +491,12 @@ def shap_plot(shaps: pd.DataFrame, x: pd.DataFrame):
plt.tight_layout()
return fig


def get_shap(
x: pd.DataFrame,
y: pd.DataFrame,
pipeline: dict[int, Pipeline]
):
x: pd.DataFrame,
y: pd.DataFrame,
pipeline: dict[int, Pipeline]
):
shaps = pd.DataFrame()
for fold in pipeline.keys():
if len(pipeline.keys()) > 1:
Expand Down
1 change: 1 addition & 0 deletions tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def test_bland_altman(
results.bland_altman_plot()
results.save_plots('.tmp/tests')


@pytest.mark.plots
def test_shap(
trained_models
Expand Down

0 comments on commit 579dee8

Please sign in to comment.