diff --git a/README.md b/README.md index 18ba0b2..b0c964b 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ Java library and command-line application for converting [StatsModels](https://w * [Poisson](https://www.statsmodels.org/dev/generated/statsmodels.discrete.discrete_model.Poisson.html) * [OrderedModel](https://www.statsmodels.org/dev/generated/statsmodels.miscmodels.ordinal_model.OrderedModel.html): * Distributions: `logit`, `probit` + * Univariate Time-Series Analysis: + * [ARIMA](https://www.statsmodels.org/dev/generated/statsmodels.tsa.arima.model.ARIMA.html) * Production quality: * Complete test coverage. * Fully compliant with the [JPMML-Evaluator](https://github.com/jpmml/jpmml-evaluator) library. diff --git a/pmml-statsmodels/src/main/java/statsmodels/tools/Bunch.java b/pmml-statsmodels/src/main/java/statsmodels/tools/Bunch.java new file mode 100644 index 0000000..ff11a8e --- /dev/null +++ b/pmml-statsmodels/src/main/java/statsmodels/tools/Bunch.java @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 Villu Ruusmann + * + * This file is part of JPMML-StatsModels + * + * JPMML-StatsModels is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-StatsModels is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-StatsModels. If not, see . + */ +package statsmodels.tools; + +import org.jpmml.python.PythonObject; + +public class Bunch extends PythonObject { + + public Bunch(String module, String name){ + super(module, name); + } + + public void __setstate__(Bunch bunch){ + super.__setstate__(bunch); + } +} \ No newline at end of file diff --git a/pmml-statsmodels/src/main/java/statsmodels/tsa/arima/ARIMA.java b/pmml-statsmodels/src/main/java/statsmodels/tsa/arima/ARIMA.java new file mode 100644 index 0000000..335a821 --- /dev/null +++ b/pmml-statsmodels/src/main/java/statsmodels/tsa/arima/ARIMA.java @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2024 Villu Ruusmann + * + * This file is part of JPMML-StatsModels + * + * JPMML-StatsModels is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-StatsModels is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-StatsModels. If not, see . + */ +package statsmodels.tsa.arima; + +import java.util.List; + +import org.dmg.pmml.Array; +import org.dmg.pmml.DataField; +import org.dmg.pmml.DataType; +import org.dmg.pmml.MiningField.UsageType; +import org.dmg.pmml.MiningFunction; +import org.dmg.pmml.MiningSchema; +import org.dmg.pmml.OpType; +import org.dmg.pmml.time_series.InterceptVector; +import org.dmg.pmml.time_series.MeasurementMatrix; +import org.dmg.pmml.time_series.StateSpaceModel; +import org.dmg.pmml.time_series.StateVector; +import org.dmg.pmml.time_series.TransitionMatrix; +import org.jpmml.converter.CMatrix; +import org.jpmml.converter.Matrix; +import org.jpmml.converter.ModelUtil; +import org.jpmml.converter.PMMLUtil; +import org.jpmml.converter.Schema; +import org.jpmml.python.HasArray; +import org.jpmml.statsmodels.StatsModelsEncoder; +import statsmodels.Results; +import statsmodels.tsa.TimeSeriesModel; +import statsmodels.tsa.statespace.SmootherResults; + +public class ARIMA extends TimeSeriesModel { + + public ARIMA(String module, String name){ + super(module, name); + } + + @Override + public Schema encodeSchema(StatsModelsEncoder encoder){ + Schema schema = super.encodeSchema(encoder); + + @SuppressWarnings("unused") + DataField dataField = encoder.createDataField("horizon", OpType.CONTINUOUS, DataType.INTEGER); + + return schema; + } + + @Override + public org.dmg.pmml.time_series.TimeSeriesModel encodeModel(Results results, Schema schema){ + HasArray predictedState = results.getArray("predicted_state"); + SmootherResults smootherResults = results.get("smoother_results", SmootherResults.class); + + HasArray design = smootherResults.getDesign(); + HasArray obsIntercept = smootherResults.getObsIntercept(); + HasArray transition = smootherResults.getTransition(); + + MiningSchema miningSchema = ModelUtil.createMiningSchema(schema.getLabel()) + .addMiningFields(ModelUtil.createMiningField("horizon", UsageType.SUPPLEMENTARY)); + + StateVector stateVector = new StateVector(createRealArray(predictedState, -1)); + + MeasurementMatrix measurementMatrix = new MeasurementMatrix(createMatrix(design)); + + TransitionMatrix transitionMatrix = new TransitionMatrix(createMatrix(transition)); + + InterceptVector interceptVector = new InterceptVector(createRealArray(obsIntercept, -1)) + .setType(InterceptVector.Type.OBSERVATION); + + StateSpaceModel stateSpaceModel = new StateSpaceModel() + .setStateVector(stateVector) + .setMeasurementMatrix(measurementMatrix) + .setTransitionMatrix(transitionMatrix) + .setInterceptVector(interceptVector); + + org.dmg.pmml.time_series.TimeSeriesModel timeSeriesModel = new org.dmg.pmml.time_series.TimeSeriesModel(MiningFunction.TIME_SERIES, org.dmg.pmml.time_series.TimeSeriesModel.Algorithm.STATE_SPACE_MODEL, miningSchema) + .setStateSpaceModel(stateSpaceModel); + + return timeSeriesModel; + } + + static + private Array createRealArray(HasArray hasArray, int column){ + Matrix matrix = toMatrix(hasArray); + + List columnValues; + + if(column >= 0){ + columnValues = (List)matrix.getColumnValues(column); + } else + + { + columnValues = (List)matrix.getColumnValues(matrix.getColumns() + column); + } + + return PMMLUtil.createRealArray(columnValues); + } + + static + private org.dmg.pmml.Matrix createMatrix(HasArray hasArray){ + Matrix matrix = toMatrix(hasArray); + + org.dmg.pmml.Matrix result = new org.dmg.pmml.Matrix() + .setNbRows(matrix.getRows()) + .setNbCols(matrix.getColumns()); + + for(int row = 0; row < matrix.getRows(); row++){ + List rowValues = (List)matrix.getRowValues(row); + + result.addArrays(PMMLUtil.createRealArray(rowValues)); + } + + return result; + } + + static + private Matrix toMatrix(HasArray hasArray){ + int[] shape = hasArray.getArrayShape(); + List values = hasArray.getArrayContent(); + + if(shape.length == 3){ + + if(shape[2] != 1){ + throw new IllegalArgumentException(); + } + } + + return new CMatrix(values, shape[0], shape[1]); + } + + private static final String NAME_INDEX = "index"; +} \ No newline at end of file diff --git a/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/Initialization.java b/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/Initialization.java new file mode 100644 index 0000000..f0eff4b --- /dev/null +++ b/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/Initialization.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2024 Villu Ruusmann + * + * This file is part of JPMML-StatsModels + * + * JPMML-StatsModels is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-StatsModels is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-StatsModels. If not, see . + */ +package statsmodels.tsa.statespace; + +import org.jpmml.python.CythonObject; + +public class Initialization extends CythonObject { + + public Initialization(String module, String name){ + super(module, name); + } + + @Override + public void __init__(Object[] args){ + super.__setstate__(INIT_ATTRIBUTES, args); + } + + private static final String[] INIT_ATTRIBUTES = { + "k_states", + "constant", + "stationary_cov", + "approximate_diffuse_variance" + }; +} \ No newline at end of file diff --git a/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/KalmanFilter.java b/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/KalmanFilter.java new file mode 100644 index 0000000..3c4d138 --- /dev/null +++ b/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/KalmanFilter.java @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024 Villu Ruusmann + * + * This file is part of JPMML-StatsModels + * + * JPMML-StatsModels is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-StatsModels is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-StatsModels. If not, see . + */ +package statsmodels.tsa.statespace; + +import org.jpmml.python.CythonObject; + +public class KalmanFilter extends CythonObject { + + public KalmanFilter(String module, String name){ + super(module, name); + } + + @Override + public void __init__(Object[] args){ + super.__setstate__(INIT_ATTRIBUTES, args); + } + + private static final String[] INIT_ATTRIBUTES = { + "model", + "filter_method", + "inversion_method", + "stability_method", + "conserve_memory", + "filter_timing", + "tolerance", + "loglikelihood_burn" + }; +} \ No newline at end of file diff --git a/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/KalmanSmoother.java b/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/KalmanSmoother.java new file mode 100644 index 0000000..9e4b744 --- /dev/null +++ b/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/KalmanSmoother.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2024 Villu Ruusmann + * + * This file is part of JPMML-StatsModels + * + * JPMML-StatsModels is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-StatsModels is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-StatsModels. If not, see . + */ +package statsmodels.tsa.statespace; + +import org.jpmml.python.CythonObject; + +public class KalmanSmoother extends CythonObject { + + public KalmanSmoother(String module, String name){ + super(module, name); + } + + @Override + public void __init__(Object[] args){ + super.__setstate__(INIT_ATTRIBUTES, args); + } + + private static final String[] INIT_ATTRIBUTES = { + "model", + "kfilter", + "smoother_output", + "smooth_method" + }; +} \ No newline at end of file diff --git a/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/SmootherResults.java b/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/SmootherResults.java new file mode 100644 index 0000000..cce6597 --- /dev/null +++ b/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/SmootherResults.java @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 Villu Ruusmann + * + * This file is part of JPMML-StatsModels + * + * JPMML-StatsModels is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-StatsModels is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-StatsModels. If not, see . + */ +package statsmodels.tsa.statespace; + +import org.jpmml.python.HasArray; +import org.jpmml.python.PythonObject; + +public class SmootherResults extends PythonObject { + + public SmootherResults(String module, String name){ + super(module, name); + } + + public HasArray getDesign(){ + return getArray("design"); + } + + public HasArray getObsIntercept(){ + return getArray("obs_intercept"); + } + + public HasArray getTransition(){ + return getArray("transition"); + } +} \ No newline at end of file diff --git a/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/Statespace.java b/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/Statespace.java new file mode 100644 index 0000000..84e8dff --- /dev/null +++ b/pmml-statsmodels/src/main/java/statsmodels/tsa/statespace/Statespace.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024 Villu Ruusmann + * + * This file is part of JPMML-StatsModels + * + * JPMML-StatsModels is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-StatsModels is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-StatsModels. If not, see . + */ +package statsmodels.tsa.statespace; + +import org.jpmml.python.CythonObject; + +public class Statespace extends CythonObject { + + public Statespace(String module, String name){ + super(module, name); + } + + @Override + public void __init__(Object[] args){ + super.__setstate__(INIT_ATTRBUTES, args); + } + + private static final String[] INIT_ATTRBUTES = { + "obs", + "design", + "obs_intercept", + "obs_cov", + "transition", + "state_intercept", + "selection", + "state_cov", + "diagonal_obs_cov" + }; +} \ No newline at end of file diff --git a/pmml-statsmodels/src/main/resources/META-INF/statsmodels2pmml.properties b/pmml-statsmodels/src/main/resources/META-INF/statsmodels2pmml.properties index 3d957b3..be8a42f 100644 --- a/pmml-statsmodels/src/main/resources/META-INF/statsmodels2pmml.properties +++ b/pmml-statsmodels/src/main/resources/META-INF/statsmodels2pmml.properties @@ -28,4 +28,13 @@ statsmodels.regression.linear_model.WLS = statsmodels.regression.LinearRegressio statsmodels.regression.linear_model.RegressionResults = statsmodels.Results statsmodels.regression.linear_model.RegressionResultsWrapper = statsmodels.ResultsWrapper statsmodels.regression.quantile_regression.QuantReg = statsmodels.regression.LinearRegression -statsmodels.regression.quantile_regression.QuantRegResults = statsmodels.Results \ No newline at end of file +statsmodels.regression.quantile_regression.QuantRegResults = statsmodels.Results +statsmodels.tsa.arima.model.ARIMA = statsmodels.tsa.arima.ARIMA +statsmodels.tsa.arima.model.ARIMAResults = statsmodels.Results +statsmodels.tsa.arima.model.ARIMAResultsWrapper = statsmodels.ResultsWrapper +statsmodels.tsa.statespace._initialization.(dInitialization|zInitialization) = statsmodels.tsa.statespace.Initialization +statsmodels.tsa.statespace._kalman_filter.(dKalmanFilter|zKalmanFilter) = statsmodels.tsa.statespace.KalmanFilter +statsmodels.tsa.statespace._kalman_smoother.dKalmanSmoother = statsmodels.tsa.statespace.KalmanSmoother +statsmodels.tsa.statespace._representation.(dStatespace|zStatespace) = statsmodels.tsa.statespace.Statespace +statsmodels.tsa.statespace.kalman_smoother.SmootherResults = statsmodels.tsa.statespace.SmootherResults +statsmodels.tools.tools.Bunch = statsmodels.tools.Bunch diff --git a/pmml-statsmodels/src/test/java/org/jpmml/statsmodels/testing/TimeSeriesTest.java b/pmml-statsmodels/src/test/java/org/jpmml/statsmodels/testing/TimeSeriesTest.java new file mode 100644 index 0000000..3316cfe --- /dev/null +++ b/pmml-statsmodels/src/test/java/org/jpmml/statsmodels/testing/TimeSeriesTest.java @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 Villu Ruusmann + * + * This file is part of JPMML-StatsModels + * + * JPMML-StatsModels is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-StatsModels is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-StatsModels. If not, see . + */ +package org.jpmml.statsmodels.testing; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import com.google.common.base.Equivalence; +import com.google.common.collect.Iterables; +import org.jpmml.evaluator.ResultField; +import org.jpmml.evaluator.testing.Batch; +import org.jpmml.evaluator.testing.BatchUtil; +import org.jpmml.evaluator.testing.Conflict; +import org.jpmml.evaluator.time_series.SeriesForecast; +import org.junit.Test; + +public class TimeSeriesTest extends StatsModelsEncoderBatchTest { + + @Override + public void evaluate(Batch batch) throws Exception { + Function, List>> function = new Function, List>>(){ + + @Override + public List> apply(Map map){ + Map.Entry entry = Iterables.getOnlyElement(map.entrySet()); + + String name = entry.getKey(); + SeriesForecast seriesForecast = (SeriesForecast)entry.getValue(); + + List values = seriesForecast.getValues(); + + return values.stream() + .map(value -> Collections.singletonMap(name, value)) + .collect(Collectors.toList()); + } + }; + + List conflicts = BatchUtil.evaluateSingleton(batch, function); + + checkConflicts(conflicts); + } + + @Override + public StatsModelsEncoderBatch createBatch(String algorithm, String dataset, Predicate predicate, Equivalence equivalence){ + StatsModelsEncoderBatch result = new StatsModelsEncoderBatch(algorithm, dataset, predicate, equivalence){ + + @Override + public TimeSeriesTest getArchiveBatchTest(){ + return TimeSeriesTest.this; + } + + @Override + public List> getInput() throws IOException { + String algorithm = getAlgorithm(); + + // XXX + if("SSM".equals(algorithm)){ + return Collections.singletonList(Collections.singletonMap("horizon", 12)); + } + + return super.getInput(); + } + }; + + return result; + } + + @Test + public void evaluateSSMAirline() throws Exception { + evaluate("SSM", "Airline"); + } +} \ No newline at end of file diff --git a/pmml-statsmodels/src/test/resources/csv/Airline.csv b/pmml-statsmodels/src/test/resources/csv/Airline.csv new file mode 100644 index 0000000..cd7b703 --- /dev/null +++ b/pmml-statsmodels/src/test/resources/csv/Airline.csv @@ -0,0 +1,145 @@ +Month,Passengers +1949-01,112 +1949-02,118 +1949-03,132 +1949-04,129 +1949-05,121 +1949-06,135 +1949-07,148 +1949-08,148 +1949-09,136 +1949-10,119 +1949-11,104 +1949-12,118 +1950-01,115 +1950-02,126 +1950-03,141 +1950-04,135 +1950-05,125 +1950-06,149 +1950-07,170 +1950-08,170 +1950-09,158 +1950-10,133 +1950-11,114 +1950-12,140 +1951-01,145 +1951-02,150 +1951-03,178 +1951-04,163 +1951-05,172 +1951-06,178 +1951-07,199 +1951-08,199 +1951-09,184 +1951-10,162 +1951-11,146 +1951-12,166 +1952-01,171 +1952-02,180 +1952-03,193 +1952-04,181 +1952-05,183 +1952-06,218 +1952-07,230 +1952-08,242 +1952-09,209 +1952-10,191 +1952-11,172 +1952-12,194 +1953-01,196 +1953-02,196 +1953-03,236 +1953-04,235 +1953-05,229 +1953-06,243 +1953-07,264 +1953-08,272 +1953-09,237 +1953-10,211 +1953-11,180 +1953-12,201 +1954-01,204 +1954-02,188 +1954-03,235 +1954-04,227 +1954-05,234 +1954-06,264 +1954-07,302 +1954-08,293 +1954-09,259 +1954-10,229 +1954-11,203 +1954-12,229 +1955-01,242 +1955-02,233 +1955-03,267 +1955-04,269 +1955-05,270 +1955-06,315 +1955-07,364 +1955-08,347 +1955-09,312 +1955-10,274 +1955-11,237 +1955-12,278 +1956-01,284 +1956-02,277 +1956-03,317 +1956-04,313 +1956-05,318 +1956-06,374 +1956-07,413 +1956-08,405 +1956-09,355 +1956-10,306 +1956-11,271 +1956-12,306 +1957-01,315 +1957-02,301 +1957-03,356 +1957-04,348 +1957-05,355 +1957-06,422 +1957-07,465 +1957-08,467 +1957-09,404 +1957-10,347 +1957-11,305 +1957-12,336 +1958-01,340 +1958-02,318 +1958-03,362 +1958-04,348 +1958-05,363 +1958-06,435 +1958-07,491 +1958-08,505 +1958-09,404 +1958-10,359 +1958-11,310 +1958-12,337 +1959-01,360 +1959-02,342 +1959-03,406 +1959-04,396 +1959-05,420 +1959-06,472 +1959-07,548 +1959-08,559 +1959-09,463 +1959-10,407 +1959-11,362 +1959-12,405 +1960-01,417 +1960-02,391 +1960-03,419 +1960-04,461 +1960-05,472 +1960-06,535 +1960-07,622 +1960-08,606 +1960-09,508 +1960-10,461 +1960-11,390 +1960-12,432 diff --git a/pmml-statsmodels/src/test/resources/csv/SSMAirline.csv b/pmml-statsmodels/src/test/resources/csv/SSMAirline.csv new file mode 100644 index 0000000..9065a0d --- /dev/null +++ b/pmml-statsmodels/src/test/resources/csv/SSMAirline.csv @@ -0,0 +1,13 @@ +Passengers +462.84681216415777 +419.6656890912251 +432.5549230698275 +432.41897014081127 +464.6245890660358 +525.8395267588148 +589.7411706246921 +572.8376749381728 +522.6846728352011 +477.68338818213715 +426.0288999284137 +453.26381518906385 diff --git a/pmml-statsmodels/src/test/resources/main.py b/pmml-statsmodels/src/test/resources/main.py index bf79683..1ff93d8 100644 --- a/pmml-statsmodels/src/test/resources/main.py +++ b/pmml-statsmodels/src/test/resources/main.py @@ -5,6 +5,7 @@ from statsmodels.genmod.families import Binomial, Gaussian, Poisson from statsmodels.miscmodels.ordinal_model import OrderedModel from statsmodels.tools import add_constant +from statsmodels.tsa.arima.model import ARIMA import numpy import pandas @@ -19,8 +20,29 @@ def split_csv(df): def store_csv(df, name): df.to_csv("csv/" + name + ".csv", index = False) -def store_pkl(results, name): - results.save("pkl/" + name + ".pkl", remove_data = hasattr(results, "remove_data")) +def store_pkl(results, name, remove_data = None): + if remove_data: + remove_data &= hasattr(results, "remove_data") + results.save("pkl/" + name + ".pkl", remove_data = remove_data) + +airline_df = load_csv("Airline") +airline_df["Month"] = pandas.to_datetime(airline_df["Month"]) +airline_df.set_index("Month", inplace = True) +print(airline_df.dtypes) + +airline_y = airline_df["Passengers"] + +def build_airline(model, name, fit_method = "fit"): + results = getattr(model, fit_method)() + print(results.summary()) + + store_pkl(results, name, remove_data = False) + + airline_passengers = results.predict(start = len(airline_y), end = len(airline_y) + 11) + airline_passengers.name = "Passengers" + store_csv(airline_passengers, name) + +build_airline(ARIMA(airline_y, order = (12, 0, 0)), "SSMAirline") audit_df = load_csv("Audit") print(audit_df.dtypes) diff --git a/pmml-statsmodels/src/test/resources/pkl/SSMAirline.pkl b/pmml-statsmodels/src/test/resources/pkl/SSMAirline.pkl new file mode 100644 index 0000000..bca1b7c Binary files /dev/null and b/pmml-statsmodels/src/test/resources/pkl/SSMAirline.pkl differ