diff --git a/build.gradle b/build.gradle index f816fc2d62c5..9f34251638c6 100644 --- a/build.gradle +++ b/build.gradle @@ -155,7 +155,8 @@ ext { pythonProjects = [ project(':h2o-py'), - project(':h2o-py-cloud-extensions') + project(':h2o-py-cloud-extensions'), + project(':h2o-py-mlflow-flavor') ] // The project which need to be run under CI only diff --git a/h2o-py-mlflow-flavor/README.rst b/h2o-py-mlflow-flavor/README.rst new file mode 100644 index 000000000000..c4defa766905 --- /dev/null +++ b/h2o-py-mlflow-flavor/README.rst @@ -0,0 +1,110 @@ +H2O-3 MLFlow Flavor +=================== + +A tiny library containing a `MLFlow <https://mlflow.org>`_ flavor for working with H2O-3 MOJO and POJO models. + +Logging Models to MLFlow Registry +--------------------------------- + +A model trained with the H2O-3 runtime can be exported to the MLFlow registry with the `log_model` function: + +.. code-block:: Python + + import mlflow + import h2o_mlflow_flavor + + mlflow.set_tracking_uri("http://127.0.0.1:8080") + + h2o_model = ... training phase ... + + with mlflow.start_run(run_name="myrun") as run: + h2o_mlflow_flavor.log_model(h2o_model=h2o_model, + artifact_path="folder", + model_type="MOJO", + extra_prediction_args=["--predictCalibrated"]) + + +Compared to the `log_model` functions of the other flavors that are part of MLFlow, this function has two extra arguments: + +* ``model_type`` - It indicates whether the model should be exported as `MOJO <https://docs.h2o.ai/h2o/latest-stable/h2o-docs/mojo-quickstart.html>`_ or `POJO <https://docs.h2o.ai/h2o/latest-stable/h2o-docs/pojo-quickstart.html>`_. The default value is `MOJO`. + +* ``extra_prediction_args`` - A list of extra arguments for the Java scoring process. Possible values: + + * ``--setConvertInvalidNum`` - The scoring process will convert invalid numbers to NA. + + * ``--predictContributions`` - The scoring process will also return Shapley values along with the predictions. The model must support Shapley values; otherwise the scoring process will throw an error. + + * ``--predictCalibrated`` - The scoring process will also return calibrated prediction values. 
+ +The `save_model` function that persists h2o binary model to MOJO or POJO has the same signature as the `log_model` function. + +Extracting Information about Model +---------------------------------- + +The flavor offers several functions to extract information about the model. + +* ``get_metrics(h2o_model, metric_type=None)`` - Extracts metrics from the trained H2O binary model. It returns dictionary and takes following parameters: + + * ``h2o_model`` - An H2O binary model. + + * ``metric_type`` - The type of metrics. Possible values are "training", "validation", "cross_validation". If parameter is not specified, metrics for all types are returned. + +* ``get_params(h2o_model)`` - Extracts training parameters for the H2O binary model. It returns dictionary and expects one parameter: + + * ``h2o_model`` - An H2O binary model. + +* ``get_input_example(h2o_model, number_of_records=5, relevant_columns_only=True)`` - Creates an example Pandas dataset from the training dataset of H2O binary model. It takes following parameters: + + * ``h2o_model`` - An H2O binary model. + + * ``number_of_records`` - A number of records that will be extracted from the training dataset. + + * ``relevant_columns_only`` - A flag indicating whether the output dataset should contain only columns required by the model. Defaults to ``True``. + +The functions can be utilized as follows: + +.. code-block:: Python + + import mlflow + import h2o_mlflow_flavor + + mlflow.set_tracking_uri("http://127.0.0.1:8080") + + h2o_model = ... training phase ... 
+ + with mlflow.start_run(run_name="myrun") as run: + mlflow.log_params(h2o_mlflow_flavor.get_params(h2o_model)) + mlflow.log_metrics(h2o_mlflow_flavor.get_metrics(h2o_model)) + input_example = h2o_mlflow_flavor.get_input_example(h2o_model) + h2o_mlflow_flavor.log_model(h2o_model=h2o_model, + input_example=input_example, + artifact_path="folder", + model_type="MOJO", + extra_prediction_args=["--predictCalibrated"]) + + +Model Scoring +------------- + +After a model is obtained from the model registry, it doesn't require the h2o runtime to score. The only things +that the model requires are a Java runtime and the ``h2o-genmodel.jar`` which was persisted with the model during the saving procedure. +The model can be loaded by the function ``load_model(model_uri, dst_path=None)``. It returns an object making +predictions on a Pandas dataframe and takes the following parameters: + +* ``model_uri`` - A unique identifier of the model within the MLFlow registry. + +* ``dst_path`` - (Optional) A local filesystem path for downloading the persisted form of the model. + +The object for scoring can also be obtained via the `pyfunc` flavor as follows: + +.. 
code-block:: Python + + import mlflow + mlflow.set_tracking_uri("http://127.0.0.1:8080") + + logged_model = 'runs:/9a42265cf0ef484c905b02afb8fe6246/iris' + loaded_model = mlflow.pyfunc.load_model(logged_model) + + import pandas as pd + data = pd.read_csv("http://h2o-public-test-data.s3.amazonaws.com/smalldata/iris/iris_wheader.csv") + loaded_model.predict(data) diff --git a/h2o-py-mlflow-flavor/build.gradle b/h2o-py-mlflow-flavor/build.gradle new file mode 100644 index 000000000000..1edc3ccab43e --- /dev/null +++ b/h2o-py-mlflow-flavor/build.gradle @@ -0,0 +1,63 @@ +description = "H2O-3 MLFlow Flavor" + +dependencies {} + +def buildVersion = new H2OBuildVersion(rootDir, version) + +ext { + PROJECT_VERSION = buildVersion.getProjectVersion() + pythonexe = findProperty("pythonExec") ?: "python" + pipexe = findProperty("pipExec") ?: "pip" + if (System.env.VIRTUAL_ENV) { + pythonexe = "${System.env.VIRTUAL_ENV}/bin/python".toString() + pipexe = "${System.env.VIRTUAL_ENV}/bin/pip".toString() + } + testsPath = file("tests") +} + +task copySrcFiles(type: Copy) { + from ("${projectDir}") { + include "setup.py" + include "setup.cfg" + include "h2o_mlflow_flavor/**" + include "README.rst" + } + into "${buildDir}" +} + +task buildDist(type: Exec, dependsOn: [copySrcFiles]) { + workingDir buildDir + doFirst { + file("${buildDir}/tmp").mkdirs() + standardOutput = new FileOutputStream(file("${buildDir}/tmp/h2o_mlflow_flavor_buildDist.out")) + } + commandLine getOsSpecificCommandLine([pythonexe, "setup.py", "bdist_wheel"]) +} + +task copyMainDist(type: Copy, dependsOn: [buildDist]) { + from ("${buildDir}/main/") { + include "dist/**" + } + into "${buildDir}" +} + +task pythonVersion(type: Exec) { + doFirst { + println(System.env.VIRTUAL_ENV) + println(environment) + } + commandLine getOsSpecificCommandLine([pythonexe, "--version"]) +} + +task cleanBuild(type: Delete) { + doFirst { + println "Cleaning..." 
+ } + delete file("build/") +} + +// +// Define the dependencies +// +clean.dependsOn cleanBuild +build.dependsOn copyMainDist diff --git a/h2o-py-mlflow-flavor/examples/DRF_mojo.ipynb b/h2o-py-mlflow-flavor/examples/DRF_mojo.ipynb new file mode 100644 index 000000000000..7327c94f9a0d --- /dev/null +++ b/h2o-py-mlflow-flavor/examples/DRF_mojo.ipynb @@ -0,0 +1,125 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3ded5553", + "metadata": {}, + "outputs": [], + "source": [ + "# Start H2O-3 runtime.\n", + "\n", + "import h2o\n", + "h2o.init(strict_version_check=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e746ad4", + "metadata": {}, + "outputs": [], + "source": [ + "# Configure DRF algorithm and train a model.\n", + "\n", + "from h2o.estimators import H2ORandomForestEstimator\n", + "\n", + "# Import the cars dataset into H2O:\n", + "cars = h2o.import_file(\"https://s3.amazonaws.com/h2o-public-test-data/smalldata/junit/cars_20mpg.csv\")\n", + "\n", + "# Set the predictors and response;\n", + "# set the response as a factor:\n", + "cars[\"economy_20mpg\"] = cars[\"economy_20mpg\"].asfactor()\n", + "predictors = [\"displacement\",\"power\",\"weight\",\"acceleration\",\"year\"]\n", + "response = \"economy_20mpg\"\n", + "\n", + "# Split the dataset into a train and valid set:\n", + "train, valid = cars.split_frame(ratios=[.8], seed=1234)\n", + "drf = H2ORandomForestEstimator(ntrees=10,\n", + " max_depth=5,\n", + " min_rows=10,\n", + " calibrate_model=True,\n", + " calibration_frame=valid,\n", + " binomial_double_trees=True)\n", + "drf.train(x=predictors,\n", + " y=response,\n", + " training_frame=train,\n", + " validation_frame=valid)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29eb0722", + "metadata": {}, + "outputs": [], + "source": [ + "# Log the model to an MLFlow reqistry.\n", + "\n", + "import mlflow\n", + "import h2o_mlflow_flavor\n", + 
"mlflow.set_tracking_uri(\"http://127.0.0.1:8080\")\n", + "\n", + "with mlflow.start_run(run_name=\"cars\") as run:\n", + " mlflow.log_params(h2o_mlflow_flavor.get_params(drf)) # Log training parameters of the model (optional).\n", + " mlflow.log_metrics(h2o_mlflow_flavor.get_metrics(drf)) # Log performance matrics of the model (optional).\n", + " input_example = h2o_mlflow_flavor.get_input_example(drf) # Extract input example from training dataset (optional)\n", + " h2o_mlflow_flavor.log_model(drf, \"cars\", input_example=input_example,\n", + " model_type=\"MOJO\", # Specify whether the output model should be MOJO or POJO. (MOJO is default)\n", + " extra_prediction_args=[\"--predictCalibrated\"]) # Add extra prediction args if needed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bed1dafe", + "metadata": {}, + "outputs": [], + "source": [ + "# Load model from the MLFlow registry and score with the model.\n", + "\n", + "import mlflow\n", + "mlflow.set_tracking_uri(\"http://127.0.0.1:8080\")\n", + "\n", + "logged_model = 'runs:/a9ff364f07fa499eb44e7c49e47fab11/cars' # Specify correct id of your run.\n", + "\n", + "# Load model as a PyFuncModel.\n", + "loaded_model = mlflow.pyfunc.load_model(logged_model)\n", + "\n", + "# Predict on a Pandas DataFrame.\n", + "import pandas as pd\n", + "data = pd.read_csv(\"https://s3.amazonaws.com/h2o-public-test-data/smalldata/junit/cars_20mpg.csv\")\n", + "loaded_model.predict(data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "905b0c4c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mlflow", + "language": "python", + "name": "mlflow" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + 
"nbformat_minor": 5 +} diff --git a/h2o-py-mlflow-flavor/examples/KMeans_pojo.ipynb b/h2o-py-mlflow-flavor/examples/KMeans_pojo.ipynb new file mode 100644 index 000000000000..e83f909b085b --- /dev/null +++ b/h2o-py-mlflow-flavor/examples/KMeans_pojo.ipynb @@ -0,0 +1,121 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3ded5553", + "metadata": {}, + "outputs": [], + "source": [ + "# Start H2O-3 runtime.\n", + "\n", + "import h2o\n", + "h2o.init(strict_version_check=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e746ad4", + "metadata": {}, + "outputs": [], + "source": [ + "# Configure K-Means algorithm and train a model.\n", + "\n", + "from h2o.estimators import H2OKMeansEstimator\n", + "\n", + "# Import the iris dataset into H2O:\n", + "iris = h2o.import_file(\"http://h2o-public-test-data.s3.amazonaws.com/smalldata/iris/iris_wheader.csv\")\n", + "\n", + "# Set the predictors:\n", + "predictors = [\"sepal_len\", \"sepal_wid\", \"petal_len\", \"petal_wid\"]\n", + "\n", + "# Split the dataset into a train and valid set:\n", + "train, valid = iris.split_frame(ratios=[.8], seed=1234)\n", + "\n", + "# Build and train the model:\n", + "kmeans = H2OKMeansEstimator(k=10,\n", + " estimate_k=True,\n", + " standardize=False,\n", + " seed=1234)\n", + "kmeans.train(x=predictors,\n", + " training_frame=train,\n", + " validation_frame=valid)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29eb0722", + "metadata": {}, + "outputs": [], + "source": [ + "# Log the model to an MLFlow reqistry.\n", + "\n", + "import mlflow\n", + "import h2o_mlflow_flavor\n", + "mlflow.set_tracking_uri(\"http://127.0.0.1:8080\")\n", + "\n", + "with mlflow.start_run(run_name=\"iris\") as run:\n", + " mlflow.log_params(h2o_mlflow_flavor.get_params(kmeans)) # Log training parameters of the model (optional).\n", + " mlflow.log_metrics(h2o_mlflow_flavor.get_metrics(kmeans)) # Log performance matrics of the model 
(optional).\n", + " input_example = h2o_mlflow_flavor.get_input_example(kmeans) # Extract input example from training dataset (optional)\n", + " h2o_mlflow_flavor.log_model(kmeans, \"iris\", input_example=input_example,\n", + " model_type=\"POJO\") # Specify whether the output model should be MOJO or POJO. (MOJO is default)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bed1dafe", + "metadata": {}, + "outputs": [], + "source": [ + "# Load model from the MLFlow registry and score with the model.\n", + "\n", + "import mlflow\n", + "mlflow.set_tracking_uri(\"http://127.0.0.1:8080\")\n", + "\n", + "logged_model = 'runs:/9a42265cf0ef484c905b02afb8fe6246/iris' # Specify correct id of your run.\n", + "\n", + "# Load model as a PyFuncModel.\n", + "loaded_model = mlflow.pyfunc.load_model(logged_model)\n", + "\n", + "# Predict on a Pandas DataFrame.\n", + "import pandas as pd\n", + "data = pd.read_csv(\"http://h2o-public-test-data.s3.amazonaws.com/smalldata/iris/iris_wheader.csv\")\n", + "loaded_model.predict(data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "905b0c4c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mlflow", + "language": "python", + "name": "mlflow" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/h2o-py-mlflow-flavor/h2o_mlflow_flavor/__init__.py b/h2o-py-mlflow-flavor/h2o_mlflow_flavor/__init__.py new file mode 100644 index 000000000000..4d6aa93bf459 --- /dev/null +++ b/h2o-py-mlflow-flavor/h2o_mlflow_flavor/__init__.py @@ -0,0 +1,384 @@ +""" +The `h2o_mlflow_flavor` module provides an API for working with H2O MOJO and POJO models. 
def get_default_pip_requirements():
    """
    Return the default pip requirements for MLflow Models produced by this flavor.

    :return: A list of default pip requirements for MLflow Models produced by this flavor.
             Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
             that, at minimum, contains these requirements.
    """
    return [_get_pinned_requirement("h2o_mlflow_flavor")]


def get_default_conda_env():
    """
    Return the default Conda environment for MLflow Models produced by this flavor.

    :return: The default Conda environment for MLflow Models produced by calls to
             :func:`save_model()` and :func:`log_model()`.
    """
    return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())


