From 1a431a75b9dd14cd67ba505f7fe491ad8b218495 Mon Sep 17 00:00:00 2001 From: Toshiya Kobayashi Date: Thu, 5 Dec 2024 18:52:17 +0900 Subject: [PATCH] [incubator-kie-drools-6180] accumulate min doesn't evaluate correctly with more than 18 digits BigDecimal --- .../rule/builder/util/AccumulateUtil.java | 16 +- .../BigDecimalMaxAccumulateFunction.java | 100 ++++++++++++ .../BigDecimalMinAccumulateFunction.java | 98 +++++++++++ .../BigIntegerMaxAccumulateFunction.java | 100 ++++++++++++ .../BigIntegerMinAccumulateFunction.java | 98 +++++++++++ .../META-INF/kie.default.properties.conf | 4 + .../integrationtests/AccumulateTest.java | 154 ++++++++++++++++++ 7 files changed, 568 insertions(+), 2 deletions(-) create mode 100644 drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMaxAccumulateFunction.java create mode 100644 drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMinAccumulateFunction.java create mode 100644 drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMaxAccumulateFunction.java create mode 100644 drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMinAccumulateFunction.java diff --git a/drools-compiler/src/main/java/org/drools/compiler/rule/builder/util/AccumulateUtil.java b/drools-compiler/src/main/java/org/drools/compiler/rule/builder/util/AccumulateUtil.java index 192873909f1a..8191d4004ed7 100644 --- a/drools-compiler/src/main/java/org/drools/compiler/rule/builder/util/AccumulateUtil.java +++ b/drools-compiler/src/main/java/org/drools/compiler/rule/builder/util/AccumulateUtil.java @@ -55,7 +55,11 @@ public static String getFunctionName(Supplier> exprClassSupplier, Strin functionName = "maxI"; } else if (exprClass == Long.class) { functionName = "maxL"; - } else if (Number.class.isAssignableFrom( exprClass )) { + } else if (exprClass == BigInteger.class) { + functionName = "maxBI"; + } else if (exprClass == BigDecimal.class) { + functionName = "maxBD"; + } else if (isFixedPrecisionNumber(exprClass)) { functionName = "maxN"; } } else if (functionName.equals("min")) { @@ -64,13 +68,21 @@ public static String getFunctionName(Supplier> exprClassSupplier, Strin functionName = "minI"; } else if (exprClass == Long.class) { functionName = "minL"; - } else if (Number.class.isAssignableFrom( exprClass )) { + } else if (exprClass == BigInteger.class) { + functionName = "minBI"; + } else if (exprClass == BigDecimal.class) { + functionName = "minBD"; + } else if (isFixedPrecisionNumber(exprClass)) { functionName = "minN"; } } return functionName; } + private static boolean isFixedPrecisionNumber(Class exprClass) { + return Number.class.isAssignableFrom(exprClass) && exprClass != BigDecimal.class && exprClass != BigInteger.class; + } + @SuppressWarnings("unchecked") public static AccumulateFunction loadAccumulateFunction(ClassLoader classLoader, String identifier, String className) { diff --git a/drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMaxAccumulateFunction.java b/drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMaxAccumulateFunction.java new file mode 100644 index 000000000000..884fa35a822a --- /dev/null +++ b/drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMaxAccumulateFunction.java @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.drools.core.base.accumulators; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.math.BigDecimal; + +/** + * An implementation of an accumulator capable of calculating maximum values + */ +public class BigDecimalMaxAccumulateFunction extends AbstractAccumulateFunction { + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + + } + + public void writeExternal(ObjectOutput out) throws IOException { + + } + + protected static class MaxData implements Externalizable { + public BigDecimal max = null; + + public MaxData() {} + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + max = (BigDecimal) in.readObject(); + } + + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(max); + } + + @Override + public String toString() { + return "max"; + } + } + + public MaxData createContext() { + return new MaxData(); + } + + public void init(MaxData data) { + data.max = null; + } + + public void accumulate(MaxData data, + Object value) { + if (value != null) { + BigDecimal bdValue = (BigDecimal) value; + data.max = data.max == null || data.max.compareTo(bdValue) < 0 ? + bdValue : + data.max; + } + } + + public void reverse(MaxData data, + Object value) { + } + + @Override + public boolean tryReverse( MaxData data, Object value ) { + if (value != null) { + return data.max.compareTo((BigDecimal) value) > 0; + } + return true; + } + + public Object getResult(MaxData data) { + return data.max; + } + + public boolean supportsReverse() { + return false; + } + + public Class getResultType() { + return BigDecimal.class; + } +} diff --git a/drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMinAccumulateFunction.java b/drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMinAccumulateFunction.java new file mode 100644 index 000000000000..c3f0a628315b --- /dev/null +++ b/drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMinAccumulateFunction.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.drools.core.base.accumulators; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.math.BigDecimal; + +/** + * An implementation of an accumulator capable of calculating minimum values + */ +public class BigDecimalMinAccumulateFunction extends AbstractAccumulateFunction { + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + } + + public void writeExternal(ObjectOutput out) throws IOException { + } + + protected static class MinData implements Externalizable { + public BigDecimal min = null; + + public MinData() {} + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + min = (BigDecimal) in.readObject(); + } + + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(min); + } + + @Override + public String toString() { + return "min"; + } + } + + public MinData createContext() { + return new MinData(); + } + + public void init(MinData data) { + data.min = null; + } + + public void accumulate(MinData data, + Object value) { + if (value != null) { + BigDecimal bdValue = (BigDecimal) value; + data.min = data.min == null || data.min.compareTo(bdValue) > 0 ? + bdValue : + data.min; + } + } + + @Override + public boolean tryReverse( MinData data, Object value ) { + if (value != null) { + return data.min.compareTo((BigDecimal) value) < 0; + } + return true; + } + + public void reverse(MinData data, + Object value) { + } + + public Object getResult(MinData data) { + return data.min; + } + + public boolean supportsReverse() { + return false; + } + + public Class getResultType() { + return BigDecimal.class; + } +} diff --git a/drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMaxAccumulateFunction.java b/drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMaxAccumulateFunction.java new file mode 100644 index 000000000000..2fee047f05d2 --- /dev/null +++ b/drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMaxAccumulateFunction.java @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.drools.core.base.accumulators; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.math.BigInteger; + +/** + * An implementation of an accumulator capable of calculating maximum values + */ +public class BigIntegerMaxAccumulateFunction extends AbstractAccumulateFunction { + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + + } + + public void writeExternal(ObjectOutput out) throws IOException { + + } + + protected static class MaxData implements Externalizable { + public BigInteger max = null; + + public MaxData() {} + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + max = (BigInteger) in.readObject(); + } + + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(max); + } + + @Override + public String toString() { + return "max"; + } + } + + public MaxData createContext() { + return new MaxData(); + } + + public void init(MaxData data) { + data.max = null; + } + + public void accumulate(MaxData data, + Object value) { + if (value != null) { + BigInteger biValue = (BigInteger) value; + data.max = data.max == null || data.max.compareTo(biValue) < 0 ? + biValue : + data.max; + } + } + + public void reverse(MaxData data, + Object value) { + } + + @Override + public boolean tryReverse( MaxData data, Object value ) { + if (value != null) { + return data.max.compareTo((BigInteger) value) > 0; + } + return true; + } + + public Object getResult(MaxData data) { + return data.max; + } + + public boolean supportsReverse() { + return false; + } + + public Class getResultType() { + return BigInteger.class; + } +} diff --git a/drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMinAccumulateFunction.java b/drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMinAccumulateFunction.java new file mode 100644 index 000000000000..f292e73afff6 --- /dev/null +++ b/drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMinAccumulateFunction.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.drools.core.base.accumulators; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.math.BigInteger; + +/** + * An implementation of an accumulator capable of calculating minimum values + */ +public class BigIntegerMinAccumulateFunction extends AbstractAccumulateFunction { + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + } + + public void writeExternal(ObjectOutput out) throws IOException { + } + + protected static class MinData implements Externalizable { + public BigInteger min = null; + + public MinData() {} + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + min = (BigInteger) in.readObject(); + } + + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(min); + } + + @Override + public String toString() { + return "min"; + } + } + + public MinData createContext() { + return new MinData(); + } + + public void init(MinData data) { + data.min = null; + } + + public void accumulate(MinData data, + Object value) { + if (value != null) { + BigInteger biValue = (BigInteger) value; + data.min = data.min == null || data.min.compareTo(biValue) > 0 ? + biValue : + data.min; + } + } + + @Override + public boolean tryReverse( MinData data, Object value ) { + if (value != null) { + return data.min.compareTo((BigInteger) value) < 0; + } + return true; + } + + public void reverse(MinData data, + Object value) { + } + + public Object getResult(MinData data) { + return data.min; + } + + public boolean supportsReverse() { + return false; + } + + public Class getResultType() { + return BigInteger.class; + } +} diff --git a/drools-core/src/main/resources/META-INF/kie.default.properties.conf b/drools-core/src/main/resources/META-INF/kie.default.properties.conf index f3fcd0b2e699..823723c7e171 100644 --- a/drools-core/src/main/resources/META-INF/kie.default.properties.conf +++ b/drools-core/src/main/resources/META-INF/kie.default.properties.conf @@ -44,10 +44,14 @@ drools.accumulate.function.max = org.drools.core.base.accumulators.MaxAccumulate drools.accumulate.function.maxN = org.drools.core.base.accumulators.NumericMaxAccumulateFunction drools.accumulate.function.maxI = org.drools.core.base.accumulators.IntegerMaxAccumulateFunction drools.accumulate.function.maxL = org.drools.core.base.accumulators.LongMaxAccumulateFunction +drools.accumulate.function.maxBI = org.drools.core.base.accumulators.BigIntegerMaxAccumulateFunction +drools.accumulate.function.maxBD = org.drools.core.base.accumulators.BigDecimalMaxAccumulateFunction drools.accumulate.function.min = org.drools.core.base.accumulators.MinAccumulateFunction drools.accumulate.function.minN = org.drools.core.base.accumulators.NumericMinAccumulateFunction drools.accumulate.function.minI = org.drools.core.base.accumulators.IntegerMinAccumulateFunction drools.accumulate.function.minL = org.drools.core.base.accumulators.LongMinAccumulateFunction +drools.accumulate.function.minBI = org.drools.core.base.accumulators.BigIntegerMinAccumulateFunction +drools.accumulate.function.minBD = org.drools.core.base.accumulators.BigDecimalMinAccumulateFunction drools.accumulate.function.count = org.drools.core.base.accumulators.CountAccumulateFunction drools.accumulate.function.collectList = org.drools.core.base.accumulators.CollectListAccumulateFunction drools.accumulate.function.collectSet = org.drools.core.base.accumulators.CollectSetAccumulateFunction diff --git a/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/compiler/integrationtests/AccumulateTest.java b/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/compiler/integrationtests/AccumulateTest.java index 53b1c746f186..8b180f266dd9 100644 --- a/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/compiler/integrationtests/AccumulateTest.java +++ b/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/compiler/integrationtests/AccumulateTest.java @@ -22,6 +22,7 @@ import java.io.ObjectOutput; import java.io.Serializable; import java.math.BigDecimal; +import java.math.BigInteger; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -37,6 +38,7 @@ import org.drools.core.RuleSessionConfiguration; import org.drools.commands.runtime.rule.InsertElementsCommand; import org.drools.kiesession.rulebase.InternalKnowledgeBase; +import org.drools.mvel.compiler.Primitives; import org.drools.testcoverage.common.model.Cheese; import org.drools.testcoverage.common.model.Cheesery; import org.drools.testcoverage.common.model.Order; @@ -3941,4 +3943,156 @@ public void testPeerCollectWithEager(KieBaseTestConfiguration kieBaseTestConfigu kieSession.dispose(); } } + + @ParameterizedTest(name = "KieBase type={0}") + @MethodSource("parameters") + void minWithBigDecimalHighAccuracy(KieBaseTestConfiguration kieBaseTestConfiguration) { + final String drl = + "import " + Primitives.class.getCanonicalName() + ";\n" + + "global java.util.List results;\n" + + "rule R1 when\n" + + " accumulate(Primitives($bd : bigDecimal), $min : min($bd))\n" + + "then\n" + + " results.add($min);\n" + + // to confirm if $min is BigDecimal at build time (Not Comparable) + // The return value isn't important to assert + " System.out.println($min.scale());\n" + + "end"; + + final KieBase kieBase = KieBaseUtil.getKieBaseFromKieModuleFromDrl("accumulate-test", kieBaseTestConfiguration, drl); + final KieSession kieSession = kieBase.newKieSession(); + try { + List results = new ArrayList<>(); + kieSession.setGlobal("results", results); + Primitives p1 = new Primitives(); + p1.setBigDecimal(new BigDecimal("2024043020240501130000")); + Primitives p2_smallest = new Primitives(); + p2_smallest.setBigDecimal(new BigDecimal("2024043020240501120000")); + Primitives p3 = new Primitives(); + p3.setBigDecimal(new BigDecimal("2024043020240501150000")); + + kieSession.insert(p1); + kieSession.insert(p2_smallest); + kieSession.insert(p3); + kieSession.fireAllRules(); + assertThat(results).hasSize(1); + assertThat(results.get(0)).isEqualTo(p2_smallest.getBigDecimal()); + } finally { + kieSession.dispose(); + } + } + + @ParameterizedTest(name = "KieBase type={0}") + @MethodSource("parameters") + void minWithBigIntegerHighAccuracy(KieBaseTestConfiguration kieBaseTestConfiguration) { + final String drl = + "import " + Primitives.class.getCanonicalName() + ";\n" + + "global java.util.List results;\n" + + "rule R1 when\n" + + " accumulate(Primitives($bi : bigInteger), $min : min($bi))\n" + + "then\n" + + " results.add($min);\n" + + // to confirm if $min is BigInteger at build time (Not Comparable) + // The return value isn't important to assert + " System.out.println($min.nextProbablePrime());\n" + + "end"; + + final KieBase kieBase = KieBaseUtil.getKieBaseFromKieModuleFromDrl("accumulate-test", kieBaseTestConfiguration, drl); + final KieSession kieSession = kieBase.newKieSession(); + try { + List results = new ArrayList<>(); + kieSession.setGlobal("results", results); + Primitives p1 = new Primitives(); + p1.setBigInteger(new BigInteger("2024043020240501130000")); + Primitives p2_smallest = new Primitives(); + p2_smallest.setBigInteger(new BigInteger("2024043020240501120000")); + Primitives p3 = new Primitives(); + p3.setBigInteger(new BigInteger("2024043020240501150000")); + + kieSession.insert(p1); + kieSession.insert(p2_smallest); + kieSession.insert(p3); + kieSession.fireAllRules(); + assertThat(results).hasSize(1); + assertThat(results.get(0)).isEqualTo(p2_smallest.getBigInteger()); + } finally { + kieSession.dispose(); + } + } + + @ParameterizedTest(name = "KieBase type={0}") + @MethodSource("parameters") + void maxWithBigDecimalHighAccuracy(KieBaseTestConfiguration kieBaseTestConfiguration) { + final String drl = + "import " + Primitives.class.getCanonicalName() + ";\n" + + "global java.util.List results;\n" + + "rule R1 when\n" + + " accumulate(Primitives($bd : bigDecimal), $max : max($bd))\n" + + "then\n" + + " results.add($max);\n" + + // to confirm if $max is BigDecimal at build time (Not Comparable) + // The return value isn't important to assert + " System.out.println($max.scale());\n" + + "end"; + + final KieBase kieBase = KieBaseUtil.getKieBaseFromKieModuleFromDrl("accumulate-test", kieBaseTestConfiguration, drl); + final KieSession kieSession = kieBase.newKieSession(); + try { + List results = new ArrayList<>(); + kieSession.setGlobal("results", results); + Primitives p1 = new Primitives(); + p1.setBigDecimal(new BigDecimal("2024043020240501130000")); + Primitives p2_largest = new Primitives(); + p2_largest.setBigDecimal(new BigDecimal("2024043020240501150000")); + Primitives p3 = new Primitives(); + p3.setBigDecimal(new BigDecimal("2024043020240501120000")); + + kieSession.insert(p1); + kieSession.insert(p2_largest); + kieSession.insert(p3); + kieSession.fireAllRules(); + assertThat(results).hasSize(1); + assertThat(results.get(0)).isEqualTo(p2_largest.getBigDecimal()); + } finally { + kieSession.dispose(); + } + } + + @ParameterizedTest(name = "KieBase type={0}") + @MethodSource("parameters") + void maxWithBigIntegerHighAccuracy(KieBaseTestConfiguration kieBaseTestConfiguration) { + final String drl = + "import " + Primitives.class.getCanonicalName() + ";\n" + + "global java.util.List results;\n" + + "rule R1 when\n" + + " accumulate(Primitives($bi : bigInteger), $max : max($bi))\n" + + "then\n" + + " results.add($max);\n" + + // to confirm if $max is BigInteger at build time (Not Comparable) + // The return value isn't important to assert + " System.out.println($max.nextProbablePrime());\n" + + "end"; + + final KieBase kieBase = KieBaseUtil.getKieBaseFromKieModuleFromDrl("accumulate-test", kieBaseTestConfiguration, drl); + final KieSession kieSession = kieBase.newKieSession(); + try { + List results = new ArrayList<>(); + kieSession.setGlobal("results", results); + Primitives p1 = new Primitives(); + p1.setBigInteger(new BigInteger("2024043020240501130000")); + Primitives p2_largest = new Primitives(); + p2_largest.setBigInteger(new BigInteger("2024043020240501150000")); + Primitives p3 = new Primitives(); + p3.setBigInteger(new BigInteger("2024043020240501120000")); + + kieSession.insert(p1); + kieSession.insert(p2_largest); + kieSession.insert(p3); + kieSession.fireAllRules(); + assertThat(results).hasSize(1); + assertThat(results.get(0)).isEqualTo(p2_largest.getBigInteger()); + } finally { + kieSession.dispose(); + } + } }