Skip to content

Commit

Permalink
[ENH] Refactor of BaseDistribution and descendants - generalised di…
Browse files Browse the repository at this point in the history
…stribution param broadcasting in base class (#54)

Mirror of sktime sktime/sktime#5176

Refactors the `_get_bc_params` function in the `BaseDistribution` class.
Moves the `_get_bc_params` method from the child distributions to the parent distribution class, `BaseDistribution`.


This implementation is quite simple: it doesn't use `_tags`, and the child distributions still have to call the `_get_bc_params` method themselves.
  • Loading branch information
Alex-JG3 authored Aug 31, 2023
1 parent d1b693f commit 1ce6f30
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 34 deletions.
37 changes: 37 additions & 0 deletions skpro/distributions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,43 @@ def _method_error_msg(self, method="this method", severity="warn", fill_in=None)
else:
return msg

def _get_bc_params(self, *args, dtype=None):
    """Fully broadcast tuple of parameters given param shapes and index, columns.

    Parameters
    ----------
    args : float, int, array of floats, or array of ints (1D or 2D)
        Distribution parameters that are to be made broadcastable. If no positional
        arguments are provided, all parameters of `self` are used except for `index`
        and `columns`.
    dtype : str, optional
        broadcasted parameter arrays are cast to all have datatype `dtype`.
        If None, then no datatype casting is done.

    Returns
    -------
    Tuple of float or integer arrays
        Each element of the tuple represents a different broadcastable distribution
        parameter. The tuple has one entry per parameter, in the order given.
    """
    number_of_params = len(args)
    if number_of_params == 0:
        # Handle case where no positional arguments are provided
        params = self.get_params()
        # tolerate distributions that do not expose index/columns as parameters
        params.pop("index", None)
        params.pop("columns", None)
        args = tuple(params.values())
        number_of_params = len(args)

    # index (as column vector) and columns (as row vector) only take part in the
    # broadcast to fix the target shape; they are never returned
    if hasattr(self, "index") and self.index is not None:
        args += (self.index.to_numpy().reshape(-1, 1),)
    if hasattr(self, "columns") and self.columns is not None:
        args += (self.columns.to_numpy(),)

    # drop the index/columns helper arrays *before* casting: their labels may be
    # non-numeric (e.g. strings) and must not be forced to `dtype`
    bc = np.broadcast_arrays(*args)[:number_of_params]
    if dtype is not None:
        bc = [array.astype(dtype) for array in bc]
    return tuple(bc)

def pdf(self, x):
r"""Probability density function.
Expand Down
12 changes: 1 addition & 11 deletions skpro/distributions/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, mu, scale, index=None, columns=None):
# todo: untangle index handling
# and broadcast of parameters.
# move this functionality to the base class
self._mu, self._scale = self._get_bc_params()
self._mu, self._scale = self._get_bc_params(self.mu, self.scale)
shape = self._mu.shape

if index is None:
Expand All @@ -56,16 +56,6 @@ def __init__(self, mu, scale, index=None, columns=None):

super().__init__(index=index, columns=columns)

def _get_bc_params(self):
    """Broadcast ``mu`` and ``scale`` of self against each other and index, columns.

    When set, the index (reshaped to a column vector) and the columns take
    part in the broadcast; only the two parameter arrays are returned.
    """
    operands = [self.mu, self.scale]

    row_labels = getattr(self, "index", None)
    if row_labels is not None:
        operands.append(row_labels.to_numpy().reshape(-1, 1))

    col_labels = getattr(self, "columns", None)
    if col_labels is not None:
        operands.append(col_labels.to_numpy())

    mu_bc, scale_bc, *_ = np.broadcast_arrays(*operands)
    return mu_bc, scale_bc

def energy(self, x=None):
r"""Energy of self, w.r.t. self or a constant frame x.
Expand Down
12 changes: 1 addition & 11 deletions skpro/distributions/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, mu, sigma, index=None, columns=None):
# todo: untangle index handling
# and broadcast of parameters.
# move this functionality to the base class
self._mu, self._sigma = self._get_bc_params()
self._mu, self._sigma = self._get_bc_params(self.mu, self.sigma)
shape = self._mu.shape

if index is None:
Expand All @@ -57,16 +57,6 @@ def __init__(self, mu, sigma, index=None, columns=None):

super(Normal, self).__init__(index=index, columns=columns)

def _get_bc_params(self):
    """Return ``mu`` and ``sigma`` of self, broadcast to a common shape.

    The index (as a column vector) and the columns are included in the
    broadcast whenever they are present and not None, but are not returned.
    """
    has_index = hasattr(self, "index") and self.index is not None
    has_columns = hasattr(self, "columns") and self.columns is not None

    extras = []
    if has_index:
        extras.append(self.index.to_numpy().reshape(-1, 1))
    if has_columns:
        extras.append(self.columns.to_numpy())

    broadcasted = np.broadcast_arrays(self.mu, self.sigma, *extras)
    return broadcasted[0], broadcasted[1]

def energy(self, x=None):
r"""Energy of self, w.r.t. self or a constant frame x.
Expand Down
14 changes: 3 additions & 11 deletions skpro/distributions/t.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def __init__(self, mu, sigma, df=1, index=None, columns=None):
self.index = index
self.columns = columns

self._mu, self._sigma, self._df = self._get_bc_params()
self._mu, self._sigma, self._df = self._get_bc_params(
self.mu, self.sigma, self.df
)
shape = self._mu.shape

if index is None:
Expand All @@ -56,16 +58,6 @@ def __init__(self, mu, sigma, df=1, index=None, columns=None):

super().__init__(index=index, columns=columns)

def _get_bc_params(self):
    """Broadcast ``mu``, ``sigma`` and ``df`` of self to a common shape.

    If index and/or columns are set on self, they join the broadcast (index
    as a column vector); only the three parameter arrays are returned.
    """
    pieces = [self.mu, self.sigma, self.df]

    idx = getattr(self, "index", None)
    cols = getattr(self, "columns", None)
    if idx is not None:
        pieces.append(idx.to_numpy().reshape(-1, 1))
    if cols is not None:
        pieces.append(cols.to_numpy())

    mu_bc, sigma_bc, df_bc = np.broadcast_arrays(*pieces)[:3]
    return mu_bc, sigma_bc, df_bc

def mean(self):
r"""Return expected value of the distribution.
Expand Down
2 changes: 1 addition & 1 deletion skpro/distributions/tests/test_base_default_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, mu, sigma, index=None, columns=None):
self.index = index
self.columns = columns

self._mu, self._sigma = self._get_bc_params()
self._mu, self._sigma = self._get_bc_params(self.mu, self.sigma)
shape = self._mu.shape

if index is None:
Expand Down

0 comments on commit 1ce6f30

Please sign in to comment.