def get_params(h2o_model):
    """
    Extracts training parameters for the H2O binary model.

    The model id, every ``*_frame`` parameter and every ``keep_cross_validation_*``
    parameter are left out of the result.

    :param h2o_model: An H2O binary model.
    :return: A dictionary of parameters that were used for training the model.
    """
    def is_valid(key):
        return (
            key != "model_id"
            and not key.endswith("_frame")
            and not key.startswith("keep_cross_validation_")
        )

    return {key: val for key, val in h2o_model.actual_params.items() if is_valid(key)}


def get_metrics(h2o_model, metric_type=None):
    """
    Extracts metrics from the H2O binary model.

    Only boolean and numeric metric values are returned; keys ending with ``checksum``
    are skipped. Every key is prefixed with ``training_``, ``validation_`` or ``cv_``
    according to the section it comes from.

    :param h2o_model: An H2O binary model.
    :param metric_type: The type of metrics. Possible values are "training", "validation",
                        "cross_validation" (or its alias "cv"); matching is case-insensitive.
                        If parameter is not specified, metrics for all types are returned.
    :return: A dictionary of model metrics.
    """
    def is_valid(key, val):
        return isinstance(val, (bool, float, int)) and not str(key).endswith("checksum")

    def get_metrics_section(output, prefix, section_name):
        dictionary = dict(output[section_name]._metric_json.items())
        # Without a custom metric its value is a meaningless placeholder. Both keys are
        # optional, hence ``get``/``pop`` instead of indexing and ``del``.
        if dictionary.get("custom_metric_name") is None:
            dictionary.pop("custom_metric_value", None)
        return {prefix + str(key): val for key, val in dictionary.items() if is_valid(key, val)}

    # ``str.lower`` is the Python spelling; the former ``toLowerCase`` call raised
    # AttributeError for every explicitly requested metric type.
    metric_type_lower = metric_type.lower() if metric_type else None

    # (section key in the model output, prefix of the produced keys, accepted metric types)
    sections = (
        ("training_metrics", "training_", ("training",)),
        ("validation_metrics", "validation_", ("validation",)),
        ("cross_validation_metrics", "cv_", ("cv", "cross_validation")),
    )

    output = h2o_model._model_json["output"]
    metrics = {}
    for section_name, prefix, accepted_types in sections:
        requested = metric_type_lower is None or metric_type_lower in accepted_types
        if requested and output.get(section_name):
            metrics.update(get_metrics_section(output, prefix, section_name))

    return metrics
