Skip to content

Commit

Permalink
feat: extend optional parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
prokolyvakis committed Aug 25, 2023
1 parent 20bbc58 commit 32fc168
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions experiments/synthetic/two_gaussians_mix.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""Unimodality hypothesis testing experiments with a mixture of 2D gaussians.
Usage:
two_gaussians_mix.py <pj> <pv> <sims> [--samples=<s> --noise=<n> --seed=<sd>]
two_gaussians_mix.py <pj> <pv> <sims> [--samples=<s> --noise=<n> --seed=<sd> --dist=<ds> --obs=<o> --plot=<f>]
two_gaussians_mix.py -h | --help
Options:
-h --help Show this screen.
--samples=<s> The number of samples [default: 200].
--noise=<n> The standard deviation inside the clusters [default: 0].
--seed=<sd> The seed [default: 42].
--dist=<ds> The type of distance [default: mahalanobis].
--obs=<o> The type of the observer [default: percentile].
--plot=<f> Whether to produce a plot or not [default: False].
"""
import sys
import warnings
Expand All @@ -22,6 +25,7 @@
from hdunim.projections import IdentityProjector
from hdunim.projections import JohnsonLindenstrauss
from hdunim.observer import PercentileObserver
from hdunim.observer import RandomObserver
from hdunim.projections import View
from hdunim.unimodality import UnimodalityTest
from hdunim.unimodality import MonteCarloUnimodalityTest
Expand All @@ -40,8 +44,23 @@
set_seed(SEED)

pt = str(arguments['<pj>'])
p = JohnsonLindenstrauss() if pt == 'jl' else IdentityProjector()
v = View(p, PercentileObserver(0.99))
if pt == 'jl':
p = JohnsonLindenstrauss()
elif pt == 'i':
p = IdentityProjector()
else:
raise ValueError(f'The projection type: {pt} is not supported!')

dt = str(arguments['--dist'])
ot = str(arguments['--obs'])
if ot == 'percentile':
o = PercentileObserver(0.99, dt)
elif ot == 'random':
o = RandomObserver()
else:
raise ValueError(f'The observer type: {ot} is not supported!')

v = View(p, o, dt)
t = UnimodalityTest(v, float(arguments['<pv>']))
mct = MonteCarloUnimodalityTest(
t,
Expand All @@ -58,15 +77,16 @@
)

tr = 'unimodal' if mct.test(g.x) else 'bimodal'
logger.info(f'The statistical test says {tr} and the data were {g.t}!')
msg = dict(arguments)
msg['groundtruth'] = g.t
msg['result'] = tr
msg.pop('--help')
msg['parity'] = int(tr == g.t)

logger.info(
'The inputs and the output of the experiments is: '
f'{msg}'
)

# plot_clustered_data(g.x, g.y)
if eval(arguments['--plot']):
plot_clustered_data(g.x, g.y)

0 comments on commit 32fc168

Please sign in to comment.