Skip to content

Commit

Permalink
more features for the gam seasonality
Browse files Browse the repository at this point in the history
  • Loading branch information
hagne committed Mar 4, 2024
1 parent 6c93d1c commit aeb7bd8
Showing 1 changed file with 224 additions and 36 deletions.
260 changes: 224 additions & 36 deletions atmPy/general/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import matplotlib.colors as _mcolors
#import plt_tools as _plt_tools
from atmPy.tools import array_tools as _array_tools
import atmPy.tools.plt_tool_kit.colors as _atmcols
# from pygam import LinearGAM, pygam_s, pygam_l

import datetime
Expand All @@ -16,6 +17,8 @@
import matplotlib.lines as _mpllines
import matplotlib.dates as _mpldates

import scipy.stats as scistats


class Statistics(object):
def __init__(self, parent_ts):
Expand Down Expand Up @@ -55,6 +58,8 @@ def __init__(self, parent_stats = None):
self._seasonality_nsplines = None
self._trend_lam = None
self._trend_nsplines = None
self._prediction_grid_size = None
self._prediction_confidence = None

self._rerun()

Expand All @@ -63,7 +68,33 @@ def _rerun(self):
self._prediction = None
self._gam = None

@property
def prediction_confidence(self):
if isinstance(self._prediction_confidence, type(None)):
self._prediction_confidence = 0.95

return self._prediction_confidence

@prediction_confidence.setter
def prediction_confidence(self,value):
assert(0<value<1)
self._prediction_confidence = value
self._prediction = None
return

@property
def prediction_grid_size(self):
if isinstance(self._prediction_grid_size, type(None)):
self._prediction_grid_size = 100
return self._prediction_grid_size

@prediction_grid_size.setter
def prediction_grid_size(self, value):
self._prediction_grid_size = value
self._prediction = None
return


@property
def splines_per_year(self):
return self._splines_per_year
Expand Down Expand Up @@ -146,6 +177,7 @@ def data(self, data_column = None):#, linear = False, resolution = 1e5):

row = data.iloc[0,:]
start_date = row.name
self._start_date = start_date
Xdf['dsincestart'] = data.apply(lambda row: (row.name - start_date )/ datetime.timedelta(days = 1), axis = 1)
Xdf['doy'] = data.apply(lambda row: (row.name - _pd.to_datetime(row.name.year, format= '%Y')) / datetime.timedelta(days = 1), axis = 1)
# if sun:
Expand Down Expand Up @@ -192,8 +224,11 @@ def gam_inst(self, linear = False, ):
else:
year = pygam.s(0, lam = self.trend_lambda, n_splines=self.trend_nsplines)#int(n_splines))

doy = pygam.s(1, basis = 'cp', lam = self.seasonality_lambda, n_splines= self.seasonality_nsplines)
self._gam = pygam.GAM(year + doy , distribution = self.distribution, link = self.link)
doy = pygam.s(1, basis = 'cp', lam = self.seasonality_lambda, n_splines= self.seasonality_nsplines)
if self.distribution == 'normal':
self._gam = pygam.LinearGAM(year + doy)
else:
self._gam = pygam.GAM(year + doy , distribution = self.distribution, link = self.link)
return self._gam