def get_input_example(h2o_model, number_of_records=5, relevant_columns_only=True):
    """
    Creates an example Pandas dataset from the training dataset of H2O binary model.

    :param h2o_model: An H2O binary model.
    :param number_of_records: A number of records that will be extracted from the training dataset.
    :param relevant_columns_only: A flag indicating whether the output dataset should contain
                                  only columns required by the model. Defaults to ``True``.
    :return: Pandas dataset made from the training dataset of H2O binary model
    """
    import h2o

    frame = h2o.get_frame(h2o_model.actual_params["training_frame"]).head(number_of_records)
    result = frame.as_data_frame()
    if not relevant_columns_only:
        return result

    relevant_columns = _get_relevant_columns(h2o_model)
    # Keep the column order of the training frame.
    input_columns = [col for col in frame.col_names if col in relevant_columns]
    return result[input_columns]


def _get_relevant_columns(h2o_model):
    """
    Lists the columns the model needs for scoring.

    The names come from ``original_names`` of the model output (``names`` when it is empty),
    without the response column and the ignored columns.

    :param h2o_model: An H2O binary model.
    :return: A list of column names.
    """
    output = h2o_model._model_json["output"]
    names = output.get("original_names") or output["names"]
    response_column = h2o_model.actual_params.get("response_column")
    irrelevant_columns = list(h2o_model.actual_params.get("ignored_columns") or [])
    if response_column:
        irrelevant_columns.append(response_column)
    return [feature for feature in names if feature not in irrelevant_columns]


