diff --git a/src/main/java/sklearn2pmml/pipeline/PMMLPipeline.java b/src/main/java/sklearn2pmml/pipeline/PMMLPipeline.java index 80a89b1f3..0d000f81e 100644 --- a/src/main/java/sklearn2pmml/pipeline/PMMLPipeline.java +++ b/src/main/java/sklearn2pmml/pipeline/PMMLPipeline.java @@ -523,12 +523,28 @@ public Estimator getEstimator(){ List steps = getSteps(); if(steps.size() < 1){ - throw new IllegalArgumentException("Expected one or more elements, got zero elements"); + throw new IllegalArgumentException("Expected one or more steps, got zero steps"); } Object[] lastStep = steps.get(steps.size() - 1); - return TupleUtil.extractElement(lastStep, 1, Estimator.class); + try { + return TupleUtil.extractElement(lastStep, 1, Estimator.class); + } catch(IllegalArgumentException iaeEstimator){ + Transformer transformer = null; + + try { + transformer = TupleUtil.extractElement(lastStep, 1, Transformer.class); + } catch(IllegalArgumentException iaeTransformer){ + // Ignored + } + + if(transformer != null){ + throw new IllegalArgumentException("Expected an estimator object as the last step, got a transformer object (" + ClassDictUtil.formatClass(transformer) + ")"); + } + + throw iaeEstimator; + } } @Override diff --git a/src/main/java/sklearn_pandas/DataFrameMapper.java b/src/main/java/sklearn_pandas/DataFrameMapper.java index 4352b2d87..f88c1e13b 100644 --- a/src/main/java/sklearn_pandas/DataFrameMapper.java +++ b/src/main/java/sklearn_pandas/DataFrameMapper.java @@ -48,7 +48,7 @@ public List initializeFeatures(SkLearnEncoder encoder){ List rows = getFeatures(); if(!(Boolean.FALSE).equals(_default)){ - throw new IllegalArgumentException(); + throw new IllegalArgumentException("Attribute \'" + ClassDictUtil.formatMember(this, "default") + "\' must be set to the 'False' value"); } List result = new ArrayList<>();