@property
Expand Down Expand Up @@ -228,6 +263,18 @@ def fit_res(self, data_column = None,
self._fit_res = self.gam_inst.fit(X, y) #this is just a nother handle to the very same gam instance from above!! If the lambda optimization is run it still connects to the same gam instance!!
return self._fit_res

def get_quantile(self, q = None, std = None):
if isinstance(q, type(None)):
if isinstance(std, type(None)):
raise ValueError('Dude, at last one needs to be given!')
else:
q = scistats.norm.cdf(std)
XX = self.data.iloc[:,1:].values
pred = self.gam_inst.prediction_intervals(XX, quantiles=q)
dindex = self._start_date + _pd.to_timedelta(XX[:,0], 'd')
da = _xr.DataArray(pred[:,0], coords = {'datetime_full': dindex})
return da

@property
def prediction(self):
if isinstance(self._prediction, type(None)):
Expand All @@ -237,8 +284,10 @@ def prediction(self):
gam = self.fit_res
for i,col in enumerate(Xdf.columns):
term = gam.terms[i]
XX = gam.generate_X_grid(term=i)
pdep, confi = gam.partial_dependence(term=i, X=XX, width=0.95)
XX = gam.generate_X_grid(term=i, n = self.prediction_grid_size)
pdep, confi = gam.partial_dependence(term=i, X=XX, width=self.prediction_confidence)
self.tp_confi = confi
self.tp_pdep = pdep
dsincestart_feat = XX[:, term.feature]
colname = Xdf.columns[i]
if colname == 'dsincestart':
Expand All @@ -248,26 +297,22 @@ def prediction(self):
else:
index = dsincestart_feat
# print(pdep.shape)
self.tp_colname = colname
self.tp_index = index
ds[f'partial_{colname}'] = _xr.DataArray(pdep, coords={colname:index})
# return ds
# return pdep, dsincestart_feat
# if i != 0:
# df = _pd.DataFrame(pdep, index = dsincestart_feat)
# else:
# # assert(not isinstance(Xdf, type(None))), 'Xdf needed if i==0'
# idxl = []
# for val in dsincestart_feat:
# idxl.append(Xdf.dsincestart.sub(val).abs().idxmin())
# tidx = Xdf.loc[idxl].index
# df = _pd.DataFrame(pdep, index = tidx)
# df.index.name = 'datetime'
# df.index = df.index.tz_localize(None)
ds[f'partial_{colname}_confidence'] = _xr.DataArray(confi, coords={colname:index,
'boundary': ['top', 'bottom']})

# # if col:
# df.columns = [Xdf.columns[i]]
# if i == 1:
# df.index.name = 'day of year'
self._prediction = ds
XX = self.data.iloc[:,1:].values
pred = gam.predict(XX)
dindex = self._start_date + _pd.to_timedelta(XX[:,0], 'd')
ds['prediction'] = _xr.DataArray(pred, coords = {'datetime_full': dindex})

if hasattr(gam, 'prediction_intervals'):
pred = gam.prediction_intervals(XX)
ds['prediction_confidence'] = _xr.DataArray(pred, coords = {'datetime_full': dindex,
'boundary': ['top', 'bottom']})
self._prediction = ds
return self._prediction

@prediction.setter
Expand All @@ -277,6 +322,7 @@ def prediction(self, value):
def plot_seasonality(self, ax = None,
offset = 0,
xticklablesmonth = True,
show_confidence = True,
**plot_kwargs):

if isinstance(ax, type(None)):
Expand All @@ -290,6 +336,12 @@ def plot_seasonality(self, ax = None,

(self.prediction.partial_doy + offset).plot(ax = a, **plot_kwargs)

if show_confidence:
bot,up = self.prediction.partial_doy_confidence.values.transpose()
fill = a.fill_between(self.prediction.doy,bot,up, color = [0,0,0,0.3])
else:
fill = None

if xticklablesmonth:
a.xaxis.set_major_locator(mdates.MonthLocator()) # Set the major ticks to be at the beginning of each month
a.xaxis.set_minor_locator(mdates.WeekdayLocator()) # Set the minor ticks to be at the beginning of each week
Expand All @@ -298,27 +350,154 @@ def plot_seasonality(self, ax = None,
a.xaxis.set_minor_locator(_plt.NullLocator())
a.set_xlabel('')
# a.set_xlabel('Day of year')
return f,a
return f,a,fill

def plot_trend(self, ax = None,
shade_seasons = True,
offset= 0, **plot_kwargs):
def plot_prediction(self, ax = None, show_confidence=True, show_original_data=True, **plot_kwargs):
if isinstance(ax, type(None)):
f, a = _plt.subplots()
else:
a = ax
f = a.get_figure()


if 'label' not in plot_kwargs:
plot_kwargs['label'] = 'prediction'


self.prediction.prediction.plot(ax = a, zorder = 3, label = 'prediction')
if show_confidence:
a.fill_between(self.prediction.datetime_full,
self.prediction.prediction_confidence.sel(boundary='top'),
self.prediction.prediction_confidence.sel(boundary='bottom'),
alpha=0.5, zorder = 2, color = '0.5', label = 'confidence')
if show_original_data:
data = self._parent_stats._parent_ts.data
a.plot(data.index, data.iloc[:,0], ls = '', marker = '.', markersize = 1, zorder = 1, label = 'observation', color = '0.7')
return f,a

def plot_overview(self, axis = None, show_confidence = True, show_original_data = True):
if isinstance(axis, type(None)):
aa = []
f = _plt.figure()
f.set_figheight(f.get_figheight() * 1.5)
aa.append(f.add_subplot(3,1,1))
aa.append(f.add_subplot(3,1,2, sharex = aa[0]))
aa.append(f.add_subplot(3,1,3))
shade_seasons = {'color': '#02401A', 'alpha': 0.5}
else:
aa = axis
f = aa[0].get_figure()
shade_seasons = False

xshift = 0

self.plot_prediction(ax = aa[0], show_confidence=show_confidence, show_original_data=show_original_data)
self.plot_trend(ax=aa[1], shade_seasons=shade_seasons, show_confidence=show_confidence)
self.plot_seasonality(ax=aa[2], show_confidence = show_confidence)


# a = aa[0]
# fillb = a.get_children()[2]
# fillb.set_color('0.3')
# fillb.set_zorder(10)
for e,a in enumerate(aa):
if e == 0:
continue
poslast = aa[e-1].get_position().bounds
pos = list(a.get_position().bounds)
if e == 1:
xshift += poslast[1] - (pos[1] + pos[3])
pos[1] = pos[1] + xshift
a.set_position(pos)
aa[0].legend(fontsize = 'small')

return f,aa

def plot_trend(self, ax = None,
shade_seasons = False,
show_confidence = True,
offset= 0, **plot_kwargs):
"""
Parameters
----------
ax : TYPE, optional
DESCRIPTION. The default is None.
shade_seasons : bool or dict, optional
If to shade the seasons. If dict is provided they will be used as axvspan kwargs, e.g. color or alpha. The default is False.
show_confidence : TYPE, optional
DESCRIPTION. The default is True.
offset : TYPE, optional
DESCRIPTION. The default is 0.
**plot_kwargs : TYPE
DESCRIPTION.
Returns
-------
f : TYPE
DESCRIPTION.
a : TYPE
DESCRIPTION.
"""

if isinstance(shade_seasons, dict):
shade_kwargs = shade_seasons
shade_seasons = True
else:
shade_kwargs = {}

if not 'color' in shade_kwargs:
shade_kwargs['color'] = '0.5'

if 'label' not in plot_kwargs:
plot_kwargs['label'] = 'seasonal'

if 'zorder' not in plot_kwargs:
plot_kwargs['zorder'] = 3

if isinstance(ax, type(None)):
f, a = _plt.subplots()

else:
a = ax
f = a.get_figure()


(self.prediction.partial_datetime + offset).plot(ax = a, **plot_kwargs)



if show_confidence:
a.fill_between(self.prediction.datetime,
self.prediction.partial_datetime_confidence.sel(boundary='top'),
self.prediction.partial_datetime_confidence.sel(boundary='bottom'),
alpha=0.5, zorder = 2, color = '0.5', label = 'confidence')


if shade_seasons:
seasons = [3,6,9,12]
colorlam = lambda x: x/12 - 0.1
for year in range(_pd.Timestamp(self.prediction.datetime.min().values).year, _pd.Timestamp(self.prediction.datetime.max().values).year + 1):
for month in seasons:
# colorlam = lambda x: x/12 - 0.1

coll = []
col = _atmcols.Color(shade_kwargs['color'],
# colors[0]
# model = 'hex'
)
sats = _np.linspace(col.saturation, 0 + 0.0, 4)
brits = _np.linspace(col.brightness, 1 - 0.0, 4)

for brt,sat in zip(brits, sats):
col.saturation = sat
col.brightness = brt
coll.append(col.rgb)

shade_kwargs.pop('color')
start_year = _pd.Timestamp(self.prediction.datetime.min().values).year
end_year = _pd.Timestamp(self.prediction.datetime.max().values).year + 1
for year in range(start_year, end_year):
for e, month in enumerate(seasons):
# start = _pd.Timestamp(year,month,1)
# end = start + _pd.DateOffset(month = 3)
# print(f'{start}, {end}')
Expand All @@ -327,21 +506,30 @@ def plot_trend(self, ax = None,
end = start + _pd.to_timedelta(31*3, 'd')
end = _pd.Timestamp(end.year, end.month, 1)

a.axvspan(start, end, color = f'{colorlam(month)}', lw = 0)
# a.axvspan(start, end, color = f'{colorlam(month)}', lw = 0)
a.axvspan(start, end, color = coll[e], lw = 0, **shade_kwargs)


a.set_xlim(self.prediction.datetime.min(), self.prediction.datetime.max())

custom_lines = [_mpllines.Line2D([0], [0], color=f'{colorlam(seasons[0])}', lw=4),
_mpllines.Line2D([0], [0], color=f'{colorlam(seasons[1])}', lw=4),
_mpllines.Line2D([0], [0], color=f'{colorlam(seasons[2])}', lw=4),
_mpllines.Line2D([0], [0], color=f'{colorlam(seasons[3])}', lw=4),
# custom_lines = [_mpllines.Line2D([0], [0], color=f'{colorlam(seasons[0])}', lw=4),
# _mpllines.Line2D([0], [0], color=f'{colorlam(seasons[1])}', lw=4),
# _mpllines.Line2D([0], [0], color=f'{colorlam(seasons[2])}', lw=4),
# _mpllines.Line2D([0], [0], color=f'{colorlam(seasons[3])}', lw=4),
# ]
custom_lines = [_mpllines.Line2D([0], [0], color=coll[0], lw=4, **shade_kwargs),
_mpllines.Line2D([0], [0], color=coll[1], lw=4, **shade_kwargs),
_mpllines.Line2D([0], [0], color=coll[2], lw=4, **shade_kwargs),
_mpllines.Line2D([0], [0], color=coll[3], lw=4, **shade_kwargs),
]

a.legend(custom_lines, ['spring', 'summer', 'fall', 'winter'])

a.xaxis.set_major_locator(_mpldates.YearLocator())
noy = end_year - start_year
base = int(_np.ceil(noy/10))
a.xaxis.set_major_locator(_mpldates.YearLocator(base=base))
a.xaxis.set_major_formatter(_mpldates.DateFormatter('%Y'))
a.xaxis.set_minor_locator(_mpldates.MonthLocator())
a.xaxis.set_minor_locator(_mpldates.MonthLocator(interval = base))
a.xaxis.set_tick_params(reset = True)
a.xaxis.tick_bottom()
a.set_xlabel('')
Expand Down

0 comments on commit aeb7bd8

Please sign in to comment.