def save_model(
    h2o_model,
    path,
    conda_env=None,
    code_paths=None,
    mlflow_model=None,
    signature=None,
    input_example=None,
    pip_requirements=None,
    extra_pip_requirements=None,
    model_type="MOJO",
    extra_prediction_args=None,
):
    """
    Saves an H2O binary model to a path on the local file system in MOJO or POJO format.

    :param h2o_model: H2O binary model to be saved to MOJO or POJO.
    :param path: Local path where the model is to be saved.
    :param conda_env: Conda environment specification for the model. When ``None``, the
                      environment is generated from the pip requirements.
    :param code_paths: A list of local filesystem paths to Python file dependencies (or directories
                       containing file dependencies). These files are *prepended* to the system
                       path when the model is loaded.
    :param mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
                         A new one is created when ``None``.
    :param signature: Model signature stored in the MLmodel metadata (optional).
    :param input_example: Input example persisted together with the model (optional).
    :param pip_requirements: Pip requirements of the model. When ``None`` (and no ``conda_env``
                             is given), they are inferred, falling back to
                             :func:`get_default_pip_requirements()`.
    :param extra_pip_requirements: Additional pip requirements handed over to MLflow together
                                   with ``pip_requirements``.
    :param model_type: A flag deciding whether the model is MOJO or POJO (case-insensitive).
    :param extra_prediction_args: A list of extra arguments for java predictions process. Possible values:
        --setConvertInvalidNum - Converts invalid numbers to NA
        --predictContributions - Returns also Shapley values along with the predictions
        --predictCalibrated - Return also calibrated prediction values.
        ``None`` (the default) means no extra arguments.
    :raises ValueError: If ``model_type`` is neither 'MOJO' nor 'POJO'.
    """
    # Validate cheap arguments first so a typo fails before h2o is imported or files are touched.
    model_type_upper = model_type.upper()
    if model_type_upper not in ("MOJO", "POJO"):
        raise ValueError(f"The `model_type` parameter must be 'MOJO' or 'POJO'. The passed value was '{model_type}'.")
    # A fresh list per call; a mutable default would be shared between calls.
    extra_prediction_args = [] if extra_prediction_args is None else list(extra_prediction_args)

    import h2o

    _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)
    _validate_and_prepare_target_save_path(path)
    code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)

    if model_type_upper == "MOJO":
        model_data_path = h2o_model.download_mojo(path=path, get_genmodel_jar=True)
        model_file = os.path.basename(model_data_path)
    else:
        model_data_path = h2o_model.download_pojo(path=path, get_genmodel_jar=True)
        h2o_genmodel_jar = os.path.join(path, "h2o-genmodel.jar")
        output_path = os.path.join(path, "classes")
        # Older javac releases refuse a ``-d`` directory that does not exist yet.
        os.makedirs(output_path, exist_ok=True)
        javac_cmd = ["javac", "-cp", h2o_genmodel_jar, "-d", output_path, "-J-Xmx12g", model_data_path]
        subprocess.check_call(javac_cmd)
        # The POJO is referenced by its class name, i.e. the file name without extension.
        model_file = os.path.basename(model_data_path).replace(".java", "")

    if mlflow_model is None:
        mlflow_model = Model()
    if signature is not None:
        mlflow_model.signature = signature
    if input_example is not None:
        _save_example(mlflow_model, input_example, path)

    pyfunc.add_to_model(
        mlflow_model,
        loader_module="h2o_mlflow_flavor",
        model_path=model_file,
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
        code=code_dir_subpath,
    )

    mlflow_model.add_flavor(
        FLAVOR_NAME,
        model_file=model_file,
        model_type=model_type_upper,
        extra_prediction_args=extra_prediction_args,
        relevant_columns=_get_relevant_columns(h2o_model),
        h2o_version=h2o.__version__,
        code=code_dir_subpath,
    )
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    if conda_env is None:
        if pip_requirements is None:
            default_reqs = get_default_pip_requirements()
            inferred_reqs = mlflow.models.infer_pip_requirements(
                path, FLAVOR_NAME, fallback=default_reqs
            )
            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
        else:
            default_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs, pip_requirements, extra_pip_requirements
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
def log_model(
    h2o_model,
    artifact_path,
    conda_env=None,
    code_paths=None,
    registered_model_name=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    pip_requirements=None,
    extra_pip_requirements=None,
    model_type="MOJO",
    extra_prediction_args=None,
    **kwargs,
):
    """
    Logs an H2O model as an MLflow artifact for the current run.

    :param h2o_model: H2O model to be saved.
    :param artifact_path: Run-relative artifact path.
    :param conda_env: Conda environment specification for the model (optional).
    :param code_paths: A list of local filesystem paths to Python file dependencies (or directories
                       containing file dependencies). These files are *prepended* to the system
                       path when the model is loaded.
    :param registered_model_name: If given, create a model version under
                                  ``registered_model_name``, also creating a registered model if one
                                  with the given name does not exist.
    :param signature: Model signature stored in the MLmodel metadata (optional).
    :param input_example: Input example persisted together with the model (optional).
    :param pip_requirements: Pip requirements of the model (optional).
    :param extra_pip_requirements: Additional pip requirements of the model (optional).
    :param model_type: A flag deciding whether the model is MOJO or POJO.
    :param extra_prediction_args: A list of extra arguments for java scoring process. Possible values:
        --setConvertInvalidNum - Converts invalid numbers to NA
        --predictContributions - Returns also Shapley values along with the predictions
        --predictCalibrated - Return also calibrated prediction values.
        ``None`` (the default) means no extra arguments.
    :param kwargs: Extra keyword arguments forwarded to ``mlflow.models.Model.log``.
    :return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
             metadata of the logged model.
    """
    # The flavor handed over to MLflow is this very module, so it is imported by name here.
    import h2o_mlflow_flavor

    # A fresh list per call; a mutable default would be shared between calls.
    prediction_args = [] if extra_prediction_args is None else list(extra_prediction_args)
    return Model.log(
        artifact_path=artifact_path,
        flavor=h2o_mlflow_flavor,
        registered_model_name=registered_model_name,
        h2o_model=h2o_model,
        conda_env=conda_env,
        code_paths=code_paths,
        signature=signature,
        input_example=input_example,
        pip_requirements=pip_requirements,
        extra_pip_requirements=extra_pip_requirements,
        model_type=model_type,
        extra_prediction_args=prediction_args,
        **kwargs,
    )


