From fc6e9dbc80d4d68d83481d2afa7337dd1ae1cf93 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Thu, 7 Nov 2019 20:23:56 +0100 Subject: [PATCH 01/46] Style tweaks to all files --- dataset.py | 29 ++++++------ models.py | 100 ++++++++++++++++++++-------------------- plot.py | 123 ++++++++++++++++++++++++------------------------- rfretrieval.py | 65 +++++++++++++------------- utils.py | 22 ++++----- wpercentile.py | 29 ++++++------ 6 files changed, 178 insertions(+), 190 deletions(-) diff --git a/dataset.py b/dataset.py index 727db15..00f5cb9 100644 --- a/dataset.py +++ b/dataset.py @@ -1,4 +1,3 @@ - import os import json import logging @@ -10,46 +9,44 @@ logger = logging.getLogger(__name__) - Dataset = namedtuple("Dataset", ["training_x", "training_y", "testing_x", "testing_y", "names", "ranges", "colors"]) def load_data_file(data_file, num_features): - data = np.load(data_file) - + if data.ndim == 1: data = data[None, :] - + x = data[:, :num_features] y = data[:, num_features:] - + return x, y def load_dataset(dataset_file): - with open(dataset_file, "r") as f: dataset_info = json.load(f) - + metadata = dataset_info["metadata"] - + base_path = os.path.dirname(dataset_file) - + # Load training data - training_file = os.path.join(base_path, dataset_info["training_data"]) + training_file = os.path.join(base_path, dataset_info["training_data"]) logger.debug("Loading training data from '{}'...".format(training_file)) - training_x, training_y = load_data_file(training_file, metadata["num_features"]) - + training_x, training_y = load_data_file(training_file, + metadata["num_features"]) + # Optionally, load testing data testing_x, testing_y = None, None if dataset_info["testing_data"] is not None: testing_file = os.path.join(base_path, dataset_info["testing_data"]) logger.debug("Loading testing data from '{}'...".format(testing_file)) - testing_x, testing_y = load_data_file(testing_file, metadata["num_features"]) - + testing_x, testing_y = load_data_file(testing_file, + metadata["num_features"]) + return Dataset(training_x, training_y, testing_x, testing_y, metadata["names"], metadata["ranges"], metadata["colors"]) - diff --git a/models.py b/models.py index d9ce96f..de13a8b 100644 --- a/models.py +++ b/models.py @@ -1,4 +1,3 @@ - import logging from collections import namedtuple @@ -11,7 +10,8 @@ try: from tqdm import tqdm except ImportError: - def tqdm(x, *_, **__): return x + def tqdm(x, *_, **__): + return x __all__ = [ "Model", @@ -24,21 +24,20 @@ def tqdm(x, *_, **__): return x # Posteriors are represented as a collection of weighted samples Posterior = namedtuple("Posterior", ["samples", "weights"]) + def resample_posterior(posterior, num_draws): - p = posterior.weights / posterior.weights.sum() indices = np.random.choice(len(posterior.samples), size=num_draws, p=p) - + new_weights = np.bincount(indices, minlength=len(posterior.samples)) mask = new_weights != 0 new_samples = posterior.samples[mask] new_weights = posterior.weights[mask] - + return Posterior(new_samples, new_weights) class Model: - def __init__(self, num_trees, num_jobs, names, ranges, colors, enable_posterior=True, verbose=1): @@ -49,72 +48,73 @@ def __init__(self, num_trees, num_jobs, n_jobs=num_jobs, max_features="sqrt", min_impurity_decrease=0.01) - + self.scaler = scaler self.rf = rf - + self.num_trees = num_trees self.num_jobs = num_jobs self.verbose = verbose - + self.ranges = ranges self.names = names self.colors = colors - + # To compute the posteriors self.enable_posterior = enable_posterior self.data_leaves = None 
self.data_weights = None self.data_y = None - + def _scaler_fit(self, y): if y.ndim == 1: y = y[:, None] - + self.scaler.fit(y) - + def _scaler_transform(self, y): if y.ndim == 1: y = y[:, None] return self.scaler.transform(y)[:, 0] - + return self.scaler.transform(y) - + def _scaler_inverse_transform(self, y): - + if y.ndim == 1: y = y[:, None] # return self.scaler.inverse_transform(y)[:, 0] - + return self.scaler.inverse_transform(y) - + def fit(self, x, y): self._scaler_fit(y) self.rf.fit(x, self._scaler_transform(y)) - + # Build the structures to quickly compute the posteriors if self.enable_posterior: data_leaves = self.rf.apply(x).T self.data_leaves = _as_smallest_udtype(data_leaves) - self.data_weights = np.array([_tree_weights(tree, len(y)) for tree in self.rf]) + self.data_weights = np.array( + [_tree_weights(tree, len(y)) for tree in self.rf]) self.data_y = y - + def predict(self, x): pred = self.rf.predict(x) return self._scaler_inverse_transform(pred) - + def predict_median(self, x): return self.predict_percentile(x, 50) - + def predict_percentile(self, x, percentile): - + if not self.enable_posterior: raise ValueError("Cannot compute posteriors with this model. " "Set `enable_posterior` to True to enable posterior computation.") - + # Find the leaves for the query points leaves_x = self.rf.apply(x) - + if len(x) > self.num_trees: # If there are many queries, it is faster to find points using a cache return _posterior_percentile_cache( @@ -127,24 +127,25 @@ def predict_percentile(self, x, percentile): self.data_leaves, self.data_weights, self.data_y, leaves_x, percentile ) - + def get_params(self, deep=True): return {"num_trees": self.num_trees, "num_jobs": self.num_jobs, "names": self.names, "ranges": self.ranges, - "colors": self.colors, "enable_posterior": self.enable_posterior, + "colors": self.colors, + "enable_posterior": self.enable_posterior, "verbose": self.verbose} - + def posterior(self, x): - + if not self.enable_posterior: raise ValueError("Cannot compute posteriors with this model. " "Set `enable_posterior` to True to enable posterior computation.") - + if x.ndim > 1: raise ValueError("x.ndim must be 1") - + leaves_x = self.rf.apply(x[None, :])[0] - + return _posterior( self.data_leaves, self.data_weights, self.data_y, leaves_x @@ -152,22 +153,21 @@ def posterior(self, x): def _posterior(data_leaves, data_weights, data_y, query_leaves): - weights_x = (query_leaves[:, None] == data_leaves) * data_weights weights_x = _as_smallest_udtype(weights_x.sum(0)) - + # Remove samples with weight zero mask = weights_x != 0 samples = data_y[mask] weights = weights_x[mask] - + return Posterior(samples, weights) -def _posterior_percentile_nocache(data_leaves, data_weights, data_y, query_leaves, percentile): - +def _posterior_percentile_nocache(data_leaves, data_weights, data_y, + query_leaves, percentile): values = [] - + logger.info("Computing percentiles...") for leaves_x_i in tqdm(query_leaves): posterior = _posterior( @@ -177,19 +177,19 @@ def _posterior_percentile_nocache(data_leaves, data_weights, data_y, query_leave samples = np.repeat(posterior.samples, posterior.weights, axis=0) value = np.percentile(samples, percentile, axis=0) values.append(value) - + return np.array(values) -def _posterior_percentile_cache(data_leaves, data_weights, data_y, query_leaves, percentile): - +def _posterior_percentile_cache(data_leaves, data_weights, data_y, query_leaves, + percentile): # Build a dictionary for fast access of the contents of the leaves. 
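    # Each per-tree cache maps a leaf id to the indices of the training
    # samples in that leaf, repeated according to their bootstrap weights.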
logger.info("Building cache...") cache = [ _build_leaves_cache(leaves_i, weights_i) for leaves_i, weights_i in zip(data_leaves, data_weights) - ] - + ] + values = [] # Check the contents of the leaves in leaves_x logger.info("Computing percentiles...") @@ -200,22 +200,21 @@ def _posterior_percentile_cache(data_leaves, data_weights, data_y, query_leaves, data_elements.extend(aux) value = np.percentile(data_y[data_elements], percentile, axis=0) values.append(value) - + return np.array(values) def _build_leaves_cache(leaves, weights): - result = {} for index, (leaf, weight) in enumerate(zip(leaves, weights)): if weight == 0: continue - + if leaf not in result: result[leaf] = [index] * weight else: result[leaf].extend([index] * weight) - + return result @@ -237,11 +236,10 @@ def _as_smallest_udtype(arr): def _smallest_udtype(value): - dtypes = [np.uint8, np.uint16, np.uint32, np.uint64] - + for dtype in dtypes: if value <= np.iinfo(dtype).max: return dtype - + raise ValueError("value is too large for any dtype") diff --git a/plot.py b/plot.py index 8243e99..1fc696a 100644 --- a/plot.py +++ b/plot.py @@ -1,4 +1,3 @@ - from itertools import product import numpy as np @@ -15,7 +14,8 @@ try: from tqdm import tqdm except ImportError: - def tqdm(x, *_, **__): return x + def tqdm(x, *_, **__): + return x __all__ = [ "predicted_vs_real", @@ -27,20 +27,20 @@ def tqdm(x, *_, **__): return x def predicted_vs_real(y_real, y_pred, names, ranges, alpha='auto'): - num_plots = y_pred.shape[1] num_plot_rows = int(np.sqrt(num_plots)) num_plot_cols = (num_plots - 1) // num_plot_rows + 1 - + fig, axes = plt.subplots(num_plot_rows, num_plot_cols, - figsize=(5*num_plot_cols, 5*num_plot_rows), + figsize=(5 * num_plot_cols, 5 * num_plot_rows), squeeze=False) - - for dim, (ax, name_i, range_i) in enumerate(zip(axes.ravel(), names, ranges)): - + + for dim, (ax, name_i, range_i) in enumerate( + zip(axes.ravel(), names, ranges)): + current_real = y_real[:, dim] current_pred = y_pred[:, dim] - + if alpha == 'auto': # TODO: this is a quick fix. Check at some point in the future. 
aux, *_ = np.histogram2d(current_real, current_pred, bins=60) @@ -49,13 +49,13 @@ def predicted_vs_real(y_real, y_pred, names, ranges, alpha='auto'): alpha_ = None else: alpha_ = alpha - + r2 = metrics.r2_score(current_real, current_pred) label = "$R^2 = {:.3f}$".format(r2) ax.plot(current_real, current_pred, '.', label=label, alpha=alpha_) - + ax.plot(range_i, range_i, '--', linewidth=3, color="C3", alpha=0.8) - + ax.axis("equal") ax.grid() ax.set_xlim(range_i) @@ -63,57 +63,57 @@ def predicted_vs_real(y_real, y_pred, names, ranges, alpha='auto'): ax.set_xlabel("Real {}".format(name_i), fontsize=18) ax.set_ylabel("Predicted {}".format(name_i), fontsize=18) ax.legend(loc="upper left", fontsize=14) - + fig.tight_layout() return fig def feature_importances(forests, names, colors): - num_plots = len(forests) num_plot_rows = (num_plots - 1) // 2 + 1 num_plot_cols = 2 - + fig, axes = plt.subplots(num_plot_rows, num_plot_cols, - figsize=(15, 3.5*num_plot_rows)) - - for ax, forest_i, name_i, color_i in zip(axes.ravel(), forests, names, colors): - ax.bar(np.arange(len(forest_i.feature_importances_)), forest_i.feature_importances_, + figsize=(15, 3.5 * num_plot_rows)) + + for ax, forest_i, name_i, color_i in zip(axes.ravel(), forests, names, + colors): + ax.bar(np.arange(len(forest_i.feature_importances_)), + forest_i.feature_importances_, label="Importance for {}".format(name_i), width=0.4, color=color_i) ax.set_xlabel("Feature index", fontsize=18) ax.legend(fontsize=16) ax.grid() - + fig.tight_layout() return fig def posterior_matrix(posterior, names, ranges, colors, soft_colors=None): - samples, weights = posterior - + cmaps = [LinearSegmentedColormap.from_list("MyReds", [(1, 1, 1), c], N=256) for c in colors] - + ranges = np.array(ranges) - + if soft_colors is None: soft_colors = colors - + num_dims = samples.shape[1] - + fig, axes = plt.subplots(nrows=num_dims, ncols=num_dims, figsize=(2 * num_dims, 2 * num_dims)) - fig.subplots_adjust(left=0.07, right=1-0.05, - bottom=0.07, top=1-0.05, + fig.subplots_adjust(left=0.07, right=1 - 0.05, + bottom=0.07, top=1 - 0.05, hspace=0.05, wspace=0.05) - + iterable = zip(axes.flat, product(range(num_dims), range(num_dims))) - for ax, dims in tqdm(iterable, total=num_dims**2): + for ax, dims in tqdm(iterable, total=num_dims ** 2): # Flip dims. 
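        # product() yields (row, column); after flipping, dims[0] is the
        # x-axis (column) parameter and dims[1] the y-axis (row) parameter.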
dims = [dims[1], dims[0]] - + ax.xaxis.set_visible(False) ax.yaxis.set_visible(False) ax.title.set_visible(False) @@ -138,7 +138,7 @@ def posterior_matrix(posterior, names, ranges, colors, soft_colors=None): ax.set_ylabel("") if ax.is_last_col() and ax.is_last_row(): ax.yaxis.set_visible(False) - + if dims[0] < dims[1]: _plot_histogram2d( ax, posterior, @@ -159,27 +159,28 @@ def posterior_matrix(posterior, names, ranges, colors, soft_colors=None): samples[:, dims[:1]], weights, ranges=ranges[dims[:1]] ) - ax.bar(bins[:-1], histogram, color=soft_colors[dims[0]], width=bins[1]-bins[0]) - + ax.bar(bins[:-1], histogram, color=soft_colors[dims[0]], + width=bins[1] - bins[0]) + kd_probs = histogram expected = wmedian(samples[:, dims[0]], weights) - ax.plot([expected, expected], [0, 1.1 * kd_probs.max()], '-', linewidth=1, color='#222222') - + ax.plot([expected, expected], [0, 1.1 * kd_probs.max()], '-', + linewidth=1, color='#222222') + ax.axis([ranges[dims[0]][0], ranges[dims[0]][1], 0, 1.1 * kd_probs.max()]) - + # fig.tight_layout(pad=0) return fig def _plot_histogram2d(ax, posterior, color, cmap, dims, ranges): - samples, weights = posterior # For efficiency, do not compute the kernel density # over all the samples of the posterior. Subsample first. if len(samples) > POSTERIOR_MAX_SIZE: samples, weights = resample_posterior(posterior, POSTERIOR_MAX_SIZE) - + locations, kd_probs, *_ = _kernel_density_joint( samples[:, dims], weights, @@ -191,7 +192,7 @@ def _plot_histogram2d(ax, posterior, color, cmap, dims, ranges): colors=color, linewidths=0.5 ) - + # For the rest of the plot we use the complete posterior samples, weights = posterior histogram, grid_x, grid_y = _histogram2d( @@ -199,7 +200,7 @@ def _plot_histogram2d(ax, posterior, color, cmap, dims, ranges): ranges ) ax.pcolormesh(grid_x, grid_y, histogram, cmap=cmap) - + expected = wmedian(samples[:, dims], weights, axis=0) ax.plot([expected[0], expected[0]], [ranges[1][0], ranges[1][1]], '-', linewidth=1, color='#222222') @@ -212,19 +213,18 @@ def _plot_histogram2d(ax, posterior, color, cmap, dims, ranges): def _plot_samples(ax, posterior, color, dims, ranges): - # For efficiency, do not plot all the samples of the posterior. Subsample first. 
if len(posterior.samples) > POSTERIOR_MAX_SIZE: posterior = resample_posterior(posterior, POSTERIOR_MAX_SIZE) - + samples, weights = posterior - + points_alpha = _weights_to_alpha(weights) - + current_colors = to_rgba_array(color) current_colors = np.tile(current_colors, (len(samples), 1)) current_colors[:, 3] = points_alpha - + ax.scatter( x=samples[:, dims[0]], y=samples[:, dims[1]], @@ -233,7 +233,7 @@ def _plot_samples(ax, posterior, color, dims, ranges): marker='.', linewidth=0 ) - + ax.axis([ranges[0][0], ranges[0][1], ranges[1][0], ranges[1][1]]) @@ -243,30 +243,30 @@ def _min_max_scaler(ranges, feature_range=(0, 100)): res.data_max_ = ranges[:, 1] res.data_min_ = ranges[:, 0] res.data_range_ = res.data_max_ - res.data_min_ - res.scale_ = (feature_range[1] - feature_range[0]) / (ranges[:, 1] - ranges[:, 0]) + res.scale_ = (feature_range[1] - feature_range[0]) / ( + ranges[:, 1] - ranges[:, 0]) res.min_ = -res.scale_ * res.data_min_ res.n_samples_seen_ = 1 res.feature_range = feature_range return res -def _kernel_density_joint(samples, weights, ranges, bandwidth=1/25): - +def _kernel_density_joint(samples, weights, ranges, bandwidth=1 / 25): ndims = len(ranges) - + scaler = _min_max_scaler(ranges, feature_range=(0, 100)) - + bandwidth = bandwidth * 100 # step = 1.0 - + kd = neighbors.KernelDensity(bandwidth=bandwidth) kd.fit(scaler.transform(samples), sample_weight=weights) - + grid_shape = [100] * ndims grid = np.indices(grid_shape) locations = np.reshape(grid, (ndims, -1)).T kd_probs = np.exp(kd.score_samples(locations)) - + shape = (ndims, *grid_shape) locations = scaler.inverse_transform(locations) locations = np.reshape(locations.T, shape) @@ -275,9 +275,8 @@ def _kernel_density_joint(samples, weights, ranges, bandwidth=1/25): def _histogram1d(samples, weights, ranges, bins=20): - - assert(len(ranges) == 1) - + assert (len(ranges) == 1) + histogram, edges = np.histogram( samples[:, 0], bins=bins, @@ -288,9 +287,8 @@ def _histogram1d(samples, weights, ranges, bins=20): def _histogram2d(samples, weights, ranges, bins=20): - - assert(len(ranges) == 2) - + assert (len(ranges) == 2) + histogram, xedges, yedges = np.histogram2d( samples[:, 0], samples[:, 1], @@ -299,11 +297,10 @@ def _histogram2d(samples, weights, ranges, bins=20): weights=weights ) grid_x, grid_y = np.meshgrid(xedges, yedges) - return histogram.T, grid_x, grid_y, + return histogram.T, grid_x, grid_y, def _weights_to_alpha(weights): - # Maximum weight (removing potential outliers) max_weight = np.percentile(weights, 98) return np.clip(weights / max_weight, 0, 1) diff --git a/rfretrieval.py b/rfretrieval.py index 662265f..b5ad7df 100644 --- a/rfretrieval.py +++ b/rfretrieval.py @@ -1,4 +1,3 @@ - import argparse import os import logging @@ -27,44 +26,44 @@ def train_model(dataset, num_trees, num_jobs, verbose=1): def test_model(model, dataset, output_path): - if dataset.testing_x is None: return - + logger.info("Testing model...") pred = model.predict(dataset.testing_x) # pred = model.predict_median(dataset.testing_x) r2scores = {name_i: metrics.r2_score(real_i, pred_i) - for name_i, real_i, pred_i in zip(dataset.names, dataset.testing_y.T, pred.T)} + for name_i, real_i, pred_i in + zip(dataset.names, dataset.testing_y.T, pred.T)} print("Testing scores:") for name, values in r2scores.items(): print("\tR^2 score for {}: {:.3f}".format(name, values)) - + logger.info("Plotting testing results...") - fig = plot.predicted_vs_real(dataset.testing_y, pred, dataset.names, dataset.ranges) + fig = 
plot.predicted_vs_real(dataset.testing_y, pred, dataset.names, + dataset.ranges) fig.savefig(os.path.join(output_path, "predicted_vs_real.pdf"), bbox_inches='tight') def compute_feature_importance(model, dataset, output_path): - logger.info("Computing feature importance for individual parameters...") regr = multioutput.MultiOutputRegressor(model, n_jobs=1) regr.fit(dataset.training_x, dataset.training_y) - - fig = plot.feature_importances(forests=[i.rf for i in regr.estimators_] + [model.rf], - names=dataset.names + ["joint prediction"], - colors=dataset.colors + ["C0"]) - + + fig = plot.feature_importances( + forests=[i.rf for i in regr.estimators_] + [model.rf], + names=dataset.names + ["joint prediction"], + colors=dataset.colors + ["C0"]) + fig.savefig(os.path.join(output_path, "feature_importances.pdf"), bbox_inches='tight') def data_ranges(posterior, percentiles=(50, 16, 84)): - samples, weights = posterior values = wpercentile(samples, weights, percentiles, axis=0) - ranges = np.array([values[0], values[2]-values[0], values[0]-values[1]]) + ranges = np.array([values[0], values[2] - values[0], values[0] - values[1]]) return ranges.T @@ -72,46 +71,45 @@ def main_train(training_dataset, model_path, num_trees, num_jobs, feature_importance, quiet, **_): - logger.info("Loading dataset '{}'...".format(training_dataset)) dataset = load_dataset(training_dataset) - + logger.info("Training model...") model = train_model(dataset, num_trees, num_jobs, not quiet) - + os.makedirs(model_path, exist_ok=True) model_file = os.path.join(model_path, "model.pkl") logger.info("Saving model to '{}'...".format(model_file)) joblib.dump(model, model_file) - + logger.info("Printing model information...") print("OOB score: {:.4f}".format(model.rf.oob_score_)) - + test_model(model, dataset, model_path) - + if feature_importance: model.enable_posterior = False compute_feature_importance(model, dataset, model_path) def main_predict(model_path, data_file, output_path, plot_posterior, **_): - model_file = os.path.join(model_path, "model.pkl") logger.info("Loading random forest from '{}'...".format(model_file)) model = joblib.load(model_file) - + logger.info("Loading data from '{}'...".format(data_file)) data, _ = load_data_file(data_file, model.rf.n_features_) - + posterior = model.posterior(data[0]) - + posterior_ranges = data_ranges(posterior) for name_i, pred_range_i in zip(model.names, posterior_ranges): - print("Prediction for {}: {:.3g} [+{:.3g} -{:.3g}]".format(name_i, *pred_range_i)) - + print("Prediction for {}: {:.3g} [+{:.3g} -{:.3g}]".format(name_i, + *pred_range_i)) + if plot_posterior: logger.info("Plotting the posterior matrix...") - + fig = plot.posterior_matrix( posterior, names=model.names, @@ -130,12 +128,12 @@ def show_usage(parser, **_): def main(): - - parser = argparse.ArgumentParser(description="rfretrieval: Atmospheric retrieval with random forests.") + parser = argparse.ArgumentParser( + description="rfretrieval: Atmospheric retrieval with random forests.") parser.set_defaults(func=show_usage, parser=parser) parser.add_argument("--quiet", action='store_true') subparsers = parser.add_subparsers() - + parser_train = subparsers.add_parser('train', help="train a model") parser_train.add_argument("training_dataset", type=str, help="JSON file with the training dataset description") @@ -148,8 +146,9 @@ def main(): parser_train.add_argument("--feature-importance", action='store_true', help="compute feature importances after training") parser_train.set_defaults(func=main_train) - - parser_test = 
subparsers.add_parser('predict', help="use a trained model to perform a prediction") + + parser_test = subparsers.add_parser('predict', + help="use a trained model to perform a prediction") parser_test.set_defaults(func=main_predict) parser_test.add_argument("model_path", type=str, help="path to the trained model") @@ -159,7 +158,7 @@ def main(): help="path to write the results of the prediction") parser_test.add_argument("--plot-posterior", action='store_true', help="plot and save the scatter matrix of the posterior distribution") - + args = parser.parse_args() config_logger(level=logging.WARNING if args.quiet else logging.INFO) args.func(**vars(args)) diff --git a/utils.py b/utils.py index 414cc49..b0a999f 100644 --- a/utils.py +++ b/utils.py @@ -1,4 +1,3 @@ - import os import logging @@ -8,33 +7,32 @@ def config_logger(log_file="/dev/null", level=logging.INFO): - class MyFormatter(logging.Formatter): - + info_format = "\x1b[32;1m%(asctime)s [%(name)s]\x1b[0m %(message)s" error_format = "\x1b[31;1m%(asctime)s [%(name)s] [%(levelname)s]\x1b[0m %(message)s" - + def format(self, record): - + if record.levelno > logging.INFO: self._style._fmt = self.error_format else: self._style._fmt = self.info_format - + res = super(MyFormatter, self).format(record) return res - + rootLogger = logging.getLogger() - + fileHandler = logging.FileHandler(log_file) - fileFormatter = logging.Formatter("%(asctime)s [%(name)s] [%(levelname)s]> %(message)s") + fileFormatter = logging.Formatter( + "%(asctime)s [%(name)s] [%(levelname)s]> %(message)s") fileHandler.setFormatter(fileFormatter) rootLogger.addHandler(fileHandler) - + consoleHandler = logging.StreamHandler() consoleFormatter = MyFormatter() consoleHandler.setFormatter(consoleFormatter) rootLogger.addHandler(consoleHandler) - - rootLogger.setLevel(level) + rootLogger.setLevel(level) diff --git a/wpercentile.py b/wpercentile.py index f030db9..0985556 100644 --- a/wpercentile.py +++ b/wpercentile.py @@ -1,4 +1,3 @@ - import numpy as np __all__ = [ @@ -8,26 +7,25 @@ def _wpercentile1d(data, weights, percentiles): - if data.ndim > 1 or weights.ndim > 1: raise ValueError("data and weights must be one-dimensional arrays") - + if data.shape != weights.shape: raise ValueError("data and weights must have the same shape") - + data = np.asarray(data) weights = np.asarray(weights) percentiles = np.asarray(percentiles) - + sort_indices = np.argsort(data) sorted_data = data[sort_indices] sorted_weights = weights[sort_indices] - + cumsum_weights = np.cumsum(sorted_weights) sum_weights = cumsum_weights[-1] - - pn = 100 * (cumsum_weights - 0.5*sorted_weights) / sum_weights - + + pn = 100 * (cumsum_weights - 0.5 * sorted_weights) / sum_weights + return np.interp(percentiles, pn, sorted_data) @@ -39,22 +37,23 @@ def wpercentile(data, weights, percentiles, axis=None): data = np.ravel(data) weights = np.ravel(weights) return _wpercentile1d(data, weights, percentiles) - + axis = np.atleast_1d(axis) - + # Reshape the arrays for proper computation # Move the requested axis to the final dimensions dest_axis = list(range(len(axis))) data2 = np.moveaxis(data, axis, dest_axis) - + ndim = len(axis) shape = data2.shape newshape = (np.prod(shape[:ndim]), np.prod(shape[ndim:])) newdata = np.reshape(data2, newshape) newweights = np.reshape(weights, newshape[0]) - - result = np.apply_along_axis(_wpercentile1d, 0, newdata, newweights, percentiles) - + + result = np.apply_along_axis(_wpercentile1d, 0, newdata, newweights, + percentiles) + final_shape = (*np.shape(percentiles), 
*shape[ndim:]) return np.reshape(result, final_shape) From 87b160527b4f76d7adc0e2c5522e6a54242a5920 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Thu, 7 Nov 2019 20:41:27 +0100 Subject: [PATCH 02/46] Building CI environment --- .rtd-environment.yml | 12 +++ .travis.yml | 136 +++++++++++++++++++++++++++ hela/version.py | 218 +++++++++++++++++++++++++++++++++++++++++++ readthedocs.yml | 5 + 4 files changed, 371 insertions(+) create mode 100644 .rtd-environment.yml create mode 100644 .travis.yml create mode 100644 hela/version.py create mode 100644 readthedocs.yml diff --git a/.rtd-environment.yml b/.rtd-environment.yml new file mode 100644 index 0000000..fd0ed40 --- /dev/null +++ b/.rtd-environment.yml @@ -0,0 +1,12 @@ +name: hela + +channels: + - astropy + +dependencies: + - astropy + - numpy + - scikit-learn + - pip: + - sphinx-automodapi + - numpydoc diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..1f076a6 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,136 @@ +# We set the language to c because python isn't supported on the MacOS X nodes +# on Travis. However, the language ends up being irrelevant anyway, since we +# install Python ourselves using conda. +language: c + +os: + - linux + +# The apt packages below are needed for sphinx builds. A full list of packages +# that can be included can be found here: +# +# https://github.com/travis-ci/apt-package-whitelist/blob/master/ubuntu-precise + +addons: + apt: + packages: + - graphviz + + +stage: Comprehensive tests + +stages: + # Do the style check and a single test job, don't proceed if it fails + - name: Initial tests + # Test docs, astropy dev, and without optional dependencies + - name: Comprehensive tests + # Slow tests that should only run if comprehensive ones passed + - name: Slow tests + # These will only run when cron is opted in + - name: Cron tests + if: type = cron + + +env: + global: + # The following versions are the 'default' for tests, unless + # overridden underneath. They are defined here in order to save having + # to repeat them for all configurations. + - PYTHON_VERSION=3.7 + - NUMPY_VERSION=stable + - ASTROPY_VERSION=stable + - MAIN_CMD='python setup.py' + - SETUP_CMD='test' + - EVENT_TYPE='pull_request push' + + + # List runtime dependencies for the package that are available as conda + # packages here. + - CONDA_DEPENDENCIES='scikit-learn' + - CONDA_DEPENDENCIES_DOC='sphinx-astropy' + + # List other runtime dependencies for the package that are available as + # pip packages here. + - PIP_DEPENDENCIES='scipy matplotlib' + + # Conda packages for affiliated packages are hosted in channel + # "astropy" while builds for astropy LTS with recent numpy versions + # are in astropy-ci-extras. If your package uses either of these, + # add the channels to CONDA_CHANNELS along with any other channels + # you want to use. + - CONDA_CHANNELS='astropy' + + # If there are matplotlib or other GUI tests, uncomment the following + # line to use the X virtual framebuffer. 
+ # - SETUP_XVFB=True + + # If you want to ignore certain flake8 errors, you can list them + # in FLAKE8_OPT, for example: + # - FLAKE8_OPT='--ignore=E501' + - FLAKE8_OPT='' + +matrix: + + # Don't wait for allowed failures + fast_finish: true + + include: + # Make sure that egg_info works without dependencies + - stage: Initial tests + env: PYTHON_VERSION=3.7 SETUP_CMD='egg_info' + + # Try MacOS X, usually enough only to run from cron as hardly there are + # issues that are not picked up by a linux worker + - os: osx + stage: Cron tests + env: SETUP_CMD='test' EVENT_TYPE='cron' + + # Do a coverage test. + - os: linux + stage: Initial tests + env: SETUP_CMD='test --coverage' + + # Do a PEP8 test with flake8 + - os: linux + stage: Initial tests + env: MAIN_CMD='flake8 hipparchus --count --show-source --statistics $FLAKE8_OPT' SETUP_CMD='' + + allow_failures: + # Do a PEP8 test with flake8 + # (do allow to fail unless your code completely compliant) + # - os: linux + # stage: Initial tests + # env: MAIN_CMD='flake8 hipparchus --count --show-source --statistics $FLAKE8_OPT' SETUP_CMD='' + +install: + + # We now use the ci-helpers package to set up our testing environment. + # This is done by using Miniconda and then using conda and pip to install + # dependencies. Which dependencies are installed using conda and pip is + # determined by the CONDA_DEPENDENCIES and PIP_DEPENDENCIES variables, + # which should be space-delimited lists of package names. See the README + # in https://github.com/astropy/ci-helpers for information about the full + # list of environment variables that can be used to customize your + # environment. In some cases, ci-helpers may not offer enough flexibility + # in how to install a package, in which case you can have additional + # commands in the install: section below. + + - git clone --depth 1 git://github.com/astropy/ci-helpers.git + - source ci-helpers/travis/setup_conda.sh + + # As described above, using ci-helpers, you should be able to set up an + # environment with dependencies installed using conda and pip, but in some + # cases this may not provide enough flexibility in how to install a + # specific dependency (and it will not be able to install non-Python + # dependencies). Therefore, you can also include commands below (as + # well as at the start of the install section or in the before_install + # section if they are needed before setting up conda) to install any + # other dependencies. + +script: + - $MAIN_CMD $SETUP_CMD + +after_success: + # If coveralls.io is set up for this package, uncomment the line below. + # The coveragerc file may be customized as needed for your package. 
+ # - if [[ $SETUP_CMD == *coverage* ]]; then coveralls --rcfile='hipparchus/tests/coveragerc'; fi diff --git a/hela/version.py b/hela/version.py new file mode 100644 index 0000000..ca2074f --- /dev/null +++ b/hela/version.py @@ -0,0 +1,218 @@ +# Autogenerated by Astropy-affiliated package hela's setup.py on 2019-11-07 19:28:46 UTC +import datetime + + +import locale +import os +import subprocess +import warnings + +__all__ = ['get_git_devstr'] + + +def _decode_stdio(stream): + try: + stdio_encoding = locale.getdefaultlocale()[1] or 'utf-8' + except ValueError: + stdio_encoding = 'utf-8' + + try: + text = stream.decode(stdio_encoding) + except UnicodeDecodeError: + # Final fallback + text = stream.decode('latin1') + + return text + + +def update_git_devstr(version, path=None): + """ + Updates the git revision string if and only if the path is being imported + directly from a git working copy. This ensures that the revision number in + the version string is accurate. + """ + + try: + # Quick way to determine if we're in git or not - returns '' if not + devstr = get_git_devstr(sha=True, show_warning=False, path=path) + except OSError: + return version + + if not devstr: + # Probably not in git so just pass silently + return version + + if 'dev' in version: # update to the current git revision + version_base = version.split('.dev', 1)[0] + devstr = get_git_devstr(sha=False, show_warning=False, path=path) + + return version_base + '.dev' + devstr + else: + # otherwise it's already the true/release version + return version + + +def get_git_devstr(sha=False, show_warning=True, path=None): + """ + Determines the number of revisions in this repository. + + Parameters + ---------- + sha : bool + If True, the full SHA1 hash will be returned. Otherwise, the total + count of commits in the repository will be used as a "revision + number". + + show_warning : bool + If True, issue a warning if git returns an error code, otherwise errors + pass silently. + + path : str or None + If a string, specifies the directory to look in to find the git + repository. If `None`, the current working directory is used, and must + be the root of the git repository. + If given a filename it uses the directory containing that file. + + Returns + ------- + devversion : str + Either a string with the revision number (if `sha` is False), the + SHA1 hash of the current commit (if `sha` is True), or an empty string + if git version info could not be identified. + + """ + + if path is None: + path = os.getcwd() + + if not os.path.isdir(path): + path = os.path.abspath(os.path.dirname(path)) + + if sha: + # Faster for getting just the hash of HEAD + cmd = ['rev-parse', 'HEAD'] + else: + cmd = ['rev-list', '--count', 'HEAD'] + + def run_git(cmd): + try: + p = subprocess.Popen(['git'] + cmd, cwd=path, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.PIPE) + stdout, stderr = p.communicate() + except OSError as e: + if show_warning: + warnings.warn('Error running git: ' + str(e)) + return (None, b'', b'') + + if p.returncode == 128: + if show_warning: + warnings.warn('No git repository present at {0!r}! 
Using ' + 'default dev version.'.format(path)) + return (p.returncode, b'', b'') + if p.returncode == 129: + if show_warning: + warnings.warn('Your git looks old (does it support {0}?); ' + 'consider upgrading to v1.7.2 or ' + 'later.'.format(cmd[0])) + return (p.returncode, stdout, stderr) + elif p.returncode != 0: + if show_warning: + warnings.warn('Git failed while determining revision ' + 'count: {0}'.format(_decode_stdio(stderr))) + return (p.returncode, stdout, stderr) + + return p.returncode, stdout, stderr + + returncode, stdout, stderr = run_git(cmd) + + if not sha and returncode == 128: + # git returns 128 if the command is not run from within a git + # repository tree. In this case, a warning is produced above but we + # return the default dev version of '0'. + return '0' + elif not sha and returncode == 129: + # git returns 129 if a command option failed to parse; in + # particular this could happen in git versions older than 1.7.2 + # where the --count option is not supported + # Also use --abbrev-commit and --abbrev=0 to display the minimum + # number of characters needed per-commit (rather than the full hash) + cmd = ['rev-list', '--abbrev-commit', '--abbrev=0', 'HEAD'] + returncode, stdout, stderr = run_git(cmd) + # Fall back on the old method of getting all revisions and counting + # the lines + if returncode == 0: + return str(stdout.count(b'\n')) + else: + return '' + elif sha: + return _decode_stdio(stdout)[:40] + else: + return _decode_stdio(stdout).strip() + + +# This function is tested but it is only ever executed within a subprocess when +# creating a fake package, so it doesn't get picked up by coverage metrics. +def _get_repo_path(pathname, levels=None): # pragma: no cover + """ + Given a file or directory name, determine the root of the git repository + this path is under. If given, this won't look any higher than ``levels`` + (that is, if ``levels=0`` then the given path must be the root of the git + repository and is returned if so. + + Returns `None` if the given path could not be determined to belong to a git + repo. + """ + + if os.path.isfile(pathname): + current_dir = os.path.abspath(os.path.dirname(pathname)) + elif os.path.isdir(pathname): + current_dir = os.path.abspath(pathname) + else: + return None + + current_level = 0 + + while levels is None or current_level <= levels: + if os.path.exists(os.path.join(current_dir, '.git')): + return current_dir + + current_level += 1 + if current_dir == os.path.dirname(current_dir): + break + + current_dir = os.path.dirname(current_dir) + + return None + + +_packagename = "hela" +_last_generated_version = "0.0.dev035" +_last_githash = "fc6e9dbc80d4d68d83481d2afa7337dd1ae1cf93" + +# Determine where the source code for this module +# lives. If __file__ is not a filesystem path then +# it is assumed not to live in a git repo at all. 
+if _get_repo_path(__file__, levels=len(_packagename.split('.'))): + version = update_git_devstr(_last_generated_version, path=__file__) + githash = get_git_devstr(sha=True, show_warning=False, + path=__file__) or _last_githash +else: + # The file does not appear to live in a git repo so don't bother + # invoking git + version = _last_generated_version + githash = _last_githash + + +major = 0 +minor = 0 +bugfix = 0 + +version_info = (major, minor, bugfix) + +release = False +timestamp = datetime.datetime(2019, 11, 7, 19, 28, 46) +debug = False + +astropy_helpers_version = "unknown" diff --git a/readthedocs.yml b/readthedocs.yml new file mode 100644 index 0000000..47fe6ad --- /dev/null +++ b/readthedocs.yml @@ -0,0 +1,5 @@ +conda: + file: .rtd-environment.yml + +python: + setup_py_install: true From d8e4844d62aeca82009ecd456a8a692373c78e60 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Thu, 7 Nov 2019 20:42:57 +0100 Subject: [PATCH 03/46] Adding hela module --- hela/__init__.py | 31 ++++++ hela/_astropy_init.py | 59 +++++++++++ hela/dataset.py | 64 ++++++++++++ hela/forest.py | 225 ++++++++++++++++++++++++++++++++++++++++++ hela/models.py | 86 ++++++++++++++++ hela/plot.py | 210 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 675 insertions(+) create mode 100644 hela/__init__.py create mode 100644 hela/_astropy_init.py create mode 100644 hela/dataset.py create mode 100644 hela/forest.py create mode 100644 hela/models.py create mode 100644 hela/plot.py diff --git a/hela/__init__.py b/hela/__init__.py new file mode 100644 index 0000000..4b59994 --- /dev/null +++ b/hela/__init__.py @@ -0,0 +1,31 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst + +# Packages may add whatever they like to this file, but +# should keep this content at the top. +# ---------------------------------------------------------------------------- +from ._astropy_init import * +# ---------------------------------------------------------------------------- + +# Enforce Python version check during package import. +# This is the same check as the one at the top of setup.py +import sys + +__minimum_python_version__ = "3.5" + + +class UnsupportedPythonError(Exception): + pass + + +if sys.version_info < tuple( + (int(val) for val in __minimum_python_version__.split('.'))): + raise UnsupportedPythonError( + "packagename does not support Python < {}".format( + __minimum_python_version__)) + +if not _ASTROPY_SETUP_: + # For egg_info test builds to pass, put package imports here. 
+ from .forest import * + from .models import * + from .dataset import * + from .plot import * diff --git a/hela/_astropy_init.py b/hela/_astropy_init.py new file mode 100644 index 0000000..5371878 --- /dev/null +++ b/hela/_astropy_init.py @@ -0,0 +1,59 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst + +__all__ = ['__version__', '__githash__'] + +# this indicates whether or not we are in the package's setup.py +try: + _ASTROPY_SETUP_ +except NameError: + from sys import version_info + if version_info[0] >= 3: + import builtins + else: + import __builtin__ as builtins + builtins._ASTROPY_SETUP_ = False + +try: + from .version import version as __version__ +except ImportError: + __version__ = '' +try: + from .version import githash as __githash__ +except ImportError: + __githash__ = '' + + +if not _ASTROPY_SETUP_: # noqa + import os + from warnings import warn + from astropy.config.configuration import ( + update_default_config, + ConfigurationDefaultMissingError, + ConfigurationDefaultMissingWarning) + + # Create the test function for self test + from astropy.tests.runner import TestRunner + test = TestRunner.make_test_runner_in(os.path.dirname(__file__)) + __all__ += ['test'] + + # add these here so we only need to cleanup the namespace at the end + config_dir = None + + if not os.environ.get('ASTROPY_SKIP_CONFIG_UPDATE', False): + config_dir = os.path.dirname(__file__) + config_template = os.path.join(config_dir, __package__ + ".cfg") + if os.path.isfile(config_template): + try: + update_default_config( + __package__, config_dir, version=__version__) + except TypeError as orig_error: + try: + update_default_config(__package__, config_dir) + except ConfigurationDefaultMissingError as e: + wmsg = (e.args[0] + + " Cannot install default profile. If you are " + "importing from source, this is expected.") + warn(ConfigurationDefaultMissingWarning(wmsg)) + del e + except Exception: + raise orig_error diff --git a/hela/dataset.py b/hela/dataset.py new file mode 100644 index 0000000..b65a3be --- /dev/null +++ b/hela/dataset.py @@ -0,0 +1,64 @@ +import os +import json +from collections import namedtuple + +import numpy as np + +__all__ = ["Dataset", "load_dataset", "load_data_file"] + + +Dataset = namedtuple("Dataset", ["training_x", "training_y", + "testing_x", "testing_y", + "names", "ranges", "colors"]) + + +def load_data_file(data_file, num_features): + data = np.load(data_file) + + if data.ndim == 1: + data = data[None, :] + + x = data[:, :num_features] + y = data[:, num_features:] + + return x, y + + +def load_dataset(dataset_file): + """ + Load a dataset from a JSON file. 
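+
+    The JSON file must provide a ``metadata`` block with ``names``,
+    ``ranges``, ``colors`` and ``num_features``, plus ``training_data`` and
+    ``testing_data`` entries naming ``.npy`` files relative to the JSON
+    file's directory (``testing_data`` may be ``null``; see
+    `generate_example_data` for a concrete example).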
+ + Parameters + ---------- + dataset_file + + Returns + ------- + + """ + with open(dataset_file, "r") as f: + dataset_info = json.load(f) + + metadata = dataset_info["metadata"] + + base_path = os.path.dirname(dataset_file) + + # Load training data + training_file = os.path.join(base_path, dataset_info["training_data"]) + # Loading training data from '{}'...".format(training_file) + training_x, training_y = load_data_file(training_file, + metadata["num_features"]) + # TODO: slice training_x (data) and training_y (params) to the same length + # but something smaller for fast docs + + # Optionally, load testing data + testing_x, testing_y = None, None + if dataset_info["testing_data"] is not None: + testing_file = os.path.join(base_path, dataset_info["testing_data"]) + # Loading testing data from '{}'...".format(testing_file) + testing_x, testing_y = load_data_file(testing_file, + metadata["num_features"]) + + return Dataset(training_x, training_y, testing_x, testing_y, + metadata["names"], metadata["ranges"], metadata["colors"]) + diff --git a/hela/forest.py b/hela/forest.py new file mode 100644 index 0000000..d57fd73 --- /dev/null +++ b/hela/forest.py @@ -0,0 +1,225 @@ +import os +import json + +import numpy as np +from sklearn import metrics, multioutput +import joblib + +from .dataset import load_dataset, load_data_file +from .models import Model +from .plot import predicted_vs_real, feature_importances, posterior_matrix + +__all__ = ['RandomForest', 'generate_example_data'] + + +def train_model(dataset, num_trees, num_jobs, verbose=1): + pipeline = Model(num_trees, num_jobs, + names=dataset.names, + ranges=dataset.ranges, + colors=dataset.colors, + verbose=verbose) + pipeline.fit(dataset.training_x, dataset.training_y) + return pipeline + + +def test_model(model, dataset, output_path): + if dataset.testing_x is None: + return + + pred = model.predict(dataset.testing_x) + r2scores = {name_i: metrics.r2_score(real_i, pred_i) + for name_i, real_i, pred_i in + zip(dataset.names, dataset.testing_y.T, pred.T)} + print("Testing scores:") + for name, values in r2scores.items(): + print("\tR^2 score for {}: {:.3f}".format(name, values)) + + fig = predicted_vs_real(dataset.testing_y, pred, dataset.names, + dataset.ranges) + fig.savefig(os.path.join(output_path, "predicted_vs_real.pdf"), + bbox_inches='tight') + return r2scores + + +def compute_feature_importance(model, dataset, output_path): + regr = multioutput.MultiOutputRegressor(model, n_jobs=1) + regr.fit(dataset.training_x, dataset.training_y) + + forests = [i.rf for i in regr.estimators_] + [model.rf] + + fig = feature_importances( + forests=[i.rf for i in regr.estimators_] + [model.rf], + names=dataset.names + ["joint prediction"], + colors=dataset.colors + ["C0"]) + + fig.savefig(os.path.join(output_path, "feature_importances.pdf"), + bbox_inches='tight') + return np.array([forest_i.feature_importances_ for forest_i in forests]) + + +def prediction_ranges(preds): + percentiles = (np.percentile(pred_i, [50, 16, 84]) for pred_i in preds.T) + return np.array([(a, c - a, a - b) for a, b, c in percentiles]) + + +class RandomForest(object): + """ + A class for a random forest. 
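+
+    Wraps dataset loading, model training, feature-importance computation and
+    prediction around a scikit-learn random forest regressor.
+
+    A minimal end-to-end sketch using the synthetic data written by
+    `generate_example_data` (the ``example_model`` output directory below is
+    an arbitrary choice)::
+
+        from hela import RandomForest, generate_example_data
+
+        example_dir, dataset_json, samples = generate_example_data()
+        rf = RandomForest(dataset_json, model_path='example_model',
+                          data_file=samples)
+        r2_scores = rf.train(num_trees=100, num_jobs=1)
+        posterior_samples = rf.predict(plot_posterior=False)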
+ """ + def __init__(self, training_dataset, model_path, data_file): + """ + Parameters + ---------- + training_dataset + model_path + data_file + """ + self.training_dataset = training_dataset + self.model_path = model_path + self.data_file = data_file + self.output_path = self.model_path + + self.dataset = None + self.model = None + + def train(self, num_trees=1000, num_jobs=5, quiet=False): + """ + Train the random forest on a set of observations. + + Parameters + ---------- + num_trees + num_jobs + quiet + kwargs + + Returns + ------- + r2scores : dict + """ + # Loading dataset + self.dataset = load_dataset(self.training_dataset) + + # Training model + self.model = train_model(self.dataset, num_trees, num_jobs, not quiet) + + os.makedirs(self.model_path, exist_ok=True) + model_file = os.path.join(self.model_path, "model.pkl") + # Saving model + joblib.dump(self.model, model_file) + + # Printing model information... + print("OOB score: {:.4f}".format(self.model.rf.oob_score_)) + + r2scores = test_model(self.model, self.dataset, self.model_path) + + return r2scores + + def feature_importance(self): + """ + Compute feature importance. + + Parameters + ---------- + model + dataset + + Returns + ------- + feature_importances : `~numpy.ndarray` + """ + return compute_feature_importance(self.model, self.dataset, + self.model_path) + + def predict(self, plot_posterior=True): + """ + Predict values from the trained random forest. + + Parameters + ---------- + plot_posterior + + Returns + ------- + preds : `~numpy.ndarray` + N x M values where N is number of parameters, M is number of + samples/trees (check out attributes of model for metadata) + """ + model_file = os.path.join(self.model_path, "model.pkl") + # Loading random forest from '{}'...".format(model_file) + model = joblib.load(model_file) + + # Loading data from '{}'...".format(data_file) + data, _ = load_data_file(self.data_file, model.rf.n_features_) + + preds = model.trees_predict(data[0]) + + pred_ranges = prediction_ranges(preds) + + for name_i, pred_range_i in zip(model.names, pred_ranges): + print("Prediction for {}: {:.3g} " + "[+{:.3g} -{:.3g}]".format(name_i, *pred_range_i)) + + if plot_posterior: + # Plotting and saving the posterior matrix..." 
+ fig = posterior_matrix(preds, None, + names=model.names, + ranges=model.ranges, + colors=model.colors) + os.makedirs(self.output_path, exist_ok=True) + fig.savefig(os.path.join(self.output_path, "posterior_matrix.pdf"), + bbox_inches='tight') + return preds.T + + +def generate_example_data(): + """ + Generate an example dataset in the new directory ``linear_dataset`` + """ + example_dir = 'linear_dataset' + training_dataset = os.path.join(example_dir, 'example_dataset.json') + samples_path = 'samples.npy' + + os.makedirs(example_dir, exist_ok=True) + + # Save the dataset metadata to a JSON file + dataset = { + "metadata": { + "names": ["slope", "intercept"], + "ranges": [[0, 1], [0, 1]], + "colors": ["#F14532", "#4a98c9"], + "num_features": 1000 + }, + "training_data": "training.npy", + "testing_data": "testing.npy" + } + + with open(training_dataset, 'w') as fp: + json.dump(dataset, fp) + + # Generate fake training data + npoints = 1000 + + slopes = np.random.rand(npoints) + ints = np.random.rand(npoints) + x = np.linspace(0, 1, 1000)[:, np.newaxis] + data = slopes * x + ints + + labels = np.vstack([slopes, ints]) + X = np.vstack([data, labels]) + + # Split dataset into training and testing segments + training = X[:, :int(0.8 * npoints)].T + testing = X[:, int(-0.2 * npoints):].T + + np.save(os.path.join(example_dir, 'training.npy'), training) + np.save(os.path.join(example_dir, 'testing.npy'), testing) + + # Generate a bunch of samples with a test value to "retrieve" with the + # random forest: + true_slope = 0.2 + true_intercept = 0.5 + + samples = true_slope * x + true_intercept + np.save(samples_path, samples.T) + return example_dir, training_dataset, samples_path \ No newline at end of file diff --git a/hela/models.py b/hela/models.py new file mode 100644 index 0000000..5093c0f --- /dev/null +++ b/hela/models.py @@ -0,0 +1,86 @@ +import numpy as np + +from sklearn import ensemble +from sklearn.preprocessing import MinMaxScaler + + +class Model(object): + """ + Class for models. 
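+
+    Couples a ``MinMaxScaler`` for the target parameters with a
+    ``RandomForestRegressor``: targets are scaled to the range [0, 100]
+    before fitting, and predictions are mapped back to the original units.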
+ """ + def __init__(self, num_trees, num_jobs, + names, ranges, colors, + verbose=1): + """ + Parameters + ---------- + num_trees + num_jobs + names + ranges + colors + verbose + """ + scaler = MinMaxScaler(feature_range=(0, 100)) + rf = ensemble.RandomForestRegressor(n_estimators=num_trees, + oob_score=True, + verbose=verbose, + n_jobs=num_jobs, + max_features="sqrt", + min_impurity_decrease=0.01) + + self.scaler = scaler + self.rf = rf + + self.num_trees = num_trees + self.num_jobs = num_jobs + self.verbose = verbose + + self.ranges = ranges + self.names = names + self.colors = colors + + def _scaler_fit(self, y): + if y.ndim == 1: + y = y[:, None] + + self.scaler.fit(y) + + def _scaler_transform(self, y): + if y.ndim == 1: + y = y[:, None] + return self.scaler.transform(y)[:, 0] + + return self.scaler.transform(y) + + def _scaler_inverse_transform(self, y): + + if y.ndim == 1: + y = y[:, None] + # return self.scaler.inverse_transform(y)[:, 0] + + return self.scaler.inverse_transform(y) + + def fit(self, x, y): + self._scaler_fit(y) + self.rf.fit(x, self._scaler_transform(y)) + + def predict(self, x): + pred = self.rf.predict(x) + return self._scaler_inverse_transform(pred) + + def get_params(self, deep=True): + return {"num_trees": self.num_trees, "num_jobs": self.num_jobs, + "names": self.names, "ranges": self.ranges, + "colors": self.colors, + "verbose": self.verbose} + + def trees_predict(self, x): + + if x.ndim > 1: + raise ValueError("x.ndim must be 1") + + preds = np.array([i.predict(x[None, :])[0] + for i in self.rf.estimators_]) + return self._scaler_inverse_transform(preds) + diff --git a/hela/plot.py b/hela/plot.py new file mode 100644 index 0000000..b623c30 --- /dev/null +++ b/hela/plot.py @@ -0,0 +1,210 @@ +from itertools import product + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.colors import LinearSegmentedColormap + +from sklearn import metrics, neighbors +from sklearn.preprocessing import MinMaxScaler + + +__all__ = ['predicted_vs_real', 'feature_importances', 'posterior_matrix'] + + +def predicted_vs_real(y_real, y_pred, names, ranges): + num_plots = y_pred.shape[1] + num_plot_rows = int(np.sqrt(num_plots)) + num_plot_cols = (num_plots - 1) // num_plot_rows + 1 + + fig, axes = plt.subplots(num_plot_rows, num_plot_cols, + figsize=(5 * num_plot_cols, 5 * num_plot_rows), + squeeze=False) + + for dim, (ax, name_i, range_i) in enumerate( + zip(axes.ravel(), names, ranges)): + current_real = y_real[:, dim] + current_pred = y_pred[:, dim] + + r2 = metrics.r2_score(current_real, current_pred) + label = "$R^2 = {:.3f}$".format(r2) + ax.plot(current_real, current_pred, '.', label=label) + + ax.plot(range_i, range_i, '--', linewidth=3, color="C3", alpha=0.8) + + ax.axis("equal") + ax.grid() + ax.set_xlim(range_i) + ax.set_ylim(range_i) + ax.set_xlabel("Real {}".format(name_i), fontsize=18) + ax.set_ylabel("Predicted {}".format(name_i), fontsize=18) + ax.legend(loc="upper left", fontsize=14) + + fig.tight_layout() + return fig + + +def feature_importances(forests, names, colors): + num_plots = len(forests) + num_plot_rows = (num_plots - 1) // 2 + 1 + num_plot_cols = 2 + + fig, axes = plt.subplots(num_plot_rows, num_plot_cols, + figsize=(15, 3.5 * num_plot_rows)) + + for ax, forest_i, name_i, color_i in zip(axes.ravel(), forests, names, + colors): + ax.bar(np.arange(len(forest_i.feature_importances_)), + forest_i.feature_importances_, + label="Importance for {}".format(name_i), + width=0.4, color=color_i) + ax.set_xlabel("Feature index", fontsize=18) + 
ax.legend(fontsize=16) + ax.grid() + + fig.tight_layout() + return fig + + +def posterior_matrix(estimations, y, names, ranges, colors, soft_colors=None): + cmaps = [LinearSegmentedColormap.from_list("MyReds", [(1, 1, 1), c], N=256) + for c in colors] + + ranges = np.array(ranges) + + if soft_colors is None: + soft_colors = colors + + num_dims = estimations.shape[1] + + fig, axes = plt.subplots(nrows=num_dims, ncols=num_dims, + figsize=(2 * num_dims, 2 * num_dims)) + fig.subplots_adjust(left=0.07, right=1 - 0.05, + bottom=0.07, top=1 - 0.05, + hspace=0.05, wspace=0.05) + + for ax, dims in zip(axes.flat, product(range(num_dims), range(num_dims))): + dims = list(dims[::-1]) + ax.xaxis.set_visible(False) + ax.yaxis.set_visible(False) + ax.title.set_visible(False) + if ax.is_first_col(): + ax.yaxis.set_ticks_position('left') + ax.yaxis.set_visible(True) + if names is not None: + ax.set_ylabel(names[dims[1]], fontsize=18) + if ax.is_last_col(): + ax.yaxis.set_ticks_position('right') + ax.yaxis.set_visible(True) + if ax.is_first_row(): + ax.xaxis.set_ticks_position('top') + ax.xaxis.set_visible(True) + if ax.is_last_row(): + ax.xaxis.set_ticks_position('bottom') + ax.xaxis.set_visible(True) + if names is not None: + ax.set_xlabel(names[dims[0]], fontsize=18) + if ax.is_first_col() and ax.is_first_row(): + ax.yaxis.set_visible(False) + ax.set_ylabel("") + if ax.is_last_col() and ax.is_last_row(): + ax.yaxis.set_visible(False) + + if dims[0] < dims[1]: + locations, kd_probs, *_ = _kernel_density_joint( + estimations[:, dims], ranges[dims]) + ax.contour(locations[0], locations[1], + kd_probs, + colors=colors[dims[0]], + linewidths=0.5 + # 'copper', # 'hot', 'magma' ('copper' with white background) + ) + histogram, grid_x, grid_y = _histogram(estimations[:, dims], + ranges[dims]) + ax.pcolormesh(grid_x, grid_y, histogram, cmap=cmaps[dims[0]]) + + expected = np.median(estimations[:, dims], axis=0) + ax.plot([expected[0], expected[0]], + [ranges[dims[1]][0], ranges[dims[1]][1]], '-', linewidth=1, + color='#222222') + ax.plot([ranges[dims[0]][0], ranges[dims[0]][1]], + [expected[1], expected[1]], '-', linewidth=1, + color='#222222') + ax.plot(expected[0], expected[1], '.', color='#222222') + ax.axis('auto') + if y is not None: + real = y[dims] + ax.plot(real[0], real[1], '*', markersize=10, color='#FF0000') + ax.axis([ranges[dims[0]][0], ranges[dims[0]][1], + ranges[dims[1]][0], ranges[dims[1]][1]]) + elif dims[0] > dims[1]: + ax.plot(estimations[:, dims[0]], estimations[:, dims[1]], '.', + color=soft_colors[dims[1]]) + ax.axis([ranges[dims[0]][0], ranges[dims[0]][1], + ranges[dims[1]][0], ranges[dims[1]][1]]) + else: + histogram, bins = _histogram(estimations[:, dims[:1]], + ranges=ranges[dims[:1]]) + ax.bar(bins[:-1], histogram, color=soft_colors[dims[0]], + width=bins[1] - bins[0]) + + kd_probs = histogram + expected = np.median(estimations[:, dims[0]]) + ax.plot([expected, expected], [0, 1.1 * kd_probs.max()], '-', + linewidth=1, color='#222222') + + if y is not None: + real = y[dims[0]] + ax.plot([real, real], [0, kd_probs.max()], 'r-') + ax.axis([ranges[dims[0]][0], ranges[dims[0]][1], + 0, 1.1 * kd_probs.max()]) + + # fig.tight_layout(pad=0) + return fig + + +def _min_max_scaler(ranges, feature_range=(0, 100)): + res = MinMaxScaler() + res.data_max_ = ranges[:, 1] + res.data_min_ = ranges[:, 0] + res.data_range_ = res.data_max_ - res.data_min_ + res.scale_ = (feature_range[1] - feature_range[0]) / ( + ranges[:, 1] - ranges[:, 0]) + res.min_ = -res.scale_ * res.data_min_ + res.n_samples_seen_ = 1 
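+    # Filling in these attributes by hand leaves the scaler in a "fitted"
+    # state, so transform()/inverse_transform() work without calling fit().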
+ res.feature_range = feature_range + return res + + +def _kernel_density_joint(estimations, ranges, bandwidth=1 / 25): + ndims = len(ranges) + + scaler = _min_max_scaler(ranges, feature_range=(0, 100)) + + bandwidth = bandwidth * 100 + # step = 1.0 + + kd = neighbors.KernelDensity(bandwidth=bandwidth).fit( + scaler.transform(estimations)) + locations1d = np.arange(0, 100, 1) + locations = np.reshape(np.meshgrid(*[locations1d] * ndims), (ndims, -1)).T + kd_probs = np.exp(kd.score_samples(locations)) + + shape = (ndims,) + (len(locations1d),) * ndims + locations = scaler.inverse_transform(locations) + locations = np.reshape(locations.T, shape) + kd_probs = np.reshape(kd_probs, shape[1:]) + return locations, kd_probs, kd + + +def _histogram(estimations, ranges, bins=20): + if len(ranges) == 1: + histogram, edges = np.histogram(estimations[:, 0], bins=bins, + range=ranges[0]) + return histogram, edges + + if len(ranges) == 2: + histogram, xedges, yedges = np.histogram2d(estimations[:, 0], + estimations[:, 1], + bins=bins, range=ranges) + grid_x, grid_y = np.meshgrid(xedges, yedges) + return histogram.T, grid_x, grid_y, From 90ad5940fd5db394a236326ae2b027dcdc0f6f19 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Thu, 7 Nov 2019 20:44:56 +0100 Subject: [PATCH 04/46] Adding sphinx documentation --- docs/Makefile | 133 ++++++++++++++++++++++++ docs/conf.py | 200 +++++++++++++++++++++++++++++++++++++ docs/hela/installation.rst | 6 ++ docs/hela/tutorial.rst | 55 ++++++++++ docs/index.rst | 11 ++ docs/make.bat | 170 +++++++++++++++++++++++++++++++ 6 files changed, 575 insertions(+) create mode 100644 docs/Makefile create mode 100644 docs/conf.py create mode 100644 docs/hela/installation.rst create mode 100644 docs/hela/tutorial.rst create mode 100644 docs/index.rst create mode 100644 docs/make.bat diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..fb03f26 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,133 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +BUILDDIR = _build + +# Internal variables. +PAPEROPT_a4 = -D latex_paper_size=a4 +PAPEROPT_letter = -D latex_paper_size=letter +ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
+ +.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest + +#This is needed with git because git doesn't create a dir if it's empty +$(shell [ -d "_static" ] || mkdir -p _static) + +help: + @echo "Please use \`make ' where is one of" + @echo " html to make standalone HTML files" + @echo " dirhtml to make HTML files named index.html in directories" + @echo " singlehtml to make a single large HTML file" + @echo " pickle to make pickle files" + @echo " json to make JSON files" + @echo " htmlhelp to make HTML files and a HTML help project" + @echo " qthelp to make HTML files and a qthelp project" + @echo " devhelp to make HTML files and a Devhelp project" + @echo " epub to make an epub" + @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" + @echo " latexpdf to make LaTeX files and run them through pdflatex" + @echo " text to make text files" + @echo " man to make manual pages" + @echo " changes to make an overview of all changed/added/deprecated items" + @echo " linkcheck to check all external links for integrity" + +clean: + -rm -rf $(BUILDDIR) + -rm -rf api + -rm -rf generated + +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +dirhtml: + $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." + +singlehtml: + $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml + @echo + @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." + +pickle: + $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle + @echo + @echo "Build finished; now you can process the pickle files." + +json: + $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json + @echo + @echo "Build finished; now you can process the JSON files." + +htmlhelp: + $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp + @echo + @echo "Build finished; now you can run HTML Help Workshop with the" \ + ".hhp project file in $(BUILDDIR)/htmlhelp." + +qthelp: + $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp + @echo + @echo "Build finished; now you can run "qcollectiongenerator" with the" \ + ".qhcp project file in $(BUILDDIR)/qthelp, like this:" + @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/Astropy.qhcp" + @echo "To view the help file:" + @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/Astropy.qhc" + +devhelp: + $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp + @echo + @echo "Build finished." + @echo "To view the help file:" + @echo "# mkdir -p $$HOME/.local/share/devhelp/Astropy" + @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/Astropy" + @echo "# devhelp" + +epub: + $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub + @echo + @echo "Build finished. The epub file is in $(BUILDDIR)/epub." + +latex: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo + @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." + @echo "Run \`make' in that directory to run these through (pdf)latex" \ + "(use \`make latexpdf' here to do that automatically)." + +latexpdf: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through pdflatex..." + make -C $(BUILDDIR)/latex all-pdf + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 
+ +text: + $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text + @echo + @echo "Build finished. The text files are in $(BUILDDIR)/text." + +man: + $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man + @echo + @echo "Build finished. The manual pages are in $(BUILDDIR)/man." + +changes: + $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes + @echo + @echo "The overview file is in $(BUILDDIR)/changes." + +linkcheck: + $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck + @echo + @echo "Link check complete; look for any errors in the above output " \ + "or in $(BUILDDIR)/linkcheck/output.txt." + +doctest: + @echo "Run 'python setup.py test' in the root directory to run doctests " \ + @echo "in the documentation." diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..a14a7f7 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,200 @@ +# -*- coding: utf-8 -*- +# Licensed under a 3-clause BSD style license - see LICENSE.rst +# +# Astropy documentation build configuration file. +# +# This file is execfile()d with the current directory set to its containing dir. +# +# Note that not all possible configuration values are present in this file. +# +# All configuration values have a default. Some values are defined in +# the global Astropy configuration which is loaded here before anything else. +# See astropy.sphinx.conf for which values are set there. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# sys.path.insert(0, os.path.abspath('..')) +# IMPORTANT: the above commented section was generated by sphinx-quickstart, but +# is *NOT* appropriate for astropy or Astropy affiliated packages. It is left +# commented out with this explanation to make it clear why this should not be +# done. If the sys.path entry above is added, when the astropy.sphinx.conf +# import occurs, it will import the *source* version of astropy instead of the +# version installed (if invoked as "make html" or directly with sphinx), or the +# version in the build directory (if "python setup.py build_sphinx" is used). +# Thus, any C-extensions that are needed to build the documentation will *not* +# be accessible, and the documentation will not build correctly. + +import datetime +import os +import sys + +try: + from sphinx_astropy.conf.v1 import * # noqa +except ImportError: + print('ERROR: the documentation requires the sphinx-astropy package to be installed') + sys.exit(1) + +# Get configuration information from setup.cfg +try: + from ConfigParser import ConfigParser +except ImportError: + from configparser import ConfigParser +conf = ConfigParser() + +conf.read([os.path.join(os.path.dirname(__file__), '..', 'setup.cfg')]) +setup_cfg = dict(conf.items('metadata')) + +# -- General configuration ---------------------------------------------------- + +# By default, highlight as Python 3. +highlight_language = 'python3' + +# If your documentation needs a minimal Sphinx version, state it here. +#needs_sphinx = '1.2' + +# To perform a Sphinx version check that needs to be more specific than +# major.minor, call `check_sphinx_version("x.y.z")` here. +# check_sphinx_version("1.2.1") + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. 
+exclude_patterns.append('_templates') + +# This is added to the end of RST files - a good place to put substitutions to +# be used globally. +rst_epilog += """ +""" + +# -- Project information ------------------------------------------------------ + +# This does not *have* to match the package name, but typically does +project = setup_cfg['package_name'] +author = setup_cfg['author'] +copyright = '{0}, {1}'.format( + datetime.datetime.now().year, setup_cfg['author']) + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. + +__import__(setup_cfg['package_name']) +package = sys.modules[setup_cfg['package_name']] + +# The short X.Y version. +version = package.__version__.split('-', 1)[0] +# The full version, including alpha/beta/rc tags. +release = package.__version__ + + +# -- Options for HTML output -------------------------------------------------- + +# A NOTE ON HTML THEMES +# The global astropy configuration uses a custom theme, 'bootstrap-astropy', +# which is installed along with astropy. A different theme can be used or +# the options for this theme can be modified by overriding some of the +# variables set in the global configuration. The variables set in the +# global configuration are listed below, commented out. + + +# Add any paths that contain custom themes here, relative to this directory. +# To use a different custom theme, add the directory containing the theme. +#html_theme_path = [] + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. To override the custom theme, set this to the +# name of a builtin theme or the name of a custom theme in html_theme_path. +#html_theme = None + + +# Please update these texts to match the name of your package. +html_theme_options = { + 'logotext1': 'hela', # white, semi-bold + 'logotext2': '', # orange, light + 'logotext3': ':docs' # white, light + } + + +# Custom sidebar templates, maps document names to template names. +#html_sidebars = {} + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +#html_logo = '' + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +#html_favicon = '' + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +#html_last_updated_fmt = '' + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +html_title = '{0} v{1}'.format(project, release) + +# Output file base name for HTML help builder. +htmlhelp_basename = project + 'doc' + + +# -- Options for LaTeX output ------------------------------------------------- + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass [howto/manual]). +latex_documents = [('index', project + '.tex', project + u' Documentation', + author, 'manual')] + + +# -- Options for manual page output ------------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). 
+man_pages = [('index', project.lower(), project + u' Documentation', + [author], 1)] + + +# -- Options for the edit_on_github extension --------------------------------- + +if eval(setup_cfg.get('edit_on_github')): + extensions += ['sphinx_astropy.ext.edit_on_github'] + + versionmod = __import__(setup_cfg['package_name'] + '.version') + edit_on_github_project = setup_cfg['github_project'] + if versionmod.version.release: + edit_on_github_branch = "v" + versionmod.version.version + else: + edit_on_github_branch = "master" + + edit_on_github_source_root = "" + edit_on_github_doc_root = "docs" + +# -- Resolving issue number to links in changelog ----------------------------- +github_issues_url = 'https://github.com/{0}/issues/'.format(setup_cfg['github_project']) + +# -- Turn on nitpicky mode for sphinx (to warn about references not found) ---- +# +# nitpicky = True +# nitpick_ignore = [] +# +# Some warnings are impossible to suppress, and you can list specific references +# that should be ignored in a nitpick-exceptions file which should be inside +# the docs/ directory. The format of the file should be: +# +# +# +# for example: +# +# py:class astropy.io.votable.tree.Element +# py:class astropy.io.votable.tree.SimpleElement +# py:class astropy.io.votable.tree.SimpleElementWithContent +# +# Uncomment the following lines to enable the exceptions: +# +# for line in open('nitpick-exceptions'): +# if line.strip() == "" or line.startswith("#"): +# continue +# dtype, target = line.split(None, 1) +# target = target.strip() +# nitpick_ignore.append((dtype, six.u(target))) diff --git a/docs/hela/installation.rst b/docs/hela/installation.rst new file mode 100644 index 0000000..deb9d64 --- /dev/null +++ b/docs/hela/installation.rst @@ -0,0 +1,6 @@ +Installation +============ + +To install hela, run:: + + python setup.py install \ No newline at end of file diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst new file mode 100644 index 0000000..71d685b --- /dev/null +++ b/docs/hela/tutorial.rst @@ -0,0 +1,55 @@ +Tutorial +======== + +First, we must generate some example data, which we can do using a built-in +function called `~hela.generate_example_data`, which returns the path to the +example file directory, the training dataset path, and the path to the samples +which we'd like to predict on: + +.. code-block:: python + + from hela import generate_example_data + # Generate an example dataset directory + example_dir, training_dataset, samples_path = generate_example_data() + +What did that just do? We created an example directory called ``linear_data``, +which contains a training dataset described by the metadata file +``training_dataset``. This training dataset has... + +We also generated a bunch of samples with a known slope and intercept, called +``samples_path``, on which we'll apply our trained random forest to estimate +the slope and intercept. + +Once we have these three data structures written and their paths saved, we can +run ``hela`` on the data. First, we'll initialize a `~hela.RandomForest` object +with the paths to the three files/directories that it needs to know about: + +.. code-block:: python + + from hela import RandomForest + import matplotlib.pyplot as plt + + # Initialize a random forest object: + rf = RandomForest(training_dataset, example_dir, samples_path) + +We now have a random forest object ``rf`` which is ready for training. We can +train the random forest with 1000 trees and on a single processor: + +.. 
code-block:: python
+
+    # Train the random forest:
+    r2scores = rf.train(num_trees=1000, num_jobs=1)
+    plt.show()
+
+The `~hela.RandomForest.train` method returns a dictionary called `r2scores`
+which contains the :math:`R^2` scores of the slope and intercept.
+
+Finally, let's estimate the posterior distributions for the slope and intercept
+using the trained random forest on the sample data in ``samples_path``:
+
+.. code-block:: python
+
+    # Predict posterior distributions from random forest
+    posterior_slopes, posterior_intercepts = rf.predict(plot_posterior=True)
+    plt.show()
+
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 0000000..a53eaa8
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,11 @@
+hela
+====
+
+Random Forest retrieval for exoplanet atmospheres.
+
+.. toctree::
+   :maxdepth: 2
+
+   hela/installation.rst
+   hela/tutorial.rst
+
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 0000000..93dfe92
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,170 @@
+@ECHO OFF
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set BUILDDIR=_build
+set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .
+if NOT "%PAPER%" == "" (
+	set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
+)
+
+if "%1" == "" goto help
+
+if "%1" == "help" (
+	:help
+	echo.Please use `make ^<target^>` where ^<target^> is one of
+	echo.  html       to make standalone HTML files
+	echo.  dirhtml    to make HTML files named index.html in directories
+	echo.  singlehtml to make a single large HTML file
+	echo.  pickle     to make pickle files
+	echo.  json       to make JSON files
+	echo.  htmlhelp   to make HTML files and a HTML help project
+	echo.  qthelp     to make HTML files and a qthelp project
+	echo.  devhelp    to make HTML files and a Devhelp project
+	echo.  epub       to make an epub
+	echo.  latex      to make LaTeX files, you can set PAPER=a4 or PAPER=letter
+	echo.  text       to make text files
+	echo.  man        to make manual pages
+	echo.  changes    to make an overview over all changed/added/deprecated items
+	echo.  linkcheck  to check all external links for integrity
+	echo.  doctest    to run all doctests embedded in the documentation if enabled
+	goto end
+)
+
+if "%1" == "clean" (
+	for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
+	del /q /s %BUILDDIR%\*
+	goto end
+)
+
+if "%1" == "html" (
+	%SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
+	if errorlevel 1 exit /b 1
+	echo.
+	echo.Build finished. The HTML pages are in %BUILDDIR%/html.
+	goto end
+)
+
+if "%1" == "dirhtml" (
+	%SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
+	if errorlevel 1 exit /b 1
+	echo.
+	echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
+	goto end
+)
+
+if "%1" == "singlehtml" (
+	%SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
+	if errorlevel 1 exit /b 1
+	echo.
+	echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
+	goto end
+)
+
+if "%1" == "pickle" (
+	%SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
+	if errorlevel 1 exit /b 1
+	echo.
+	echo.Build finished; now you can process the pickle files.
+	goto end
+)
+
+if "%1" == "json" (
+	%SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
+	if errorlevel 1 exit /b 1
+	echo.
+	echo.Build finished; now you can process the JSON files.
+	goto end
+)
+
+if "%1" == "htmlhelp" (
+	%SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
+	if errorlevel 1 exit /b 1
+	echo.
+	echo.Build finished; now you can run HTML Help Workshop with the ^
+.hhp project file in %BUILDDIR%/htmlhelp.
+ goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\Astropy.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\Astropy.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. 
+ goto end +) + +:end From 1c527855e2211076e16d2a03b2ef92a7f684448a Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Thu, 7 Nov 2019 20:53:55 +0100 Subject: [PATCH 05/46] First self-building docs with plots --- .travis.yml | 4 +- ah_bootstrap.py | 1009 ++++++++++++++++++++++++++++++++++++++++ docs/hela/api.rst | 2 + docs/hela/tutorial.rst | 36 ++ docs/index.rst | 2 +- setup.cfg | 56 +++ setup.py | 153 ++++++ 7 files changed, 1259 insertions(+), 3 deletions(-) create mode 100644 ah_bootstrap.py create mode 100644 docs/hela/api.rst create mode 100644 setup.cfg create mode 100755 setup.py diff --git a/.travis.yml b/.travis.yml index 1f076a6..6724ac8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -93,14 +93,14 @@ matrix: # Do a PEP8 test with flake8 - os: linux stage: Initial tests - env: MAIN_CMD='flake8 hipparchus --count --show-source --statistics $FLAKE8_OPT' SETUP_CMD='' + env: MAIN_CMD='flake8 hela --count --show-source --statistics $FLAKE8_OPT' SETUP_CMD='' allow_failures: # Do a PEP8 test with flake8 # (do allow to fail unless your code completely compliant) # - os: linux # stage: Initial tests - # env: MAIN_CMD='flake8 hipparchus --count --show-source --statistics $FLAKE8_OPT' SETUP_CMD='' + # env: MAIN_CMD='flake8 hela --count --show-source --statistics $FLAKE8_OPT' SETUP_CMD='' install: diff --git a/ah_bootstrap.py b/ah_bootstrap.py new file mode 100644 index 0000000..67ca92b --- /dev/null +++ b/ah_bootstrap.py @@ -0,0 +1,1009 @@ +""" +This bootstrap module contains code for ensuring that the astropy_helpers +package will be importable by the time the setup.py script runs. It also +includes some workarounds to ensure that a recent-enough version of setuptools +is being used for the installation. + +This module should be the first thing imported in the setup.py of distributions +that make use of the utilities in astropy_helpers. If the distribution ships +with its own copy of astropy_helpers, this module will first attempt to import +from the shipped copy. However, it will also check PyPI to see if there are +any bug-fix releases on top of the current version that may be useful to get +past platform-specific bugs that have been fixed. When running setup.py, use +the ``--offline`` command-line option to disable the auto-upgrade checks. + +When this module is imported or otherwise executed it automatically calls a +main function that attempts to read the project's setup.cfg file, which it +checks for a configuration section called ``[ah_bootstrap]`` the presences of +that section, and options therein, determine the next step taken: If it +contains an option called ``auto_use`` with a value of ``True``, it will +automatically call the main function of this module called +`use_astropy_helpers` (see that function's docstring for full details). +Otherwise no further action is taken and by default the system-installed version +of astropy-helpers will be used (however, ``ah_bootstrap.use_astropy_helpers`` +may be called manually from within the setup.py script). + +This behavior can also be controlled using the ``--auto-use`` and +``--no-auto-use`` command-line flags. For clarity, an alias for +``--no-auto-use`` is ``--use-system-astropy-helpers``, and we recommend using +the latter if needed. + +Additional options in the ``[ah_boostrap]`` section of setup.cfg have the same +names as the arguments to `use_astropy_helpers`, and can be used to configure +the bootstrap script when ``auto_use = True``. 
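+
+For illustration, a project that vendors astropy_helpers as a git submodule
+might use a section like the following (values here are only an example; the
+right settings depend on the project)::
+
+    [ah_bootstrap]
+    auto_use = True
+    path = astropy_helpers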
+ +See https://github.com/astropy/astropy-helpers for more details, and for the +latest version of this module. +""" + +import contextlib +import errno +import io +import locale +import os +import re +import subprocess as sp +import sys + +from distutils import log +from distutils.debug import DEBUG + +from configparser import ConfigParser, RawConfigParser + +import pkg_resources + +from setuptools import Distribution +from setuptools.package_index import PackageIndex + +# This is the minimum Python version required for astropy-helpers +__minimum_python_version__ = (3, 5) + +# TODO: Maybe enable checking for a specific version of astropy_helpers? +DIST_NAME = 'astropy-helpers' +PACKAGE_NAME = 'astropy_helpers' +UPPER_VERSION_EXCLUSIVE = None + +# Defaults for other options +DOWNLOAD_IF_NEEDED = True +INDEX_URL = 'https://pypi.python.org/simple' +USE_GIT = True +OFFLINE = False +AUTO_UPGRADE = True + +# A list of all the configuration options and their required types +CFG_OPTIONS = [ + ('auto_use', bool), ('path', str), ('download_if_needed', bool), + ('index_url', str), ('use_git', bool), ('offline', bool), + ('auto_upgrade', bool) +] + +# Start off by parsing the setup.cfg file + +_err_help_msg = """ +If the problem persists consider installing astropy_helpers manually using pip +(`pip install astropy_helpers`) or by manually downloading the source archive, +extracting it, and installing by running `python setup.py install` from the +root of the extracted source code. +""" + +SETUP_CFG = ConfigParser() + +if os.path.exists('setup.cfg'): + + try: + SETUP_CFG.read('setup.cfg') + except Exception as e: + if DEBUG: + raise + + log.error( + "Error reading setup.cfg: {0!r}\n{1} will not be " + "automatically bootstrapped and package installation may fail." + "\n{2}".format(e, PACKAGE_NAME, _err_help_msg)) + +# We used package_name in the package template for a while instead of name +if SETUP_CFG.has_option('metadata', 'name'): + parent_package = SETUP_CFG.get('metadata', 'name') +elif SETUP_CFG.has_option('metadata', 'package_name'): + parent_package = SETUP_CFG.get('metadata', 'package_name') +else: + parent_package = None + +if SETUP_CFG.has_option('options', 'python_requires'): + + python_requires = SETUP_CFG.get('options', 'python_requires') + + # The python_requires key has a syntax that can be parsed by SpecifierSet + # in the packaging package. However, we don't want to have to depend on that + # package, so instead we can use setuptools (which bundles packaging). We + # have to add 'python' to parse it with Requirement. 
+ + from pkg_resources import Requirement + req = Requirement.parse('python' + python_requires) + + # We want the Python version as a string, which we can get from the platform module + import platform + # strip off trailing '+' incase this is a dev install of python + python_version = platform.python_version().strip('+') + # allow pre-releases to count as 'new enough' + if not req.specifier.contains(python_version, True): + if parent_package is None: + message = "ERROR: Python {} is required by this package\n".format(req.specifier) + else: + message = "ERROR: Python {} is required by {}\n".format(req.specifier, parent_package) + sys.stderr.write(message) + sys.exit(1) + +if sys.version_info < __minimum_python_version__: + + if parent_package is None: + message = "ERROR: Python {} or later is required by astropy-helpers\n".format( + __minimum_python_version__) + else: + message = "ERROR: Python {} or later is required by astropy-helpers for {}\n".format( + __minimum_python_version__, parent_package) + + sys.stderr.write(message) + sys.exit(1) + +_str_types = (str, bytes) + + +# What follows are several import statements meant to deal with install-time +# issues with either missing or misbehaving pacakges (including making sure +# setuptools itself is installed): + +# Check that setuptools 30.3 or later is present +from distutils.version import LooseVersion + +try: + import setuptools + assert LooseVersion(setuptools.__version__) >= LooseVersion('30.3') +except (ImportError, AssertionError): + sys.stderr.write("ERROR: setuptools 30.3 or later is required by astropy-helpers\n") + sys.exit(1) + +# typing as a dependency for 1.6.1+ Sphinx causes issues when imported after +# initializing submodule with ah_boostrap.py +# See discussion and references in +# https://github.com/astropy/astropy-helpers/issues/302 + +try: + import typing # noqa +except ImportError: + pass + + +# Note: The following import is required as a workaround to +# https://github.com/astropy/astropy-helpers/issues/89; if we don't import this +# module now, it will get cleaned up after `run_setup` is called, but that will +# later cause the TemporaryDirectory class defined in it to stop working when +# used later on by setuptools +try: + import setuptools.py31compat # noqa +except ImportError: + pass + + +# matplotlib can cause problems if it is imported from within a call of +# run_setup(), because in some circumstances it will try to write to the user's +# home directory, resulting in a SandboxViolation. See +# https://github.com/matplotlib/matplotlib/pull/4165 +# Making sure matplotlib, if it is available, is imported early in the setup +# process can mitigate this (note importing matplotlib.pyplot has the same +# issue) +try: + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot +except: + # Ignore if this fails for *any* reason* + pass + + +# End compatibility imports... + + +class _Bootstrapper(object): + """ + Bootstrapper implementation. See ``use_astropy_helpers`` for parameter + documentation. 
+ """ + + def __init__(self, path=None, index_url=None, use_git=None, offline=None, + download_if_needed=None, auto_upgrade=None): + + if path is None: + path = PACKAGE_NAME + + if not (isinstance(path, _str_types) or path is False): + raise TypeError('path must be a string or False') + + if not isinstance(path, str): + fs_encoding = sys.getfilesystemencoding() + path = path.decode(fs_encoding) # path to unicode + + self.path = path + + # Set other option attributes, using defaults where necessary + self.index_url = index_url if index_url is not None else INDEX_URL + self.offline = offline if offline is not None else OFFLINE + + # If offline=True, override download and auto-upgrade + if self.offline: + download_if_needed = False + auto_upgrade = False + + self.download = (download_if_needed + if download_if_needed is not None + else DOWNLOAD_IF_NEEDED) + self.auto_upgrade = (auto_upgrade + if auto_upgrade is not None else AUTO_UPGRADE) + + # If this is a release then the .git directory will not exist so we + # should not use git. + git_dir_exists = os.path.exists(os.path.join(os.path.dirname(__file__), '.git')) + if use_git is None and not git_dir_exists: + use_git = False + + self.use_git = use_git if use_git is not None else USE_GIT + # Declared as False by default--later we check if astropy-helpers can be + # upgraded from PyPI, but only if not using a source distribution (as in + # the case of import from a git submodule) + self.is_submodule = False + + @classmethod + def main(cls, argv=None): + if argv is None: + argv = sys.argv + + config = cls.parse_config() + config.update(cls.parse_command_line(argv)) + + auto_use = config.pop('auto_use', False) + bootstrapper = cls(**config) + + if auto_use: + # Run the bootstrapper, otherwise the setup.py is using the old + # use_astropy_helpers() interface, in which case it will run the + # bootstrapper manually after reconfiguring it. + bootstrapper.run() + + return bootstrapper + + @classmethod + def parse_config(cls): + + if not SETUP_CFG.has_section('ah_bootstrap'): + return {} + + config = {} + + for option, type_ in CFG_OPTIONS: + if not SETUP_CFG.has_option('ah_bootstrap', option): + continue + + if type_ is bool: + value = SETUP_CFG.getboolean('ah_bootstrap', option) + else: + value = SETUP_CFG.get('ah_bootstrap', option) + + config[option] = value + + return config + + @classmethod + def parse_command_line(cls, argv=None): + if argv is None: + argv = sys.argv + + config = {} + + # For now we just pop recognized ah_bootstrap options out of the + # arg list. This is imperfect; in the unlikely case that a setup.py + # custom command or even custom Distribution class defines an argument + # of the same name then we will break that. However there's a catch22 + # here that we can't just do full argument parsing right here, because + # we don't yet know *how* to parse all possible command-line arguments. 
+ if '--no-git' in argv: + config['use_git'] = False + argv.remove('--no-git') + + if '--offline' in argv: + config['offline'] = True + argv.remove('--offline') + + if '--auto-use' in argv: + config['auto_use'] = True + argv.remove('--auto-use') + + if '--no-auto-use' in argv: + config['auto_use'] = False + argv.remove('--no-auto-use') + + if '--use-system-astropy-helpers' in argv: + config['auto_use'] = False + argv.remove('--use-system-astropy-helpers') + + return config + + def run(self): + strategies = ['local_directory', 'local_file', 'index'] + dist = None + + # First, remove any previously imported versions of astropy_helpers; + # this is necessary for nested installs where one package's installer + # is installing another package via setuptools.sandbox.run_setup, as in + # the case of setup_requires + for key in list(sys.modules): + try: + if key == PACKAGE_NAME or key.startswith(PACKAGE_NAME + '.'): + del sys.modules[key] + except AttributeError: + # Sometimes mysterious non-string things can turn up in + # sys.modules + continue + + # Check to see if the path is a submodule + self.is_submodule = self._check_submodule() + + for strategy in strategies: + method = getattr(self, 'get_{0}_dist'.format(strategy)) + dist = method() + if dist is not None: + break + else: + raise _AHBootstrapSystemExit( + "No source found for the {0!r} package; {0} must be " + "available and importable as a prerequisite to building " + "or installing this package.".format(PACKAGE_NAME)) + + # This is a bit hacky, but if astropy_helpers was loaded from a + # directory/submodule its Distribution object gets a "precedence" of + # "DEVELOP_DIST". However, in other cases it gets a precedence of + # "EGG_DIST". However, when activing the distribution it will only be + # placed early on sys.path if it is treated as an EGG_DIST, so always + # do that + dist = dist.clone(precedence=pkg_resources.EGG_DIST) + + # Otherwise we found a version of astropy-helpers, so we're done + # Just active the found distribution on sys.path--if we did a + # download this usually happens automatically but it doesn't hurt to + # do it again + # Note: Adding the dist to the global working set also activates it + # (makes it importable on sys.path) by default. + + try: + pkg_resources.working_set.add(dist, replace=True) + except TypeError: + # Some (much) older versions of setuptools do not have the + # replace=True option here. These versions are old enough that all + # bets may be off anyways, but it's easy enough to work around just + # in case... + if dist.key in pkg_resources.working_set.by_key: + del pkg_resources.working_set.by_key[dist.key] + pkg_resources.working_set.add(dist) + + @property + def config(self): + """ + A `dict` containing the options this `_Bootstrapper` was configured + with. + """ + + return dict((optname, getattr(self, optname)) + for optname, _ in CFG_OPTIONS if hasattr(self, optname)) + + def get_local_directory_dist(self): + """ + Handle importing a vendored package from a subdirectory of the source + distribution. 
+ """ + + if not os.path.isdir(self.path): + return + + log.info('Attempting to import astropy_helpers from {0} {1!r}'.format( + 'submodule' if self.is_submodule else 'directory', + self.path)) + + dist = self._directory_import() + + if dist is None: + log.warn( + 'The requested path {0!r} for importing {1} does not ' + 'exist, or does not contain a copy of the {1} ' + 'package.'.format(self.path, PACKAGE_NAME)) + elif self.auto_upgrade and not self.is_submodule: + # A version of astropy-helpers was found on the available path, but + # check to see if a bugfix release is available on PyPI + upgrade = self._do_upgrade(dist) + if upgrade is not None: + dist = upgrade + + return dist + + def get_local_file_dist(self): + """ + Handle importing from a source archive; this also uses setup_requires + but points easy_install directly to the source archive. + """ + + if not os.path.isfile(self.path): + return + + log.info('Attempting to unpack and import astropy_helpers from ' + '{0!r}'.format(self.path)) + + try: + dist = self._do_download(find_links=[self.path]) + except Exception as e: + if DEBUG: + raise + + log.warn( + 'Failed to import {0} from the specified archive {1!r}: ' + '{2}'.format(PACKAGE_NAME, self.path, str(e))) + dist = None + + if dist is not None and self.auto_upgrade: + # A version of astropy-helpers was found on the available path, but + # check to see if a bugfix release is available on PyPI + upgrade = self._do_upgrade(dist) + if upgrade is not None: + dist = upgrade + + return dist + + def get_index_dist(self): + if not self.download: + log.warn('Downloading {0!r} disabled.'.format(DIST_NAME)) + return None + + log.warn( + "Downloading {0!r}; run setup.py with the --offline option to " + "force offline installation.".format(DIST_NAME)) + + try: + dist = self._do_download() + except Exception as e: + if DEBUG: + raise + log.warn( + 'Failed to download and/or install {0!r} from {1!r}:\n' + '{2}'.format(DIST_NAME, self.index_url, str(e))) + dist = None + + # No need to run auto-upgrade here since we've already presumably + # gotten the most up-to-date version from the package index + return dist + + def _directory_import(self): + """ + Import astropy_helpers from the given path, which will be added to + sys.path. + + Must return True if the import succeeded, and False otherwise. + """ + + # Return True on success, False on failure but download is allowed, and + # otherwise raise SystemExit + path = os.path.abspath(self.path) + + # Use an empty WorkingSet rather than the man + # pkg_resources.working_set, since on older versions of setuptools this + # will invoke a VersionConflict when trying to install an upgrade + ws = pkg_resources.WorkingSet([]) + ws.add_entry(path) + dist = ws.by_key.get(DIST_NAME) + + if dist is None: + # We didn't find an egg-info/dist-info in the given path, but if a + # setup.py exists we can generate it + setup_py = os.path.join(path, 'setup.py') + if os.path.isfile(setup_py): + # We use subprocess instead of run_setup from setuptools to + # avoid segmentation faults - see the following for more details: + # https://github.com/cython/cython/issues/2104 + sp.check_output([sys.executable, 'setup.py', 'egg_info'], cwd=path) + + for dist in pkg_resources.find_distributions(path, True): + # There should be only one... 
+ return dist + + return dist + + def _do_download(self, version='', find_links=None): + if find_links: + allow_hosts = '' + index_url = None + else: + allow_hosts = None + index_url = self.index_url + + # Annoyingly, setuptools will not handle other arguments to + # Distribution (such as options) before handling setup_requires, so it + # is not straightforward to programmatically augment the arguments which + # are passed to easy_install + class _Distribution(Distribution): + def get_option_dict(self, command_name): + opts = Distribution.get_option_dict(self, command_name) + if command_name == 'easy_install': + if find_links is not None: + opts['find_links'] = ('setup script', find_links) + if index_url is not None: + opts['index_url'] = ('setup script', index_url) + if allow_hosts is not None: + opts['allow_hosts'] = ('setup script', allow_hosts) + return opts + + if version: + req = '{0}=={1}'.format(DIST_NAME, version) + else: + if UPPER_VERSION_EXCLUSIVE is None: + req = DIST_NAME + else: + req = '{0}<{1}'.format(DIST_NAME, UPPER_VERSION_EXCLUSIVE) + + attrs = {'setup_requires': [req]} + + # NOTE: we need to parse the config file (e.g. setup.cfg) to make sure + # it honours the options set in the [easy_install] section, and we need + # to explicitly fetch the requirement eggs as setup_requires does not + # get honored in recent versions of setuptools: + # https://github.com/pypa/setuptools/issues/1273 + + try: + + context = _verbose if DEBUG else _silence + with context(): + dist = _Distribution(attrs=attrs) + try: + dist.parse_config_files(ignore_option_errors=True) + dist.fetch_build_eggs(req) + except TypeError: + # On older versions of setuptools, ignore_option_errors + # doesn't exist, and the above two lines are not needed + # so we can just continue + pass + + # If the setup_requires succeeded it will have added the new dist to + # the main working_set + return pkg_resources.working_set.by_key.get(DIST_NAME) + except Exception as e: + if DEBUG: + raise + + msg = 'Error retrieving {0} from {1}:\n{2}' + if find_links: + source = find_links[0] + elif index_url != INDEX_URL: + source = index_url + else: + source = 'PyPI' + + raise Exception(msg.format(DIST_NAME, source, repr(e))) + + def _do_upgrade(self, dist): + # Build up a requirement for a higher bugfix release but a lower minor + # release (so API compatibility is guaranteed) + next_version = _next_version(dist.parsed_version) + + req = pkg_resources.Requirement.parse( + '{0}>{1},<{2}'.format(DIST_NAME, dist.version, next_version)) + + package_index = PackageIndex(index_url=self.index_url) + + upgrade = package_index.obtain(req) + + if upgrade is not None: + return self._do_download(version=upgrade.version) + + def _check_submodule(self): + """ + Check if the given path is a git submodule. + + See the docstrings for ``_check_submodule_using_git`` and + ``_check_submodule_no_git`` for further details. + """ + + if (self.path is None or + (os.path.exists(self.path) and not os.path.isdir(self.path))): + return False + + if self.use_git: + return self._check_submodule_using_git() + else: + return self._check_submodule_no_git() + + def _check_submodule_using_git(self): + """ + Check if the given path is a git submodule. If so, attempt to initialize + and/or update the submodule if needed. + + This function makes calls to the ``git`` command in subprocesses. The + ``_check_submodule_no_git`` option uses pure Python to check if the given + path looks like a git submodule, but it cannot perform updates. 
+        """
+
+        cmd = ['git', 'submodule', 'status', '--', self.path]
+
+        try:
+            log.info('Running `{0}`; use the --no-git option to disable git '
+                     'commands'.format(' '.join(cmd)))
+            returncode, stdout, stderr = run_cmd(cmd)
+        except _CommandNotFound:
+            # The git command simply wasn't found; this is most likely the
+            # case on user systems that don't have git and are simply
+            # trying to install the package from PyPI or a source
+            # distribution. Silently ignore this case and simply don't try
+            # to use submodules
+            return False
+
+        stderr = stderr.strip()
+
+        if returncode != 0 and stderr:
+            # Unfortunately the return code alone cannot be relied on, as
+            # earlier versions of git returned 0 even if the requested submodule
+            # does not exist
+
+            # This is a warning that occurs in perl (from running git submodule)
+            # which only occurs with a malformatted locale setting which can
+            # happen sometimes on OSX. See again
+            # https://github.com/astropy/astropy/issues/2749
+            perl_warning = ('perl: warning: Falling back to the standard locale '
+                            '("C").')
+            if not stderr.strip().endswith(perl_warning):
+                # Some other unknown error condition occurred
+                log.warn('git submodule command failed '
+                         'unexpectedly:\n{0}'.format(stderr))
+                return False
+
+        # Output of `git submodule status` is as follows:
+        #
+        # 1: Status indicator: '-' for submodule is uninitialized, '+' if
+        # submodule is initialized but is not at the commit currently indicated
+        # in .gitmodules (and thus needs to be updated), or 'U' if the
+        # submodule is in an unstable state (i.e. has merge conflicts)
+        #
+        # 2. SHA-1 hash of the current commit of the submodule (we don't really
+        # need this information but it's useful for checking that the output is
+        # correct)
+        #
+        # 3. The output of `git describe` for the submodule's current commit
+        # hash (this includes for example what branches the commit is on) but
+        # only if the submodule is initialized. We ignore this information for
+        # now
+        _git_submodule_status_re = re.compile(
+            r'^(?P<status>[+-U ])(?P<sha1>[0-9a-f]{40}) '
+            r'(?P<submodule>\S+)( .*)?$')
+
+        # The stdout should only contain one line--the status of the
+        # requested submodule
+        m = _git_submodule_status_re.match(stdout)
+        if m:
+            # Yes, the path *is* a git submodule
+            self._update_submodule(m.group('submodule'), m.group('status'))
+            return True
+        else:
+            log.warn(
+                'Unexpected output from `git submodule status`:\n{0}\n'
+                'Will attempt import from {1!r} regardless.'.format(
+                    stdout, self.path))
+            return False
+
+    def _check_submodule_no_git(self):
+        """
+        Like ``_check_submodule_using_git``, but simply parses the .gitmodules file
+        to determine if the supplied path is a git submodule, and does not exec any
+        subprocesses.
+
+        This can only determine if a path is a submodule--it does not perform
+        updates, etc. This function may need to be updated if the format of the
+        .gitmodules file is changed between git versions.
+        """
+
+        gitmodules_path = os.path.abspath('.gitmodules')
+
+        if not os.path.isfile(gitmodules_path):
+            return False
+
+        # This is a minimal reader for gitconfig-style files. It handles a few of
+        # the quirks that make gitconfig files incompatible with ConfigParser-style
+        # files, but does not support the full gitconfig syntax (just enough
+        # needed to read a .gitmodules file).
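+        #
+        # A typical .gitmodules entry has roughly this shape (illustrative):
+        #
+        #     [submodule "astropy_helpers"]
+        #         path = astropy_helpers
+        #         url = https://github.com/astropy/astropy-helpers.git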
+ gitmodules_fileobj = io.StringIO() + + # Must use io.open for cross-Python-compatible behavior wrt unicode + with io.open(gitmodules_path) as f: + for line in f: + # gitconfig files are more flexible with leading whitespace; just + # go ahead and remove it + line = line.lstrip() + + # comments can start with either # or ; + if line and line[0] in (':', ';'): + continue + + gitmodules_fileobj.write(line) + + gitmodules_fileobj.seek(0) + + cfg = RawConfigParser() + + try: + cfg.readfp(gitmodules_fileobj) + except Exception as exc: + log.warn('Malformatted .gitmodules file: {0}\n' + '{1} cannot be assumed to be a git submodule.'.format( + exc, self.path)) + return False + + for section in cfg.sections(): + if not cfg.has_option(section, 'path'): + continue + + submodule_path = cfg.get(section, 'path').rstrip(os.sep) + + if submodule_path == self.path.rstrip(os.sep): + return True + + return False + + def _update_submodule(self, submodule, status): + if status == ' ': + # The submodule is up to date; no action necessary + return + elif status == '-': + if self.offline: + raise _AHBootstrapSystemExit( + "Cannot initialize the {0} submodule in --offline mode; " + "this requires being able to clone the submodule from an " + "online repository.".format(submodule)) + cmd = ['update', '--init'] + action = 'Initializing' + elif status == '+': + cmd = ['update'] + action = 'Updating' + if self.offline: + cmd.append('--no-fetch') + elif status == 'U': + raise _AHBootstrapSystemExit( + 'Error: Submodule {0} contains unresolved merge conflicts. ' + 'Please complete or abandon any changes in the submodule so that ' + 'it is in a usable state, then try again.'.format(submodule)) + else: + log.warn('Unknown status {0!r} for git submodule {1!r}. Will ' + 'attempt to use the submodule as-is, but try to ensure ' + 'that the submodule is in a clean state and contains no ' + 'conflicts or errors.\n{2}'.format(status, submodule, + _err_help_msg)) + return + + err_msg = None + cmd = ['git', 'submodule'] + cmd + ['--', submodule] + log.warn('{0} {1} submodule with: `{2}`'.format( + action, submodule, ' '.join(cmd))) + + try: + log.info('Running `{0}`; use the --no-git option to disable git ' + 'commands'.format(' '.join(cmd))) + returncode, stdout, stderr = run_cmd(cmd) + except OSError as e: + err_msg = str(e) + else: + if returncode != 0: + err_msg = stderr + + if err_msg is not None: + log.warn('An unexpected error occurred updating the git submodule ' + '{0!r}:\n{1}\n{2}'.format(submodule, err_msg, + _err_help_msg)) + +class _CommandNotFound(OSError): + """ + An exception raised when a command run with run_cmd is not found on the + system. + """ + + +def run_cmd(cmd): + """ + Run a command in a subprocess, given as a list of command-line + arguments. + + Returns a ``(returncode, stdout, stderr)`` tuple. + """ + + try: + p = sp.Popen(cmd, stdout=sp.PIPE, stderr=sp.PIPE) + # XXX: May block if either stdout or stderr fill their buffers; + # however for the commands this is currently used for that is + # unlikely (they should have very brief output) + stdout, stderr = p.communicate() + except OSError as e: + if DEBUG: + raise + + if e.errno == errno.ENOENT: + msg = 'Command not found: `{0}`'.format(' '.join(cmd)) + raise _CommandNotFound(msg, cmd) + else: + raise _AHBootstrapSystemExit( + 'An unexpected error occurred when running the ' + '`{0}` command:\n{1}'.format(' '.join(cmd), str(e))) + + + # Can fail of the default locale is not configured properly. See + # https://github.com/astropy/astropy/issues/2749. 
For the purposes under + # consideration 'latin1' is an acceptable fallback. + try: + stdio_encoding = locale.getdefaultlocale()[1] or 'latin1' + except ValueError: + # Due to an OSX oddity locale.getdefaultlocale() can also crash + # depending on the user's locale/language settings. See: + # http://bugs.python.org/issue18378 + stdio_encoding = 'latin1' + + # Unlikely to fail at this point but even then let's be flexible + if not isinstance(stdout, str): + stdout = stdout.decode(stdio_encoding, 'replace') + if not isinstance(stderr, str): + stderr = stderr.decode(stdio_encoding, 'replace') + + return (p.returncode, stdout, stderr) + + +def _next_version(version): + """ + Given a parsed version from pkg_resources.parse_version, returns a new + version string with the next minor version. + + Examples + ======== + >>> _next_version(pkg_resources.parse_version('1.2.3')) + '1.3.0' + """ + + if hasattr(version, 'base_version'): + # New version parsing from setuptools >= 8.0 + if version.base_version: + parts = version.base_version.split('.') + else: + parts = [] + else: + parts = [] + for part in version: + if part.startswith('*'): + break + parts.append(part) + + parts = [int(p) for p in parts] + + if len(parts) < 3: + parts += [0] * (3 - len(parts)) + + major, minor, micro = parts[:3] + + return '{0}.{1}.{2}'.format(major, minor + 1, 0) + + +class _DummyFile(object): + """A noop writeable object.""" + + errors = '' # Required for Python 3.x + encoding = 'utf-8' + + def write(self, s): + pass + + def flush(self): + pass + + +@contextlib.contextmanager +def _verbose(): + yield + +@contextlib.contextmanager +def _silence(): + """A context manager that silences sys.stdout and sys.stderr.""" + + old_stdout = sys.stdout + old_stderr = sys.stderr + sys.stdout = _DummyFile() + sys.stderr = _DummyFile() + exception_occurred = False + try: + yield + except: + exception_occurred = True + # Go ahead and clean up so that exception handling can work normally + sys.stdout = old_stdout + sys.stderr = old_stderr + raise + + if not exception_occurred: + sys.stdout = old_stdout + sys.stderr = old_stderr + + +class _AHBootstrapSystemExit(SystemExit): + def __init__(self, *args): + if not args: + msg = 'An unknown problem occurred bootstrapping astropy_helpers.' + else: + msg = args[0] + + msg += '\n' + _err_help_msg + + super(_AHBootstrapSystemExit, self).__init__(msg, *args[1:]) + + +BOOTSTRAPPER = _Bootstrapper.main() + + +def use_astropy_helpers(**kwargs): + """ + Ensure that the `astropy_helpers` module is available and is importable. + This supports automatic submodule initialization if astropy_helpers is + included in a project as a git submodule, or will download it from PyPI if + necessary. + + Parameters + ---------- + + path : str or None, optional + A filesystem path relative to the root of the project's source code + that should be added to `sys.path` so that `astropy_helpers` can be + imported from that path. + + If the path is a git submodule it will automatically be initialized + and/or updated. + + The path may also be to a ``.tar.gz`` archive of the astropy_helpers + source distribution. In this case the archive is automatically + unpacked and made temporarily available on `sys.path` as a ``.egg`` + archive. + + If `None` skip straight to downloading. + + download_if_needed : bool, optional + If the provided filesystem path is not found an attempt will be made to + download astropy_helpers from PyPI. 
It will then be made temporarily
+        available on `sys.path` as a ``.egg`` archive (using the
+        ``setup_requires`` feature of setuptools). If the ``--offline`` option
+        is given at the command line the value of this argument is overridden
+        to `False`.
+
+    index_url : str, optional
+        If provided, use a different URL for the Python package index than the
+        main PyPI server.
+
+    use_git : bool, optional
+        If `False` no git commands will be used--this effectively disables
+        support for git submodules. If the ``--no-git`` option is given at the
+        command line the value of this argument is overridden to `False`.
+
+    auto_upgrade : bool, optional
+        By default, when installing a package from a non-development source
+        distribution ah_bootstrap will try to automatically check for patch
+        releases to astropy-helpers on PyPI and use the patched version over
+        any bundled versions. Setting this to `False` will disable that
+        functionality. If the ``--offline`` option is given at the command line
+        the value of this argument is overridden to `False`.
+
+    offline : bool, optional
+        If `False` disable all actions that require an internet connection,
+        including downloading packages from the package index and fetching
+        updates to any git submodule. Defaults to `True`.
+    """
+
+    global BOOTSTRAPPER
+
+    config = BOOTSTRAPPER.config
+    config.update(**kwargs)
+
+    # Create a new bootstrapper with the updated configuration and run it
+    BOOTSTRAPPER = _Bootstrapper(**config)
+    BOOTSTRAPPER.run()
diff --git a/docs/hela/api.rst b/docs/hela/api.rst
new file mode 100644
index 0000000..037d0cf
--- /dev/null
+++ b/docs/hela/api.rst
@@ -0,0 +1,2 @@
+
+.. automodapi:: hela
diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst
index 71d685b..d7378c0 100644
--- a/docs/hela/tutorial.rst
+++ b/docs/hela/tutorial.rst
@@ -41,6 +41,22 @@ train the random forest with 1000 trees and on a single processor:
     r2scores = rf.train(num_trees=1000, num_jobs=1)
     plt.show()
 
+.. plot::
+
+    from hela import generate_example_data
+    # Generate an example dataset directory
+    example_dir, training_dataset, samples_path = generate_example_data()
+
+    from hela import RandomForest
+    import matplotlib.pyplot as plt
+
+    # Initialize a random forest object:
+    rf = RandomForest(training_dataset, example_dir, samples_path)
+
+    # Train the random forest:
+    r2scores = rf.train(num_trees=1000, num_jobs=1)
+    plt.show()
+
 The `~hela.RandomForest.train` method returns a dictionary called `r2scores`
 which contains the :math:`R^2` scores of the slope and intercept.
 
@@ -53,3 +69,23 @@ using the trained random forest on the sample data in ``samples_path``:
     posterior_slopes, posterior_intercepts = rf.predict(plot_posterior=True)
     plt.show()
 
+.. plot::
+
+    from hela import generate_example_data
+    # Generate an example dataset directory
+    example_dir, training_dataset, samples_path = generate_example_data()
+
+    from hela import RandomForest
+    import matplotlib.pyplot as plt
+
+    # Initialize a random forest object:
+    rf = RandomForest(training_dataset, example_dir, samples_path)
+
+    # Train the random forest:
+    r2scores = rf.train(num_trees=1000, num_jobs=1)
+    plt.close()
+
+    # Predict posterior distributions from random forest
+    posterior_slopes, posterior_intercepts = rf.predict(plot_posterior=True)
+    plt.show()
+
diff --git a/docs/index.rst b/docs/index.rst
index a53eaa8..2b33a81 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -8,4 +8,4 @@ Random Forest retrieval for exoplanet atmospheres.
hela/installation.rst hela/tutorial.rst - + hela/api.rst diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..4aae0e7 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,56 @@ +[build_sphinx] +source-dir = docs +build-dir = docs/_build +all_files = 1 + +[build_docs] +source-dir = docs +build-dir = docs/_build +all_files = 1 + +[upload_docs] +upload-dir = docs/_build/html +show-response = 1 + +[tool:pytest] +minversion = 3.0 +norecursedirs = build docs/_build +doctest_plus = enabled +addopts = -p no:warnings + +[ah_bootstrap] +auto_use = False + +[pycodestyle] +# E101 - mix of tabs and spaces +# W191 - use of tabs +# W291 - trailing whitespace +# W292 - no newline at end of file +# W293 - trailing whitespace +# W391 - blank line at end of file +# E111 - 4 spaces per indentation level +# E112 - 4 spaces per indentation level +# E113 - 4 spaces per indentation level +# E901 - SyntaxError or IndentationError +# E902 - IOError +select = E101,W191,W291,W292,W293,W391,E111,E112,E113,E901,E902 +exclude = extern,sphinx,*parsetab.py + +[metadata] +package_name = hela +description = A Random Forest retrieval algorithm, here used to perform atmospheric retrieval on exoplanet atmospheres. +author = Pablo Márquez-Neila and Chloe Fisher +license = Other +url = https://github.com/exoclime/HELA +edit_on_github = False +github_project = exoclime/HELA +# install_requires should be formatted as a comma-separated list, e.g.: +# install_requires = astropy, scipy, matplotlib +install_requires = astropy, sklearn, matplotlib, joblib +# version should be PEP386 compatible (http://www.python.org/dev/peps/pep-0386) +version = 0.0.dev0 +# Note: you will also need to change this in your package's __init__.py +minimum_python_version = 3.5 + +[entry_points] + diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..2474ff3 --- /dev/null +++ b/setup.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python +# Licensed under a 3-clause BSD style license - see LICENSE.rst + +import glob +import os +import sys + +try: + from configparser import ConfigParser +except ImportError: + from ConfigParser import ConfigParser + +# Get some values from the setup.cfg +conf = ConfigParser() +conf.read(['setup.cfg']) +metadata = dict(conf.items('metadata')) + +PACKAGENAME = metadata.get('package_name', 'packagename') +DESCRIPTION = metadata.get('description', 'Astropy Package Template') +AUTHOR = metadata.get('author', 'Astropy Developers') +AUTHOR_EMAIL = metadata.get('author_email', '') +LICENSE = metadata.get('license', 'unknown') +URL = metadata.get('url', 'http://astropy.org') +__minimum_python_version__ = metadata.get("minimum_python_version", "2.7") + +# Enforce Python version check - this is the same check as in __init__.py but +# this one has to happen before importing ah_bootstrap. 
+if sys.version_info < tuple((int(val) for val in __minimum_python_version__.split('.'))): + sys.stderr.write("ERROR: packagename requires Python {} or later\n".format(__minimum_python_version__)) + sys.exit(1) + +# Import ah_bootstrap after the python version validation + +import ah_bootstrap +from setuptools import setup + +# A dirty hack to get around some early import/configurations ambiguities +if sys.version_info[0] >= 3: + import builtins +else: + import __builtin__ as builtins +builtins._ASTROPY_SETUP_ = True + +from astropy_helpers.astropy_helpers.setup_helpers import (register_commands, + get_debug_option, + get_package_info) +from astropy_helpers.astropy_helpers.git_helpers import get_git_devstr +from astropy_helpers.astropy_helpers.version_helpers import generate_version_py + + +# order of priority for long_description: +# (1) set in setup.cfg, +# (2) load LONG_DESCRIPTION.rst, +# (3) load README.rst, +# (4) package docstring +readme_glob = 'README*' +_cfg_long_description = metadata.get('long_description', '') +if _cfg_long_description: + LONG_DESCRIPTION = _cfg_long_description + +elif os.path.exists('LONG_DESCRIPTION.rst'): + with open('LONG_DESCRIPTION.rst') as f: + LONG_DESCRIPTION = f.read() + +elif len(glob.glob(readme_glob)) > 0: + with open(glob.glob(readme_glob)[0]) as f: + LONG_DESCRIPTION = f.read() + +else: + # Get the long description from the package's docstring + __import__(PACKAGENAME) + package = sys.modules[PACKAGENAME] + LONG_DESCRIPTION = package.__doc__ + +# Store the package name in a built-in variable so it's easy +# to get from other parts of the setup infrastructure +builtins._ASTROPY_PACKAGE_NAME_ = PACKAGENAME + +# VERSION should be PEP440 compatible (http://www.python.org/dev/peps/pep-0440) +VERSION = metadata.get('version', '0.0.dev0') + +# Indicates if this version is a release version +RELEASE = 'dev' not in VERSION + +if not RELEASE: + VERSION += get_git_devstr(False) + +# Populate the dict of setup command overrides; this should be done before +# invoking any other functionality from distutils since it can potentially +# modify distutils' behavior. +cmdclassd = register_commands(PACKAGENAME, VERSION, RELEASE) + +# Freeze build information in version.py +generate_version_py(PACKAGENAME, VERSION, RELEASE, + get_debug_option(PACKAGENAME)) + +# Treat everything in scripts except README* as a script to be installed +scripts = [fname for fname in glob.glob(os.path.join('scripts', '*')) + if not os.path.basename(fname).startswith('README')] + + +# Get configuration information from all of the various subpackages. +# See the docstring for setup_helpers.update_package_files for more +# details. +package_info = get_package_info() + +# Add the project-global data +package_info['package_data'].setdefault(PACKAGENAME, []) +package_info['package_data'][PACKAGENAME].append('data/*') + +# Define entry points for command-line scripts +entry_points = {'console_scripts': []} + +if conf.has_section('entry_points'): + entry_point_list = conf.items('entry_points') + for entry_point in entry_point_list: + entry_points['console_scripts'].append('{0} = {1}'.format( + entry_point[0], entry_point[1])) + +# Include all .c files, recursively, including those generated by +# Cython, since we can not do this in MANIFEST.in with a "dynamic" +# directory name. 
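+# The collected paths are stored relative to the package directory so that
+# they can be used directly as package_data entries.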
+c_files = [] +for root, dirs, files in os.walk(PACKAGENAME): + for filename in files: + if filename.endswith('.c'): + c_files.append( + os.path.join( + os.path.relpath(root, PACKAGENAME), filename)) +package_info['package_data'][PACKAGENAME].extend(c_files) + +# Note that requires and provides should not be included in the call to +# ``setup``, since these are now deprecated. See this link for more details: +# https://groups.google.com/forum/#!topic/astropy-dev/urYO8ckB2uM + +setup(name=PACKAGENAME, + version=VERSION, + description=DESCRIPTION, + scripts=scripts, + install_requires=[s.strip() for s in metadata.get('install_requires', + 'astropy').split(',')], + author=AUTHOR, + author_email=AUTHOR_EMAIL, + license=LICENSE, + url=URL, + long_description=LONG_DESCRIPTION, + cmdclass=cmdclassd, + zip_safe=False, + use_2to3=False, + entry_points=entry_points, + python_requires='>={}'.format(__minimum_python_version__), + **package_info +) From 9c546b0b51faebecd011954e87652c9c7761577e Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Thu, 7 Nov 2019 20:56:22 +0100 Subject: [PATCH 06/46] Putting in 'fitting a line' header --- docs/hela/tutorial.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index d7378c0..133cc9e 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -1,6 +1,9 @@ Tutorial ======== +Fitting a line +-------------- + First, we must generate some example data, which we can do using a built-in function called `~hela.generate_example_data`, which returns the path to the example file directory, the training dataset path, and the path to the samples @@ -87,5 +90,6 @@ using the trained random forest on the sample data in ``samples_path``: # Predict posterior distirbutions from random forest posterior_slopes, posterior_intercepts = rf.predict(plot_posterior=True) + plt.tight_layout() plt.show() From 17e5d0452fb66c9244a6b55df975faffe5c240f6 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Thu, 7 Nov 2019 21:02:23 +0100 Subject: [PATCH 07/46] Adding logo to docs page --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index a14a7f7..8ccbfb9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -120,7 +120,7 @@ # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = '' +html_logo = '../img/HELA_logo1.png' # The name of an image file (within the static path) to use as favicon of the # docs. 
This file should be a Windows icon file (.ico) being 16x16 or 32x32 From 28ab4e062f757fdceddac3595c3f485bcd7463f9 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 02:56:01 +0100 Subject: [PATCH 08/46] Adding upstream changes to the hela package --- .gitignore | 1 + hela/forest.py | 26 ++++-- hela/models.py | 174 +++++++++++++++++++++++++++++++++- hela/plot.py | 222 +++++++++++++++++++++++++++++++------------- hela/wpercentile.py | 61 ++++++++++++ 5 files changed, 406 insertions(+), 78 deletions(-) create mode 100644 hela/wpercentile.py diff --git a/.gitignore b/.gitignore index 72ed8f5..d9e2c3f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ saved output example_dataset +hela/version.py # Created by https://www.gitignore.io/api/osx,linux,python diff --git a/hela/forest.py b/hela/forest.py index d57fd73..9226f64 100644 --- a/hela/forest.py +++ b/hela/forest.py @@ -8,6 +8,7 @@ from .dataset import load_dataset, load_data_file from .models import Model from .plot import predicted_vs_real, feature_importances, posterior_matrix +from .wpercentile import wpercentile __all__ = ['RandomForest', 'generate_example_data'] @@ -57,9 +58,11 @@ def compute_feature_importance(model, dataset, output_path): return np.array([forest_i.feature_importances_ for forest_i in forests]) -def prediction_ranges(preds): - percentiles = (np.percentile(pred_i, [50, 16, 84]) for pred_i in preds.T) - return np.array([(a, c - a, a - b) for a, b, c in percentiles]) + def data_ranges(posterior, percentiles=(50, 16, 84)): + samples, weights = posterior + values = wpercentile(samples, weights, percentiles, axis=0) + ranges = np.array([values[0], values[2]-values[0], values[0]-values[1]]) + return ranges.T class RandomForest(object): @@ -128,6 +131,7 @@ def feature_importance(self): ------- feature_importances : `~numpy.ndarray` """ + self.model.enable_posterior = False return compute_feature_importance(self.model, self.dataset, self.model_path) @@ -152,25 +156,29 @@ def predict(self, plot_posterior=True): # Loading data from '{}'...".format(data_file) data, _ = load_data_file(self.data_file, model.rf.n_features_) - preds = model.trees_predict(data[0]) + posterior = model.posterior(data[0]) - pred_ranges = prediction_ranges(preds) - - for name_i, pred_range_i in zip(model.names, pred_ranges): + posterior_ranges = data_ranges(posterior) + for name_i, pred_range_i in zip(model.names, posterior_ranges): print("Prediction for {}: {:.3g} " "[+{:.3g} -{:.3g}]".format(name_i, *pred_range_i)) if plot_posterior: # Plotting and saving the posterior matrix..." 
- fig = posterior_matrix(preds, None, + fig = posterior_matrix(posterior, names=model.names, ranges=model.ranges, colors=model.colors) os.makedirs(self.output_path, exist_ok=True) fig.savefig(os.path.join(self.output_path, "posterior_matrix.pdf"), bbox_inches='tight') - return preds.T + return posterior +def data_ranges(posterior, percentiles=(50, 16, 84)): + samples, weights = posterior + values = wpercentile(samples, weights, percentiles, axis=0) + ranges = np.array([values[0], values[2]-values[0], values[0]-values[1]]) + return ranges.T def generate_example_data(): """ diff --git a/hela/models.py b/hela/models.py index 5093c0f..6dfc18f 100644 --- a/hela/models.py +++ b/hela/models.py @@ -1,16 +1,22 @@ +from collections import namedtuple + import numpy as np from sklearn import ensemble from sklearn.preprocessing import MinMaxScaler +from sklearn.utils import check_random_state + +from tqdm import tqdm + +__all__ = ['Model', 'Posterior', 'resample_posterior'] class Model(object): """ Class for models. """ - def __init__(self, num_trees, num_jobs, - names, ranges, colors, - verbose=1): + def __init__(self, num_trees, num_jobs, names, ranges, colors, + verbose=1, enable_posterior=True): """ Parameters ---------- @@ -20,6 +26,7 @@ def __init__(self, num_trees, num_jobs, ranges colors verbose + enable_posterior """ scaler = MinMaxScaler(feature_range=(0, 100)) rf = ensemble.RandomForestRegressor(n_estimators=num_trees, @@ -40,6 +47,12 @@ def __init__(self, num_trees, num_jobs, self.names = names self.colors = colors + # To compute the posteriors + self.enable_posterior = enable_posterior + self.data_leaves = None + self.data_weights = None + self.data_y = None + def _scaler_fit(self, y): if y.ndim == 1: y = y[:, None] @@ -65,6 +78,13 @@ def fit(self, x, y): self._scaler_fit(y) self.rf.fit(x, self._scaler_transform(y)) + # Build the structures to quickly compute the posteriors + if self.enable_posterior: + data_leaves = self.rf.apply(x).T + self.data_leaves = _as_smallest_udtype(data_leaves) + self.data_weights = np.array([_tree_weights(tree, len(y)) for tree in self.rf]) + self.data_y = y + def predict(self, x): pred = self.rf.predict(x) return self._scaler_inverse_transform(pred) @@ -72,8 +92,8 @@ def predict(self, x): def get_params(self, deep=True): return {"num_trees": self.num_trees, "num_jobs": self.num_jobs, "names": self.names, "ranges": self.ranges, - "colors": self.colors, - "verbose": self.verbose} + "colors": self.colors, "verbose": self.verbose, + "enable_posterior": self.enable_posterior,} def trees_predict(self, x): @@ -84,3 +104,147 @@ def trees_predict(self, x): for i in self.rf.estimators_]) return self._scaler_inverse_transform(preds) + def predict_median(self, x): + return self.predict_percentile(x, 50) + + def predict_percentile(self, x, percentile): + + if not self.enable_posterior: + raise ValueError("Cannot compute posteriors with this model. 
" + "Set `enable_posterior` to True to enable posterior computation.") + + # Find the leaves for the query points + leaves_x = self.rf.apply(x) + + if len(x) > self.num_trees: + # If there are many queries, it is faster to find points using a cache + return _posterior_percentile_cache( + self.data_leaves, self.data_weights, + self.data_y, leaves_x, percentile + ) + else: + # For few queries, it is faster if we just compute the posterior for each element + return _posterior_percentile_nocache( + self.data_leaves, self.data_weights, + self.data_y, leaves_x, percentile + ) + + def posterior(self, x): + leaves_x = self.rf.apply(x[None, :])[0] + if not self.enable_posterior: + raise ValueError("Cannot compute posteriors with this model. " + "Set `enable_posterior` to True to enable posterior computation.") + + return _posterior( + self.data_leaves, self.data_weights, + self.data_y, leaves_x + ) + +def _posterior(data_leaves, data_weights, data_y, query_leaves): + + weights_x = (query_leaves[:, None] == data_leaves) * data_weights + weights_x = _as_smallest_udtype(weights_x.sum(0)) + + # Remove samples with weight zero + mask = weights_x != 0 + samples = data_y[mask] + weights = weights_x[mask] + + return Posterior(samples, weights) + + +def _posterior_percentile_nocache(data_leaves, data_weights, data_y, query_leaves, percentile): + + values = [] + + # Computing percentiles... + for leaves_x_i in tqdm(query_leaves): + posterior = _posterior( + data_leaves, data_weights, + data_y, leaves_x_i + ) + samples = np.repeat(posterior.samples, posterior.weights, axis=0) + value = np.percentile(samples, percentile, axis=0) + values.append(value) + + return np.array(values) + + +def _posterior_percentile_cache(data_leaves, data_weights, data_y, query_leaves, percentile): + + # Build a dictionary for fast access of the contents of the leaves. + # Building cache... + cache = [ + _build_leaves_cache(leaves_i, weights_i) + for leaves_i, weights_i in zip(data_leaves, data_weights) + ] + + values = [] + # Check the contents of the leaves in leaves_x + # Computing percentiles... 
+ for leaves_x_i in tqdm(query_leaves): + data_elements = [] + for tree, leaves_x_i_j in enumerate(leaves_x_i): + aux = cache[tree][leaves_x_i_j] + data_elements.extend(aux) + value = np.percentile(data_y[data_elements], percentile, axis=0) + values.append(value) + + return np.array(values) + + +def _build_leaves_cache(leaves, weights): + + result = {} + for index, (leaf, weight) in enumerate(zip(leaves, weights)): + if weight == 0: + continue + + if leaf not in result: + result[leaf] = [index] * weight + else: + result[leaf].extend([index] * weight) + + return result + + +def _generate_sample_indices(random_state, n_samples): + random_instance = check_random_state(random_state) + sample_indices = random_instance.randint(0, n_samples, n_samples) + + return sample_indices + + +def _tree_weights(tree, n_samples): + indices = _generate_sample_indices(tree.random_state, n_samples) + res = np.bincount(indices, minlength=n_samples) + return _as_smallest_udtype(res) + +# Posteriors are represented as a collection of weighted samples +Posterior = namedtuple("Posterior", ["samples", "weights"]) + +def resample_posterior(posterior, num_draws): + + p = posterior.weights / posterior.weights.sum() + indices = np.random.choice(len(posterior.samples), size=num_draws, p=p) + + new_weights = np.bincount(indices, minlength=len(posterior.samples)) + mask = new_weights != 0 + new_samples = posterior.samples[mask] + new_weights = posterior.weights[mask] + + return Posterior(new_samples, new_weights) + +def _as_smallest_udtype(arr): + return arr.astype(_smallest_udtype(arr.max())) + + +def _smallest_udtype(value): + + dtypes = [np.uint8, np.uint16, np.uint32, np.uint64] + + for dtype in dtypes: + if value <= np.iinfo(dtype).max: + return dtype + + raise ValueError("value is too large for any dtype") diff --git a/hela/plot.py b/hela/plot.py index b623c30..d986d13 100644 --- a/hela/plot.py +++ b/hela/plot.py @@ -2,16 +2,23 @@ import numpy as np import matplotlib.pyplot as plt -from matplotlib.colors import LinearSegmentedColormap +from matplotlib.colors import LinearSegmentedColormap, to_rgba_array from sklearn import metrics, neighbors from sklearn.preprocessing import MinMaxScaler +from tqdm import tqdm + +from .models import resample_posterior +from .wpercentile import wmedian + __all__ = ['predicted_vs_real', 'feature_importances', 'posterior_matrix'] +POSTERIOR_MAX_SIZE = 10000 + -def predicted_vs_real(y_real, y_pred, names, ranges): +def predicted_vs_real(y_real, y_pred, names, ranges, alpha='auto'): num_plots = y_pred.shape[1] num_plot_rows = int(np.sqrt(num_plots)) num_plot_cols = (num_plots - 1) // num_plot_rows + 1 @@ -25,9 +32,18 @@ def predicted_vs_real(y_real, y_pred, names, ranges): current_real = y_real[:, dim] current_pred = y_pred[:, dim] + if alpha == 'auto': + # TODO: this is a quick fix. Check at some point in the future. 
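+            # Estimate the point density from a coarse 2D histogram of the
+            # (real, predicted) pairs and set alpha to the inverse of the
+            # 60th percentile of the non-empty bin counts, so that markers
+            # in a typically populated bin stack up to roughly full opacity.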
+ aux, *_ = np.histogram2d(current_real, current_pred, bins=60) + alpha_ = 1 / np.percentile(aux[aux > 0], 60) + elif alpha == 'none': + alpha_ = None + else: + alpha_ = alpha + r2 = metrics.r2_score(current_real, current_pred) label = "$R^2 = {:.3f}$".format(r2) - ax.plot(current_real, current_pred, '.', label=label) + ax.plot(current_real, current_pred, '.', label=label, alpha=alpha_) ax.plot(range_i, range_i, '--', linewidth=3, color="C3", alpha=0.8) @@ -65,7 +81,10 @@ def feature_importances(forests, names, colors): return fig -def posterior_matrix(estimations, y, names, ranges, colors, soft_colors=None): +def posterior_matrix(posterior, names, ranges, colors, soft_colors=None): + + samples, weights = posterior + cmaps = [LinearSegmentedColormap.from_list("MyReds", [(1, 1, 1), c], N=256) for c in colors] @@ -74,7 +93,7 @@ def posterior_matrix(estimations, y, names, ranges, colors, soft_colors=None): if soft_colors is None: soft_colors = colors - num_dims = estimations.shape[1] + num_dims = samples.shape[1] fig, axes = plt.subplots(nrows=num_dims, ncols=num_dims, figsize=(2 * num_dims, 2 * num_dims)) @@ -82,8 +101,11 @@ def posterior_matrix(estimations, y, names, ranges, colors, soft_colors=None): bottom=0.07, top=1 - 0.05, hspace=0.05, wspace=0.05) - for ax, dims in zip(axes.flat, product(range(num_dims), range(num_dims))): - dims = list(dims[::-1]) + iterable = zip(axes.flat, product(range(num_dims), range(num_dims))) + for ax, dims in tqdm(iterable, total=num_dims**2): + # Flip dims. + dims = [dims[1], dims[0]] + ax.xaxis.set_visible(False) ax.yaxis.set_visible(False) ax.title.set_visible(False) @@ -110,72 +132,121 @@ def posterior_matrix(estimations, y, names, ranges, colors, soft_colors=None): ax.yaxis.set_visible(False) if dims[0] < dims[1]: - locations, kd_probs, *_ = _kernel_density_joint( - estimations[:, dims], ranges[dims]) - ax.contour(locations[0], locations[1], - kd_probs, - colors=colors[dims[0]], - linewidths=0.5 - # 'copper', # 'hot', 'magma' ('copper' with white background) - ) - histogram, grid_x, grid_y = _histogram(estimations[:, dims], - ranges[dims]) - ax.pcolormesh(grid_x, grid_y, histogram, cmap=cmaps[dims[0]]) - - expected = np.median(estimations[:, dims], axis=0) - ax.plot([expected[0], expected[0]], - [ranges[dims[1]][0], ranges[dims[1]][1]], '-', linewidth=1, - color='#222222') - ax.plot([ranges[dims[0]][0], ranges[dims[0]][1]], - [expected[1], expected[1]], '-', linewidth=1, - color='#222222') - ax.plot(expected[0], expected[1], '.', color='#222222') - ax.axis('auto') - if y is not None: - real = y[dims] - ax.plot(real[0], real[1], '*', markersize=10, color='#FF0000') - ax.axis([ranges[dims[0]][0], ranges[dims[0]][1], - ranges[dims[1]][0], ranges[dims[1]][1]]) + _plot_histogram2d( + ax, posterior, + color=colors[dims[0]], + cmap=cmaps[dims[0]], + dims=dims, + ranges=ranges[dims] + ) elif dims[0] > dims[1]: - ax.plot(estimations[:, dims[0]], estimations[:, dims[1]], '.', - color=soft_colors[dims[1]]) - ax.axis([ranges[dims[0]][0], ranges[dims[0]][1], - ranges[dims[1]][0], ranges[dims[1]][1]]) + _plot_samples( + ax, posterior, + color=soft_colors[dims[1]], + dims=dims, + ranges=ranges[dims] + ) else: - histogram, bins = _histogram(estimations[:, dims[:1]], - ranges=ranges[dims[:1]]) + histogram, bins = _histogram1d( + samples[:, dims[:1]], weights, + ranges=ranges[dims[:1]] + ) ax.bar(bins[:-1], histogram, color=soft_colors[dims[0]], - width=bins[1] - bins[0]) + width=bins[1]-bins[0]) kd_probs = histogram - expected = np.median(estimations[:, dims[0]]) + 
expected = wmedian(samples[:, dims[0]], weights) ax.plot([expected, expected], [0, 1.1 * kd_probs.max()], '-', linewidth=1, color='#222222') - if y is not None: - real = y[dims[0]] - ax.plot([real, real], [0, kd_probs.max()], 'r-') ax.axis([ranges[dims[0]][0], ranges[dims[0]][1], 0, 1.1 * kd_probs.max()]) + # fig.tight_layout(pad=0) + # fig.tight_layout(pad=0) return fig +def _plot_histogram2d(ax, posterior, color, cmap, dims, ranges): + + samples, weights = posterior + # For efficiency, do not compute the kernel density + # over all the samples of the posterior. Subsample first. + if len(samples) > POSTERIOR_MAX_SIZE: + samples, weights = resample_posterior(posterior, POSTERIOR_MAX_SIZE) + + locations, kd_probs, *_ = _kernel_density_joint( + samples[:, dims], + weights, + ranges + ) + ax.contour( + locations[0], locations[1], + kd_probs, + colors=color, + linewidths=0.5 + ) + + # For the rest of the plot we use the complete posterior + samples, weights = posterior + histogram, grid_x, grid_y = _histogram2d( + samples[:, dims], weights, + ranges + ) + ax.pcolormesh(grid_x, grid_y, histogram, cmap=cmap) + + expected = wmedian(samples[:, dims], weights, axis=0) + ax.plot([expected[0], expected[0]], [ranges[1][0], ranges[1][1]], + '-', linewidth=1, color='#222222') + ax.plot([ranges[0][0], ranges[0][1]], [expected[1], expected[1]], + '-', linewidth=1, color='#222222') + ax.plot(expected[0], expected[1], '.', color='#222222') + ax.axis('auto') + ax.axis([ranges[0][0], ranges[0][1], + ranges[1][0], ranges[1][1]]) + + +def _plot_samples(ax, posterior, color, dims, ranges): + + # For efficiency, do not plot all the samples of the posterior. Subsample first. + if len(posterior.samples) > POSTERIOR_MAX_SIZE: + posterior = resample_posterior(posterior, POSTERIOR_MAX_SIZE) + + samples, weights = posterior + + points_alpha = _weights_to_alpha(weights) + + current_colors = to_rgba_array(color) + current_colors = np.tile(current_colors, (len(samples), 1)) + current_colors[:, 3] = points_alpha + + ax.scatter( + x=samples[:, dims[0]], + y=samples[:, dims[1]], + s=100, + c=current_colors, + marker='.', + linewidth=0 + ) + + ax.axis([ranges[0][0], ranges[0][1], + ranges[1][0], ranges[1][1]]) + def _min_max_scaler(ranges, feature_range=(0, 100)): res = MinMaxScaler() res.data_max_ = ranges[:, 1] res.data_min_ = ranges[:, 0] res.data_range_ = res.data_max_ - res.data_min_ - res.scale_ = (feature_range[1] - feature_range[0]) / ( - ranges[:, 1] - ranges[:, 0]) + res.scale_ = (feature_range[1] - feature_range[0]) / (ranges[:, 1] - ranges[:, 0]) res.min_ = -res.scale_ * res.data_min_ res.n_samples_seen_ = 1 res.feature_range = feature_range return res -def _kernel_density_joint(estimations, ranges, bandwidth=1 / 25): +def _kernel_density_joint(samples, weights, ranges, bandwidth=1/25): + ndims = len(ranges) scaler = _min_max_scaler(ranges, feature_range=(0, 100)) @@ -183,28 +254,51 @@ def _kernel_density_joint(estimations, ranges, bandwidth=1 / 25): bandwidth = bandwidth * 100 # step = 1.0 - kd = neighbors.KernelDensity(bandwidth=bandwidth).fit( - scaler.transform(estimations)) - locations1d = np.arange(0, 100, 1) - locations = np.reshape(np.meshgrid(*[locations1d] * ndims), (ndims, -1)).T + kd = neighbors.KernelDensity(bandwidth=bandwidth) + kd.fit(scaler.transform(samples), sample_weight=weights) + + grid_shape = [100] * ndims + grid = np.indices(grid_shape) + locations = np.reshape(grid, (ndims, -1)).T kd_probs = np.exp(kd.score_samples(locations)) - shape = (ndims,) + (len(locations1d),) * ndims + shape = 
(ndims, *grid_shape) locations = scaler.inverse_transform(locations) locations = np.reshape(locations.T, shape) - kd_probs = np.reshape(kd_probs, shape[1:]) + kd_probs = np.reshape(kd_probs, grid_shape) return locations, kd_probs, kd -def _histogram(estimations, ranges, bins=20): - if len(ranges) == 1: - histogram, edges = np.histogram(estimations[:, 0], bins=bins, - range=ranges[0]) - return histogram, edges +def _histogram1d(samples, weights, ranges, bins=20): + + assert(len(ranges) == 1) + + histogram, edges = np.histogram( + samples[:, 0], + bins=bins, + range=ranges[0], + weights=weights + ) + return histogram, edges + + +def _histogram2d(samples, weights, ranges, bins=20): + + assert(len(ranges) == 2) + + histogram, xedges, yedges = np.histogram2d( + samples[:, 0], + samples[:, 1], + bins=bins, + range=ranges, + weights=weights + ) + grid_x, grid_y = np.meshgrid(xedges, yedges) + return histogram.T, grid_x, grid_y, + + +def _weights_to_alpha(weights): - if len(ranges) == 2: - histogram, xedges, yedges = np.histogram2d(estimations[:, 0], - estimations[:, 1], - bins=bins, range=ranges) - grid_x, grid_y = np.meshgrid(xedges, yedges) - return histogram.T, grid_x, grid_y, + # Maximum weight (removing potential outliers) + max_weight = np.percentile(weights, 98) + return np.clip(weights / max_weight, 0, 1) diff --git a/hela/wpercentile.py b/hela/wpercentile.py new file mode 100644 index 0000000..456bc12 --- /dev/null +++ b/hela/wpercentile.py @@ -0,0 +1,61 @@ +import numpy as np + +__all__ = ["wpercentile", "wmedian"] + + +def _wpercentile1d(data, weights, percentiles): + + if data.ndim > 1 or weights.ndim > 1: + raise ValueError("data and weights must be one-dimensional arrays") + + if data.shape != weights.shape: + raise ValueError("data and weights must have the same shape") + + data = np.asarray(data) + weights = np.asarray(weights) + percentiles = np.asarray(percentiles) + + sort_indices = np.argsort(data) + sorted_data = data[sort_indices] + sorted_weights = weights[sort_indices] + + cumsum_weights = np.cumsum(sorted_weights) + sum_weights = cumsum_weights[-1] + + pn = 100 * (cumsum_weights - 0.5*sorted_weights) / sum_weights + + return np.interp(percentiles, pn, sorted_data) + + +def wpercentile(data, weights, percentiles, axis=None): + """ + Compute percentiles of a weighted sample. 
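+
+    Percentiles are computed per one-dimensional slice by sorting the data,
+    evaluating the normalised cumulative weights at the sample midpoints,
+    ``100 * (cumsum(w) - 0.5 * w) / sum(w)``, and linearly interpolating the
+    requested percentiles on that grid (see ``_wpercentile1d``).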
+ """ + if axis is None: + data = np.ravel(data) + weights = np.ravel(weights) + return _wpercentile1d(data, weights, percentiles) + + axis = np.atleast_1d(axis) + + # Reshape the arrays for proper computation + # Move the requested axis to the final dimensions + dest_axis = list(range(len(axis))) + data2 = np.moveaxis(data, axis, dest_axis) + + ndim = len(axis) + shape = data2.shape + newshape = (np.prod(shape[:ndim]), np.prod(shape[ndim:])) + newdata = np.reshape(data2, newshape) + newweights = np.reshape(weights, newshape[0]) + + result = np.apply_along_axis(_wpercentile1d, 0, newdata, newweights, + percentiles) + + final_shape = (*np.shape(percentiles), *shape[ndim:]) + return np.reshape(result, final_shape) + + +def wmedian(data, weights, axis=None): + """Compute the weighted median.""" + return wpercentile(data, weights, 50, axis) \ No newline at end of file From eff10669dabc048eea2afc2e867843a3f05c5446 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 03:01:11 +0100 Subject: [PATCH 09/46] Updating docstrings --- hela/models.py | 33 ++++++++++++++++++++++++++------- hela/wpercentile.py | 26 +++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/hela/models.py b/hela/models.py index 6dfc18f..cebc9a4 100644 --- a/hela/models.py +++ b/hela/models.py @@ -20,13 +20,13 @@ def __init__(self, num_trees, num_jobs, names, ranges, colors, """ Parameters ---------- - num_trees - num_jobs - names - ranges - colors - verbose - enable_posterior + num_trees : int + num_jobs : int + names : list + ranges : list + colors : list + verbose : bool or int + enable_posterior : bool """ scaler = MinMaxScaler(feature_range=(0, 100)) rf = ensemble.RandomForestRegressor(n_estimators=num_trees, @@ -75,6 +75,16 @@ def _scaler_inverse_transform(self, y): return self.scaler.inverse_transform(y) def fit(self, x, y): + """ + Fit the model. + + Follows scikit-learn convention. + + Parameters + ---------- + x : `~numpy.ndarray` + y : `~numpy.ndarray` + """ self._scaler_fit(y) self.rf.fit(x, self._scaler_transform(y)) @@ -86,6 +96,15 @@ def fit(self, x, y): self.data_y = y def predict(self, x): + """ + Predict on values of ``x`` + + Follows scikit-learn convention. + + Parameters + ---------- + x : `~numpy.ndarray` + """ pred = self.rf.predict(x) return self._scaler_inverse_transform(pred) diff --git a/hela/wpercentile.py b/hela/wpercentile.py index 456bc12..241d891 100644 --- a/hela/wpercentile.py +++ b/hela/wpercentile.py @@ -30,6 +30,17 @@ def _wpercentile1d(data, weights, percentiles): def wpercentile(data, weights, percentiles, axis=None): """ Compute percentiles of a weighted sample. + + Parameters + ---------- + data : `~numpy.ndarray` + weights : `~numpy.ndarray` + percentiles : list + axis : int + + Returns + ------- + ar : `~numpy.ndarray` """ if axis is None: data = np.ravel(data) @@ -57,5 +68,18 @@ def wpercentile(data, weights, percentiles, axis=None): def wmedian(data, weights, axis=None): - """Compute the weighted median.""" + """ + Compute the weighted median. 
+ + Parameters + ---------- + data : `~numpy.ndarray` + weights : `~numpy.ndarray` + axis : int + + Returns + ------- + ar : `~numpy.ndarray` + """ + return wpercentile(data, weights, 50, axis) \ No newline at end of file From 5499ec9e84a578d966238041563e9ef998074e15 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 03:02:21 +0100 Subject: [PATCH 10/46] Removing accidentally committed version.py --- hela/version.py | 218 ------------------------------------------------ 1 file changed, 218 deletions(-) delete mode 100644 hela/version.py diff --git a/hela/version.py b/hela/version.py deleted file mode 100644 index ca2074f..0000000 --- a/hela/version.py +++ /dev/null @@ -1,218 +0,0 @@ -# Autogenerated by Astropy-affiliated package hela's setup.py on 2019-11-07 19:28:46 UTC -import datetime - - -import locale -import os -import subprocess -import warnings - -__all__ = ['get_git_devstr'] - - -def _decode_stdio(stream): - try: - stdio_encoding = locale.getdefaultlocale()[1] or 'utf-8' - except ValueError: - stdio_encoding = 'utf-8' - - try: - text = stream.decode(stdio_encoding) - except UnicodeDecodeError: - # Final fallback - text = stream.decode('latin1') - - return text - - -def update_git_devstr(version, path=None): - """ - Updates the git revision string if and only if the path is being imported - directly from a git working copy. This ensures that the revision number in - the version string is accurate. - """ - - try: - # Quick way to determine if we're in git or not - returns '' if not - devstr = get_git_devstr(sha=True, show_warning=False, path=path) - except OSError: - return version - - if not devstr: - # Probably not in git so just pass silently - return version - - if 'dev' in version: # update to the current git revision - version_base = version.split('.dev', 1)[0] - devstr = get_git_devstr(sha=False, show_warning=False, path=path) - - return version_base + '.dev' + devstr - else: - # otherwise it's already the true/release version - return version - - -def get_git_devstr(sha=False, show_warning=True, path=None): - """ - Determines the number of revisions in this repository. - - Parameters - ---------- - sha : bool - If True, the full SHA1 hash will be returned. Otherwise, the total - count of commits in the repository will be used as a "revision - number". - - show_warning : bool - If True, issue a warning if git returns an error code, otherwise errors - pass silently. - - path : str or None - If a string, specifies the directory to look in to find the git - repository. If `None`, the current working directory is used, and must - be the root of the git repository. - If given a filename it uses the directory containing that file. - - Returns - ------- - devversion : str - Either a string with the revision number (if `sha` is False), the - SHA1 hash of the current commit (if `sha` is True), or an empty string - if git version info could not be identified. 
- - """ - - if path is None: - path = os.getcwd() - - if not os.path.isdir(path): - path = os.path.abspath(os.path.dirname(path)) - - if sha: - # Faster for getting just the hash of HEAD - cmd = ['rev-parse', 'HEAD'] - else: - cmd = ['rev-list', '--count', 'HEAD'] - - def run_git(cmd): - try: - p = subprocess.Popen(['git'] + cmd, cwd=path, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - stdin=subprocess.PIPE) - stdout, stderr = p.communicate() - except OSError as e: - if show_warning: - warnings.warn('Error running git: ' + str(e)) - return (None, b'', b'') - - if p.returncode == 128: - if show_warning: - warnings.warn('No git repository present at {0!r}! Using ' - 'default dev version.'.format(path)) - return (p.returncode, b'', b'') - if p.returncode == 129: - if show_warning: - warnings.warn('Your git looks old (does it support {0}?); ' - 'consider upgrading to v1.7.2 or ' - 'later.'.format(cmd[0])) - return (p.returncode, stdout, stderr) - elif p.returncode != 0: - if show_warning: - warnings.warn('Git failed while determining revision ' - 'count: {0}'.format(_decode_stdio(stderr))) - return (p.returncode, stdout, stderr) - - return p.returncode, stdout, stderr - - returncode, stdout, stderr = run_git(cmd) - - if not sha and returncode == 128: - # git returns 128 if the command is not run from within a git - # repository tree. In this case, a warning is produced above but we - # return the default dev version of '0'. - return '0' - elif not sha and returncode == 129: - # git returns 129 if a command option failed to parse; in - # particular this could happen in git versions older than 1.7.2 - # where the --count option is not supported - # Also use --abbrev-commit and --abbrev=0 to display the minimum - # number of characters needed per-commit (rather than the full hash) - cmd = ['rev-list', '--abbrev-commit', '--abbrev=0', 'HEAD'] - returncode, stdout, stderr = run_git(cmd) - # Fall back on the old method of getting all revisions and counting - # the lines - if returncode == 0: - return str(stdout.count(b'\n')) - else: - return '' - elif sha: - return _decode_stdio(stdout)[:40] - else: - return _decode_stdio(stdout).strip() - - -# This function is tested but it is only ever executed within a subprocess when -# creating a fake package, so it doesn't get picked up by coverage metrics. -def _get_repo_path(pathname, levels=None): # pragma: no cover - """ - Given a file or directory name, determine the root of the git repository - this path is under. If given, this won't look any higher than ``levels`` - (that is, if ``levels=0`` then the given path must be the root of the git - repository and is returned if so. - - Returns `None` if the given path could not be determined to belong to a git - repo. - """ - - if os.path.isfile(pathname): - current_dir = os.path.abspath(os.path.dirname(pathname)) - elif os.path.isdir(pathname): - current_dir = os.path.abspath(pathname) - else: - return None - - current_level = 0 - - while levels is None or current_level <= levels: - if os.path.exists(os.path.join(current_dir, '.git')): - return current_dir - - current_level += 1 - if current_dir == os.path.dirname(current_dir): - break - - current_dir = os.path.dirname(current_dir) - - return None - - -_packagename = "hela" -_last_generated_version = "0.0.dev035" -_last_githash = "fc6e9dbc80d4d68d83481d2afa7337dd1ae1cf93" - -# Determine where the source code for this module -# lives. If __file__ is not a filesystem path then -# it is assumed not to live in a git repo at all. 
-if _get_repo_path(__file__, levels=len(_packagename.split('.'))): - version = update_git_devstr(_last_generated_version, path=__file__) - githash = get_git_devstr(sha=True, show_warning=False, - path=__file__) or _last_githash -else: - # The file does not appear to live in a git repo so don't bother - # invoking git - version = _last_generated_version - githash = _last_githash - - -major = 0 -minor = 0 -bugfix = 0 - -version_info = (major, minor, bugfix) - -release = False -timestamp = datetime.datetime(2019, 11, 7, 19, 28, 46) -debug = False - -astropy_helpers_version = "unknown" From 9378719e20794224292f1908045461b4e3cbc34b Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 03:10:07 +0100 Subject: [PATCH 11/46] API updates to docs, docstrings --- docs/hela/tutorial.rst | 1 - hela/forest.py | 33 ++++++++++++++++----------------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index 133cc9e..3d8656e 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -89,7 +89,6 @@ using the trained random forest on the sample data in ``samples_path``: plt.close() # Predict posterior distirbutions from random forest - posterior_slopes, posterior_intercepts = rf.predict(plot_posterior=True) plt.tight_layout() plt.show() diff --git a/hela/forest.py b/hela/forest.py index 9226f64..72ed666 100644 --- a/hela/forest.py +++ b/hela/forest.py @@ -58,13 +58,6 @@ def compute_feature_importance(model, dataset, output_path): return np.array([forest_i.feature_importances_ for forest_i in forests]) - def data_ranges(posterior, percentiles=(50, 16, 84)): - samples, weights = posterior - values = wpercentile(samples, weights, percentiles, axis=0) - ranges = np.array([values[0], values[2]-values[0], values[0]-values[1]]) - return ranges.T - - class RandomForest(object): """ A class for a random forest. @@ -91,10 +84,9 @@ def train(self, num_trees=1000, num_jobs=5, quiet=False): Parameters ---------- - num_trees - num_jobs - quiet - kwargs + num_trees : int + num_jobs : int + quiet : bool Returns ------- @@ -122,11 +114,6 @@ def feature_importance(self): """ Compute feature importance. - Parameters - ---------- - model - dataset - Returns ------- feature_importances : `~numpy.ndarray` @@ -141,7 +128,7 @@ def predict(self, plot_posterior=True): Parameters ---------- - plot_posterior + plot_posterior : bool Returns ------- @@ -175,6 +162,18 @@ def predict(self, plot_posterior=True): return posterior def data_ranges(posterior, percentiles=(50, 16, 84)): + """ + Return posterior ranges. + + Parameters + ---------- + posterior : `~numpy.ndarray` + percentiles : tuple + + Returns + ------- + ranges : `~numpy.ndarray` + """ samples, weights = posterior values = wpercentile(samples, weights, percentiles, axis=0) ranges = np.array([values[0], values[2]-values[0], values[0]-values[1]]) From e33a08848ea7b97b445334977d338d8730a3fb40 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 03:22:39 +0100 Subject: [PATCH 12/46] Better docstrings, patching tutorial for Pablo's update --- docs/hela/tutorial.rst | 9 +++++++-- hela/forest.py | 25 +++++++++++++++++++------ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index 3d8656e..9fa8783 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -64,12 +64,15 @@ The `~hela.RandomForest.train` method returns a dictionary called `r2scores` which contains the :math:`R^2` scores of the slope and intercept. 
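
For a quick check of the fit you can, for instance, loop over the returned
dictionary (a minimal sketch; the keys follow the parameter names defined in
the dataset metadata, here ``slope`` and ``intercept``):

.. code-block:: python

    for name, score in r2scores.items():
        print("R^2 score for {}: {:.3f}".format(name, score))
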
Finally, let's estimate the posterior distributions for the slope and intercept -using the trained random forest on the sample data in ``samples_path``: +using the trained random forest on the sample data in ``samples_path``, where +the true values of the slope and intercept are :math:`m=0.3` and :math:`b=0.5` +using the `~hela.RandomForest.predict` method: .. code-block:: python # Predict posterior distirbutions from random forest - posterior_slopes, posterior_intercepts = rf.predict(plot_posterior=True) + samples, weights = rf.predict(plot_posterior=True) + posterior_slopes, posterior_intercepts = samples.T plt.show() .. plot:: @@ -89,6 +92,8 @@ using the trained random forest on the sample data in ``samples_path``: plt.close() # Predict posterior distirbutions from random forest + samples, weights = rf.predict(plot_posterior=True) + posterior_slopes, posterior_intercepts = samples.T plt.tight_layout() plt.show() diff --git a/hela/forest.py b/hela/forest.py index 72ed666..cf81703 100644 --- a/hela/forest.py +++ b/hela/forest.py @@ -66,9 +66,12 @@ def __init__(self, training_dataset, model_path, data_file): """ Parameters ---------- - training_dataset - model_path - data_file + training_dataset : str + Path to the dataset metadata JSON file + model_path : str + Path to the output directory to create and populate + data_file : str + Path to the numpy pickle of the samples to predict on """ self.training_dataset = training_dataset self.model_path = model_path @@ -91,6 +94,7 @@ def train(self, num_trees=1000, num_jobs=5, quiet=False): Returns ------- r2scores : dict + :math:`R^2` values for each parameter after training """ # Loading dataset self.dataset = load_dataset(self.training_dataset) @@ -133,8 +137,8 @@ def predict(self, plot_posterior=True): Returns ------- preds : `~numpy.ndarray` - N x M values where N is number of parameters, M is number of - samples/trees (check out attributes of model for metadata) + ``N x M`` values where ``N`` is number of parameters, ``M`` is + number of samples/trees (check out attributes of model for metadata) """ model_file = os.path.join(self.model_path, "model.pkl") # Loading random forest from '{}'...".format(model_file) @@ -181,7 +185,16 @@ def data_ranges(posterior, percentiles=(50, 16, 84)): def generate_example_data(): """ - Generate an example dataset in the new directory ``linear_dataset`` + Generate an example dataset in the new directory ``linear_dataset``. + + Returns + ------- + example_dir : str + Path to the directory of the example data + training_dataset : str + Path to the dataset metadata JSON file + samples_path : str + Path to the numpy pickle of the samples to predict on """ example_dir = 'linear_dataset' training_dataset = os.path.join(example_dir, 'example_dataset.json') From 4b4576ac33a20d87c252a6fdccf6ad9c30b8c46c Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 03:46:12 +0100 Subject: [PATCH 13/46] More docstrings, explicit class definitions rather than namedtuples --- docs/hela/tutorial.rst | 24 ++++++++++++++++---- hela/dataset.py | 30 ++++++++++++++++++++----- hela/forest.py | 16 ++++++------- hela/plot.py | 51 ++++++++++++++++++++++++++++++++++++++---- 4 files changed, 99 insertions(+), 22 deletions(-) diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index 9fa8783..b20273d 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -16,8 +16,23 @@ which we'd like to predict on: example_dir, training_dataset, samples_path = generate_example_data() What did that just do? 
We created an example directory called ``linear_data``, -which contains a training dataset described by the metadata file -``training_dataset``. This training dataset has... +which contains a training dataset described by the metadata file located at path +``training_dataset``. This training dataset contains a JSON file describing the +free parameters, which looks like this: + +.. code-block:: python + + {"metadata": + {"names": ["slope", "intercept"], + "ranges": [[0, 1], [0, 1]], + "colors": ["#F14532", "#4a98c9"], + "num_features": 1000}, + "training_data": "training.npy", + "testing_data": "testing.npy"} + +This file tells the model what the two fitting parameters are and their rainges, +where to grab the training and testing datasets (in the npy pickle files), the +number of features (1000), the colors to use for each parameter in the plots. We also generated a bunch of samples with a known slope and intercept, called ``samples_path``, on which we'll apply our trained random forest to estimate @@ -61,7 +76,8 @@ train the random forest with 1000 trees and on a single processor: plt.show() The `~hela.RandomForest.train` method returns a dictionary called `r2scores` -which contains the :math:`R^2` scores of the slope and intercept. +which contains the :math:`R^2` scores of the slope and intercept, which should +both be close to unity for this example. Finally, let's estimate the posterior distributions for the slope and intercept using the trained random forest on the sample data in ``samples_path``, where @@ -70,7 +86,7 @@ using the `~hela.RandomForest.predict` method: .. code-block:: python - # Predict posterior distirbutions from random forest + # Predict posterior distributions from random forest samples, weights = rf.predict(plot_posterior=True) posterior_slopes, posterior_intercepts = samples.T plt.show() diff --git a/hela/dataset.py b/hela/dataset.py index b65a3be..ddf832e 100644 --- a/hela/dataset.py +++ b/hela/dataset.py @@ -1,16 +1,35 @@ import os import json -from collections import namedtuple import numpy as np __all__ = ["Dataset", "load_dataset", "load_data_file"] -Dataset = namedtuple("Dataset", ["training_x", "training_y", - "testing_x", "testing_y", - "names", "ranges", "colors"]) - +class Dataset(object): + """ + Class for a dataset used for training the random forest. 
+ """ + def __init__(self, training_x, training_y, testing_x, testing_y, names, + ranges, colors): + """ + Parameters + ---------- + training_x : `~numpy.ndarray` + training_y : `~numpy.ndarray` + testing_x : `~numpy.ndarray` + testing_y : `~numpy.ndarray` + names : list + ranges : list + colors : list + """ + self.training_x = training_x + self.training_y = training_y + self.testing_x = testing_x + self.testing_y = testing_y + self.names = names + self.ranges = ranges + self.colors = colors def load_data_file(data_file, num_features): data = np.load(data_file) @@ -61,4 +80,3 @@ def load_dataset(dataset_file): return Dataset(training_x, training_y, testing_x, testing_y, metadata["names"], metadata["ranges"], metadata["colors"]) - diff --git a/hela/forest.py b/hela/forest.py index cf81703..935f6d3 100644 --- a/hela/forest.py +++ b/hela/forest.py @@ -7,7 +7,7 @@ from .dataset import load_dataset, load_data_file from .models import Model -from .plot import predicted_vs_real, feature_importances, posterior_matrix +from .plot import plot_predicted_vs_real, plot_feature_importances, plot_posterior_matrix from .wpercentile import wpercentile __all__ = ['RandomForest', 'generate_example_data'] @@ -35,8 +35,8 @@ def test_model(model, dataset, output_path): for name, values in r2scores.items(): print("\tR^2 score for {}: {:.3f}".format(name, values)) - fig = predicted_vs_real(dataset.testing_y, pred, dataset.names, - dataset.ranges) + fig = plot_predicted_vs_real(dataset.testing_y, pred, dataset.names, + dataset.ranges) fig.savefig(os.path.join(output_path, "predicted_vs_real.pdf"), bbox_inches='tight') return r2scores @@ -48,7 +48,7 @@ def compute_feature_importance(model, dataset, output_path): forests = [i.rf for i in regr.estimators_] + [model.rf] - fig = feature_importances( + fig = plot_feature_importances( forests=[i.rf for i in regr.estimators_] + [model.rf], names=dataset.names + ["joint prediction"], colors=dataset.colors + ["C0"]) @@ -156,10 +156,10 @@ def predict(self, plot_posterior=True): if plot_posterior: # Plotting and saving the posterior matrix..." - fig = posterior_matrix(posterior, - names=model.names, - ranges=model.ranges, - colors=model.colors) + fig = plot_posterior_matrix(posterior, + names=model.names, + ranges=model.ranges, + colors=model.colors) os.makedirs(self.output_path, exist_ok=True) fig.savefig(os.path.join(self.output_path, "posterior_matrix.pdf"), bbox_inches='tight') diff --git a/hela/plot.py b/hela/plot.py index d986d13..a6e5892 100644 --- a/hela/plot.py +++ b/hela/plot.py @@ -13,12 +13,28 @@ from .wpercentile import wmedian -__all__ = ['predicted_vs_real', 'feature_importances', 'posterior_matrix'] +__all__ = ['plot_predicted_vs_real', 'plot_feature_importances', + 'plot_posterior_matrix'] POSTERIOR_MAX_SIZE = 10000 -def predicted_vs_real(y_real, y_pred, names, ranges, alpha='auto'): +def plot_predicted_vs_real(y_real, y_pred, names, ranges, alpha='auto'): + """ + Plot predicted and real parameter values. + + Parameters + ---------- + y_real + y_pred + names + ranges + alpha + + Returns + ------- + + """ num_plots = y_pred.shape[1] num_plot_rows = int(np.sqrt(num_plots)) num_plot_cols = (num_plots - 1) // num_plot_rows + 1 @@ -59,7 +75,20 @@ def predicted_vs_real(y_real, y_pred, names, ranges, alpha='auto'): return fig -def feature_importances(forests, names, colors): +def plot_feature_importances(forests, names, colors): + """ + Plot the feature importances. 
+ + Parameters + ---------- + forests + names + colors + + Returns + ------- + + """ num_plots = len(forests) num_plot_rows = (num_plots - 1) // 2 + 1 num_plot_cols = 2 @@ -81,8 +110,22 @@ def feature_importances(forests, names, colors): return fig -def posterior_matrix(posterior, names, ranges, colors, soft_colors=None): +def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): + """ + Plot the posterior matrix. + + Parameters + ---------- + posterior + names + ranges + colors + soft_colors + + Returns + ------- + """ samples, weights = posterior cmaps = [LinearSegmentedColormap.from_list("MyReds", [(1, 1, 1), c], N=256) From 2a13bd65dfda4c0c723c20baf4ef9668a35374d0 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 03:48:32 +0100 Subject: [PATCH 14/46] Defining Posterior class --- hela/models.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/hela/models.py b/hela/models.py index cebc9a4..21b97ae 100644 --- a/hela/models.py +++ b/hela/models.py @@ -159,6 +159,20 @@ def posterior(self, x): self.data_y, leaves_x ) +class Posterior(object): + """ + Posteriors are represented as a collection of weighted samples + """ + def __init__(self, samples, weights): + """ + Parameters + ---------- + samples : `~numpy.ndarray` + weights : `~numpy.ndarray` + """ + self.samples = samples + self.weights = weights + def _posterior(data_leaves, data_weights, data_y, query_leaves): weights_x = (query_leaves[:, None] == data_leaves) * data_weights @@ -239,9 +253,6 @@ def _tree_weights(tree, n_samples): res = np.bincount(indices, minlength=n_samples) return _as_smallest_udtype(res) -# Posteriors are represented as a collection of weighted samples -Posterior = namedtuple("Posterior", ["samples", "weights"]) - def resample_posterior(posterior, num_draws): p = posterior.weights / posterior.weights.sum() From 0bdc590bec94d661ee1509f330dec935592ee2fb Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 03:50:35 +0100 Subject: [PATCH 15/46] Better code formatting --- hela/forest.py | 3 ++- hela/models.py | 23 +++++++++++++---------- hela/plot.py | 21 ++++++++------------- 3 files changed, 23 insertions(+), 24 deletions(-) diff --git a/hela/forest.py b/hela/forest.py index 935f6d3..30a369a 100644 --- a/hela/forest.py +++ b/hela/forest.py @@ -7,7 +7,8 @@ from .dataset import load_dataset, load_data_file from .models import Model -from .plot import plot_predicted_vs_real, plot_feature_importances, plot_posterior_matrix +from .plot import (plot_predicted_vs_real, plot_feature_importances, + plot_posterior_matrix) from .wpercentile import wpercentile __all__ = ['RandomForest', 'generate_example_data'] diff --git a/hela/models.py b/hela/models.py index 21b97ae..84526d1 100644 --- a/hela/models.py +++ b/hela/models.py @@ -15,6 +15,7 @@ class Model(object): """ Class for models. 
""" + def __init__(self, num_trees, num_jobs, names, ranges, colors, verbose=1, enable_posterior=True): """ @@ -92,7 +93,8 @@ def fit(self, x, y): if self.enable_posterior: data_leaves = self.rf.apply(x).T self.data_leaves = _as_smallest_udtype(data_leaves) - self.data_weights = np.array([_tree_weights(tree, len(y)) for tree in self.rf]) + self.data_weights = np.array( + [_tree_weights(tree, len(y)) for tree in self.rf]) self.data_y = y def predict(self, x): @@ -159,10 +161,12 @@ def posterior(self, x): self.data_y, leaves_x ) + class Posterior(object): """ Posteriors are represented as a collection of weighted samples """ + def __init__(self, samples, weights): """ Parameters @@ -173,8 +177,8 @@ def __init__(self, samples, weights): self.samples = samples self.weights = weights -def _posterior(data_leaves, data_weights, data_y, query_leaves): +def _posterior(data_leaves, data_weights, data_y, query_leaves): weights_x = (query_leaves[:, None] == data_leaves) * data_weights weights_x = _as_smallest_udtype(weights_x.sum(0)) @@ -186,8 +190,8 @@ def _posterior(data_leaves, data_weights, data_y, query_leaves): return Posterior(samples, weights) -def _posterior_percentile_nocache(data_leaves, data_weights, data_y, query_leaves, percentile): - +def _posterior_percentile_nocache(data_leaves, data_weights, data_y, + query_leaves, percentile): values = [] # Computing percentiles... @@ -203,14 +207,14 @@ def _posterior_percentile_nocache(data_leaves, data_weights, data_y, query_leave return np.array(values) -def _posterior_percentile_cache(data_leaves, data_weights, data_y, query_leaves, percentile): - +def _posterior_percentile_cache(data_leaves, data_weights, data_y, query_leaves, + percentile): # Build a dictionary for fast access of the contents of the leaves. # Building cache... cache = [ _build_leaves_cache(leaves_i, weights_i) for leaves_i, weights_i in zip(data_leaves, data_weights) - ] + ] values = [] # Check the contents of the leaves in leaves_x @@ -227,7 +231,6 @@ def _posterior_percentile_cache(data_leaves, data_weights, data_y, query_leaves, def _build_leaves_cache(leaves, weights): - result = {} for index, (leaf, weight) in enumerate(zip(leaves, weights)): if weight == 0: @@ -253,8 +256,8 @@ def _tree_weights(tree, n_samples): res = np.bincount(indices, minlength=n_samples) return _as_smallest_udtype(res) -def resample_posterior(posterior, num_draws): +def resample_posterior(posterior, num_draws): p = posterior.weights / posterior.weights.sum() indices = np.random.choice(len(posterior.samples), size=num_draws, p=p) @@ -265,12 +268,12 @@ def resample_posterior(posterior, num_draws): return Posterior(new_samples, new_weights) + def _as_smallest_udtype(arr): return arr.astype(_smallest_udtype(arr.max())) def _smallest_udtype(value): - dtypes = [np.uint8, np.uint16, np.uint32, np.uint64] for dtype in dtypes: diff --git a/hela/plot.py b/hela/plot.py index a6e5892..fa721ae 100644 --- a/hela/plot.py +++ b/hela/plot.py @@ -12,7 +12,6 @@ from .models import resample_posterior from .wpercentile import wmedian - __all__ = ['plot_predicted_vs_real', 'plot_feature_importances', 'plot_posterior_matrix'] @@ -145,7 +144,7 @@ def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): hspace=0.05, wspace=0.05) iterable = zip(axes.flat, product(range(num_dims), range(num_dims))) - for ax, dims in tqdm(iterable, total=num_dims**2): + for ax, dims in tqdm(iterable, total=num_dims ** 2): # Flip dims. 
dims = [dims[1], dims[0]] @@ -195,7 +194,7 @@ def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): ranges=ranges[dims[:1]] ) ax.bar(bins[:-1], histogram, color=soft_colors[dims[0]], - width=bins[1]-bins[0]) + width=bins[1] - bins[0]) kd_probs = histogram expected = wmedian(samples[:, dims[0]], weights) @@ -210,8 +209,8 @@ def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): # fig.tight_layout(pad=0) return fig -def _plot_histogram2d(ax, posterior, color, cmap, dims, ranges): +def _plot_histogram2d(ax, posterior, color, cmap, dims, ranges): samples, weights = posterior # For efficiency, do not compute the kernel density # over all the samples of the posterior. Subsample first. @@ -250,7 +249,6 @@ def _plot_histogram2d(ax, posterior, color, cmap, dims, ranges): def _plot_samples(ax, posterior, color, dims, ranges): - # For efficiency, do not plot all the samples of the posterior. Subsample first. if len(posterior.samples) > POSTERIOR_MAX_SIZE: posterior = resample_posterior(posterior, POSTERIOR_MAX_SIZE) @@ -281,15 +279,15 @@ def _min_max_scaler(ranges, feature_range=(0, 100)): res.data_max_ = ranges[:, 1] res.data_min_ = ranges[:, 0] res.data_range_ = res.data_max_ - res.data_min_ - res.scale_ = (feature_range[1] - feature_range[0]) / (ranges[:, 1] - ranges[:, 0]) + res.scale_ = (feature_range[1] - feature_range[0]) / ( + ranges[:, 1] - ranges[:, 0]) res.min_ = -res.scale_ * res.data_min_ res.n_samples_seen_ = 1 res.feature_range = feature_range return res -def _kernel_density_joint(samples, weights, ranges, bandwidth=1/25): - +def _kernel_density_joint(samples, weights, ranges, bandwidth=1 / 25): ndims = len(ranges) scaler = _min_max_scaler(ranges, feature_range=(0, 100)) @@ -313,8 +311,7 @@ def _kernel_density_joint(samples, weights, ranges, bandwidth=1/25): def _histogram1d(samples, weights, ranges, bins=20): - - assert(len(ranges) == 1) + assert (len(ranges) == 1) histogram, edges = np.histogram( samples[:, 0], @@ -326,8 +323,7 @@ def _histogram1d(samples, weights, ranges, bins=20): def _histogram2d(samples, weights, ranges, bins=20): - - assert(len(ranges) == 2) + assert (len(ranges) == 2) histogram, xedges, yedges = np.histogram2d( samples[:, 0], @@ -341,7 +337,6 @@ def _histogram2d(samples, weights, ranges, bins=20): def _weights_to_alpha(weights): - # Maximum weight (removing potential outliers) max_weight = np.percentile(weights, 98) return np.clip(weights / max_weight, 0, 1) From 1b0e2c73afa470b2875262a4ba0223d675adf87c Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 04:32:47 +0100 Subject: [PATCH 16/46] updating Posterior API --- README.md | 5 ++-- docs/hela/installation.rst | 10 ++++++- docs/hela/legacy.rst | 57 ++++++++++++++++++++++++++++++++++++++ docs/hela/tutorial.rst | 10 +++---- docs/index.rst | 1 + hela/forest.py | 4 +-- hela/plot.py | 32 +++++++++------------ 7 files changed, 90 insertions(+), 29 deletions(-) create mode 100644 docs/hela/legacy.rst diff --git a/README.md b/README.md index f7b2b3e..356df9a 100644 --- a/README.md +++ b/README.md @@ -12,9 +12,10 @@ The set-up here is simply a Random Forest algorithm, for use on a training set p HELA is developed for use with Python 3 and requires the following packages: - numpy -- sklearn - matplotlib - +- sklearn +- tqdm +- joblib ## Running HELA diff --git a/docs/hela/installation.rst b/docs/hela/installation.rst index deb9d64..8436f05 100644 --- a/docs/hela/installation.rst +++ b/docs/hela/installation.rst @@ -1,6 +1,14 @@ Installation ============ 
+There are several dependencies for you to install, which you can do with ``pip`` +like so:: + + pip install numpy matplotlib sklearn tqdm joblib + + To install hela, run:: - python setup.py install \ No newline at end of file + git clone https://github.com/exoclime/HELA.git + cd HELA + python setup.py install diff --git a/docs/hela/legacy.rst b/docs/hela/legacy.rst new file mode 100644 index 0000000..7a1f55e --- /dev/null +++ b/docs/hela/legacy.rst @@ -0,0 +1,57 @@ +Legacy API +========== + +We start with training our forest, on the example dataset provided. To check how +to run the training stage, you can run:: + + python rfretrieval.py train -h + +This will show you the usage of ``train``, in case you need a reminder. So, we +run training as follows:: + + python rfretrieval.py train example_dataset/example_dataset.json example_model/ + + +The ``training_dataset`` refers to the ``.json`` file in the dataset folder. +The ``training.npy`` and ``testing.npy`` files must also be in this folder. +The ``model_path`` is just some new output path you need to choose a name for. +It will be created. + +You can also edit the number of trees used, and the number of jobs, and find the +feature importances, by running with the extra optional arguments:: + + python rfretrieval.py train example_dataset/example_dataset.json example_model/ --num-trees 100 --num-jobs 3 --feature-importance + +The default number of trees is 1000. The default number of jobs is 5. The +default does not run the feature importance. This is because it requires +training a new forest for each parameter, so makes the process much slower, and +you may not need the feature importance every time you use HELA. + +Once running, HELA will update you at several stages of training, telling you +how long each stage has taken. E.g.:: + + [Parallel(n_jobs=5)]: Done 40 tasks | elapsed: 5.0s + +The ``40 tasks`` refers to the first 40 trees having been trained. + +After training is complete, HELA will run testing. It will print an :math:`R^2` +score for each parameter, and plot the results. The forest itself, ``model.pkl`` +and the predicted vs real graph can now be found in the ``example_model/`` +folder. + +You can now use your forest to predict on data. In the example dataset we have +included the WASP-12b data, for which this particular training set was tailored +for. You can check how the prediction stage runs by running:: + + python rfretreival.py predict -h + +For this stage, you must provide the model's path, the data file, and an output +folder. Whether the posteriors are plotted or not is optional. So, to include +the posteriors, we run:: + + python rfretrieval.py predict example_model/ example_dataset/WASP12b.npy example_plots/ --plot-posterior + +This will give you a prediction for each parameter on this data file. The +numbers given are the median, and in brackets the 16th and 84th percentiles, of +the posteriors. The posterior matrix can now be found in the ``example_plots/`` +folder. diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index b20273d..0ae43dc 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -87,8 +87,8 @@ using the `~hela.RandomForest.predict` method: .. code-block:: python # Predict posterior distributions from random forest - samples, weights = rf.predict(plot_posterior=True) - posterior_slopes, posterior_intercepts = samples.T + posterior = rf.predict(plot_posterior=True) + posterior_slopes, posterior_intercepts = posterior.samples.T plt.show() .. 
plot:: @@ -107,9 +107,9 @@ using the `~hela.RandomForest.predict` method: r2scores = rf.train(num_trees=1000, num_jobs=1) plt.close() - # Predict posterior distirbutions from random forest - samples, weights = rf.predict(plot_posterior=True) - posterior_slopes, posterior_intercepts = samples.T + # Predict posterior distributions from random forest + posterior = rf.predict(plot_posterior=True) + posterior_slopes, posterior_intercepts = posterior.samples.T plt.tight_layout() plt.show() diff --git a/docs/index.rst b/docs/index.rst index 2b33a81..3728648 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,3 +9,4 @@ Random Forest retrieval for exoplanet atmospheres. hela/installation.rst hela/tutorial.rst hela/api.rst + hela/legacy.rst \ No newline at end of file diff --git a/hela/forest.py b/hela/forest.py index 30a369a..5204222 100644 --- a/hela/forest.py +++ b/hela/forest.py @@ -179,8 +179,8 @@ def data_ranges(posterior, percentiles=(50, 16, 84)): ------- ranges : `~numpy.ndarray` """ - samples, weights = posterior - values = wpercentile(samples, weights, percentiles, axis=0) + values = wpercentile(posterior.samples, posterior.weights, + percentiles, axis=0) ranges = np.array([values[0], values[2]-values[0], values[0]-values[1]]) return ranges.T diff --git a/hela/plot.py b/hela/plot.py index fa721ae..abf802f 100644 --- a/hela/plot.py +++ b/hela/plot.py @@ -125,8 +125,6 @@ def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): ------- """ - samples, weights = posterior - cmaps = [LinearSegmentedColormap.from_list("MyReds", [(1, 1, 1), c], N=256) for c in colors] @@ -135,7 +133,7 @@ def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): if soft_colors is None: soft_colors = colors - num_dims = samples.shape[1] + num_dims = posterior.samples.shape[1] fig, axes = plt.subplots(nrows=num_dims, ncols=num_dims, figsize=(2 * num_dims, 2 * num_dims)) @@ -190,14 +188,14 @@ def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): ) else: histogram, bins = _histogram1d( - samples[:, dims[:1]], weights, + posterior.samples[:, dims[:1]], posterior.weights, ranges=ranges[dims[:1]] ) ax.bar(bins[:-1], histogram, color=soft_colors[dims[0]], width=bins[1] - bins[0]) kd_probs = histogram - expected = wmedian(samples[:, dims[0]], weights) + expected = wmedian(posterior.samples[:, dims[0]], posterior.weights) ax.plot([expected, expected], [0, 1.1 * kd_probs.max()], '-', linewidth=1, color='#222222') @@ -211,15 +209,14 @@ def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): def _plot_histogram2d(ax, posterior, color, cmap, dims, ranges): - samples, weights = posterior # For efficiency, do not compute the kernel density # over all the samples of the posterior. Subsample first. 
- if len(samples) > POSTERIOR_MAX_SIZE: - samples, weights = resample_posterior(posterior, POSTERIOR_MAX_SIZE) + if len(posterior.samples) > POSTERIOR_MAX_SIZE: + posterior = resample_posterior(posterior, POSTERIOR_MAX_SIZE) locations, kd_probs, *_ = _kernel_density_joint( - samples[:, dims], - weights, + posterior.samples[:, dims], + posterior.weights, ranges ) ax.contour( @@ -230,14 +227,13 @@ def _plot_histogram2d(ax, posterior, color, cmap, dims, ranges): ) # For the rest of the plot we use the complete posterior - samples, weights = posterior histogram, grid_x, grid_y = _histogram2d( - samples[:, dims], weights, + posterior.samples[:, dims], posterior.weights, ranges ) ax.pcolormesh(grid_x, grid_y, histogram, cmap=cmap) - expected = wmedian(samples[:, dims], weights, axis=0) + expected = wmedian(posterior.samples[:, dims], posterior.weights, axis=0) ax.plot([expected[0], expected[0]], [ranges[1][0], ranges[1][1]], '-', linewidth=1, color='#222222') ax.plot([ranges[0][0], ranges[0][1]], [expected[1], expected[1]], @@ -253,17 +249,15 @@ def _plot_samples(ax, posterior, color, dims, ranges): if len(posterior.samples) > POSTERIOR_MAX_SIZE: posterior = resample_posterior(posterior, POSTERIOR_MAX_SIZE) - samples, weights = posterior - - points_alpha = _weights_to_alpha(weights) + points_alpha = _weights_to_alpha(posterior.weights) current_colors = to_rgba_array(color) - current_colors = np.tile(current_colors, (len(samples), 1)) + current_colors = np.tile(current_colors, (len(posterior.samples), 1)) current_colors[:, 3] = points_alpha ax.scatter( - x=samples[:, dims[0]], - y=samples[:, dims[1]], + x=posterior.samples[:, dims[0]], + y=posterior.samples[:, dims[1]], s=100, c=current_colors, marker='.', From 83d72fb0fff4e6ad2c42c402dc5839df478eddd8 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 04:36:16 +0100 Subject: [PATCH 17/46] Removing redundant train step, removing empty notebook --- docs/hela/tutorial.rst | 4 -- ...pheric retrieval with Random Forests.ipynb | 46 ------------------- 2 files changed, 50 deletions(-) delete mode 100644 notebooks/Atmospheric retrieval with Random Forests.ipynb diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index 0ae43dc..c8dc876 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -103,10 +103,6 @@ using the `~hela.RandomForest.predict` method: # Initialize a random forest object: rf = RandomForest(training_dataset, example_dir, samples_path) - # Train the random forest: - r2scores = rf.train(num_trees=1000, num_jobs=1) - plt.close() - # Predict posterior distributions from random forest posterior = rf.predict(plot_posterior=True) posterior_slopes, posterior_intercepts = posterior.samples.T diff --git a/notebooks/Atmospheric retrieval with Random Forests.ipynb b/notebooks/Atmospheric retrieval with Random Forests.ipynb deleted file mode 100644 index dde2c9d..0000000 --- a/notebooks/Atmospheric retrieval with Random Forests.ipynb +++ /dev/null @@ -1,46 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Atmospheric Retrieval with Random Forests\n", - "\n", - "This notebook contains a implementation of the method introduced in the paper _Using Supervised Machine Learning to Analyze Spectra of Exoplanetary Atmospheres_.\n", - "\n", - "## Instructions\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - 
"language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 6cf54dc3cbc168804489c5debc75a04d8fba1d3a Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 04:52:36 +0100 Subject: [PATCH 18/46] Updating readme, moving old code to legacy dir --- README.rst | 45 +++++++++++++++++++++++++ docs/hela/tutorial.rst | 2 +- README.md => legacy/README.md | 0 models.py => legacy/models.py | 0 plot.py => legacy/plot.py | 0 rfretrieval.py => legacy/rfretrieval.py | 0 utils.py => legacy/utils.py | 0 wpercentile.py => legacy/wpercentile.py | 0 8 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 README.rst rename README.md => legacy/README.md (100%) rename models.py => legacy/models.py (100%) rename plot.py => legacy/plot.py (100%) rename rfretrieval.py => legacy/rfretrieval.py (100%) rename utils.py => legacy/utils.py (100%) rename wpercentile.py => legacy/wpercentile.py (100%) diff --git a/README.rst b/README.rst new file mode 100644 index 0000000..bf18c45 --- /dev/null +++ b/README.rst @@ -0,0 +1,45 @@ +HELA +==== + +.. image:: img/HELA_logo1.png + +.. image:: http://img.shields.io/badge/powered%20by-AstroPy-orange.svg?style=flat + :target: http://www.astropy.org/ + +.. image:: http://img.shields.io/badge/arXiv-1806.03944-red.svg?style=flat + :target: https://arxiv.org/abs/1806.03944 + :alt: arXiv paper + +A Random Forest retrieval algorithm, here used to perform atmospheric retrieval on exoplanet atmospheres. + +Legacy API +++++++++++ + +If you used HELA previous to the most recent major update and want to recover +the old behavior of HELA, visit the [``legacy``](legacy) directory for a +legacy branch of the package. + +Citation +++++++++ + +If you use HELA in your work, please cite [Marquez-Neila et al. 2018](https://ui.adsabs.harvard.edu/abs/2018NatAs...2..719M/abstract): + +``` +@ARTICLE{2018NatAs...2..719M, + author = {{M{\'a}rquez-Neila}, Pablo and {Fisher}, Chloe and {Sznitman}, Raphael and + {Heng}, Kevin}, + title = "{Supervised machine learning for analysing spectra of exoplanetary atmospheres}", + journal = {Nature Astronomy}, + keywords = {Astrophysics - Earth and Planetary Astrophysics, Physics - Atmospheric and Oceanic Physics, Physics - Data Analysis, Statistics and Probability}, + year = "2018", + month = "Jun", + volume = {2}, + pages = {719-724}, + doi = {10.1038/s41550-018-0504-2}, +archivePrefix = {arXiv}, + eprint = {1806.03944}, + primaryClass = {astro-ph.EP}, + adsurl = {https://ui.adsabs.harvard.edu/abs/2018NatAs...2..719M}, + adsnote = {Provided by the SAO/NASA Astrophysics Data System} +} +``` \ No newline at end of file diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index c8dc876..9165314 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -15,7 +15,7 @@ which we'd like to predict on: # Generate an example dataset directory example_dir, training_dataset, samples_path = generate_example_data() -What did that just do? We created an example directory called ``linear_data``, +This handy command created an example directory called ``linear_data``, which contains a training dataset described by the metadata file located at path ``training_dataset``. 
This training dataset contains a JSON file describing the free parameters, which looks like this: diff --git a/README.md b/legacy/README.md similarity index 100% rename from README.md rename to legacy/README.md diff --git a/models.py b/legacy/models.py similarity index 100% rename from models.py rename to legacy/models.py diff --git a/plot.py b/legacy/plot.py similarity index 100% rename from plot.py rename to legacy/plot.py diff --git a/rfretrieval.py b/legacy/rfretrieval.py similarity index 100% rename from rfretrieval.py rename to legacy/rfretrieval.py diff --git a/utils.py b/legacy/utils.py similarity index 100% rename from utils.py rename to legacy/utils.py diff --git a/wpercentile.py b/legacy/wpercentile.py similarity index 100% rename from wpercentile.py rename to legacy/wpercentile.py From ce1801e78f7010d69ba5aed4f8c90c2e437ed4f9 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 04:54:43 +0100 Subject: [PATCH 19/46] readme tweaks --- README.rst | 42 +++++++++++++++++++----------------------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/README.rst b/README.rst index bf18c45..e05b49a 100644 --- a/README.rst +++ b/README.rst @@ -1,6 +1,3 @@ -HELA -==== - .. image:: img/HELA_logo1.png .. image:: http://img.shields.io/badge/powered%20by-AstroPy-orange.svg?style=flat @@ -16,7 +13,7 @@ Legacy API ++++++++++ If you used HELA previous to the most recent major update and want to recover -the old behavior of HELA, visit the [``legacy``](legacy) directory for a +the old behavior of HELA, visit the `legacy `_ directory for a legacy branch of the package. Citation @@ -24,22 +21,21 @@ Citation If you use HELA in your work, please cite [Marquez-Neila et al. 2018](https://ui.adsabs.harvard.edu/abs/2018NatAs...2..719M/abstract): -``` -@ARTICLE{2018NatAs...2..719M, - author = {{M{\'a}rquez-Neila}, Pablo and {Fisher}, Chloe and {Sznitman}, Raphael and - {Heng}, Kevin}, - title = "{Supervised machine learning for analysing spectra of exoplanetary atmospheres}", - journal = {Nature Astronomy}, - keywords = {Astrophysics - Earth and Planetary Astrophysics, Physics - Atmospheric and Oceanic Physics, Physics - Data Analysis, Statistics and Probability}, - year = "2018", - month = "Jun", - volume = {2}, - pages = {719-724}, - doi = {10.1038/s41550-018-0504-2}, -archivePrefix = {arXiv}, - eprint = {1806.03944}, - primaryClass = {astro-ph.EP}, - adsurl = {https://ui.adsabs.harvard.edu/abs/2018NatAs...2..719M}, - adsnote = {Provided by the SAO/NASA Astrophysics Data System} -} -``` \ No newline at end of file +.. 
code-block:: + @ARTICLE{2018NatAs...2..719M, + author = {{M{\'a}rquez-Neila}, Pablo and {Fisher}, Chloe and {Sznitman}, Raphael and + {Heng}, Kevin}, + title = "{Supervised machine learning for analysing spectra of exoplanetary atmospheres}", + journal = {Nature Astronomy}, + keywords = {Astrophysics - Earth and Planetary Astrophysics, Physics - Atmospheric and Oceanic Physics, Physics - Data Analysis, Statistics and Probability}, + year = "2018", + month = "Jun", + volume = {2}, + pages = {719-724}, + doi = {10.1038/s41550-018-0504-2}, + archivePrefix = {arXiv}, + eprint = {1806.03944}, + primaryClass = {astro-ph.EP}, + adsurl = {https://ui.adsabs.harvard.edu/abs/2018NatAs...2..719M}, + adsnote = {Provided by the SAO/NASA Astrophysics Data System} + } From a4d4103603fca95f9d0c7e5caac58c0b474c7fb1 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 04:55:38 +0100 Subject: [PATCH 20/46] readme tweaks --- README.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index e05b49a..e3254c0 100644 --- a/README.rst +++ b/README.rst @@ -7,7 +7,8 @@ :target: https://arxiv.org/abs/1806.03944 :alt: arXiv paper -A Random Forest retrieval algorithm, here used to perform atmospheric retrieval on exoplanet atmospheres. +A Random Forest retrieval algorithm, here used to perform atmospheric retrieval + on exoplanet atmospheres. Legacy API ++++++++++ @@ -19,9 +20,11 @@ legacy branch of the package. Citation ++++++++ -If you use HELA in your work, please cite [Marquez-Neila et al. 2018](https://ui.adsabs.harvard.edu/abs/2018NatAs...2..719M/abstract): +If you use HELA in your work, please cite +`Marquez-Neila et al. 2018 `_: .. code-block:: + @ARTICLE{2018NatAs...2..719M, author = {{M{\'a}rquez-Neila}, Pablo and {Fisher}, Chloe and {Sznitman}, Raphael and {Heng}, Kevin}, From bed28826e8dbed5675ef48a4328c3818dc5d59b2 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 04:58:26 +0100 Subject: [PATCH 21/46] Moving final file into legacy dir --- README.rst | 2 +- legacy/README.md | 2 +- dataset.py => legacy/dataset.py | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename dataset.py => legacy/dataset.py (100%) diff --git a/README.rst b/README.rst index e3254c0..46d81e2 100644 --- a/README.rst +++ b/README.rst @@ -8,7 +8,7 @@ :alt: arXiv paper A Random Forest retrieval algorithm, here used to perform atmospheric retrieval - on exoplanet atmospheres. +on exoplanet atmospheres. Legacy API ++++++++++ diff --git a/legacy/README.md b/legacy/README.md index 356df9a..c6321c7 100644 --- a/legacy/README.md +++ b/legacy/README.md @@ -1,4 +1,4 @@ -# HELA drawing +# HELA drawing A Random Forest retrieval algorithm, here used to perform atmospheric retrieval on exoplanet atmospheres. 
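
The patches in this part of the series move users from the legacy ``rfretrieval.py`` command line to the programmatic interface documented in ``docs/hela/tutorial.rst``. As a minimal sketch of that workflow — assuming, as the tutorial changes above suggest, that ``generate_example_data`` and ``RandomForest`` are importable from the top level of the ``hela`` package and that the generated example dataset has two free parameters (a slope and an intercept) — the new usage might look like:

```python
import matplotlib.pyplot as plt

# Assumption: both helpers are exposed at the top level of the hela package,
# as the tutorial snippets in this series imply.
from hela import generate_example_data, RandomForest

# Build the example ``linear_data`` directory, its dataset metadata file,
# and a samples file to predict on.
example_dir, training_dataset, samples_path = generate_example_data()

# Train a forest on the example dataset, then predict a posterior for the samples.
rf = RandomForest(training_dataset, example_dir, samples_path)
r2scores = rf.train(num_trees=1000, num_jobs=1)

# ``predict`` now returns a Posterior object instead of a (samples, weights) tuple.
posterior = rf.predict(plot_posterior=True)
posterior_slopes, posterior_intercepts = posterior.samples.T
plt.show()
```

The legacy command-line flow described in the README above is kept under the ``legacy/`` directory, so this change is additive rather than a removal of the old behaviour.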
diff --git a/dataset.py b/legacy/dataset.py similarity index 100% rename from dataset.py rename to legacy/dataset.py From 102f303c99abb7cc0762483a327ca4deaeeaf570 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 05:03:03 +0100 Subject: [PATCH 22/46] Fixing legacy README, adding developer docs --- docs/hela/developer.rst | 10 ++++++++++ docs/index.rst | 3 ++- legacy/README.md | 6 +++--- 3 files changed, 15 insertions(+), 4 deletions(-) create mode 100644 docs/hela/developer.rst diff --git a/docs/hela/developer.rst b/docs/hela/developer.rst new file mode 100644 index 0000000..176c7d9 --- /dev/null +++ b/docs/hela/developer.rst @@ -0,0 +1,10 @@ +Developer Docs +============== + +To run the tests locally, run:: + + python setup.py test + +To build the documentation locally, run:: + + python setup.py build_docs \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 3728648..8219e21 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,4 +9,5 @@ Random Forest retrieval for exoplanet atmospheres. hela/installation.rst hela/tutorial.rst hela/api.rst - hela/legacy.rst \ No newline at end of file + hela/legacy.rst + hela/developer.rst diff --git a/legacy/README.md b/legacy/README.md index c6321c7..5d69278 100644 --- a/legacy/README.md +++ b/legacy/README.md @@ -32,7 +32,7 @@ python3 rfretrieval.py train -h This will show you the usage of 'train', in case you need a reminder. So, we run training as follows: ``` -python3 rfretrieval.py train example_dataset/example_dataset.json example_model/ +python3 rfretrieval.py train ../example_dataset/example_dataset.json ../example_model/ ``` The ```training_dataset``` refers to the ```.json``` file in the dataset folder. The ```training.npy``` and ```testing.npy``` files must also be in this folder. The ```model_path``` is just some new output path you need to choose a name for. It will be created. @@ -40,7 +40,7 @@ The ```training_dataset``` refers to the ```.json``` file in the dataset folder. You can also edit the number of trees used, and the number of jobs, and find the feature importances, by running with the extra optional arguments: ``` -python3 rfretrieval.py train example_dataset/example_dataset.json example_model/ --num-trees 100 --num-jobs 3 --feature-importance +python3 rfretrieval.py train ../example_dataset/example_dataset.json ../example_model/ --num-trees 100 --num-jobs 3 --feature-importance ``` The default number of trees is 1000. The default number of jobs is 5. The default does not run the feature importance. This is because it requires training a new forest for each parameter, so makes the process much slower, and you may not need the feature importance every time you use HELA. @@ -64,7 +64,7 @@ python3 rfretreival.py predict -h For this stage, you must provide the model's path, the data file, and an output folder. Whether the posteriors are plotted or not is optional. So, to include the posteriors, we run: ``` -python3 rfretrieval.py predict example_model/ example_dataset/WASP12b.npy example_plots/ --plot-posterior +python3 rfretrieval.py predict ../example_model/ ../example_dataset/WASP12b.npy ../example_plots/ --plot-posterior ``` This will give you a prediction for each parameter on this data file. The numbers given are the median, and in brackets the 16th and 84th percentiles, of the posteriors. The posterior matrix can now be found in the ```example_plots/``` folder. 
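
The legacy README above reports each prediction as a median with the 16th and 84th percentiles of the posterior. Under the updated ``Posterior`` API introduced earlier in this series, those summaries come from weighted samples. Below is a minimal sketch of how such a weighted posterior might be summarised, assuming ``Posterior`` and ``resample_posterior`` are importable from ``hela.models`` (the location used by ``hela/plot.py`` in these patches), that the weights are small non-negative integers as the forest produces, and using tiny made-up numbers purely for illustration:

```python
import numpy as np

# Assumption: models.py lives inside the hela package in this series.
from hela.models import Posterior, resample_posterior

# A tiny, made-up weighted posterior: three samples of two parameters.
posterior = Posterior(samples=np.array([[0.1, 1.0],
                                        [0.2, 1.1],
                                        [0.3, 0.9]]),
                      weights=np.array([1, 4, 2]))

# Optionally thin a very large posterior before plotting, as hela.plot does
# when the number of samples exceeds POSTERIOR_MAX_SIZE; the surviving
# samples keep their original weights.
thinned = resample_posterior(posterior, num_draws=1000)

# Expand the integer weights into repeated rows and summarise with percentiles,
# mirroring what _posterior_percentile_nocache does for each query point.
expanded = np.repeat(thinned.samples, thinned.weights, axis=0)
median, lower, upper = np.percentile(expanded, [50, 16, 84], axis=0)
print(median, lower, upper)
```

For reference, ``data_ranges`` in ``hela/forest.py`` performs essentially this computation directly on the weighted samples via ``wpercentile``, without expanding the weights first.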
From df9d6b0e4412d5eb8c490fad70c496a7f744b812 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 11:00:27 +0100 Subject: [PATCH 23/46] tweaking docs build --- .gitmodules | 3 + ah_bootstrap.py | 334 ++++++++++++++++++++++-------------------------- astropy_helpers | 1 + 3 files changed, 160 insertions(+), 178 deletions(-) create mode 100644 .gitmodules create mode 160000 astropy_helpers diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..2aef90e --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "astropy-helpers"] + path = astropy-helpers + url = git@github.com:astropy/astropy-helpers.git diff --git a/ah_bootstrap.py b/ah_bootstrap.py index 67ca92b..0dc5007 100644 --- a/ah_bootstrap.py +++ b/ah_bootstrap.py @@ -19,14 +19,9 @@ contains an option called ``auto_use`` with a value of ``True``, it will automatically call the main function of this module called `use_astropy_helpers` (see that function's docstring for full details). -Otherwise no further action is taken and by default the system-installed version -of astropy-helpers will be used (however, ``ah_bootstrap.use_astropy_helpers`` -may be called manually from within the setup.py script). - -This behavior can also be controlled using the ``--auto-use`` and -``--no-auto-use`` command-line flags. For clarity, an alias for -``--no-auto-use`` is ``--use-system-astropy-helpers``, and we recommend using -the latter if needed. +Otherwise no further action is taken (however, +``ah_bootstrap.use_astropy_helpers`` may be called manually from within the +setup.py script). Additional options in the ``[ah_boostrap]`` section of setup.cfg have the same names as the arguments to `use_astropy_helpers`, and can be used to configure @@ -38,6 +33,7 @@ import contextlib import errno +import imp import io import locale import os @@ -45,133 +41,54 @@ import subprocess as sp import sys -from distutils import log -from distutils.debug import DEBUG - -from configparser import ConfigParser, RawConfigParser - -import pkg_resources - -from setuptools import Distribution -from setuptools.package_index import PackageIndex - -# This is the minimum Python version required for astropy-helpers -__minimum_python_version__ = (3, 5) - -# TODO: Maybe enable checking for a specific version of astropy_helpers? -DIST_NAME = 'astropy-helpers' -PACKAGE_NAME = 'astropy_helpers' -UPPER_VERSION_EXCLUSIVE = None - -# Defaults for other options -DOWNLOAD_IF_NEEDED = True -INDEX_URL = 'https://pypi.python.org/simple' -USE_GIT = True -OFFLINE = False -AUTO_UPGRADE = True - -# A list of all the configuration options and their required types -CFG_OPTIONS = [ - ('auto_use', bool), ('path', str), ('download_if_needed', bool), - ('index_url', str), ('use_git', bool), ('offline', bool), - ('auto_upgrade', bool) -] - -# Start off by parsing the setup.cfg file - -_err_help_msg = """ -If the problem persists consider installing astropy_helpers manually using pip -(`pip install astropy_helpers`) or by manually downloading the source archive, -extracting it, and installing by running `python setup.py install` from the -root of the extracted source code. 
-""" - -SETUP_CFG = ConfigParser() - -if os.path.exists('setup.cfg'): - - try: - SETUP_CFG.read('setup.cfg') - except Exception as e: - if DEBUG: - raise +try: + from ConfigParser import ConfigParser, RawConfigParser +except ImportError: + from configparser import ConfigParser, RawConfigParser - log.error( - "Error reading setup.cfg: {0!r}\n{1} will not be " - "automatically bootstrapped and package installation may fail." - "\n{2}".format(e, PACKAGE_NAME, _err_help_msg)) -# We used package_name in the package template for a while instead of name -if SETUP_CFG.has_option('metadata', 'name'): - parent_package = SETUP_CFG.get('metadata', 'name') -elif SETUP_CFG.has_option('metadata', 'package_name'): - parent_package = SETUP_CFG.get('metadata', 'package_name') +if sys.version_info[0] < 3: + _str_types = (str, unicode) + _text_type = unicode + PY3 = False else: - parent_package = None - -if SETUP_CFG.has_option('options', 'python_requires'): - - python_requires = SETUP_CFG.get('options', 'python_requires') - - # The python_requires key has a syntax that can be parsed by SpecifierSet - # in the packaging package. However, we don't want to have to depend on that - # package, so instead we can use setuptools (which bundles packaging). We - # have to add 'python' to parse it with Requirement. - - from pkg_resources import Requirement - req = Requirement.parse('python' + python_requires) - - # We want the Python version as a string, which we can get from the platform module - import platform - # strip off trailing '+' incase this is a dev install of python - python_version = platform.python_version().strip('+') - # allow pre-releases to count as 'new enough' - if not req.specifier.contains(python_version, True): - if parent_package is None: - message = "ERROR: Python {} is required by this package\n".format(req.specifier) - else: - message = "ERROR: Python {} is required by {}\n".format(req.specifier, parent_package) - sys.stderr.write(message) - sys.exit(1) - -if sys.version_info < __minimum_python_version__: - - if parent_package is None: - message = "ERROR: Python {} or later is required by astropy-helpers\n".format( - __minimum_python_version__) - else: - message = "ERROR: Python {} or later is required by astropy-helpers for {}\n".format( - __minimum_python_version__, parent_package) - - sys.stderr.write(message) - sys.exit(1) - -_str_types = (str, bytes) + _str_types = (str, bytes) + _text_type = str + PY3 = True # What follows are several import statements meant to deal with install-time # issues with either missing or misbehaving pacakges (including making sure # setuptools itself is installed): -# Check that setuptools 30.3 or later is present -from distutils.version import LooseVersion +# Some pre-setuptools checks to ensure that either distribute or setuptools >= +# 0.7 is used (over pre-distribute setuptools) if it is available on the path; +# otherwise the latest setuptools will be downloaded and bootstrapped with +# ``ez_setup.py``. 
This used to be included in a separate file called +# setuptools_bootstrap.py; but it was combined into ah_bootstrap.py try: - import setuptools - assert LooseVersion(setuptools.__version__) >= LooseVersion('30.3') -except (ImportError, AssertionError): - sys.stderr.write("ERROR: setuptools 30.3 or later is required by astropy-helpers\n") - sys.exit(1) - -# typing as a dependency for 1.6.1+ Sphinx causes issues when imported after -# initializing submodule with ah_boostrap.py -# See discussion and references in -# https://github.com/astropy/astropy-helpers/issues/302 - -try: - import typing # noqa -except ImportError: - pass + import pkg_resources + _setuptools_req = pkg_resources.Requirement.parse('setuptools>=0.7') + # This may raise a DistributionNotFound in which case no version of + # setuptools or distribute is properly installed + _setuptools = pkg_resources.get_distribution('setuptools') + if _setuptools not in _setuptools_req: + # Older version of setuptools; check if we have distribute; again if + # this results in DistributionNotFound we want to give up + _distribute = pkg_resources.get_distribution('distribute') + if _setuptools != _distribute: + # It's possible on some pathological systems to have an old version + # of setuptools and distribute on sys.path simultaneously; make + # sure distribute is the one that's used + sys.path.insert(1, _distribute.location) + _distribute.activate() + imp.reload(pkg_resources) +except: + # There are several types of exceptions that can occur here; if all else + # fails bootstrap and use the bootstrapped version + from ez_setup import use_setuptools + use_setuptools() # Note: The following import is required as a workaround to @@ -180,7 +97,7 @@ # later cause the TemporaryDirectory class defined in it to stop working when # used later on by setuptools try: - import setuptools.py31compat # noqa + import setuptools.py31compat except ImportError: pass @@ -204,6 +121,36 @@ # End compatibility imports... +# In case it didn't successfully import before the ez_setup checks +import pkg_resources + +from setuptools import Distribution +from setuptools.package_index import PackageIndex +from setuptools.sandbox import run_setup + +from distutils import log +from distutils.debug import DEBUG + + +# TODO: Maybe enable checking for a specific version of astropy_helpers? +DIST_NAME = 'astropy-helpers' +PACKAGE_NAME = 'astropy_helpers' + +# Defaults for other options +DOWNLOAD_IF_NEEDED = True +INDEX_URL = 'https://pypi.python.org/simple' +USE_GIT = True +OFFLINE = False +AUTO_UPGRADE = True + +# A list of all the configuration options and their required types +CFG_OPTIONS = [ + ('auto_use', bool), ('path', str), ('download_if_needed', bool), + ('index_url', str), ('use_git', bool), ('offline', bool), + ('auto_upgrade', bool) +] + + class _Bootstrapper(object): """ Bootstrapper implementation. 
See ``use_astropy_helpers`` for parameter @@ -219,7 +166,7 @@ def __init__(self, path=None, index_url=None, use_git=None, offline=None, if not (isinstance(path, _str_types) or path is False): raise TypeError('path must be a string or False') - if not isinstance(path, str): + if PY3 and not isinstance(path, _text_type): fs_encoding = sys.getfilesystemencoding() path = path.decode(fs_encoding) # path to unicode @@ -273,20 +220,36 @@ def main(cls, argv=None): @classmethod def parse_config(cls): + if not os.path.exists('setup.cfg'): + return {} + + cfg = ConfigParser() - if not SETUP_CFG.has_section('ah_bootstrap'): + try: + cfg.read('setup.cfg') + except Exception as e: + if DEBUG: + raise + + log.error( + "Error reading setup.cfg: {0!r}\n{1} will not be " + "automatically bootstrapped and package installation may fail." + "\n{2}".format(e, PACKAGE_NAME, _err_help_msg)) + return {} + + if not cfg.has_section('ah_bootstrap'): return {} config = {} for option, type_ in CFG_OPTIONS: - if not SETUP_CFG.has_option('ah_bootstrap', option): + if not cfg.has_option('ah_bootstrap', option): continue if type_ is bool: - value = SETUP_CFG.getboolean('ah_bootstrap', option) + value = cfg.getboolean('ah_bootstrap', option) else: - value = SETUP_CFG.get('ah_bootstrap', option) + value = cfg.get('ah_bootstrap', option) config[option] = value @@ -313,18 +276,6 @@ def parse_command_line(cls, argv=None): config['offline'] = True argv.remove('--offline') - if '--auto-use' in argv: - config['auto_use'] = True - argv.remove('--auto-use') - - if '--no-auto-use' in argv: - config['auto_use'] = False - argv.remove('--no-auto-use') - - if '--use-system-astropy-helpers' in argv: - config['auto_use'] = False - argv.remove('--use-system-astropy-helpers') - return config def run(self): @@ -502,10 +453,9 @@ def _directory_import(self): # setup.py exists we can generate it setup_py = os.path.join(path, 'setup.py') if os.path.isfile(setup_py): - # We use subprocess instead of run_setup from setuptools to - # avoid segmentation faults - see the following for more details: - # https://github.com/cython/cython/issues/2104 - sp.check_output([sys.executable, 'setup.py', 'egg_info'], cwd=path) + with _silence(): + run_setup(os.path.join(path, 'setup.py'), + ['egg_info']) for dist in pkg_resources.find_distributions(path, True): # There should be only one... @@ -540,32 +490,16 @@ def get_option_dict(self, command_name): if version: req = '{0}=={1}'.format(DIST_NAME, version) else: - if UPPER_VERSION_EXCLUSIVE is None: - req = DIST_NAME - else: - req = '{0}<{1}'.format(DIST_NAME, UPPER_VERSION_EXCLUSIVE) + req = DIST_NAME attrs = {'setup_requires': [req]} - # NOTE: we need to parse the config file (e.g. 
setup.cfg) to make sure - # it honours the options set in the [easy_install] section, and we need - # to explicitly fetch the requirement eggs as setup_requires does not - # get honored in recent versions of setuptools: - # https://github.com/pypa/setuptools/issues/1273 - try: - - context = _verbose if DEBUG else _silence - with context(): - dist = _Distribution(attrs=attrs) - try: - dist.parse_config_files(ignore_option_errors=True) - dist.fetch_build_eggs(req) - except TypeError: - # On older versions of setuptools, ignore_option_errors - # doesn't exist, and the above two lines are not needed - # so we can just continue - pass + if DEBUG: + _Distribution(attrs=attrs) + else: + with _silence(): + _Distribution(attrs=attrs) # If the setup_requires succeeded it will have added the new dist to # the main working_set @@ -675,8 +609,8 @@ def _check_submodule_using_git(self): # only if the submodule is initialized. We ignore this information for # now _git_submodule_status_re = re.compile( - r'^(?P[+-U ])(?P[0-9a-f]{40}) ' - r'(?P\S+)( .*)?$') + '^(?P[+-U ])(?P[0-9a-f]{40}) ' + '(?P\S+)( .*)?$') # The stdout should only contain one line--the status of the # requested submodule @@ -768,7 +702,7 @@ def _update_submodule(self, submodule, status): if self.offline: cmd.append('--no-fetch') elif status == 'U': - raise _AHBootstrapSystemExit( + raise _AHBoostrapSystemExit( 'Error: Submodule {0} contains unresolved merge conflicts. ' 'Please complete or abandon any changes in the submodule so that ' 'it is in a usable state, then try again.'.format(submodule)) @@ -829,7 +763,7 @@ def run_cmd(cmd): msg = 'Command not found: `{0}`'.format(' '.join(cmd)) raise _CommandNotFound(msg, cmd) else: - raise _AHBootstrapSystemExit( + raise _AHBoostrapSystemExit( 'An unexpected error occurred when running the ' '`{0}` command:\n{1}'.format(' '.join(cmd), str(e))) @@ -846,9 +780,9 @@ def run_cmd(cmd): stdio_encoding = 'latin1' # Unlikely to fail at this point but even then let's be flexible - if not isinstance(stdout, str): + if not isinstance(stdout, _text_type): stdout = stdout.decode(stdio_encoding, 'replace') - if not isinstance(stderr, str): + if not isinstance(stderr, _text_type): stderr = stderr.decode(stdio_encoding, 'replace') return (p.returncode, stdout, stderr) @@ -901,10 +835,6 @@ def flush(self): pass -@contextlib.contextmanager -def _verbose(): - yield - @contextlib.contextmanager def _silence(): """A context manager that silences sys.stdout and sys.stderr.""" @@ -928,6 +858,14 @@ def _silence(): sys.stderr = old_stderr +_err_help_msg = """ +If the problem persists consider installing astropy_helpers manually using pip +(`pip install astropy_helpers`) or by manually downloading the source archive, +extracting it, and installing by running `python setup.py install` from the +root of the extracted source code. +""" + + class _AHBootstrapSystemExit(SystemExit): def __init__(self, *args): if not args: @@ -940,6 +878,46 @@ def __init__(self, *args): super(_AHBootstrapSystemExit, self).__init__(msg, *args[1:]) +if sys.version_info[:2] < (2, 7): + # In Python 2.6 the distutils log does not log warnings, errors, etc. 
to + # stderr so we have to wrap it to ensure consistency at least in this + # module + import distutils + + class log(object): + def __getattr__(self, attr): + return getattr(distutils.log, attr) + + def warn(self, msg, *args): + self._log_to_stderr(distutils.log.WARN, msg, *args) + + def error(self, msg): + self._log_to_stderr(distutils.log.ERROR, msg, *args) + + def fatal(self, msg): + self._log_to_stderr(distutils.log.FATAL, msg, *args) + + def log(self, level, msg, *args): + if level in (distutils.log.WARN, distutils.log.ERROR, + distutils.log.FATAL): + self._log_to_stderr(level, msg, *args) + else: + distutils.log.log(level, msg, *args) + + def _log_to_stderr(self, level, msg, *args): + # This is the only truly 'public' way to get the current threshold + # of the log + current_threshold = distutils.log.set_threshold(distutils.log.WARN) + distutils.log.set_threshold(current_threshold) + if level >= current_threshold: + if args: + msg = msg % args + sys.stderr.write('%s\n' % msg) + sys.stderr.flush() + + log = log() + + BOOTSTRAPPER = _Bootstrapper.main() diff --git a/astropy_helpers b/astropy_helpers new file mode 160000 index 0000000..d8f4890 --- /dev/null +++ b/astropy_helpers @@ -0,0 +1 @@ +Subproject commit d8f48901442e4056879c4249e73ecd7d04a28282 From 9cb3d80d50fd90b5c270884fd63446ed7e11da23 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 12:01:48 +0100 Subject: [PATCH 24/46] Fixing up package-template structure --- .gitmodules | 6 +- MANIFEST.in | 39 ++++++ ah_bootstrap.py | 319 +++++++++++++++++++++++++---------------------- docs/conf.py | 23 ++-- hela/conftest.py | 0 setup.cfg | 63 +++++----- setup.py | 150 +++------------------- 7 files changed, 268 insertions(+), 332 deletions(-) create mode 100644 MANIFEST.in create mode 100644 hela/conftest.py diff --git a/.gitmodules b/.gitmodules index 2aef90e..22d01ae 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ -[submodule "astropy-helpers"] - path = astropy-helpers - url = git@github.com:astropy/astropy-helpers.git +[submodule "astropy_helpers"] + path = astropy_helpers + url = git://github.com/astropy/astropy-helpers diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..d311a39 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,39 @@ +include README.rst +include CHANGES.rst + +include ah_bootstrap.py +include setup.cfg +include hela/tests/coveragerc + +recursive-include hela *.pyx *.c *.pxd +recursive-include docs * +recursive-include licenses * +recursive-include cextern * +recursive-include scripts * + +prune build +prune docs/_build +prune docs/api + + +# the next few stanzas are for astropy_helpers. It's derived from the +# astropy_helpers/MANIFEST.in, but requires additional includes for the actual +# package directory and egg-info. + +include astropy_helpers/README.rst +include astropy_helpers/CHANGES.rst +include astropy_helpers/LICENSE.rst +recursive-include astropy_helpers/licenses * + +include astropy_helpers/ah_bootstrap.py + +recursive-include astropy_helpers/astropy_helpers *.py *.pyx *.c *.h *.rst +recursive-include astropy_helpers/astropy_helpers.egg-info * +# include the sphinx stuff with "*" because there are css/html/rst/etc. 
+recursive-include astropy_helpers/astropy_helpers/sphinx * + +prune astropy_helpers/build +prune astropy_helpers/astropy_helpers/tests + + +global-exclude *.pyc *.o diff --git a/ah_bootstrap.py b/ah_bootstrap.py index 0dc5007..512a05a 100644 --- a/ah_bootstrap.py +++ b/ah_bootstrap.py @@ -19,9 +19,14 @@ contains an option called ``auto_use`` with a value of ``True``, it will automatically call the main function of this module called `use_astropy_helpers` (see that function's docstring for full details). -Otherwise no further action is taken (however, -``ah_bootstrap.use_astropy_helpers`` may be called manually from within the -setup.py script). +Otherwise no further action is taken and by default the system-installed version +of astropy-helpers will be used (however, ``ah_bootstrap.use_astropy_helpers`` +may be called manually from within the setup.py script). + +This behavior can also be controlled using the ``--auto-use`` and +``--no-auto-use`` command-line flags. For clarity, an alias for +``--no-auto-use`` is ``--use-system-astropy-helpers``, and we recommend using +the latter if needed. Additional options in the ``[ah_boostrap]`` section of setup.cfg have the same names as the arguments to `use_astropy_helpers`, and can be used to configure @@ -33,7 +38,6 @@ import contextlib import errno -import imp import io import locale import os @@ -41,54 +45,126 @@ import subprocess as sp import sys -try: - from ConfigParser import ConfigParser, RawConfigParser -except ImportError: - from configparser import ConfigParser, RawConfigParser +from distutils import log +from distutils.debug import DEBUG + +from configparser import ConfigParser, RawConfigParser + +import pkg_resources + +from setuptools import Distribution +from setuptools.package_index import PackageIndex + +# This is the minimum Python version required for astropy-helpers +__minimum_python_version__ = (3, 5) + +# TODO: Maybe enable checking for a specific version of astropy_helpers? +DIST_NAME = 'astropy-helpers' +PACKAGE_NAME = 'astropy_helpers' +UPPER_VERSION_EXCLUSIVE = None + +# Defaults for other options +DOWNLOAD_IF_NEEDED = True +INDEX_URL = 'https://pypi.python.org/simple' +USE_GIT = True +OFFLINE = False +AUTO_UPGRADE = True + +# A list of all the configuration options and their required types +CFG_OPTIONS = [ + ('auto_use', bool), ('path', str), ('download_if_needed', bool), + ('index_url', str), ('use_git', bool), ('offline', bool), + ('auto_upgrade', bool) +] + +# Start off by parsing the setup.cfg file + +SETUP_CFG = ConfigParser() + +if os.path.exists('setup.cfg'): + + try: + SETUP_CFG.read('setup.cfg') + except Exception as e: + if DEBUG: + raise + log.error( + "Error reading setup.cfg: {0!r}\n{1} will not be " + "automatically bootstrapped and package installation may fail." 
+ "\n{2}".format(e, PACKAGE_NAME, _err_help_msg)) -if sys.version_info[0] < 3: - _str_types = (str, unicode) - _text_type = unicode - PY3 = False +# We used package_name in the package template for a while instead of name +if SETUP_CFG.has_option('metadata', 'name'): + parent_package = SETUP_CFG.get('metadata', 'name') +elif SETUP_CFG.has_option('metadata', 'package_name'): + parent_package = SETUP_CFG.get('metadata', 'package_name') else: - _str_types = (str, bytes) - _text_type = str - PY3 = True + parent_package = None + +if SETUP_CFG.has_option('options', 'python_requires'): + + python_requires = SETUP_CFG.get('options', 'python_requires') + + # The python_requires key has a syntax that can be parsed by SpecifierSet + # in the packaging package. However, we don't want to have to depend on that + # package, so instead we can use setuptools (which bundles packaging). We + # have to add 'python' to parse it with Requirement. + + from pkg_resources import Requirement + req = Requirement.parse('python' + python_requires) + + # We want the Python version as a string, which we can get from the platform module + import platform + # strip off trailing '+' incase this is a dev install of python + python_version = platform.python_version().strip('+') + # allow pre-releases to count as 'new enough' + if not req.specifier.contains(python_version, True): + if parent_package is None: + message = "ERROR: Python {} is required by this package\n".format(req.specifier) + else: + message = "ERROR: Python {} is required by {}\n".format(req.specifier, parent_package) + sys.stderr.write(message) + sys.exit(1) + +if sys.version_info < __minimum_python_version__: + + if parent_package is None: + message = "ERROR: Python {} or later is required by astropy-helpers\n".format( + __minimum_python_version__) + else: + message = "ERROR: Python {} or later is required by astropy-helpers for {}\n".format( + __minimum_python_version__, parent_package) + + sys.stderr.write(message) + sys.exit(1) + +_str_types = (str, bytes) # What follows are several import statements meant to deal with install-time # issues with either missing or misbehaving pacakges (including making sure # setuptools itself is installed): +# Check that setuptools 30.3 or later is present +from distutils.version import LooseVersion -# Some pre-setuptools checks to ensure that either distribute or setuptools >= -# 0.7 is used (over pre-distribute setuptools) if it is available on the path; -# otherwise the latest setuptools will be downloaded and bootstrapped with -# ``ez_setup.py``. 
This used to be included in a separate file called -# setuptools_bootstrap.py; but it was combined into ah_bootstrap.py try: - import pkg_resources - _setuptools_req = pkg_resources.Requirement.parse('setuptools>=0.7') - # This may raise a DistributionNotFound in which case no version of - # setuptools or distribute is properly installed - _setuptools = pkg_resources.get_distribution('setuptools') - if _setuptools not in _setuptools_req: - # Older version of setuptools; check if we have distribute; again if - # this results in DistributionNotFound we want to give up - _distribute = pkg_resources.get_distribution('distribute') - if _setuptools != _distribute: - # It's possible on some pathological systems to have an old version - # of setuptools and distribute on sys.path simultaneously; make - # sure distribute is the one that's used - sys.path.insert(1, _distribute.location) - _distribute.activate() - imp.reload(pkg_resources) -except: - # There are several types of exceptions that can occur here; if all else - # fails bootstrap and use the bootstrapped version - from ez_setup import use_setuptools - use_setuptools() + import setuptools + assert LooseVersion(setuptools.__version__) >= LooseVersion('30.3') +except (ImportError, AssertionError): + sys.stderr.write("ERROR: setuptools 30.3 or later is required by astropy-helpers\n") + sys.exit(1) + +# typing as a dependency for 1.6.1+ Sphinx causes issues when imported after +# initializing submodule with ah_boostrap.py +# See discussion and references in +# https://github.com/astropy/astropy-helpers/issues/302 + +try: + import typing # noqa +except ImportError: + pass # Note: The following import is required as a workaround to @@ -97,7 +173,7 @@ # later cause the TemporaryDirectory class defined in it to stop working when # used later on by setuptools try: - import setuptools.py31compat + import setuptools.py31compat # noqa except ImportError: pass @@ -121,36 +197,6 @@ # End compatibility imports... -# In case it didn't successfully import before the ez_setup checks -import pkg_resources - -from setuptools import Distribution -from setuptools.package_index import PackageIndex -from setuptools.sandbox import run_setup - -from distutils import log -from distutils.debug import DEBUG - - -# TODO: Maybe enable checking for a specific version of astropy_helpers? -DIST_NAME = 'astropy-helpers' -PACKAGE_NAME = 'astropy_helpers' - -# Defaults for other options -DOWNLOAD_IF_NEEDED = True -INDEX_URL = 'https://pypi.python.org/simple' -USE_GIT = True -OFFLINE = False -AUTO_UPGRADE = True - -# A list of all the configuration options and their required types -CFG_OPTIONS = [ - ('auto_use', bool), ('path', str), ('download_if_needed', bool), - ('index_url', str), ('use_git', bool), ('offline', bool), - ('auto_upgrade', bool) -] - - class _Bootstrapper(object): """ Bootstrapper implementation. 
See ``use_astropy_helpers`` for parameter @@ -166,7 +212,7 @@ def __init__(self, path=None, index_url=None, use_git=None, offline=None, if not (isinstance(path, _str_types) or path is False): raise TypeError('path must be a string or False') - if PY3 and not isinstance(path, _text_type): + if not isinstance(path, str): fs_encoding = sys.getfilesystemencoding() path = path.decode(fs_encoding) # path to unicode @@ -220,36 +266,20 @@ def main(cls, argv=None): @classmethod def parse_config(cls): - if not os.path.exists('setup.cfg'): - return {} - - cfg = ConfigParser() - try: - cfg.read('setup.cfg') - except Exception as e: - if DEBUG: - raise - - log.error( - "Error reading setup.cfg: {0!r}\n{1} will not be " - "automatically bootstrapped and package installation may fail." - "\n{2}".format(e, PACKAGE_NAME, _err_help_msg)) - return {} - - if not cfg.has_section('ah_bootstrap'): + if not SETUP_CFG.has_section('ah_bootstrap'): return {} config = {} for option, type_ in CFG_OPTIONS: - if not cfg.has_option('ah_bootstrap', option): + if not SETUP_CFG.has_option('ah_bootstrap', option): continue if type_ is bool: - value = cfg.getboolean('ah_bootstrap', option) + value = SETUP_CFG.getboolean('ah_bootstrap', option) else: - value = cfg.get('ah_bootstrap', option) + value = SETUP_CFG.get('ah_bootstrap', option) config[option] = value @@ -276,6 +306,18 @@ def parse_command_line(cls, argv=None): config['offline'] = True argv.remove('--offline') + if '--auto-use' in argv: + config['auto_use'] = True + argv.remove('--auto-use') + + if '--no-auto-use' in argv: + config['auto_use'] = False + argv.remove('--no-auto-use') + + if '--use-system-astropy-helpers' in argv: + config['auto_use'] = False + argv.remove('--use-system-astropy-helpers') + return config def run(self): @@ -453,9 +495,10 @@ def _directory_import(self): # setup.py exists we can generate it setup_py = os.path.join(path, 'setup.py') if os.path.isfile(setup_py): - with _silence(): - run_setup(os.path.join(path, 'setup.py'), - ['egg_info']) + # We use subprocess instead of run_setup from setuptools to + # avoid segmentation faults - see the following for more details: + # https://github.com/cython/cython/issues/2104 + sp.check_output([sys.executable, 'setup.py', 'egg_info'], cwd=path) for dist in pkg_resources.find_distributions(path, True): # There should be only one... @@ -490,16 +533,32 @@ def get_option_dict(self, command_name): if version: req = '{0}=={1}'.format(DIST_NAME, version) else: - req = DIST_NAME + if UPPER_VERSION_EXCLUSIVE is None: + req = DIST_NAME + else: + req = '{0}<{1}'.format(DIST_NAME, UPPER_VERSION_EXCLUSIVE) attrs = {'setup_requires': [req]} + # NOTE: we need to parse the config file (e.g. 
setup.cfg) to make sure + # it honours the options set in the [easy_install] section, and we need + # to explicitly fetch the requirement eggs as setup_requires does not + # get honored in recent versions of setuptools: + # https://github.com/pypa/setuptools/issues/1273 + try: - if DEBUG: - _Distribution(attrs=attrs) - else: - with _silence(): - _Distribution(attrs=attrs) + + context = _verbose if DEBUG else _silence + with context(): + dist = _Distribution(attrs=attrs) + try: + dist.parse_config_files(ignore_option_errors=True) + dist.fetch_build_eggs(req) + except TypeError: + # On older versions of setuptools, ignore_option_errors + # doesn't exist, and the above two lines are not needed + # so we can just continue + pass # If the setup_requires succeeded it will have added the new dist to # the main working_set @@ -609,8 +668,8 @@ def _check_submodule_using_git(self): # only if the submodule is initialized. We ignore this information for # now _git_submodule_status_re = re.compile( - '^(?P[+-U ])(?P[0-9a-f]{40}) ' - '(?P\S+)( .*)?$') + r'^(?P[+-U ])(?P[0-9a-f]{40}) ' + r'(?P\S+)( .*)?$') # The stdout should only contain one line--the status of the # requested submodule @@ -702,7 +761,7 @@ def _update_submodule(self, submodule, status): if self.offline: cmd.append('--no-fetch') elif status == 'U': - raise _AHBoostrapSystemExit( + raise _AHBootstrapSystemExit( 'Error: Submodule {0} contains unresolved merge conflicts. ' 'Please complete or abandon any changes in the submodule so that ' 'it is in a usable state, then try again.'.format(submodule)) @@ -763,7 +822,7 @@ def run_cmd(cmd): msg = 'Command not found: `{0}`'.format(' '.join(cmd)) raise _CommandNotFound(msg, cmd) else: - raise _AHBoostrapSystemExit( + raise _AHBootstrapSystemExit( 'An unexpected error occurred when running the ' '`{0}` command:\n{1}'.format(' '.join(cmd), str(e))) @@ -780,9 +839,9 @@ def run_cmd(cmd): stdio_encoding = 'latin1' # Unlikely to fail at this point but even then let's be flexible - if not isinstance(stdout, _text_type): + if not isinstance(stdout, str): stdout = stdout.decode(stdio_encoding, 'replace') - if not isinstance(stderr, _text_type): + if not isinstance(stderr, str): stderr = stderr.decode(stdio_encoding, 'replace') return (p.returncode, stdout, stderr) @@ -835,6 +894,10 @@ def flush(self): pass +@contextlib.contextmanager +def _verbose(): + yield + @contextlib.contextmanager def _silence(): """A context manager that silences sys.stdout and sys.stderr.""" @@ -878,46 +941,6 @@ def __init__(self, *args): super(_AHBootstrapSystemExit, self).__init__(msg, *args[1:]) -if sys.version_info[:2] < (2, 7): - # In Python 2.6 the distutils log does not log warnings, errors, etc. 
to - # stderr so we have to wrap it to ensure consistency at least in this - # module - import distutils - - class log(object): - def __getattr__(self, attr): - return getattr(distutils.log, attr) - - def warn(self, msg, *args): - self._log_to_stderr(distutils.log.WARN, msg, *args) - - def error(self, msg): - self._log_to_stderr(distutils.log.ERROR, msg, *args) - - def fatal(self, msg): - self._log_to_stderr(distutils.log.FATAL, msg, *args) - - def log(self, level, msg, *args): - if level in (distutils.log.WARN, distutils.log.ERROR, - distutils.log.FATAL): - self._log_to_stderr(level, msg, *args) - else: - distutils.log.log(level, msg, *args) - - def _log_to_stderr(self, level, msg, *args): - # This is the only truly 'public' way to get the current threshold - # of the log - current_threshold = distutils.log.set_threshold(distutils.log.WARN) - distutils.log.set_threshold(current_threshold) - if level >= current_threshold: - if args: - msg = msg % args - sys.stderr.write('%s\n' % msg) - sys.stderr.flush() - - log = log() - - BOOTSTRAPPER = _Bootstrapper.main() diff --git a/docs/conf.py b/docs/conf.py index 8ccbfb9..7f9dd4a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,9 +25,10 @@ # Thus, any C-extensions that are needed to build the documentation will *not* # be accessible, and the documentation will not build correctly. -import datetime import os import sys +import datetime +from importlib import import_module try: from sphinx_astropy.conf.v1 import * # noqa @@ -36,10 +37,7 @@ sys.exit(1) # Get configuration information from setup.cfg -try: - from ConfigParser import ConfigParser -except ImportError: - from configparser import ConfigParser +from configparser import ConfigParser conf = ConfigParser() conf.read([os.path.join(os.path.dirname(__file__), '..', 'setup.cfg')]) @@ -69,7 +67,7 @@ # -- Project information ------------------------------------------------------ # This does not *have* to match the package name, but typically does -project = setup_cfg['package_name'] +project = setup_cfg['name'] author = setup_cfg['author'] copyright = '{0}, {1}'.format( datetime.datetime.now().year, setup_cfg['author']) @@ -78,8 +76,8 @@ # |version| and |release|, also used in various other places throughout the # built documents. -__import__(setup_cfg['package_name']) -package = sys.modules[setup_cfg['package_name']] +import_module(setup_cfg['name']) +package = sys.modules[setup_cfg['name']] # The short X.Y version. version = package.__version__.split('-', 1)[0] @@ -107,7 +105,6 @@ #html_theme = None -# Please update these texts to match the name of your package. html_theme_options = { 'logotext1': 'hela', # white, semi-bold 'logotext2': '', # orange, light @@ -120,7 +117,7 @@ # The name of an image file (relative to this directory) to place at the top # of the sidebar. -html_logo = '../img/HELA_logo1.png' +#html_logo = '' # The name of an image file (within the static path) to use as favicon of the # docs. 
This file should be a Windows icon file (.ico) being 16x16 or 32x32 @@ -160,10 +157,10 @@ if eval(setup_cfg.get('edit_on_github')): extensions += ['sphinx_astropy.ext.edit_on_github'] - versionmod = __import__(setup_cfg['package_name'] + '.version') + versionmod = import_module(setup_cfg['name'] + '.version') edit_on_github_project = setup_cfg['github_project'] - if versionmod.version.release: - edit_on_github_branch = "v" + versionmod.version.version + if versionmod.release: + edit_on_github_branch = "v" + versionmod.version else: edit_on_github_branch = "master" diff --git a/hela/conftest.py b/hela/conftest.py new file mode 100644 index 0000000..e69de29 diff --git a/setup.cfg b/setup.cfg index 4aae0e7..0ce1bce 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,28 @@ +[metadata] +name = hela +# version should be PEP440 compatible (https://www.python.org/dev/peps/pep-0440/) +version = 0.0.dev +author = Pablo Márquez-Neila and Chloe Fisher +author_email = +description = A Random Forest retrieval algorithm, here used to perform atmospheric retrieval on exoplanet atmospheres. +long_description = +license = BSD 3-Clause +url = https://github.com/exoclime/HELA +edit_on_github = False +github_project = exoclime/HELA +python_requires = ">=3.5" + +[options] +# install_requires should be formatted as a semicolon-separated list, e.g.: +install_requires = astropy; joblib; scikit-learn; matplotlib +zip_safe = False +use_2to3 = False + +[options.package_data] +* = *.c +hela = data/* +hela.tests = coveragerc + [build_sphinx] source-dir = docs build-dir = docs/_build @@ -19,38 +44,10 @@ doctest_plus = enabled addopts = -p no:warnings [ah_bootstrap] -auto_use = False +auto_use = True -[pycodestyle] -# E101 - mix of tabs and spaces -# W191 - use of tabs -# W291 - trailing whitespace -# W292 - no newline at end of file -# W293 - trailing whitespace -# W391 - blank line at end of file -# E111 - 4 spaces per indentation level -# E112 - 4 spaces per indentation level -# E113 - 4 spaces per indentation level -# E901 - SyntaxError or IndentationError -# E902 - IOError -select = E101,W191,W291,W292,W293,W391,E111,E112,E113,E901,E902 -exclude = extern,sphinx,*parsetab.py - -[metadata] -package_name = hela -description = A Random Forest retrieval algorithm, here used to perform atmospheric retrieval on exoplanet atmospheres. 
-author = Pablo Márquez-Neila and Chloe Fisher -license = Other -url = https://github.com/exoclime/HELA -edit_on_github = False -github_project = exoclime/HELA -# install_requires should be formatted as a comma-separated list, e.g.: -# install_requires = astropy, scipy, matplotlib -install_requires = astropy, sklearn, matplotlib, joblib -# version should be PEP386 compatible (http://www.python.org/dev/peps/pep-0386) -version = 0.0.dev0 -# Note: you will also need to change this in your package's __init__.py -minimum_python_version = 3.5 - -[entry_points] +[flake8] +exclude = extern,sphinx,*parsetab.py,astropy_helpers,ah_bootstrap.py,conftest.py,docs/conf.py,setup.py +[pycodestyle] +exclude = extern,sphinx,*parsetab.py,astropy_helpers,ah_bootstrap.py,conftest.py,docs/conf.py,setup.py diff --git a/setup.py b/setup.py index 2474ff3..c771525 100755 --- a/setup.py +++ b/setup.py @@ -1,153 +1,33 @@ #!/usr/bin/env python -# Licensed under a 3-clause BSD style license - see LICENSE.rst - -import glob -import os -import sys - -try: - from configparser import ConfigParser -except ImportError: - from ConfigParser import ConfigParser -# Get some values from the setup.cfg -conf = ConfigParser() -conf.read(['setup.cfg']) -metadata = dict(conf.items('metadata')) - -PACKAGENAME = metadata.get('package_name', 'packagename') -DESCRIPTION = metadata.get('description', 'Astropy Package Template') -AUTHOR = metadata.get('author', 'Astropy Developers') -AUTHOR_EMAIL = metadata.get('author_email', '') -LICENSE = metadata.get('license', 'unknown') -URL = metadata.get('url', 'http://astropy.org') -__minimum_python_version__ = metadata.get("minimum_python_version", "2.7") +# Licensed under a 3-clause BSD style license - see LICENSE.rst -# Enforce Python version check - this is the same check as in __init__.py but -# this one has to happen before importing ah_bootstrap. 
-if sys.version_info < tuple((int(val) for val in __minimum_python_version__.split('.'))): - sys.stderr.write("ERROR: packagename requires Python {} or later\n".format(__minimum_python_version__)) - sys.exit(1) +import builtins -# Import ah_bootstrap after the python version validation +# Ensure that astropy-helpers is available +import ah_bootstrap # noqa -import ah_bootstrap from setuptools import setup +from setuptools.config import read_configuration -# A dirty hack to get around some early import/configurations ambiguities -if sys.version_info[0] >= 3: - import builtins -else: - import __builtin__ as builtins -builtins._ASTROPY_SETUP_ = True - -from astropy_helpers.astropy_helpers.setup_helpers import (register_commands, - get_debug_option, - get_package_info) -from astropy_helpers.astropy_helpers.git_helpers import get_git_devstr -from astropy_helpers.astropy_helpers.version_helpers import generate_version_py - - -# order of priority for long_description: -# (1) set in setup.cfg, -# (2) load LONG_DESCRIPTION.rst, -# (3) load README.rst, -# (4) package docstring -readme_glob = 'README*' -_cfg_long_description = metadata.get('long_description', '') -if _cfg_long_description: - LONG_DESCRIPTION = _cfg_long_description - -elif os.path.exists('LONG_DESCRIPTION.rst'): - with open('LONG_DESCRIPTION.rst') as f: - LONG_DESCRIPTION = f.read() - -elif len(glob.glob(readme_glob)) > 0: - with open(glob.glob(readme_glob)[0]) as f: - LONG_DESCRIPTION = f.read() - -else: - # Get the long description from the package's docstring - __import__(PACKAGENAME) - package = sys.modules[PACKAGENAME] - LONG_DESCRIPTION = package.__doc__ +from astropy_helpers.setup_helpers import register_commands, get_package_info +from astropy_helpers.version_helpers import generate_version_py # Store the package name in a built-in variable so it's easy # to get from other parts of the setup infrastructure -builtins._ASTROPY_PACKAGE_NAME_ = PACKAGENAME +builtins._ASTROPY_PACKAGE_NAME_ = read_configuration('setup.cfg')['metadata']['name'] -# VERSION should be PEP440 compatible (http://www.python.org/dev/peps/pep-0440) -VERSION = metadata.get('version', '0.0.dev0') - -# Indicates if this version is a release version -RELEASE = 'dev' not in VERSION - -if not RELEASE: - VERSION += get_git_devstr(False) - -# Populate the dict of setup command overrides; this should be done before -# invoking any other functionality from distutils since it can potentially -# modify distutils' behavior. -cmdclassd = register_commands(PACKAGENAME, VERSION, RELEASE) - -# Freeze build information in version.py -generate_version_py(PACKAGENAME, VERSION, RELEASE, - get_debug_option(PACKAGENAME)) - -# Treat everything in scripts except README* as a script to be installed -scripts = [fname for fname in glob.glob(os.path.join('scripts', '*')) - if not os.path.basename(fname).startswith('README')] +# Create a dictionary with setup command overrides. Note that this gets +# information about the package (name and version) from the setup.cfg file. +cmdclass = register_commands() +# Freeze build information in version.py. Note that this gets information +# about the package (name and version) from the setup.cfg file. +version = generate_version_py() # Get configuration information from all of the various subpackages. # See the docstring for setup_helpers.update_package_files for more # details. 
package_info = get_package_info() -# Add the project-global data -package_info['package_data'].setdefault(PACKAGENAME, []) -package_info['package_data'][PACKAGENAME].append('data/*') - -# Define entry points for command-line scripts -entry_points = {'console_scripts': []} - -if conf.has_section('entry_points'): - entry_point_list = conf.items('entry_points') - for entry_point in entry_point_list: - entry_points['console_scripts'].append('{0} = {1}'.format( - entry_point[0], entry_point[1])) - -# Include all .c files, recursively, including those generated by -# Cython, since we can not do this in MANIFEST.in with a "dynamic" -# directory name. -c_files = [] -for root, dirs, files in os.walk(PACKAGENAME): - for filename in files: - if filename.endswith('.c'): - c_files.append( - os.path.join( - os.path.relpath(root, PACKAGENAME), filename)) -package_info['package_data'][PACKAGENAME].extend(c_files) - -# Note that requires and provides should not be included in the call to -# ``setup``, since these are now deprecated. See this link for more details: -# https://groups.google.com/forum/#!topic/astropy-dev/urYO8ckB2uM - -setup(name=PACKAGENAME, - version=VERSION, - description=DESCRIPTION, - scripts=scripts, - install_requires=[s.strip() for s in metadata.get('install_requires', - 'astropy').split(',')], - author=AUTHOR, - author_email=AUTHOR_EMAIL, - license=LICENSE, - url=URL, - long_description=LONG_DESCRIPTION, - cmdclass=cmdclassd, - zip_safe=False, - use_2to3=False, - entry_points=entry_points, - python_requires='>={}'.format(__minimum_python_version__), - **package_info -) +setup(version=version, cmdclass=cmdclass, **package_info) From afce4ffd41aa85872bd124464f7eb57912ac147a Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 12:04:08 +0100 Subject: [PATCH 25/46] Adding logo back into docs/conf.py --- docs/conf.py | 2 +- hela/conftest.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 7f9dd4a..63aa90d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -117,7 +117,7 @@ # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = '' +html_logo = '../img/HELA_logo1.png' # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 diff --git a/hela/conftest.py b/hela/conftest.py index e69de29..6cbbb30 100644 --- a/hela/conftest.py +++ b/hela/conftest.py @@ -0,0 +1,49 @@ +# This file is used to configure the behavior of pytest when using the Astropy +# test infrastructure. +import os + +from astropy.version import version as astropy_version +if astropy_version < '3.0': + # With older versions of Astropy, we actually need to import the pytest + # plugins themselves in order to make them discoverable by pytest. + from astropy.tests.pytest_plugins import * +else: + # As of Astropy 3.0, the pytest plugins provided by Astropy are + # automatically made available when Astropy is installed. This means it's + # not necessary to import them here, but we still need to import global + # variables that are used for configuration. + from astropy.tests.plugins.display import (pytest_report_header, + PYTEST_HEADER_MODULES, + TESTED_VERSIONS) + +from astropy.tests.helper import enable_deprecations_as_exceptions + +## Uncomment the following line to treat all DeprecationWarnings as +## exceptions. 
For Astropy v2.0 or later, there are 2 additional keywords, +## as follow (although default should work for most cases). +## To ignore some packages that produce deprecation warnings on import +## (in addition to 'compiler', 'scipy', 'pygments', 'ipykernel', and +## 'setuptools'), add: +## modules_to_ignore_on_import=['module_1', 'module_2'] +## To ignore some specific deprecation warning messages for Python version +## MAJOR.MINOR or later, add: +## warnings_to_ignore_by_pyver={(MAJOR, MINOR): ['Message to ignore']} +# enable_deprecations_as_exceptions() + +# Customize the following lines to add/remove entries from +# the list of packages for which version numbers are displayed when running +# the tests. Making it pass for KeyError is essential in some cases when +# the package uses other astropy affiliated packages. +try: + PYTEST_HEADER_MODULES['Astropy'] = 'astropy' + del PYTEST_HEADER_MODULES['h5py'] +except KeyError: + pass + +# This is to figure out the package version, rather than +# using Astropy's +from .version import version, astropy_helpers_version + +packagename = os.path.basename(os.path.dirname(__file__)) +TESTED_VERSIONS[packagename] = version +TESTED_VERSIONS['astropy_helpers'] = astropy_helpers_version From 652179f1c85babf75bc50e40b903dea7681af1a5 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 12:12:18 +0100 Subject: [PATCH 26/46] Adding first end-to-end test which checks against linear docs example --- hela/tests/__init__.py | 4 ++++ hela/tests/coveragerc | 31 +++++++++++++++++++++++++++++++ hela/tests/test_example.py | 27 +++++++++++++++++++++++++++ 3 files changed, 62 insertions(+) create mode 100644 hela/tests/__init__.py create mode 100644 hela/tests/coveragerc create mode 100644 hela/tests/test_example.py diff --git a/hela/tests/__init__.py b/hela/tests/__init__.py new file mode 100644 index 0000000..838b457 --- /dev/null +++ b/hela/tests/__init__.py @@ -0,0 +1,4 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst +""" +This module contains package tests. 
+""" diff --git a/hela/tests/coveragerc b/hela/tests/coveragerc new file mode 100644 index 0000000..bec7c29 --- /dev/null +++ b/hela/tests/coveragerc @@ -0,0 +1,31 @@ +[run] +source = {packagename} +omit = + {packagename}/_astropy_init* + {packagename}/conftest* + {packagename}/cython_version* + {packagename}/setup_package* + {packagename}/*/setup_package* + {packagename}/*/*/setup_package* + {packagename}/tests/* + {packagename}/*/tests/* + {packagename}/*/*/tests/* + {packagename}/version* + +[report] +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about packages we have installed + except ImportError + + # Don't complain if tests don't hit assertions + raise AssertionError + raise NotImplementedError + + # Don't complain about script hooks + def main\(.*\): + + # Ignore branches that don't pertain to this version of Python + pragma: py{ignore_python_version} \ No newline at end of file diff --git a/hela/tests/test_example.py b/hela/tests/test_example.py new file mode 100644 index 0000000..eae1b2e --- /dev/null +++ b/hela/tests/test_example.py @@ -0,0 +1,27 @@ + + +def test_linear_end_to_end(): + from ..forest import generate_example_data + example_dir, training_dataset, samples_path = generate_example_data() + + # Import RandomForest object from HELA + from ..forest import RandomForest + + # Initialize a random forest object: + rf = RandomForest(training_dataset, example_dir, samples_path) + + # Train the random forest: + r2scores = rf.train(num_trees=1000, num_jobs=1) + + # Do a rough check that the R^2 values are near unity + assert abs(r2scores['slope'] - 1) < 0.01 + assert abs(r2scores['intercept'] - 1) < 0.01 + + # Predict posterior distributions from random forest + posterior = rf.predict(plot_posterior=False) + posterior_slopes, posterior_intercepts = posterior.samples.T + + # Do a very generous check that the posterior distributions match + # the expected values + assert abs(posterior_slopes.mean() - 0.3) < 0.1 + assert abs(posterior_intercepts.mean() - 0.5) < 0.1 From c01c96ba43db704eb527895b2587591c5e13fb76 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 13:33:01 +0100 Subject: [PATCH 27/46] Fixing astropy helpers version --- .gitmodules | 2 +- astropy_helpers | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index 22d01ae..ba99405 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "astropy_helpers"] path = astropy_helpers - url = git://github.com/astropy/astropy-helpers + url = git@github.com:astropy/astropy-helpers.git diff --git a/astropy_helpers b/astropy_helpers index d8f4890..ce42e6e 160000 --- a/astropy_helpers +++ b/astropy_helpers @@ -1 +1 @@ -Subproject commit d8f48901442e4056879c4249e73ecd7d04a28282 +Subproject commit ce42e6e238c200a4715785ef8c9d233f612d0c75 From 5cfa38aff15b5bea60b25793cfd3f5b6a0355c45 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 13:37:22 +0100 Subject: [PATCH 28/46] changing url for ah submodule --- .gitmodules | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index ba99405..22d01ae 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "astropy_helpers"] path = astropy_helpers - url = git@github.com:astropy/astropy-helpers.git + url = git://github.com/astropy/astropy-helpers From 4d4de146d15fc4f60db9908d66044bafc8f8b76f Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 13:44:38 +0100 Subject: [PATCH 29/46] Formatting fixes 
--- hela/__init__.py | 10 +++++----- hela/forest.py | 17 +++++++++++------ hela/models.py | 20 +++++++++++--------- hela/plot.py | 6 ++++-- hela/wpercentile.py | 5 ++--- 5 files changed, 33 insertions(+), 25 deletions(-) diff --git a/hela/__init__.py b/hela/__init__.py index 4b59994..fe10686 100644 --- a/hela/__init__.py +++ b/hela/__init__.py @@ -3,7 +3,7 @@ # Packages may add whatever they like to this file, but # should keep this content at the top. # ---------------------------------------------------------------------------- -from ._astropy_init import * +from ._astropy_init import * # noqa # ---------------------------------------------------------------------------- # Enforce Python version check during package import. @@ -25,7 +25,7 @@ class UnsupportedPythonError(Exception): if not _ASTROPY_SETUP_: # For egg_info test builds to pass, put package imports here. - from .forest import * - from .models import * - from .dataset import * - from .plot import * + from .forest import * # noqa + from .models import * # noqa + from .dataset import * # noqa + from .plot import * # noqa diff --git a/hela/forest.py b/hela/forest.py index 5204222..6838518 100644 --- a/hela/forest.py +++ b/hela/forest.py @@ -50,9 +50,9 @@ def compute_feature_importance(model, dataset, output_path): forests = [i.rf for i in regr.estimators_] + [model.rf] fig = plot_feature_importances( - forests=[i.rf for i in regr.estimators_] + [model.rf], - names=dataset.names + ["joint prediction"], - colors=dataset.colors + ["C0"]) + forests=[i.rf for i in regr.estimators_] + [model.rf], + names=dataset.names + ["joint prediction"], + colors=dataset.colors + ["C0"]) fig.savefig(os.path.join(output_path, "feature_importances.pdf"), bbox_inches='tight') @@ -63,6 +63,7 @@ class RandomForest(object): """ A class for a random forest. """ + def __init__(self, training_dataset, model_path, data_file): """ Parameters @@ -139,7 +140,8 @@ def predict(self, plot_posterior=True): ------- preds : `~numpy.ndarray` ``N x M`` values where ``N`` is number of parameters, ``M`` is - number of samples/trees (check out attributes of model for metadata) + number of samples/trees (check out attributes of model for + metadata) """ model_file = os.path.join(self.model_path, "model.pkl") # Loading random forest from '{}'...".format(model_file) @@ -166,6 +168,7 @@ def predict(self, plot_posterior=True): bbox_inches='tight') return posterior + def data_ranges(posterior, percentiles=(50, 16, 84)): """ Return posterior ranges. @@ -181,9 +184,11 @@ def data_ranges(posterior, percentiles=(50, 16, 84)): """ values = wpercentile(posterior.samples, posterior.weights, percentiles, axis=0) - ranges = np.array([values[0], values[2]-values[0], values[0]-values[1]]) + ranges = np.array( + [values[0], values[2] - values[0], values[0] - values[1]]) return ranges.T + def generate_example_data(): """ Generate an example dataset in the new directory ``linear_dataset``. 
@@ -243,4 +248,4 @@ def generate_example_data(): samples = true_slope * x + true_intercept np.save(samples_path, samples.T) - return example_dir, training_dataset, samples_path \ No newline at end of file + return example_dir, training_dataset, samples_path diff --git a/hela/models.py b/hela/models.py index 84526d1..f619a2c 100644 --- a/hela/models.py +++ b/hela/models.py @@ -1,5 +1,3 @@ -from collections import namedtuple - import numpy as np from sklearn import ensemble @@ -114,7 +112,7 @@ def get_params(self, deep=True): return {"num_trees": self.num_trees, "num_jobs": self.num_jobs, "names": self.names, "ranges": self.ranges, "colors": self.colors, "verbose": self.verbose, - "enable_posterior": self.enable_posterior,} + "enable_posterior": self.enable_posterior} def trees_predict(self, x): @@ -132,19 +130,22 @@ def predict_percentile(self, x, percentile): if not self.enable_posterior: raise ValueError("Cannot compute posteriors with this model. " - "Set `enable_posterior` to True to enable posterior computation.") + "Set `enable_posterior` to True to enable " + "posterior computation.") # Find the leaves for the query points leaves_x = self.rf.apply(x) if len(x) > self.num_trees: - # If there are many queries, it is faster to find points using a cache + # If there are many queries, it is faster to find points + # using a cache return _posterior_percentile_cache( self.data_leaves, self.data_weights, self.data_y, leaves_x, percentile ) else: - # For few queries, it is faster if we just compute the posterior for each element + # For few queries, it is faster if we just compute the posterior + # for each element return _posterior_percentile_nocache( self.data_leaves, self.data_weights, self.data_y, leaves_x, percentile @@ -154,7 +155,8 @@ def posterior(self, x): leaves_x = self.rf.apply(x[None, :])[0] if not self.enable_posterior: raise ValueError("Cannot compute posteriors with this model. " - "Set `enable_posterior` to True to enable posterior computation.") + "Set `enable_posterior` to True to enable " + "posterior computation.") return _posterior( self.data_leaves, self.data_weights, @@ -207,8 +209,8 @@ def _posterior_percentile_nocache(data_leaves, data_weights, data_y, return np.array(values) -def _posterior_percentile_cache(data_leaves, data_weights, data_y, query_leaves, - percentile): +def _posterior_percentile_cache(data_leaves, data_weights, data_y, + query_leaves, percentile): # Build a dictionary for fast access of the contents of the leaves. # Building cache... cache = [ diff --git a/hela/plot.py b/hela/plot.py index abf802f..f1f5247 100644 --- a/hela/plot.py +++ b/hela/plot.py @@ -195,7 +195,8 @@ def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): width=bins[1] - bins[0]) kd_probs = histogram - expected = wmedian(posterior.samples[:, dims[0]], posterior.weights) + expected = wmedian(posterior.samples[:, dims[0]], + posterior.weights) ax.plot([expected, expected], [0, 1.1 * kd_probs.max()], '-', linewidth=1, color='#222222') @@ -245,7 +246,8 @@ def _plot_histogram2d(ax, posterior, color, cmap, dims, ranges): def _plot_samples(ax, posterior, color, dims, ranges): - # For efficiency, do not plot all the samples of the posterior. Subsample first. + # For efficiency, do not plot all the samples of the posterior. + # Subsample first. 
if len(posterior.samples) > POSTERIOR_MAX_SIZE: posterior = resample_posterior(posterior, POSTERIOR_MAX_SIZE) diff --git a/hela/wpercentile.py b/hela/wpercentile.py index 241d891..62c639e 100644 --- a/hela/wpercentile.py +++ b/hela/wpercentile.py @@ -4,7 +4,6 @@ def _wpercentile1d(data, weights, percentiles): - if data.ndim > 1 or weights.ndim > 1: raise ValueError("data and weights must be one-dimensional arrays") @@ -22,7 +21,7 @@ def _wpercentile1d(data, weights, percentiles): cumsum_weights = np.cumsum(sorted_weights) sum_weights = cumsum_weights[-1] - pn = 100 * (cumsum_weights - 0.5*sorted_weights) / sum_weights + pn = 100 * (cumsum_weights - 0.5 * sorted_weights) / sum_weights return np.interp(percentiles, pn, sorted_data) @@ -82,4 +81,4 @@ def wmedian(data, weights, axis=None): ar : `~numpy.ndarray` """ - return wpercentile(data, weights, 50, axis) \ No newline at end of file + return wpercentile(data, weights, 50, axis) From 7f11aefe46cd65fedfe5f2a7a5025ab271522643 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 13:51:53 +0100 Subject: [PATCH 30/46] flake8 improvements --- .travis.yml | 5 +---- hela/__init__.py | 4 ++-- hela/dataset.py | 1 + hela/models.py | 6 ++---- hela/plot.py | 4 ++-- 5 files changed, 8 insertions(+), 12 deletions(-) diff --git a/.travis.yml b/.travis.yml index 6724ac8..67ac7d0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -64,10 +64,7 @@ env: # line to use the X virtual framebuffer. # - SETUP_XVFB=True - # If you want to ignore certain flake8 errors, you can list them - # in FLAKE8_OPT, for example: - # - FLAKE8_OPT='--ignore=E501' - - FLAKE8_OPT='' + - FLAKE8_OPT='--ignore=E501,W504' matrix: diff --git a/hela/__init__.py b/hela/__init__.py index fe10686..4999d8b 100644 --- a/hela/__init__.py +++ b/hela/__init__.py @@ -20,10 +20,10 @@ class UnsupportedPythonError(Exception): if sys.version_info < tuple( (int(val) for val in __minimum_python_version__.split('.'))): raise UnsupportedPythonError( - "packagename does not support Python < {}".format( + "hela does not support Python < {}".format( __minimum_python_version__)) -if not _ASTROPY_SETUP_: +if not _ASTROPY_SETUP_: # noqa # For egg_info test builds to pass, put package imports here. from .forest import * # noqa from .models import * # noqa diff --git a/hela/dataset.py b/hela/dataset.py index ddf832e..bbd9aaa 100644 --- a/hela/dataset.py +++ b/hela/dataset.py @@ -31,6 +31,7 @@ def __init__(self, training_x, training_y, testing_x, testing_y, names, self.ranges = ranges self.colors = colors + def load_data_file(data_file, num_features): data = np.load(data_file) diff --git a/hela/models.py b/hela/models.py index f619a2c..452e7f4 100644 --- a/hela/models.py +++ b/hela/models.py @@ -213,10 +213,8 @@ def _posterior_percentile_cache(data_leaves, data_weights, data_y, query_leaves, percentile): # Build a dictionary for fast access of the contents of the leaves. # Building cache... 
- cache = [ - _build_leaves_cache(leaves_i, weights_i) - for leaves_i, weights_i in zip(data_leaves, data_weights) - ] + cache = [_build_leaves_cache(leaves_i, weights_i) + for leaves_i, weights_i in zip(data_leaves, data_weights)] values = [] # Check the contents of the leaves in leaves_x diff --git a/hela/plot.py b/hela/plot.py index f1f5247..9f3ea1f 100644 --- a/hela/plot.py +++ b/hela/plot.py @@ -275,8 +275,8 @@ def _min_max_scaler(ranges, feature_range=(0, 100)): res.data_max_ = ranges[:, 1] res.data_min_ = ranges[:, 0] res.data_range_ = res.data_max_ - res.data_min_ - res.scale_ = (feature_range[1] - feature_range[0]) / ( - ranges[:, 1] - ranges[:, 0]) + res.scale_ = ((feature_range[1] - feature_range[0]) / + (ranges[:, 1] - ranges[:, 0])) res.min_ = -res.scale_ * res.data_min_ res.n_samples_seen_ = 1 res.feature_range = feature_range From bee11bbb41c0be4cb83cb0421efc2d6407c3139f Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 13:56:20 +0100 Subject: [PATCH 31/46] Adding missing testing matrix dependencies --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 67ac7d0..fc8593e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -51,7 +51,7 @@ env: # List other runtime dependencies for the package that are available as # pip packages here. - - PIP_DEPENDENCIES='scipy matplotlib' + - PIP_DEPENDENCIES='scipy matplotlib tqdm joblib' # Conda packages for affiliated packages are hosted in channel # "astropy" while builds for astropy LTS with recent numpy versions From 7001cdc46c66684a44eef4d2cb9fbbe3a45a8058 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 14:01:29 +0100 Subject: [PATCH 32/46] Editting end-to-end test to improve reliability --- hela/tests/test_example.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hela/tests/test_example.py b/hela/tests/test_example.py index eae1b2e..76a4b64 100644 --- a/hela/tests/test_example.py +++ b/hela/tests/test_example.py @@ -23,5 +23,6 @@ def test_linear_end_to_end(): # Do a very generous check that the posterior distributions match # the expected values - assert abs(posterior_slopes.mean() - 0.3) < 0.1 - assert abs(posterior_intercepts.mean() - 0.5) < 0.1 + assert abs(posterior_slopes.mean() - 0.3) < 3 * posterior_slopes.std() + assert (abs(posterior_intercepts.mean() - 0.5) < + 3 * posterior_intercepts.std()) From 50bc2c6de166cf13f67a3af33eb323e7721d0ccb Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 16:01:09 +0100 Subject: [PATCH 33/46] Splitting out plotting functions from the calculation functions --- hela/forest.py | 102 ++++++++++++++++++++++++------------- hela/plot.py | 6 +-- hela/tests/test_example.py | 2 +- 3 files changed, 71 insertions(+), 39 deletions(-) diff --git a/hela/forest.py b/hela/forest.py index 6838518..f8e66d9 100644 --- a/hela/forest.py +++ b/hela/forest.py @@ -36,26 +36,18 @@ def test_model(model, dataset, output_path): for name, values in r2scores.items(): print("\tR^2 score for {}: {:.3f}".format(name, values)) - fig = plot_predicted_vs_real(dataset.testing_y, pred, dataset.names, - dataset.ranges) + fig, axes = plot_predicted_vs_real(dataset.testing_y, pred, dataset.names, + dataset.ranges) fig.savefig(os.path.join(output_path, "predicted_vs_real.pdf"), bbox_inches='tight') return r2scores -def compute_feature_importance(model, dataset, output_path): +def compute_feature_importance(model, dataset): regr = multioutput.MultiOutputRegressor(model, n_jobs=1) regr.fit(dataset.training_x, 
dataset.training_y) forests = [i.rf for i in regr.estimators_] + [model.rf] - - fig = plot_feature_importances( - forests=[i.rf for i in regr.estimators_] + [model.rf], - names=dataset.names + ["joint prediction"], - colors=dataset.colors + ["C0"]) - - fig.savefig(os.path.join(output_path, "feature_importances.pdf"), - bbox_inches='tight') return np.array([forest_i.feature_importances_ for forest_i in forests]) @@ -82,6 +74,9 @@ def __init__(self, training_dataset, model_path, data_file): self.dataset = None self.model = None + self._feature_importance = None + self._posterior = None + self.oob = None def train(self, num_trees=1000, num_jobs=5, quiet=False): """ @@ -106,11 +101,13 @@ def train(self, num_trees=1000, num_jobs=5, quiet=False): os.makedirs(self.model_path, exist_ok=True) model_file = os.path.join(self.model_path, "model.pkl") + # Saving model joblib.dump(self.model, model_file) # Printing model information... print("OOB score: {:.4f}".format(self.model.rf.oob_score_)) + self.oob = self.model.rf.oob_score_ r2scores = test_model(self.model, self.dataset, self.model_path) @@ -124,11 +121,30 @@ def feature_importance(self): ------- feature_importances : `~numpy.ndarray` """ - self.model.enable_posterior = False - return compute_feature_importance(self.model, self.dataset, - self.model_path) + if self._feature_importance is None: + self._feature_importance = compute_feature_importance(self.model, + self.dataset) + return self._feature_importance + + def plot_feature_importance(self): + """ + Plot the feature importances. + + Returns + ------- + fig, axes + """ + forests = self.feature_importance() + fig, axes = plot_feature_importances(forests=forests, + names=(self.dataset.names + + ["joint prediction"]), + colors=self.dataset.colors + ["C0"]) + + fig.savefig(os.path.join(self.output_path, "feature_importances.pdf"), + bbox_inches='tight') + return fig, axes - def predict(self, plot_posterior=True): + def predict(self, quiet=False): """ Predict values from the trained random forest. @@ -143,30 +159,46 @@ def predict(self, plot_posterior=True): number of samples/trees (check out attributes of model for metadata) """ + if self._posterior is None: + model_file = os.path.join(self.model_path, "model.pkl") + # Loading random forest from '{}'...".format(model_file) + model = joblib.load(model_file) + + # Loading data from '{}'...".format(data_file) + data, _ = load_data_file(self.data_file, model.rf.n_features_) + + posterior = model.posterior(data[0]) + + if not quiet: + posterior_ranges = data_ranges(posterior) + for name_i, pred_range_i in zip(model.names, posterior_ranges): + print("Prediction for {}: {:.3g} " + "[+{:.3g} -{:.3g}]".format(name_i, *pred_range_i)) + + self._posterior = posterior + + return self._posterior + + def plot_posterior(self): + """ + Plot the posterior distributions for each parameter. + + Returns + ------- + fig, axes + """ model_file = os.path.join(self.model_path, "model.pkl") # Loading random forest from '{}'...".format(model_file) model = joblib.load(model_file) - # Loading data from '{}'...".format(data_file) - data, _ = load_data_file(self.data_file, model.rf.n_features_) - - posterior = model.posterior(data[0]) - - posterior_ranges = data_ranges(posterior) - for name_i, pred_range_i in zip(model.names, posterior_ranges): - print("Prediction for {}: {:.3g} " - "[+{:.3g} -{:.3g}]".format(name_i, *pred_range_i)) - - if plot_posterior: - # Plotting and saving the posterior matrix..." 
- fig = plot_posterior_matrix(posterior, - names=model.names, - ranges=model.ranges, - colors=model.colors) - os.makedirs(self.output_path, exist_ok=True) - fig.savefig(os.path.join(self.output_path, "posterior_matrix.pdf"), - bbox_inches='tight') - return posterior + fig, axes = plot_posterior_matrix(self._posterior, + names=model.names, + ranges=model.ranges, + colors=model.colors) + os.makedirs(self.output_path, exist_ok=True) + fig.savefig(os.path.join(self.output_path, "posterior_matrix.pdf"), + bbox_inches='tight') + return fig, axes def data_ranges(posterior, percentiles=(50, 16, 84)): diff --git a/hela/plot.py b/hela/plot.py index 9f3ea1f..6f5d67d 100644 --- a/hela/plot.py +++ b/hela/plot.py @@ -71,7 +71,7 @@ def plot_predicted_vs_real(y_real, y_pred, names, ranges, alpha='auto'): ax.legend(loc="upper left", fontsize=14) fig.tight_layout() - return fig + return fig, axes def plot_feature_importances(forests, names, colors): @@ -106,7 +106,7 @@ def plot_feature_importances(forests, names, colors): ax.grid() fig.tight_layout() - return fig + return fig, axes def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): @@ -206,7 +206,7 @@ def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): # fig.tight_layout(pad=0) # fig.tight_layout(pad=0) - return fig + return fig, axes def _plot_histogram2d(ax, posterior, color, cmap, dims, ranges): diff --git a/hela/tests/test_example.py b/hela/tests/test_example.py index 76a4b64..e6582d5 100644 --- a/hela/tests/test_example.py +++ b/hela/tests/test_example.py @@ -18,7 +18,7 @@ def test_linear_end_to_end(): assert abs(r2scores['intercept'] - 1) < 0.01 # Predict posterior distributions from random forest - posterior = rf.predict(plot_posterior=False) + posterior = rf.predict() posterior_slopes, posterior_intercepts = posterior.samples.T # Do a very generous check that the posterior distributions match From c9bcdc3cc80f3a0d872b1ddbfd0dd9fd92505da5 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 16:14:58 +0100 Subject: [PATCH 34/46] Updating file/object names with @pmneila's help --- docs/hela/tutorial.rst | 26 +++++++++++++------------- hela/__init__.py | 4 ++-- hela/{forest.py => model.py} | 18 +++++++++--------- hela/plot.py | 2 +- hela/tests/test_example.py | 12 ++++++------ hela/{models.py => wrapper.py} | 6 +++--- 6 files changed, 34 insertions(+), 34 deletions(-) rename hela/{forest.py => model.py} (95%) rename hela/{models.py => wrapper.py} (98%) diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index 9165314..26dc8ea 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -39,16 +39,16 @@ We also generated a bunch of samples with a known slope and intercept, called the slope and intercept. Once we have these three data structures written and their paths saved, we can -run ``hela`` on the data. First, we'll initialize a `~hela.RandomForest` object +run ``hela`` on the data. First, we'll initialize a `~hela.Model` object with the paths to the three files/directories that it needs to know about: .. code-block:: python - from hela import RandomForest + from hela import Model import matplotlib.pyplot as plt # Initialize a random forest object: - rf = RandomForest(training_dataset, example_dir, samples_path) + m = Model(training_dataset, example_dir, samples_path) We now have a random forest object ``rf`` which is ready for training. 
We can train the random forest with 1000 trees and on a single processor: @@ -56,7 +56,7 @@ train the random forest with 1000 trees and on a single processor: .. code-block:: python # Train the random forest: - r2scores = rf.train(num_trees=1000, num_jobs=1) + r2scores = m.train(num_trees=1000, num_jobs=1) plt.show() .. plot:: @@ -65,29 +65,29 @@ train the random forest with 1000 trees and on a single processor: # Generate an example dataset directory example_dir, training_dataset, samples_path = generate_example_data() - from hela import RandomForest + from hela import Model import matplotlib.pyplot as plt # Initialize a random forest object: - rf = RandomForest(training_dataset, example_dir, samples_path) + m = Model(training_dataset, example_dir, samples_path) # Train the random forest: - r2scores = rf.train(num_trees=1000, num_jobs=1) + r2scores = m.train(num_trees=1000, num_jobs=1) plt.show() -The `~hela.RandomForest.train` method returns a dictionary called `r2scores` +The `~hela.Model.train` method returns a dictionary called `r2scores` which contains the :math:`R^2` scores of the slope and intercept, which should both be close to unity for this example. Finally, let's estimate the posterior distributions for the slope and intercept using the trained random forest on the sample data in ``samples_path``, where the true values of the slope and intercept are :math:`m=0.3` and :math:`b=0.5` -using the `~hela.RandomForest.predict` method: +using the `~hela.Model.predict` method: .. code-block:: python # Predict posterior distributions from random forest - posterior = rf.predict(plot_posterior=True) + posterior = m.predict() posterior_slopes, posterior_intercepts = posterior.samples.T plt.show() @@ -97,14 +97,14 @@ using the `~hela.RandomForest.predict` method: # Generate an example dataset directory example_dir, training_dataset, samples_path = generate_example_data() - from hela import RandomForest + from hela import Model import matplotlib.pyplot as plt # Initialize a random forest object: - rf = RandomForest(training_dataset, example_dir, samples_path) + m = Model(training_dataset, example_dir, samples_path) # Predict posterior distributions from random forest - posterior = rf.predict(plot_posterior=True) + posterior = m.predict() posterior_slopes, posterior_intercepts = posterior.samples.T plt.tight_layout() plt.show() diff --git a/hela/__init__.py b/hela/__init__.py index 4999d8b..f44d80d 100644 --- a/hela/__init__.py +++ b/hela/__init__.py @@ -25,7 +25,7 @@ class UnsupportedPythonError(Exception): if not _ASTROPY_SETUP_: # noqa # For egg_info test builds to pass, put package imports here. 
- from .forest import * # noqa - from .models import * # noqa + from .model import * # noqa + from .wrapper import * # noqa from .dataset import * # noqa from .plot import * # noqa diff --git a/hela/forest.py b/hela/model.py similarity index 95% rename from hela/forest.py rename to hela/model.py index f8e66d9..5cd906b 100644 --- a/hela/forest.py +++ b/hela/model.py @@ -6,20 +6,20 @@ import joblib from .dataset import load_dataset, load_data_file -from .models import Model +from .wrapper import RandomForestWrapper from .plot import (plot_predicted_vs_real, plot_feature_importances, plot_posterior_matrix) from .wpercentile import wpercentile -__all__ = ['RandomForest', 'generate_example_data'] +__all__ = ['Model', 'generate_example_data'] def train_model(dataset, num_trees, num_jobs, verbose=1): - pipeline = Model(num_trees, num_jobs, - names=dataset.names, - ranges=dataset.ranges, - colors=dataset.colors, - verbose=verbose) + pipeline = RandomForestWrapper(num_trees, num_jobs, + names=dataset.names, + ranges=dataset.ranges, + colors=dataset.colors, + verbose=verbose) pipeline.fit(dataset.training_x, dataset.training_y) return pipeline @@ -51,9 +51,9 @@ def compute_feature_importance(model, dataset): return np.array([forest_i.feature_importances_ for forest_i in forests]) -class RandomForest(object): +class Model(object): """ - A class for a random forest. + A class for a trainable random forest model. """ def __init__(self, training_dataset, model_path, data_file): diff --git a/hela/plot.py b/hela/plot.py index 6f5d67d..c217004 100644 --- a/hela/plot.py +++ b/hela/plot.py @@ -9,7 +9,7 @@ from tqdm import tqdm -from .models import resample_posterior +from .wrapper import resample_posterior from .wpercentile import wmedian __all__ = ['plot_predicted_vs_real', 'plot_feature_importances', diff --git a/hela/tests/test_example.py b/hela/tests/test_example.py index e6582d5..a10b489 100644 --- a/hela/tests/test_example.py +++ b/hela/tests/test_example.py @@ -1,24 +1,24 @@ def test_linear_end_to_end(): - from ..forest import generate_example_data + from ..model import generate_example_data example_dir, training_dataset, samples_path = generate_example_data() # Import RandomForest object from HELA - from ..forest import RandomForest + from ..model import Model - # Initialize a random forest object: - rf = RandomForest(training_dataset, example_dir, samples_path) + # Initialize a model: + m = Model(training_dataset, example_dir, samples_path) # Train the random forest: - r2scores = rf.train(num_trees=1000, num_jobs=1) + r2scores = m.train(num_trees=1000, num_jobs=1) # Do a rough check that the R^2 values are near unity assert abs(r2scores['slope'] - 1) < 0.01 assert abs(r2scores['intercept'] - 1) < 0.01 # Predict posterior distributions from random forest - posterior = rf.predict() + posterior = m.predict() posterior_slopes, posterior_intercepts = posterior.samples.T # Do a very generous check that the posterior distributions match diff --git a/hela/models.py b/hela/wrapper.py similarity index 98% rename from hela/models.py rename to hela/wrapper.py index 452e7f4..1a87b92 100644 --- a/hela/models.py +++ b/hela/wrapper.py @@ -6,12 +6,12 @@ from tqdm import tqdm -__all__ = ['Model', 'Posterior', 'resample_posterior'] +__all__ = ['RandomForestWrapper', 'Posterior', 'resample_posterior'] -class Model(object): +class RandomForestWrapper(object): """ - Class for models. + Wrapper class for the scikit-learn RandomForestRegressor. 
""" def __init__(self, num_trees, num_jobs, names, ranges, colors, From 8a6c4ede855443b46c3d9be2d4114b4561efac51 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 16:24:42 +0100 Subject: [PATCH 35/46] Fixing tutorial with new object names --- docs/hela/tutorial.rst | 8 ++++++++ hela/model.py | 28 ++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index 26dc8ea..53f210f 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -57,6 +57,9 @@ train the random forest with 1000 trees and on a single processor: # Train the random forest: r2scores = m.train(num_trees=1000, num_jobs=1) + + # Plot the results: + m.plot_correlations() plt.show() .. plot:: @@ -73,6 +76,7 @@ train the random forest with 1000 trees and on a single processor: # Train the random forest: r2scores = m.train(num_trees=1000, num_jobs=1) + m.plot_correlations() plt.show() The `~hela.Model.train` method returns a dictionary called `r2scores` @@ -89,6 +93,9 @@ using the `~hela.Model.predict` method: # Predict posterior distributions from random forest posterior = m.predict() posterior_slopes, posterior_intercepts = posterior.samples.T + + # Plot the posteriors + m.plot_posterior() plt.show() .. plot:: @@ -106,6 +113,7 @@ using the `~hela.Model.predict` method: # Predict posterior distributions from random forest posterior = m.predict() posterior_slopes, posterior_intercepts = posterior.samples.T + m.plot_posterior() plt.tight_layout() plt.show() diff --git a/hela/model.py b/hela/model.py index 5cd906b..2dadf6a 100644 --- a/hela/model.py +++ b/hela/model.py @@ -24,7 +24,7 @@ def train_model(dataset, num_trees, num_jobs, verbose=1): return pipeline -def test_model(model, dataset, output_path): +def test_model(model, dataset): if dataset.testing_x is None: return @@ -36,11 +36,7 @@ def test_model(model, dataset, output_path): for name, values in r2scores.items(): print("\tR^2 score for {}: {:.3f}".format(name, values)) - fig, axes = plot_predicted_vs_real(dataset.testing_y, pred, dataset.names, - dataset.ranges) - fig.savefig(os.path.join(output_path, "predicted_vs_real.pdf"), - bbox_inches='tight') - return r2scores + return pred, r2scores def compute_feature_importance(model, dataset): @@ -77,6 +73,7 @@ def __init__(self, training_dataset, model_path, data_file): self._feature_importance = None self._posterior = None self.oob = None + self.pred = None def train(self, num_trees=1000, num_jobs=5, quiet=False): """ @@ -109,10 +106,25 @@ def train(self, num_trees=1000, num_jobs=5, quiet=False): print("OOB score: {:.4f}".format(self.model.rf.oob_score_)) self.oob = self.model.rf.oob_score_ - r2scores = test_model(self.model, self.dataset, self.model_path) - + pred, r2scores = test_model(self.model, self.dataset) + self.pred = pred return r2scores + def plot_correlations(self): + """ + Plot training correlations. + + Returns + ------- + fig, axes + """ + fig, axes = plot_predicted_vs_real(self.dataset.testing_y, self.pred, + self.dataset.names, + self.dataset.ranges) + fig.savefig(os.path.join(self.output_path, "predicted_vs_real.pdf"), + bbox_inches='tight') + return fig, axes + def feature_importance(self): """ Compute feature importance. 
From f2b7b47bc86450efb20945b1a00cee86c32e8f6b Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 16:28:03 +0100 Subject: [PATCH 36/46] flake8 fix --- hela/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hela/model.py b/hela/model.py index 2dadf6a..91ef687 100644 --- a/hela/model.py +++ b/hela/model.py @@ -150,7 +150,8 @@ def plot_feature_importance(self): fig, axes = plot_feature_importances(forests=forests, names=(self.dataset.names + ["joint prediction"]), - colors=self.dataset.colors + ["C0"]) + colors=(self.dataset.colors + + ["C0"])) fig.savefig(os.path.join(self.output_path, "feature_importances.pdf"), bbox_inches='tight') From 739d065dfa493e1811e6b65f16849a1ceae58be5 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Fri, 8 Nov 2019 16:30:12 +0100 Subject: [PATCH 37/46] Fixing important typo in test --- docs/hela/tutorial.rst | 2 +- hela/tests/test_example.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index 53f210f..95b9f2c 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -85,7 +85,7 @@ both be close to unity for this example. Finally, let's estimate the posterior distributions for the slope and intercept using the trained random forest on the sample data in ``samples_path``, where -the true values of the slope and intercept are :math:`m=0.3` and :math:`b=0.5` +the true values of the slope and intercept are :math:`m=0.2` and :math:`b=0.5` using the `~hela.Model.predict` method: .. code-block:: python diff --git a/hela/tests/test_example.py b/hela/tests/test_example.py index a10b489..f7cc8a5 100644 --- a/hela/tests/test_example.py +++ b/hela/tests/test_example.py @@ -23,6 +23,6 @@ def test_linear_end_to_end(): # Do a very generous check that the posterior distributions match # the expected values - assert abs(posterior_slopes.mean() - 0.3) < 3 * posterior_slopes.std() + assert abs(posterior_slopes.mean() - 0.2) < 3 * posterior_slopes.std() assert (abs(posterior_intercepts.mean() - 0.5) < 3 * posterior_intercepts.std()) From 296d476f81a65354f60d75de03e9fdee7280d29b Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Mon, 11 Nov 2019 09:16:33 +0100 Subject: [PATCH 38/46] Fixing typos in docs --- docs/hela/tutorial.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index 95b9f2c..d9ad38d 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -30,7 +30,7 @@ free parameters, which looks like this: "training_data": "training.npy", "testing_data": "testing.npy"} -This file tells the model what the two fitting parameters are and their rainges, +This file tells the model what the two fitting parameters are and their ranges, where to grab the training and testing datasets (in the npy pickle files), the number of features (1000), the colors to use for each parameter in the plots. @@ -47,10 +47,10 @@ with the paths to the three files/directories that it needs to know about: from hela import Model import matplotlib.pyplot as plt - # Initialize a random forest object: + # Initialize a retrieval model object: m = Model(training_dataset, example_dir, samples_path) -We now have a random forest object ``rf`` which is ready for training. We can +We now have a random forest object ``m`` which is ready for training. We can train the random forest with 1000 trees and on a single processor: .. 
code-block:: python @@ -79,7 +79,7 @@ train the random forest with 1000 trees and on a single processor: m.plot_correlations() plt.show() -The `~hela.Model.train` method returns a dictionary called `r2scores` +The `~hela.Model.train` method returns a dictionary called ``r2scores`` which contains the :math:`R^2` scores of the slope and intercept, which should both be close to unity for this example. From c5e62d2f8e5fde4ed3b5bf6c411fbe1cd6fa1e9e Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Mon, 11 Nov 2019 19:05:48 +0100 Subject: [PATCH 39/46] Renaming Model -> Retrieval --- docs/hela/tutorial.rst | 36 ++++++++++++++++----------------- hela/__init__.py | 2 +- hela/{model.py => retrieval.py} | 4 ++-- hela/tests/test_example.py | 10 ++++----- 4 files changed, 26 insertions(+), 26 deletions(-) rename hela/{model.py => retrieval.py} (99%) diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index d9ad38d..64d23cb 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -39,27 +39,27 @@ We also generated a bunch of samples with a known slope and intercept, called the slope and intercept. Once we have these three data structures written and their paths saved, we can -run ``hela`` on the data. First, we'll initialize a `~hela.Model` object +run ``hela`` on the data. First, we'll initialize a `~hela.Retrieval` object with the paths to the three files/directories that it needs to know about: .. code-block:: python - from hela import Model + from hela import Retrieval import matplotlib.pyplot as plt # Initialize a retrieval model object: - m = Model(training_dataset, example_dir, samples_path) + r = Retrieval(training_dataset, example_dir, samples_path) -We now have a random forest object ``m`` which is ready for training. We can +We now have a Retrieval object ``r`` which is ready for training. We can train the random forest with 1000 trees and on a single processor: .. code-block:: python # Train the random forest: - r2scores = m.train(num_trees=1000, num_jobs=1) + r2scores = r.train(num_trees=1000, num_jobs=1) # Plot the results: - m.plot_correlations() + r.plot_correlations() plt.show() .. plot:: @@ -68,34 +68,34 @@ train the random forest with 1000 trees and on a single processor: # Generate an example dataset directory example_dir, training_dataset, samples_path = generate_example_data() - from hela import Model + from hela import Retrieval import matplotlib.pyplot as plt # Initialize a random forest object: - m = Model(training_dataset, example_dir, samples_path) + r = Retrieval(training_dataset, example_dir, samples_path) # Train the random forest: - r2scores = m.train(num_trees=1000, num_jobs=1) - m.plot_correlations() + r2scores = r.train(num_trees=1000, num_jobs=1) + r.plot_correlations() plt.show() -The `~hela.Model.train` method returns a dictionary called ``r2scores`` +The `~hela.Retrieval.train` method returns a dictionary called ``r2scores`` which contains the :math:`R^2` scores of the slope and intercept, which should both be close to unity for this example. Finally, let's estimate the posterior distributions for the slope and intercept using the trained random forest on the sample data in ``samples_path``, where the true values of the slope and intercept are :math:`m=0.2` and :math:`b=0.5` -using the `~hela.Model.predict` method: +using the `~hela.Retrieval.predict` method: .. 
code-block:: python # Predict posterior distributions from random forest - posterior = m.predict() + posterior = r.predict() posterior_slopes, posterior_intercepts = posterior.samples.T # Plot the posteriors - m.plot_posterior() + r.plot_posterior() plt.show() .. plot:: @@ -104,16 +104,16 @@ using the `~hela.Model.predict` method: # Generate an example dataset directory example_dir, training_dataset, samples_path = generate_example_data() - from hela import Model + from hela import Retrieval import matplotlib.pyplot as plt # Initialize a random forest object: - m = Model(training_dataset, example_dir, samples_path) + r = Retrieval(training_dataset, example_dir, samples_path) # Predict posterior distributions from random forest - posterior = m.predict() + posterior = r.predict() posterior_slopes, posterior_intercepts = posterior.samples.T - m.plot_posterior() + r.plot_posterior() plt.tight_layout() plt.show() diff --git a/hela/__init__.py b/hela/__init__.py index f44d80d..6bc7f0f 100644 --- a/hela/__init__.py +++ b/hela/__init__.py @@ -25,7 +25,7 @@ class UnsupportedPythonError(Exception): if not _ASTROPY_SETUP_: # noqa # For egg_info test builds to pass, put package imports here. - from .model import * # noqa + from .retrieval import * # noqa from .wrapper import * # noqa from .dataset import * # noqa from .plot import * # noqa diff --git a/hela/model.py b/hela/retrieval.py similarity index 99% rename from hela/model.py rename to hela/retrieval.py index 91ef687..6fd2dfd 100644 --- a/hela/model.py +++ b/hela/retrieval.py @@ -11,7 +11,7 @@ plot_posterior_matrix) from .wpercentile import wpercentile -__all__ = ['Model', 'generate_example_data'] +__all__ = ['Retrieval', 'generate_example_data'] def train_model(dataset, num_trees, num_jobs, verbose=1): @@ -47,7 +47,7 @@ def compute_feature_importance(model, dataset): return np.array([forest_i.feature_importances_ for forest_i in forests]) -class Model(object): +class Retrieval(object): """ A class for a trainable random forest model. 
""" diff --git a/hela/tests/test_example.py b/hela/tests/test_example.py index f7cc8a5..5ff2303 100644 --- a/hela/tests/test_example.py +++ b/hela/tests/test_example.py @@ -1,24 +1,24 @@ def test_linear_end_to_end(): - from ..model import generate_example_data + from ..retrieval import generate_example_data example_dir, training_dataset, samples_path = generate_example_data() # Import RandomForest object from HELA - from ..model import Model + from ..retrieval import Retrieval # Initialize a model: - m = Model(training_dataset, example_dir, samples_path) + r = Retrieval(training_dataset, example_dir, samples_path) # Train the random forest: - r2scores = m.train(num_trees=1000, num_jobs=1) + r2scores = r.train(num_trees=1000, num_jobs=1) # Do a rough check that the R^2 values are near unity assert abs(r2scores['slope'] - 1) < 0.01 assert abs(r2scores['intercept'] - 1) < 0.01 # Predict posterior distributions from random forest - posterior = m.predict() + posterior = r.predict() posterior_slopes, posterior_intercepts = posterior.samples.T # Do a very generous check that the posterior distributions match From 8223dabf4492477d497e4bb3b62829b097215114 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Mon, 11 Nov 2019 19:17:03 +0100 Subject: [PATCH 40/46] Reordering returned values from generate_example_data --- docs/hela/tutorial.rst | 6 +++--- hela/retrieval.py | 2 +- hela/tests/test_example.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index 64d23cb..86aca7c 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -13,7 +13,7 @@ which we'd like to predict on: from hela import generate_example_data # Generate an example dataset directory - example_dir, training_dataset, samples_path = generate_example_data() + training_dataset, example_dir, samples_path = generate_example_data() This handy command created an example directory called ``linear_data``, which contains a training dataset described by the metadata file located at path @@ -66,7 +66,7 @@ train the random forest with 1000 trees and on a single processor: from hela import generate_example_data # Generate an example dataset directory - example_dir, training_dataset, samples_path = generate_example_data() + training_dataset, example_dir, samples_path = generate_example_data() from hela import Retrieval import matplotlib.pyplot as plt @@ -102,7 +102,7 @@ using the `~hela.Retrieval.predict` method: from hela import generate_example_data # Generate an example dataset directory - example_dir, training_dataset, samples_path = generate_example_data() + training_dataset, example_dir, samples_path = generate_example_data() from hela import Retrieval import matplotlib.pyplot as plt diff --git a/hela/retrieval.py b/hela/retrieval.py index 6fd2dfd..de10a49 100644 --- a/hela/retrieval.py +++ b/hela/retrieval.py @@ -293,4 +293,4 @@ def generate_example_data(): samples = true_slope * x + true_intercept np.save(samples_path, samples.T) - return example_dir, training_dataset, samples_path + return training_dataset, example_dir, samples_path diff --git a/hela/tests/test_example.py b/hela/tests/test_example.py index 5ff2303..5f32efd 100644 --- a/hela/tests/test_example.py +++ b/hela/tests/test_example.py @@ -2,7 +2,7 @@ def test_linear_end_to_end(): from ..retrieval import generate_example_data - example_dir, training_dataset, samples_path = generate_example_data() + training_dataset, example_dir, samples_path = generate_example_data() # Import RandomForest object from 
HELA from ..retrieval import Retrieval From f98e83c95132732ff01d1b2145af3ad08f8698f5 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Tue, 12 Nov 2019 09:09:57 +0100 Subject: [PATCH 41/46] Naming tweaks from @pmneila --- hela/__init__.py | 2 +- hela/plot.py | 2 +- hela/{wrapper.py => posteriors.py} | 8 ++--- hela/retrieval.py | 50 ++++++++++++------------------ 4 files changed, 25 insertions(+), 37 deletions(-) rename hela/{wrapper.py => posteriors.py} (97%) diff --git a/hela/__init__.py b/hela/__init__.py index 6bc7f0f..f8e5a0b 100644 --- a/hela/__init__.py +++ b/hela/__init__.py @@ -26,6 +26,6 @@ class UnsupportedPythonError(Exception): if not _ASTROPY_SETUP_: # noqa # For egg_info test builds to pass, put package imports here. from .retrieval import * # noqa - from .wrapper import * # noqa + from .posteriors import * # noqa from .dataset import * # noqa from .plot import * # noqa diff --git a/hela/plot.py b/hela/plot.py index c217004..7c53332 100644 --- a/hela/plot.py +++ b/hela/plot.py @@ -9,7 +9,7 @@ from tqdm import tqdm -from .wrapper import resample_posterior +from .posteriors import resample_posterior from .wpercentile import wmedian __all__ = ['plot_predicted_vs_real', 'plot_feature_importances', diff --git a/hela/wrapper.py b/hela/posteriors.py similarity index 97% rename from hela/wrapper.py rename to hela/posteriors.py index 1a87b92..f0a8629 100644 --- a/hela/wrapper.py +++ b/hela/posteriors.py @@ -6,12 +6,12 @@ from tqdm import tqdm -__all__ = ['RandomForestWrapper', 'Posterior', 'resample_posterior'] +__all__ = ['PosteriorRandomForest', 'Posterior', 'resample_posterior'] -class RandomForestWrapper(object): +class PosteriorRandomForest(object): """ - Wrapper class for the scikit-learn RandomForestRegressor. + Produces posterior samples from a random forest. """ def __init__(self, num_trees, num_jobs, names, ranges, colors, @@ -151,7 +151,7 @@ def predict_percentile(self, x, percentile): self.data_y, leaves_x, percentile ) - def posterior(self, x): + def predict_posterior(self, x): leaves_x = self.rf.apply(x[None, :])[0] if not self.enable_posterior: raise ValueError("Cannot compute posteriors with this model. " diff --git a/hela/retrieval.py b/hela/retrieval.py index de10a49..4a7b0b1 100644 --- a/hela/retrieval.py +++ b/hela/retrieval.py @@ -6,7 +6,7 @@ import joblib from .dataset import load_dataset, load_data_file -from .wrapper import RandomForestWrapper +from .posteriors import PosteriorRandomForest from .plot import (plot_predicted_vs_real, plot_feature_importances, plot_posterior_matrix) from .wpercentile import wpercentile @@ -15,11 +15,11 @@ def train_model(dataset, num_trees, num_jobs, verbose=1): - pipeline = RandomForestWrapper(num_trees, num_jobs, - names=dataset.names, - ranges=dataset.ranges, - colors=dataset.colors, - verbose=verbose) + pipeline = PosteriorRandomForest(num_trees, num_jobs, + names=dataset.names, + ranges=dataset.ranges, + colors=dataset.colors, + verbose=verbose) pipeline.fit(dataset.training_x, dataset.training_y) return pipeline @@ -71,7 +71,6 @@ def __init__(self, training_dataset, model_path, data_file): self.dataset = None self.model = None self._feature_importance = None - self._posterior = None self.oob = None self.pred = None @@ -102,8 +101,7 @@ def train(self, num_trees=1000, num_jobs=5, quiet=False): # Saving model joblib.dump(self.model, model_file) - # Printing model information... - print("OOB score: {:.4f}".format(self.model.rf.oob_score_)) + # saving model information... 
self.oob = self.model.rf.oob_score_ pred, r2scores = test_model(self.model, self.dataset) @@ -121,8 +119,6 @@ def plot_correlations(self): fig, axes = plot_predicted_vs_real(self.dataset.testing_y, self.pred, self.dataset.names, self.dataset.ranges) - fig.savefig(os.path.join(self.output_path, "predicted_vs_real.pdf"), - bbox_inches='tight') return fig, axes def feature_importance(self): @@ -152,9 +148,6 @@ def plot_feature_importance(self): ["joint prediction"]), colors=(self.dataset.colors + ["C0"])) - - fig.savefig(os.path.join(self.output_path, "feature_importances.pdf"), - bbox_inches='tight') return fig, axes def predict(self, quiet=False): @@ -172,25 +165,22 @@ def predict(self, quiet=False): number of samples/trees (check out attributes of model for metadata) """ - if self._posterior is None: - model_file = os.path.join(self.model_path, "model.pkl") - # Loading random forest from '{}'...".format(model_file) - model = joblib.load(model_file) - - # Loading data from '{}'...".format(data_file) - data, _ = load_data_file(self.data_file, model.rf.n_features_) + model_file = os.path.join(self.model_path, "model.pkl") + # Loading random forest from '{}'...".format(model_file) + model = joblib.load(model_file) - posterior = model.posterior(data[0]) + # Loading data from '{}'...".format(data_file) + data, _ = load_data_file(self.data_file, model.rf.n_features_) - if not quiet: - posterior_ranges = data_ranges(posterior) - for name_i, pred_range_i in zip(model.names, posterior_ranges): - print("Prediction for {}: {:.3g} " - "[+{:.3g} -{:.3g}]".format(name_i, *pred_range_i)) + posterior = model.predict_posterior(data[0]) - self._posterior = posterior + if not quiet: + posterior_ranges = data_ranges(posterior) + for name_i, pred_range_i in zip(model.names, posterior_ranges): + print("Prediction for {}: {:.3g} " + "[+{:.3g} -{:.3g}]".format(name_i, *pred_range_i)) - return self._posterior + return posterior def plot_posterior(self): """ @@ -209,8 +199,6 @@ def plot_posterior(self): ranges=model.ranges, colors=model.colors) os.makedirs(self.output_path, exist_ok=True) - fig.savefig(os.path.join(self.output_path, "posterior_matrix.pdf"), - bbox_inches='tight') return fig, axes From cc7168b570f83a9d8e6ba8af73ff9efb0aa20d91 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Tue, 12 Nov 2019 09:40:16 +0100 Subject: [PATCH 42/46] API, example data tweaks --- docs/hela/tutorial.rst | 22 +++++++++++----------- hela/retrieval.py | 28 ++++++++++------------------ 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index 86aca7c..c78f3f0 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -13,7 +13,7 @@ which we'd like to predict on: from hela import generate_example_data # Generate an example dataset directory - training_dataset, example_dir, samples_path = generate_example_data() + training_dataset, example_dir, data = generate_example_data() This handy command created an example directory called ``linear_data``, which contains a training dataset described by the metadata file located at path @@ -48,7 +48,7 @@ with the paths to the three files/directories that it needs to know about: import matplotlib.pyplot as plt # Initialize a retrieval model object: - r = Retrieval(training_dataset, example_dir, samples_path) + r = Retrieval(training_dataset, example_dir) We now have a Retrieval object ``r`` which is ready for training. 
We can train the random forest with 1000 trees and on a single processor: @@ -66,13 +66,13 @@ train the random forest with 1000 trees and on a single processor: from hela import generate_example_data # Generate an example dataset directory - training_dataset, example_dir, samples_path = generate_example_data() + training_dataset, example_dir, data = generate_example_data() from hela import Retrieval import matplotlib.pyplot as plt # Initialize a random forest object: - r = Retrieval(training_dataset, example_dir, samples_path) + r = Retrieval(training_dataset, example_dir) # Train the random forest: r2scores = r.train(num_trees=1000, num_jobs=1) @@ -84,36 +84,36 @@ which contains the :math:`R^2` scores of the slope and intercept, which should both be close to unity for this example. Finally, let's estimate the posterior distributions for the slope and intercept -using the trained random forest on the sample data in ``samples_path``, where +using the trained random forest on the sample data in ``data``, where the true values of the slope and intercept are :math:`m=0.2` and :math:`b=0.5` using the `~hela.Retrieval.predict` method: .. code-block:: python # Predict posterior distributions from random forest - posterior = r.predict() + posterior = r.predict(data) posterior_slopes, posterior_intercepts = posterior.samples.T # Plot the posteriors - r.plot_posterior() + r.plot_posterior(posterior) plt.show() .. plot:: from hela import generate_example_data # Generate an example dataset directory - training_dataset, example_dir, samples_path = generate_example_data() + training_dataset, example_dir, data = generate_example_data() from hela import Retrieval import matplotlib.pyplot as plt # Initialize a random forest object: - r = Retrieval(training_dataset, example_dir, samples_path) + r = Retrieval(training_dataset, example_dir) # Predict posterior distributions from random forest - posterior = r.predict() + posterior = r.predict(data) posterior_slopes, posterior_intercepts = posterior.samples.T - r.plot_posterior() + r.plot_posterior(posterior) plt.tight_layout() plt.show() diff --git a/hela/retrieval.py b/hela/retrieval.py index 4a7b0b1..d27c777 100644 --- a/hela/retrieval.py +++ b/hela/retrieval.py @@ -5,7 +5,7 @@ from sklearn import metrics, multioutput import joblib -from .dataset import load_dataset, load_data_file +from .dataset import load_dataset from .posteriors import PosteriorRandomForest from .plot import (plot_predicted_vs_real, plot_feature_importances, plot_posterior_matrix) @@ -52,7 +52,7 @@ class Retrieval(object): A class for a trainable random forest model. """ - def __init__(self, training_dataset, model_path, data_file): + def __init__(self, training_dataset, model_path): """ Parameters ---------- @@ -60,12 +60,9 @@ def __init__(self, training_dataset, model_path, data_file): Path to the dataset metadata JSON file model_path : str Path to the output directory to create and populate - data_file : str - Path to the numpy pickle of the samples to predict on """ self.training_dataset = training_dataset self.model_path = model_path - self.data_file = data_file self.output_path = self.model_path self.dataset = None @@ -150,12 +147,14 @@ def plot_feature_importance(self): ["C0"])) return fig, axes - def predict(self, quiet=False): + def predict(self, x, quiet=False): """ Predict values from the trained random forest. 
Parameters ---------- + x : `~numpy.ndarray` + plot_posterior : bool Returns @@ -169,10 +168,7 @@ def predict(self, quiet=False): # Loading random forest from '{}'...".format(model_file) model = joblib.load(model_file) - # Loading data from '{}'...".format(data_file) - data, _ = load_data_file(self.data_file, model.rf.n_features_) - - posterior = model.predict_posterior(data[0]) + posterior = model.predict_posterior(x) if not quiet: posterior_ranges = data_ranges(posterior) @@ -182,7 +178,7 @@ def predict(self, quiet=False): return posterior - def plot_posterior(self): + def plot_posterior(self, posterior): """ Plot the posterior distributions for each parameter. @@ -194,11 +190,10 @@ def plot_posterior(self): # Loading random forest from '{}'...".format(model_file) model = joblib.load(model_file) - fig, axes = plot_posterior_matrix(self._posterior, + fig, axes = plot_posterior_matrix(posterior, names=model.names, ranges=model.ranges, colors=model.colors) - os.makedirs(self.output_path, exist_ok=True) return fig, axes @@ -232,12 +227,10 @@ def generate_example_data(): Path to the directory of the example data training_dataset : str Path to the dataset metadata JSON file - samples_path : str - Path to the numpy pickle of the samples to predict on + samples : `~numpy.ndarray` """ example_dir = 'linear_dataset' training_dataset = os.path.join(example_dir, 'example_dataset.json') - samples_path = 'samples.npy' os.makedirs(example_dir, exist_ok=True) @@ -280,5 +273,4 @@ def generate_example_data(): true_intercept = 0.5 samples = true_slope * x + true_intercept - np.save(samples_path, samples.T) - return training_dataset, example_dir, samples_path + return training_dataset, example_dir, samples.T[0] From 6268e8464d3a21c6844bf8acc6a6e32a5868b049 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Tue, 12 Nov 2019 10:16:35 +0100 Subject: [PATCH 43/46] Adding missing training line from docs build --- docs/hela/tutorial.rst | 2 ++ hela/posteriors.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index c78f3f0..99f1abf 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -110,6 +110,8 @@ using the `~hela.Retrieval.predict` method: # Initialize a random forest object: r = Retrieval(training_dataset, example_dir) + r2scores = r.train(num_trees=1000, num_jobs=1) + # Predict posterior distributions from random forest posterior = r.predict(data) posterior_slopes, posterior_intercepts = posterior.samples.T diff --git a/hela/posteriors.py b/hela/posteriors.py index f0a8629..c2d2aa1 100644 --- a/hela/posteriors.py +++ b/hela/posteriors.py @@ -108,7 +108,7 @@ def predict(self, x): pred = self.rf.predict(x) return self._scaler_inverse_transform(pred) - def get_params(self, deep=True): + def get_params(self): return {"num_trees": self.num_trees, "num_jobs": self.num_jobs, "names": self.names, "ranges": self.ranges, "colors": self.colors, "verbose": self.verbose, From 51402538260e75f45145f1a75252d7ad84618cc8 Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Tue, 12 Nov 2019 10:20:41 +0100 Subject: [PATCH 44/46] Fix broken test due to update to generate_example_data --- hela/tests/test_example.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hela/tests/test_example.py b/hela/tests/test_example.py index 5f32efd..cadd791 100644 --- a/hela/tests/test_example.py +++ b/hela/tests/test_example.py @@ -2,13 +2,13 @@ def test_linear_end_to_end(): from ..retrieval import generate_example_data - training_dataset, 
example_dir, samples_path = generate_example_data() + training_dataset, example_dir, data = generate_example_data() # Import RandomForest object from HELA from ..retrieval import Retrieval # Initialize a model: - r = Retrieval(training_dataset, example_dir, samples_path) + r = Retrieval(training_dataset, example_dir) # Train the random forest: r2scores = r.train(num_trees=1000, num_jobs=1) @@ -18,7 +18,7 @@ def test_linear_end_to_end(): assert abs(r2scores['intercept'] - 1) < 0.01 # Predict posterior distributions from random forest - posterior = r.predict() + posterior = r.predict(data) posterior_slopes, posterior_intercepts = posterior.samples.T # Do a very generous check that the posterior distributions match From 1b5267c3e7772a797839ddd786dd69953797f0ae Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Sun, 1 Dec 2019 10:47:10 +0100 Subject: [PATCH 45/46] Object-oriented refactor, but a bit more explicit --- docs/hela/tutorial.rst | 72 +++++++++---------- example.py | 32 +++++++++ hela/dataset.py | 75 +++++++++---------- hela/plot.py | 16 +++-- hela/posteriors.py | 43 ++++++++++- hela/retrieval.py | 144 ++++++++++++------------------------- hela/tests/test_example.py | 28 ++++---- 7 files changed, 216 insertions(+), 194 deletions(-) create mode 100644 example.py diff --git a/docs/hela/tutorial.rst b/docs/hela/tutorial.rst index 99f1abf..d5c95e4 100644 --- a/docs/hela/tutorial.rst +++ b/docs/hela/tutorial.rst @@ -44,11 +44,11 @@ with the paths to the three files/directories that it needs to know about: .. code-block:: python - from hela import Retrieval + from hela import Retrieval, Dataset import matplotlib.pyplot as plt # Initialize a retrieval model object: - r = Retrieval(training_dataset, example_dir) + r = Retrieval() We now have a Retrieval object ``r`` which is ready for training. We can train the random forest with 1000 trees and on a single processor: @@ -58,25 +58,26 @@ train the random forest with 1000 trees and on a single processor: # Train the random forest: r2scores = r.train(num_trees=1000, num_jobs=1) - # Plot the results: - r.plot_correlations() - plt.show() + # Plot predicted vs real: + fig, ax = r.plot_predicted_vs_real(dataset) .. plot:: - from hela import generate_example_data - # Generate an example dataset directory - training_dataset, example_dir, data = generate_example_data() - - from hela import Retrieval import matplotlib.pyplot as plt + from hela import Retrieval, Dataset, generate_example_data - # Initialize a random forest object: - r = Retrieval(training_dataset, example_dir) + # Create an example dataset + training_dataset, example_dir, example_data = generate_example_data() - # Train the random forest: - r2scores = r.train(num_trees=1000, num_jobs=1) - r.plot_correlations() + # Load the dataset + dataset = Dataset.load_json("linear_dataset/example_dataset.json") + + # Train the model + r = Retrieval() + r2scores = r.train(dataset, num_trees=1000, num_jobs=5) + + # Plot predicted vs real: + fig, ax = r.plot_predicted_vs_real(dataset) plt.show() The `~hela.Retrieval.train` method returns a dictionary called ``r2scores`` @@ -85,37 +86,36 @@ both be close to unity for this example. Finally, let's estimate the posterior distributions for the slope and intercept using the trained random forest on the sample data in ``data``, where -the true values of the slope and intercept are :math:`m=0.2` and :math:`b=0.5` +the true values of the slope and intercept are :math:`m=0.7` and :math:`b=0.5` using the `~hela.Retrieval.predict` method: .. 
code-block:: python - # Predict posterior distributions from random forest - posterior = r.predict(data) - posterior_slopes, posterior_intercepts = posterior.samples.T + # Predict posterior distribution for slope and intercept of example data + posterior = r.predict(example_data) - # Plot the posteriors - r.plot_posterior(posterior) - plt.show() + # Plot posterior distribution matrix + fig2, ax2 = posterior.plot_posterior_matrix(dataset) .. plot:: - from hela import generate_example_data - # Generate an example dataset directory - training_dataset, example_dir, data = generate_example_data() - - from hela import Retrieval import matplotlib.pyplot as plt + from hela import Retrieval, Dataset, generate_example_data - # Initialize a random forest object: - r = Retrieval(training_dataset, example_dir) + # Create an example dataset + training_dataset, example_dir, example_data = generate_example_data() - r2scores = r.train(num_trees=1000, num_jobs=1) + # Load the dataset + dataset = Dataset.load_json("linear_dataset/example_dataset.json") - # Predict posterior distributions from random forest - posterior = r.predict(data) - posterior_slopes, posterior_intercepts = posterior.samples.T - r.plot_posterior(posterior) - plt.tight_layout() - plt.show() + # Train the model + r = Retrieval() + r2scores = r.train(dataset, num_trees=1000, num_jobs=5) + + # Predict posterior distribution for slope and intercept of example data + posterior = r.predict(example_data) + + # Plot posterior distribution matrix + fig2, ax2 = posterior.plot_posterior_matrix(dataset) + plt.show() \ No newline at end of file diff --git a/example.py b/example.py new file mode 100644 index 0000000..02e743e --- /dev/null +++ b/example.py @@ -0,0 +1,32 @@ +import matplotlib.pyplot as plt +from hela import Retrieval, Dataset, generate_example_data + +# Create an example dataset +training_dataset, example_dir, example_data = generate_example_data() + +# Load the dataset +dataset = Dataset.load_json("linear_dataset/example_dataset.json") + +# Train the model +r = Retrieval() +r2scores = r.train(dataset, num_trees=1000, num_jobs=5) + +# Print OOB score (optional) +print(r.oob) + +# Plot predicted vs real: +fig, ax = r.plot_predicted_vs_real(dataset) + +# Save the model (optional) +# from hela import save_model +# save_model("linear_dataset/model.pkl", r.model) + +# Predict posterior distribution for slope and intercept of example data +posterior = r.predict(example_data) + +# Print posterior ranges (optional) +posterior.print_percentiles(dataset.names) + +fig2, ax2 = posterior.plot_posterior_matrix(dataset) + +plt.show() \ No newline at end of file diff --git a/hela/dataset.py b/hela/dataset.py index bbd9aaa..628afad 100644 --- a/hela/dataset.py +++ b/hela/dataset.py @@ -3,7 +3,7 @@ import numpy as np -__all__ = ["Dataset", "load_dataset", "load_data_file"] +__all__ = ["Dataset", "load_data_file"] class Dataset(object): @@ -31,6 +31,40 @@ def __init__(self, training_x, training_y, testing_x, testing_y, names, self.ranges = ranges self.colors = colors + @classmethod + def load_json(cls, path): + """ + Load a JSON file containing dataset parameters. 
+ + Parameters + ---------- + path : str + Path to the JSON file + """ + with open(path, "r") as f: + dataset_info = json.load(f) + + metadata = dataset_info["metadata"] + + base_path = os.path.dirname(path) + + # Load training data + training_file = os.path.join(base_path, dataset_info["training_data"]) + # Loading training data from '{}'...".format(training_file) + training_x, training_y = load_data_file(training_file, + metadata["num_features"]) + + # Optionally, load testing data + testing_x, testing_y = None, None + if dataset_info["testing_data"] is not None: + testing_file = os.path.join(base_path, dataset_info["testing_data"]) + # Loading testing data from '{}'...".format(testing_file) + testing_x, testing_y = load_data_file(testing_file, + metadata["num_features"]) + + return cls(training_x, training_y, testing_x, testing_y, + metadata["names"], metadata["ranges"], metadata["colors"]) + def load_data_file(data_file, num_features): data = np.load(data_file) @@ -42,42 +76,3 @@ def load_data_file(data_file, num_features): y = data[:, num_features:] return x, y - - -def load_dataset(dataset_file): - """ - Load a dataset from a JSON file. - - Parameters - ---------- - dataset_file - - Returns - ------- - - """ - with open(dataset_file, "r") as f: - dataset_info = json.load(f) - - metadata = dataset_info["metadata"] - - base_path = os.path.dirname(dataset_file) - - # Load training data - training_file = os.path.join(base_path, dataset_info["training_data"]) - # Loading training data from '{}'...".format(training_file) - training_x, training_y = load_data_file(training_file, - metadata["num_features"]) - # TODO: slice training_x (data) and training_y (params) to the same length - # but something smaller for fast docs - - # Optionally, load testing data - testing_x, testing_y = None, None - if dataset_info["testing_data"] is not None: - testing_file = os.path.join(base_path, dataset_info["testing_data"]) - # Loading testing data from '{}'...".format(testing_file) - testing_x, testing_y = load_data_file(testing_file, - metadata["num_features"]) - - return Dataset(training_x, training_y, testing_x, testing_y, - metadata["names"], metadata["ranges"], metadata["colors"]) diff --git a/hela/plot.py b/hela/plot.py index 7c53332..a5ed2c0 100644 --- a/hela/plot.py +++ b/hela/plot.py @@ -18,7 +18,7 @@ POSTERIOR_MAX_SIZE = 10000 -def plot_predicted_vs_real(y_real, y_pred, names, ranges, alpha='auto'): +def plot_predicted_vs_real(dataset, retrieval, alpha='auto'): """ Plot predicted and real parameter values. @@ -34,6 +34,10 @@ def plot_predicted_vs_real(y_real, y_pred, names, ranges, alpha='auto'): ------- """ + + y_real, y_pred, names, ranges = (dataset.testing_y, retrieval.pred, + dataset.names, dataset.ranges) + num_plots = y_pred.shape[1] num_plot_rows = int(np.sqrt(num_plots)) num_plot_cols = (num_plots - 1) // num_plot_rows + 1 @@ -109,7 +113,7 @@ def plot_feature_importances(forests, names, colors): return fig, axes -def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): +def plot_posterior_matrix(posterior, dataset, soft_colors=None): """ Plot the posterior matrix. 
@@ -125,6 +129,9 @@ def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): ------- """ + + names, ranges, colors = dataset.names, dataset.ranges, dataset.colors + cmaps = [LinearSegmentedColormap.from_list("MyReds", [(1, 1, 1), c], N=256) for c in colors] @@ -202,10 +209,7 @@ def plot_posterior_matrix(posterior, names, ranges, colors, soft_colors=None): ax.axis([ranges[dims[0]][0], ranges[dims[0]][1], 0, 1.1 * kd_probs.max()]) - - # fig.tight_layout(pad=0) - - # fig.tight_layout(pad=0) + fig.tight_layout() return fig, axes diff --git a/hela/posteriors.py b/hela/posteriors.py index c2d2aa1..9dc1ecf 100644 --- a/hela/posteriors.py +++ b/hela/posteriors.py @@ -6,6 +6,8 @@ from tqdm import tqdm +from .wpercentile import wpercentile + __all__ = ['PosteriorRandomForest', 'Posterior', 'resample_posterior'] @@ -13,7 +15,6 @@ class PosteriorRandomForest(object): """ Produces posterior samples from a random forest. """ - def __init__(self, num_trees, num_jobs, names, ranges, colors, verbose=1, enable_posterior=True): """ @@ -179,6 +180,46 @@ def __init__(self, samples, weights): self.samples = samples self.weights = weights + def print_percentiles(self, names): + """ + Print the median and the +/- 1 sigma posterior values. + + Parameters + ---------- + names + + Returns + ------- + + """ + posterior_ranges = self.data_ranges() + for name_i, pred_range_i in zip(names, posterior_ranges): + print("Prediction for {}: {:.3g} " + "[+{:.3g} -{:.3g}]".format(name_i, *pred_range_i)) + + + def data_ranges(self, percentiles=(50, 16, 84)): + """ + Return posterior ranges. + + Parameters + ---------- + posterior : `~numpy.ndarray` + percentiles : tuple + + Returns + ------- + ranges : `~numpy.ndarray` + """ + values = wpercentile(self.samples, self.weights, + percentiles, axis=0) + ranges = np.array( + [values[0], values[2] - values[0], values[0] - values[1]]) + return ranges.T + + def plot_posterior_matrix(self, dataset): + from .plot import plot_posterior_matrix + return plot_posterior_matrix(self, dataset) def _posterior(data_leaves, data_weights, data_y, query_leaves): weights_x = (query_leaves[:, None] == data_leaves) * data_weights diff --git a/hela/retrieval.py b/hela/retrieval.py index d27c777..8729581 100644 --- a/hela/retrieval.py +++ b/hela/retrieval.py @@ -5,13 +5,18 @@ from sklearn import metrics, multioutput import joblib -from .dataset import load_dataset from .posteriors import PosteriorRandomForest -from .plot import (plot_predicted_vs_real, plot_feature_importances, - plot_posterior_matrix) -from .wpercentile import wpercentile +from .plot import plot_predicted_vs_real, plot_feature_importances -__all__ = ['Retrieval', 'generate_example_data'] +__all__ = ['Retrieval', 'generate_example_data', 'save_model', 'load_model'] + + +def save_model(path, model, **kwargs): + joblib.dump(model, path, **kwargs) + + +def load_model(path, **kwargs): + joblib.load(path, **kwargs) def train_model(dataset, num_trees, num_jobs, verbose=1): @@ -52,31 +57,19 @@ class Retrieval(object): A class for a trainable random forest model. 
""" - def __init__(self, training_dataset, model_path): - """ - Parameters - ---------- - training_dataset : str - Path to the dataset metadata JSON file - model_path : str - Path to the output directory to create and populate - """ - self.training_dataset = training_dataset - self.model_path = model_path - self.output_path = self.model_path - - self.dataset = None + def __init__(self): self.model = None self._feature_importance = None self.oob = None self.pred = None - def train(self, num_trees=1000, num_jobs=5, quiet=False): + def train(self, dataset, num_trees=1000, num_jobs=5, quiet=False): """ Train the random forest on a set of observations. Parameters ---------- + dataset : `~hela.Dataset` num_trees : int num_jobs : int quiet : bool @@ -86,26 +79,17 @@ def train(self, num_trees=1000, num_jobs=5, quiet=False): r2scores : dict :math:`R^2` values for each parameter after training """ - # Loading dataset - self.dataset = load_dataset(self.training_dataset) - # Training model - self.model = train_model(self.dataset, num_trees, num_jobs, not quiet) - - os.makedirs(self.model_path, exist_ok=True) - model_file = os.path.join(self.model_path, "model.pkl") - - # Saving model - joblib.dump(self.model, model_file) + self.model = train_model(dataset, num_trees, num_jobs, not quiet) # saving model information... self.oob = self.model.rf.oob_score_ - pred, r2scores = test_model(self.model, self.dataset) + pred, r2scores = test_model(self.model, dataset) self.pred = pred return r2scores - def plot_correlations(self): + def plot_predicted_vs_real(self, dataset): """ Plot training correlations. @@ -113,41 +97,47 @@ def plot_correlations(self): ------- fig, axes """ - fig, axes = plot_predicted_vs_real(self.dataset.testing_y, self.pred, - self.dataset.names, - self.dataset.ranges) + fig, axes = plot_predicted_vs_real(dataset, self) return fig, axes - def feature_importance(self): + def feature_importance(self, dataset): """ Compute feature importance. + Parameters + ---------- + dataset : `~hela.Dataset` + Returns ------- feature_importances : `~numpy.ndarray` """ if self._feature_importance is None: self._feature_importance = compute_feature_importance(self.model, - self.dataset) + dataset) return self._feature_importance - def plot_feature_importance(self): + def plot_feature_importance(self, dataset): """ Plot the feature importances. + Parameters + ---------- + dataset : `~hela.Dataset` + Returns ------- fig, axes """ forests = self.feature_importance() fig, axes = plot_feature_importances(forests=forests, - names=(self.dataset.names + + names=(dataset.names + ["joint prediction"]), - colors=(self.dataset.colors + + colors=(dataset.colors + ["C0"])) return fig, axes - def predict(self, x, quiet=False): + def predict(self, x): """ Predict values from the trained random forest. 
@@ -155,8 +145,6 @@ def predict(self, x, quiet=False): ---------- x : `~numpy.ndarray` - plot_posterior : bool - Returns ------- preds : `~numpy.ndarray` @@ -164,58 +152,10 @@ def predict(self, x, quiet=False): number of samples/trees (check out attributes of model for metadata) """ - model_file = os.path.join(self.model_path, "model.pkl") - # Loading random forest from '{}'...".format(model_file) - model = joblib.load(model_file) - - posterior = model.predict_posterior(x) - - if not quiet: - posterior_ranges = data_ranges(posterior) - for name_i, pred_range_i in zip(model.names, posterior_ranges): - print("Prediction for {}: {:.3g} " - "[+{:.3g} -{:.3g}]".format(name_i, *pred_range_i)) + posterior = self.model.predict_posterior(x) return posterior - def plot_posterior(self, posterior): - """ - Plot the posterior distributions for each parameter. - - Returns - ------- - fig, axes - """ - model_file = os.path.join(self.model_path, "model.pkl") - # Loading random forest from '{}'...".format(model_file) - model = joblib.load(model_file) - - fig, axes = plot_posterior_matrix(posterior, - names=model.names, - ranges=model.ranges, - colors=model.colors) - return fig, axes - - -def data_ranges(posterior, percentiles=(50, 16, 84)): - """ - Return posterior ranges. - - Parameters - ---------- - posterior : `~numpy.ndarray` - percentiles : tuple - - Returns - ------- - ranges : `~numpy.ndarray` - """ - values = wpercentile(posterior.samples, posterior.weights, - percentiles, axis=0) - ranges = np.array( - [values[0], values[2] - values[0], values[0] - values[1]]) - return ranges.T - def generate_example_data(): """ @@ -252,24 +192,32 @@ def generate_example_data(): # Generate fake training data npoints = 1000 - slopes = np.random.rand(npoints) - ints = np.random.rand(npoints) - x = np.linspace(0, 1, 1000)[:, np.newaxis] - data = slopes * x + ints + slopes = np.random.uniform(size=npoints) + ints = np.random.uniform(size=npoints) + x = np.linspace(0, 1, 1000)[:, None] + + # Add correlated noise to parameters to introduce degeneracies + noise_ints = np.random.normal(scale=0.15, size=npoints) + noise_slopes = (np.abs(noise_ints) + + np.random.normal(scale=0.02, size=npoints)) + + # Add also noise to data points (not strictly necessary) + data = ((slopes + noise_slopes) * x + (ints + noise_ints) + + np.random.normal(scale=0.01, size=(1000, npoints))) labels = np.vstack([slopes, ints]) X = np.vstack([data, labels]) # Split dataset into training and testing segments training = X[:, :int(0.8 * npoints)].T - testing = X[:, int(-0.2 * npoints):].T + testing = X[:, int(0.8 * npoints):].T np.save(os.path.join(example_dir, 'training.npy'), training) np.save(os.path.join(example_dir, 'testing.npy'), testing) # Generate a bunch of samples with a test value to "retrieve" with the # random forest: - true_slope = 0.2 + true_slope = 0.7 true_intercept = 0.5 samples = true_slope * x + true_intercept diff --git a/hela/tests/test_example.py b/hela/tests/test_example.py index cadd791..4d8c9a9 100644 --- a/hela/tests/test_example.py +++ b/hela/tests/test_example.py @@ -1,28 +1,30 @@ def test_linear_end_to_end(): - from ..retrieval import generate_example_data - training_dataset, example_dir, data = generate_example_data() + from ..retrieval import Retrieval, generate_example_data + from ..dataset import Dataset - # Import RandomForest object from HELA - from ..retrieval import Retrieval + # Create an example dataset + training_dataset, example_dir, example_data = generate_example_data() - # Initialize a model: - r = 
Retrieval(training_dataset, example_dir) + # Load the dataset + dataset = Dataset.load_json("linear_dataset/example_dataset.json") - # Train the random forest: - r2scores = r.train(num_trees=1000, num_jobs=1) + # Train the model + r = Retrieval() + r2scores = r.train(dataset, num_trees=1000, num_jobs=5) + + # Predict posterior distribution for slope and intercept of example data + posterior = r.predict(example_data) # Do a rough check that the R^2 values are near unity - assert abs(r2scores['slope'] - 1) < 0.01 - assert abs(r2scores['intercept'] - 1) < 0.01 + assert abs(r2scores['slope'] - 1) < 0.3 + assert abs(r2scores['intercept'] - 1) < 0.3 - # Predict posterior distributions from random forest - posterior = r.predict(data) posterior_slopes, posterior_intercepts = posterior.samples.T # Do a very generous check that the posterior distributions match # the expected values - assert abs(posterior_slopes.mean() - 0.2) < 3 * posterior_slopes.std() + assert abs(posterior_slopes.mean() - 0.7) < 3 * posterior_slopes.std() assert (abs(posterior_intercepts.mean() - 0.5) < 3 * posterior_intercepts.std()) From 8b514042a693a8d404fdc984d6dcd60e383d80cc Mon Sep 17 00:00:00 2001 From: Brett Morris Date: Sun, 1 Dec 2019 10:53:04 +0100 Subject: [PATCH 46/46] minor flake8 (formatting) fixes --- hela/dataset.py | 3 ++- hela/posteriors.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/hela/dataset.py b/hela/dataset.py index 628afad..5e935a4 100644 --- a/hela/dataset.py +++ b/hela/dataset.py @@ -57,7 +57,8 @@ def load_json(cls, path): # Optionally, load testing data testing_x, testing_y = None, None if dataset_info["testing_data"] is not None: - testing_file = os.path.join(base_path, dataset_info["testing_data"]) + testing_file = os.path.join(base_path, + dataset_info["testing_data"]) # Loading testing data from '{}'...".format(testing_file) testing_x, testing_y = load_data_file(testing_file, metadata["num_features"]) diff --git a/hela/posteriors.py b/hela/posteriors.py index 9dc1ecf..1f87ad0 100644 --- a/hela/posteriors.py +++ b/hela/posteriors.py @@ -197,7 +197,6 @@ def print_percentiles(self, names): print("Prediction for {}: {:.3g} " "[+{:.3g} -{:.3g}]".format(name_i, *pred_range_i)) - def data_ranges(self, percentiles=(50, 16, 84)): """ Return posterior ranges. @@ -221,6 +220,7 @@ def plot_posterior_matrix(self, dataset): from .plot import plot_posterior_matrix return plot_posterior_matrix(self, dataset) + def _posterior(data_leaves, data_weights, data_y, query_leaves): weights_x = (query_leaves[:, None] == data_leaves) * data_weights weights_x = _as_smallest_udtype(weights_x.sum(0))
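
The PATCH 45/46 refactor drops the automatic joblib.dump/joblib.load round trip that Retrieval.train and Retrieval.predict used to perform; persisting a trained forest is now left to the caller (example.py only hints at this with the commented-out save_model call). Below is a minimal sketch of what that can look like with the post-refactor API, using joblib directly as the earlier train/predict code did. The pickle path "linear_dataset/model.pkl" is only an illustrative choice, not something the patches mandate.

    import joblib
    from hela import Retrieval, Dataset, generate_example_data

    # Build the toy linear dataset used throughout the tutorial
    training_dataset, example_dir, example_data = generate_example_data()
    dataset = Dataset.load_json(training_dataset)

    # Train, then persist the fitted PosteriorRandomForest held in r.model
    r = Retrieval()
    r2scores = r.train(dataset, num_trees=1000, num_jobs=1)
    joblib.dump(r.model, "linear_dataset/model.pkl")

    # Later: reload the forest and predict a posterior without retraining
    model = joblib.load("linear_dataset/model.pkl")
    posterior = model.predict_posterior(example_data)
    posterior.print_percentiles(dataset.names)

The sketch calls joblib.load directly rather than the load_model helper added in PATCH 45 because, as written in that patch, load_model wraps joblib.load without returning the loaded object.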