Skip to content

Commit

Permalink
Test suite tidy (handley-lab#319)
Browse files Browse the repository at this point in the history
* Removed fastkde from test_samples

* bump version to 2.1.1

* Added bump version script for convenience

* Moved astropy to skip rather than fail, and marked all fastkde

* removed erroneous kind check

* Added utils

* More updates for other conditional skipping

* Changed getdist to be an importerror

* Moved importerrors into the function itself

* bump version to 2.1.2

* update merged changes to new `fastkde_skipif` functionality

* Corrected merge of fastkde keys

* Caught importerror

* Renamed skipif functions

* Converted xfails to xskips

* Reverted tests

* Reordered fastkde

* converted xfails to skips

* Removed final xskips

* reinstated xskips

---------

Co-authored-by: Lukas Hergt <[email protected]>
  • Loading branch information
williamjameshandley and lukashergt authored Aug 4, 2023
1 parent 4a35d6e commit dc15c5b
Show file tree
Hide file tree
Showing 14 changed files with 440 additions and 391 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
anesthetic: nested sampling post-processing
===========================================
:Authors: Will Handley and Lukas Hergt
:Version: 2.1.4
:Version: 2.1.5
:Homepage: https://github.com/handley-lab/anesthetic
:Documentation: http://anesthetic.readthedocs.io/

Expand Down
2 changes: 1 addition & 1 deletion anesthetic/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.1.4'
__version__ = '2.1.5'
5 changes: 4 additions & 1 deletion anesthetic/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ def to_getdist(samples):
getdist_samples : :class:`getdist.mcsamples.MCSamples`
getdist equivalent samples
"""
import getdist
try:
import getdist
except ModuleNotFoundError:
raise ImportError("You need to install getdist to use to_getdist")
labels = np.char.strip(samples.get_labels().astype(str), '$')
samples = samples.drop_labels()
ranges = samples.agg(['min', 'max']).T.apply(tuple, axis=1).to_dict()
Expand Down
17 changes: 6 additions & 11 deletions anesthetic/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,6 @@
from scipy.special import erf
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
from matplotlib.axes import Axes
try:
from astropy.visualization import hist
except ImportError:
pass
try:
from anesthetic.kde import fastkde_1d, fastkde_2d
except ImportError:
pass
import matplotlib.cbook as cbook
import matplotlib.lines as mlines
from matplotlib.ticker import MaxNLocator, AutoMinorLocator
Expand Down Expand Up @@ -768,8 +760,9 @@ def fastkde_plot_1d(ax, data, *args, **kwargs):
q = quantile_plot_interval(q=q)

try:
from anesthetic.kde import fastkde_1d
x, p, xmin, xmax = fastkde_1d(data, xmin, xmax)
except NameError:
except ImportError:
raise ImportError("You need to install fastkde to use fastkde")
p /= p.max()
i = ((x > quantile(x, q[0], p)) & (x < quantile(x, q[-1], p)))
Expand Down Expand Up @@ -971,10 +964,11 @@ def hist_plot_1d(ax, data, *args, **kwargs):

if isinstance(bins, str) and bins in ['knuth', 'freedman', 'blocks']:
try:
from astropy.visualization import hist
h, edges, bars = hist(data, ax=ax, bins=bins,
range=range, histtype=histtype,
color=color, *args, **kwargs)
except NameError:
except ImportError:
raise ImportError("You need to install astropy to use astropyhist")
else:
h, edges, bars = ax.hist(data, weights=weights, bins=bins,
Expand Down Expand Up @@ -1048,10 +1042,11 @@ def fastkde_contour_plot_2d(ax, data_x, data_y, *args, **kwargs):
kwargs.pop('q', None)

try:
from anesthetic.kde import fastkde_2d
x, y, pdf, xmin, xmax, ymin, ymax = fastkde_2d(data_x, data_y,
xmin=xmin, xmax=xmax,
ymin=ymin, ymax=ymax)
except NameError:
except ImportError:
raise ImportError("You need to install fastkde to use fastkde")

levels = iso_probability_contours(pdf, contours=levels)
Expand Down
8 changes: 4 additions & 4 deletions anesthetic/read/ultranest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
import os
import json
from anesthetic.samples import NestedSamples
try:
import h5py
except ImportError:
pass


def read_ultranest(root, *args, **kwargs):
Expand All @@ -23,6 +19,10 @@ def read_ultranest(root, *args, **kwargs):
num_params = len(labels)

filepath = os.path.join(root, 'results', 'points.hdf5')
try:
import h5py
except ImportError:
raise ImportError('h5py is required to read UltraNest results')
with h5py.File(filepath, 'r') as fileobj:
points = fileobj['points']
_, ncols = points.shape
Expand Down
38 changes: 38 additions & 0 deletions bin/bump_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python
from utils import run
from packaging import version
import sys

vfile = "anesthetic/_version.py"
README = "README.rst"

current_version = run("cat", vfile)
current_version = current_version.split("=")[-1].strip().strip("'")
current_version = version.parse(current_version)

if len(sys.argv) > 1:
update_type = sys.argv[1]
else:
update_type = "micro"

major = current_version.major
minor = current_version.minor
micro = current_version.micro

if update_type == "micro":
micro+=1
elif update_type == "minor":
minor+=1
micro=0
elif update_type == "major":
major+=1
minor=0
micro=0

new_version = version.parse(f"{major}.{minor}.{micro}")

for f in [vfile, README]:
run("sed", "-i", f"s/{current_version}/{new_version}/g", f)

run("git", "add", vfile, README)
run("git", "commit", "-m", f"bump version to {new_version}")
10 changes: 2 additions & 8 deletions bin/check_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@
import sys
import subprocess
from packaging import version
from utils import unit_incremented
from utils import unit_incremented, run

vfile = "anesthetic/_version.py"
README = "README.rst"


def run(*args):
return subprocess.run(args, text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE).stdout


current_version = run("cat", vfile)
current_version = current_version.split("=")[-1].strip().strip("'")

Expand Down
8 changes: 8 additions & 0 deletions bin/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from packaging import version
import subprocess


def run(*args):
"""Run a bash command and return the output in Python."""
return subprocess.run(args, text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE).stdout


def unit_incremented(a, b):
Expand Down
21 changes: 10 additions & 11 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from anesthetic import read_chains
from anesthetic.convert import to_getdist
from numpy.testing import assert_array_equal
from utils import getdist_mark_xfail


@getdist_mark_xfail
def test_to_getdist():
try:
anesthetic_samples = read_chains('./tests/example_data/gd')
getdist_samples = to_getdist(anesthetic_samples)
anesthetic_samples = read_chains('./tests/example_data/gd')
getdist_samples = to_getdist(anesthetic_samples)

assert_array_equal(getdist_samples.samples, anesthetic_samples)
assert_array_equal(getdist_samples.weights,
anesthetic_samples.get_weights())
assert_array_equal(getdist_samples.samples, anesthetic_samples)
assert_array_equal(getdist_samples.weights,
anesthetic_samples.get_weights())

for param, p in zip(getdist_samples.getParamNames().names,
anesthetic_samples.drop_labels().columns):
for param, p in zip(getdist_samples.getParamNames().names,
anesthetic_samples.drop_labels().columns):

assert param.name == p
except ModuleNotFoundError:
pass
assert param.name == p
18 changes: 4 additions & 14 deletions tests/test_gui.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import sys
import anesthetic.examples._matplotlib_agg # noqa: F401
from anesthetic import read_chains
import pytest
import pandas._testing as tm
try:
import h5py # noqa: F401
except ImportError:
pass
from utils import skipif_no_h5py


@pytest.fixture(autouse=True)
Expand All @@ -16,10 +12,8 @@ def close_figures_on_teardown():

@pytest.mark.parametrize('root', ["./tests/example_data/pc",
"./tests/example_data/mn",
"./tests/example_data/un"])
skipif_no_h5py("./tests/example_data/un")])
def test_gui(root):
if 'un' in root and 'h5py' not in sys.modules:
pytest.skip("`h5py` not in sys.modules, but needed for ultranest.")
samples = read_chains(root)
plotter = samples.gui()

Expand Down Expand Up @@ -65,10 +59,8 @@ def test_gui(root):

@pytest.mark.parametrize('root', ["./tests/example_data/pc",
"./tests/example_data/mn",
"./tests/example_data/un"])
skipif_no_h5py("./tests/example_data/un")])
def test_gui_params(root):
if 'un' in root and 'h5py' not in sys.modules:
pytest.skip("`h5py` not in sys.modules, but needed for ultranest.")
samples = read_chains(root)
params = samples.columns.get_level_values(0).to_list()
plotter = samples.gui()
Expand All @@ -80,10 +72,8 @@ def test_gui_params(root):

@pytest.mark.parametrize('root', ["./tests/example_data/pc",
"./tests/example_data/mn",
"./tests/example_data/un"])
skipif_no_h5py("./tests/example_data/un")])
def test_slider_reset_range(root):
if 'un' in root and 'h5py' not in sys.modules:
pytest.skip("`h5py` not in sys.modules, but needed for ultranest.")
plotter = read_chains(root).gui()
plotter.evolution.reset_range(-3, 3)
assert plotter.evolution.ax.get_xlim() == (-3.0, 3.0)
Loading

0 comments on commit dc15c5b

Please sign in to comment.