def load_model(model_uri, dst_path=None):
    """
    Obtains a model from MLFlow registry.

    :param model_uri: An unique identifier of the model within MLFlow registry.
    :param dst_path: (Optional) A local folder for downloading the persisted form of the model.
    :return: A model making predictions on Pandas dataframe.
    """
    path = _download_artifact_from_uri(
        artifact_uri=model_uri, output_path=dst_path
    )
    return _load_model(path)


def _load_model(path):
    """
    Builds the scoring wrapper from the flavor configuration stored at ``path``.

    :param path: Local filesystem path of the downloaded MLflow model.
    :return: An :class:`_H2OModelWrapper` instance.
    """
    flavor_conf = _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME)
    return _H2OModelWrapper(
        flavor_conf["model_file"],
        flavor_conf["model_type"],
        path,
        flavor_conf["extra_prediction_args"],
        flavor_conf["relevant_columns"],
    )


class _H2OModelWrapper:
    """
    Scores a persisted MOJO or POJO by running ``hex.genmodel.tools.PredictCsv`` from the
    ``h2o-genmodel.jar`` stored next to the model in a ``java`` subprocess.
    """

    def __init__(self, model_file, model_type, path, extra_prediction_args, relevant_columns):
        self.model_file = model_file
        self.model_type = model_type
        self.path = path
        # The flavor configuration may carry no extra arguments at all.
        self.extra_prediction_args = list(extra_prediction_args) if extra_prediction_args is not None else []
        self.relevant_columns = relevant_columns
        self.genmodel_jar_path = os.path.join(path, "h2o-genmodel.jar")

    def predict(self, dataframe, params=None):
        """
        :param dataframe: Model input data. It must contain all the relevant columns of the model.
        :param params: Additional parameters to pass to the model for inference. Currently unused.

        :return: Model predictions as a Pandas dataframe sharing the index of ``dataframe``.
        :raises AssertionError: If the java scoring process exits with a non-zero code.
        """
        import csv

        separator = "`"
        with tempfile.TemporaryDirectory() as tempdir:
            input_file = os.path.join(tempdir, "input.csv")
            output_file = os.path.join(tempdir, "output.csv")
            sub_dataframe = dataframe[self.relevant_columns]
            sub_dataframe.to_csv(input_file, index=False, quoting=csv.QUOTE_NONNUMERIC, sep=separator)

            if self.model_type == "MOJO":
                class_path = self.genmodel_jar_path
                type_parameter = "--mojo"
                model_artefact = os.path.join(self.path, self.model_file)
            else:
                # A POJO is scored from its compiled classes, addressed by class name.
                class_path_separator = ";" if sys.platform == "win32" else ":"
                class_path = self.genmodel_jar_path + class_path_separator + os.path.join(self.path, "classes")
                type_parameter = "--pojo"
                model_artefact = self.model_file.replace(".class", "")

            java_cmd = ["java", "-cp", class_path,
                        "-ea", "-Xmx12g", "-XX:ReservedCodeCacheSize=256m",
                        "hex.genmodel.tools.PredictCsv", "--separator", separator,
                        "--input", input_file, "--output", output_file, type_parameter, model_artefact, "--decimal"]
            java_cmd += self.extra_prediction_args
            ret = subprocess.call(java_cmd)
            # Raised explicitly so the check survives ``python -O``, which strips ``assert``.
            # The exception type and message stay the same as before.
            if ret != 0:
                raise AssertionError("GenModel finished with return code %d." % ret)
            predicted = pandas.read_csv(output_file)
            predicted.index = dataframe.index
            return predicted


def _load_pyfunc(path):
    """
    Load PyFunc implementation. Called by ``pyfunc.load_model``.

    :param path: Local filesystem path to the MLflow Model with the ``h2o_mojo_pojo`` flavor.
    """
    return _load_model(path)
+ """ + return _load_model(path) diff --git a/h2o-py-mlflow-flavor/setup.cfg b/h2o-py-mlflow-flavor/setup.cfg new file mode 100644 index 000000000000..a986ae5a563b --- /dev/null +++ b/h2o-py-mlflow-flavor/setup.cfg @@ -0,0 +1,20 @@ +[flake8] +# +# E241: (Multiple spaces after ':' or ',') Occasionally aligning code fragments vertically improves readability +# E265: (Block comment should start with '# ') I like having banner comments of the form #--------------------- +# E302: (Functions should be separated with 2 blank lines) PEP8 says that sometimes groups of related functions may be +# separated with 3 lines to improve readability. We do that. +# E303: (Classes should be separated with ? blank lines) "Spare is better than dense". Extra separators don't hurt. +# E701: (Multiple statements on the same line) PEP8 allows multiple statements on the same line in certain situations, +# for example `if foo: continue` is more readable in 1 line than in 2. +# +# D105: (Missing docstring in magic method) Magic methods have well-defined meaning, docstrings are redundant. +# +ignore = E241,E265,E302,E303,E701,D105 +max-line-length = 120 +application-import-names = h2o_mlflow_flavor +import-order-style = smarkets +inline-quotes = " + +[bdist_wheel] +universal = 1 diff --git a/h2o-py-mlflow-flavor/setup.py b/h2o-py-mlflow-flavor/setup.py new file mode 100644 index 000000000000..2335b221fda8 --- /dev/null +++ b/h2o-py-mlflow-flavor/setup.py @@ -0,0 +1,71 @@ +# -*- encoding: utf-8 -*- +from setuptools import setup, find_packages +from codecs import open +import os + +here = os.path.abspath(os.path.dirname(__file__)) + +# Get the long description from the relevant file +with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: + long_description = f.read() + +version = "0.1.0" +packages = find_packages(exclude=["tests*"]) +print("Found packages: %r" % packages) + +setup( + name='h2o_mlflow_flavor', + + # Versions should comply with PEP440. 
For a discussion on single-sourcing + # the version across setup.py and the project code, see + # https://packaging.python.org/en/latest/single_source_version.html + version = version, + + description='A mlflow flavor for working with H2O-3 MOJO and POJO models', + long_description=long_description, + + # The project's main homepage. + url='https://github.com/h2oai/h2o-3.git', + + # Author details + author='H2O.ai', + author_email='support@h2o.ai', + + # Choose your license + license='Apache v2', + + # See https://pypi.python.org/pypi?%3Aaction=list_classifiers + classifiers=[ + # How mature is this project? Common values are + # 3 - Alpha + # 4 - Beta + # 5 - Production/Stable + "Development Status :: 3 - Alpha", + + # Indicate who your project is intended for + "Intended Audience :: Education", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Intended Audience :: Customer Service", + "Intended Audience :: Financial and Insurance Industry", + "Intended Audience :: Healthcare Industry", + "Intended Audience :: Telecommunications Industry", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Information Analysis", + + # Pick your license as you wish (should match "license" above) + "License :: OSI Approved :: Apache Software License", + + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + ], + + keywords='ML Flow, H2O-3', + + packages=packages, + + # run-time dependencies + install_requires=["mlflow>=1.29.0"] +) diff --git a/settings.gradle b/settings.gradle index f7e8a05d05c2..e1d575b534f9 100644 --- a/settings.gradle +++ b/settings.gradle @@ -10,6 +10,7 @@ include 'h2o-app' include 'h2o-r' include 'h2o-py' include 'h2o-py-cloud-extensions' +include 'h2o-py-mlflow-flavor' include 'h2o-assemblies:main' include 'h2o-assemblies:minimal' include 
'h2o-assemblies:steam'