From 2501a0ba718afec0f784d97c01c3557ae95ee78b Mon Sep 17 00:00:00 2001 From: Toshiya Kobayashi Date: Wed, 10 Jan 2024 00:13:53 +0900 Subject: [PATCH] [KIE-775] drools executable-model fails with a bind variable to a calculation result of int and BigDecimal (#5636) --- .../ArithmeticCoercedExpression.java | 4 ++ .../generator/drlxparse/ConstraintParser.java | 50 +++++++++++++++---- .../expressiontyper/ExpressionTyper.java | 30 ++++++++--- .../bigdecimaltest/BigDecimalTest.java | 41 +++++++++++++++ 4 files changed, 110 insertions(+), 15 deletions(-) diff --git a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ArithmeticCoercedExpression.java b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ArithmeticCoercedExpression.java index d9074df9ca1..279a23f3fb5 100644 --- a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ArithmeticCoercedExpression.java +++ b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ArithmeticCoercedExpression.java @@ -51,6 +51,10 @@ public ArithmeticCoercedExpression(TypedExpression left, TypedExpression right, this.operator = operator; } + /* + * This coercion only deals with String vs Numeric types. + * BigDecimal arithmetic operation is handled by ExpressionTyper.convertArithmeticBinaryToMethodCall() + */ public ArithmeticCoercedExpressionResult coerce() { if (!requiresCoercion()) { diff --git a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ConstraintParser.java b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ConstraintParser.java index f70f296cd3c..04bfcdc68fb 100644 --- a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ConstraintParser.java +++ b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ConstraintParser.java @@ -103,6 +103,9 @@ import static org.drools.modelcompiler.builder.generator.DslMethodNames.createDslTopLevelMethod; import static org.drools.modelcompiler.builder.generator.drlxparse.MultipleDrlxParseSuccess.createMultipleDrlxParseSuccess; import static org.drools.modelcompiler.builder.generator.drlxparse.SpecialComparisonCase.specialComparisonFactory; +import static org.drools.modelcompiler.builder.generator.expressiontyper.ExpressionTyper.convertArithmeticBinaryToMethodCall; +import static org.drools.modelcompiler.builder.generator.expressiontyper.ExpressionTyper.getBinaryTypeAfterConversion; +import static org.drools.modelcompiler.builder.generator.expressiontyper.ExpressionTyper.shouldConvertArithmeticBinaryToMethodCall; import static org.drools.modelcompiler.builder.generator.expressiontyper.FlattenScope.transformFullyQualifiedInlineCastExpr; import static org.drools.mvel.parser.printer.PrintUtil.printNode; import static org.drools.mvel.parser.utils.AstUtils.isLogicalOperator; @@ -195,11 +198,23 @@ private void logWarnIfNoReactOnCausedByVariableFromDifferentPattern(DrlxParseRes } private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleResult, String bindId) { - DeclarationSpec decl = context.addDeclaration( bindId, singleResult.getLeftExprTypeBeforeCoercion() ); + DeclarationSpec decl = context.addDeclaration(bindId, getDeclarationType(drlx, singleResult)); if (drlx.getExpr() instanceof NameExpr) { decl.setBoundVariable( PrintUtil.printNode(drlx.getExpr()) ); } else if (drlx.getExpr() instanceof BinaryExpr) { - decl.setBoundVariable(PrintUtil.printNode(drlx.getExpr().asBinaryExpr().getLeft())); + Expression leftMostExpression = getLeftMostExpression(drlx.getExpr().asBinaryExpr()); + decl.setBoundVariable(PrintUtil.printNode(leftMostExpression)); + if (singleResult.getExpr() instanceof MethodCallExpr) { + // BinaryExpr was converted to MethodCallExpr. Create a TypedExpression for the leftmost expression of the BinaryExpr + ExpressionTyperContext expressionTyperContext = new ExpressionTyperContext(); + ExpressionTyper expressionTyper = new ExpressionTyper(context, singleResult.getPatternType(), bindId, false, expressionTyperContext); + TypedExpressionResult leftTypedExpressionResult = expressionTyper.toTypedExpression(leftMostExpression); + Optional optLeft = leftTypedExpressionResult.getTypedExpression(); + if (!optLeft.isPresent()) { + throw new IllegalStateException("Cannot create TypedExpression for " + drlx.getExpr().asBinaryExpr().getLeft()); + } + singleResult.setBoundExpr(optLeft.get()); + } } decl.setBelongingPatternDescr(context.getCurrentPatternDescr()); singleResult.setExprBinding( bindId ); @@ -209,6 +224,24 @@ private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleRe } } + private static Class getDeclarationType(DrlxExpression drlx, SingleDrlxParseSuccess singleResult) { + if (drlx.getBind() != null && drlx.getExpr() instanceof EnclosedExpr) { + // in case of enclosed, bind type should be the calculation result type + // If drlx.getBind() == null, a bind variable is inside the enclosed expression, leave it to the default behavior + return (Class)singleResult.getExprType(); + } else { + return singleResult.getLeftExprTypeBeforeCoercion(); + } + } + + private Expression getLeftMostExpression(BinaryExpr binaryExpr) { + Expression left = binaryExpr.getLeft(); + if (left instanceof BinaryExpr) { + return getLeftMostExpression((BinaryExpr) left); + } + return left; + } + /* This is the entry point for Constraint Transformation from a parsed MVEL constraint to a Java Expression @@ -650,17 +683,16 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class patternT Expression combo; - boolean arithmeticExpr = isArithmeticOperator(operator); boolean isBetaConstraint = right.getExpression() != null && hasDeclarationFromOtherPattern( expressionTyperContext ); boolean requiresSplit = operator == BinaryExpr.Operator.AND && binaryExpr.getRight() instanceof HalfBinaryExpr && !isBetaConstraint; + Type exprType = isBooleanOperator( operator ) ? boolean.class : left.getType(); + if (equalityExpr) { combo = getEqualityExpression( left, right, operator ).expression; - } else if (arithmeticExpr && (left.isBigDecimal())) { - ConstraintCompiler constraintCompiler = createConstraintCompiler(this.context, of(patternType)); - CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(binaryExpr); - - combo = compiledExpressionResult.getExpression(); + } else if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) { + combo = convertArithmeticBinaryToMethodCall(binaryExpr, of(patternType), this.context); + exprType = getBinaryTypeAfterConversion(left.getType(), right.getType()); } else { if (left.getExpression() == null || right.getExpression() == null) { return new DrlxParseFail(new ParseExpressionErrorResult(drlxExpr)); @@ -688,7 +720,7 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class patternT constraintType = Index.ConstraintType.FORALL_SELF_JOIN; } - return new SingleDrlxParseSuccess(patternType, bindingId, combo, isBooleanOperator( operator ) ? boolean.class : left.getType()) + return new SingleDrlxParseSuccess(patternType, bindingId, combo, exprType) .setDecodeConstraintType( constraintType ) .setUsedDeclarations( expressionTyperContext.getUsedDeclarations() ) .setUsedDeclarationsOnLeft( usedDeclarationsOnLeft ) diff --git a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/expressiontyper/ExpressionTyper.java b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/expressiontyper/ExpressionTyper.java index 44ea4010453..28f007fea4a 100644 --- a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/expressiontyper/ExpressionTyper.java +++ b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/expressiontyper/ExpressionTyper.java @@ -237,7 +237,14 @@ private Optional toTypedExpressionRec(Expression drlxExpr) { right = coerced.getCoercedRight(); final BinaryExpr combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator); - return of(new TypedExpression(combo, left.getType())); + + if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) { + Expression expression = convertArithmeticBinaryToMethodCall(combo, of(typeCursor), ruleContext); + java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(left.getType(), right.getType()); + return of(new TypedExpression(expression, binaryType)); + } else { + return of(new TypedExpression(combo, left.getType())); + } } if (drlxExpr instanceof HalfBinaryExpr) { @@ -809,11 +816,12 @@ private TypedExpressionCursor binaryExpr(BinaryExpr binaryExpr) { TypedExpression rightTypedExpression = right.getTypedExpression() .orElseThrow(() -> new NoSuchElementException("TypedExpressionResult doesn't contain TypedExpression!")); binaryExpr.setRight(rightTypedExpression.getExpression()); - java.lang.reflect.Type binaryType = getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator()); - if (shouldConvertArithmeticBinaryToMethodCall(binaryExpr.getOperator(), binaryType)) { + if (shouldConvertArithmeticBinaryToMethodCall(binaryExpr.getOperator(), leftTypedExpression.getType(), rightTypedExpression.getType())) { Expression compiledExpression = convertArithmeticBinaryToMethodCall(binaryExpr, leftTypedExpression.getOriginalPatternType(), ruleContext); + java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(leftTypedExpression.getType(), rightTypedExpression.getType()); return new TypedExpressionCursor(compiledExpression, binaryType); } else { + java.lang.reflect.Type binaryType = getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator()); return new TypedExpressionCursor(binaryExpr, binaryType); } } @@ -822,7 +830,7 @@ private TypedExpressionCursor binaryExpr(BinaryExpr binaryExpr) { * Converts arithmetic binary expression (including coercion) to method call using ConstraintCompiler. * This method can be generic, so we may centralize the calls in drools-model */ - private static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryExpr, Optional> originalPatternType, RuleContext ruleContext) { + public static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryExpr, Optional> originalPatternType, RuleContext ruleContext) { ConstraintCompiler constraintCompiler = createConstraintCompiler(ruleContext, originalPatternType); CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(printNode(binaryExpr)); return compiledExpressionResult.getExpression(); @@ -831,8 +839,15 @@ private static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryE /* * BigDecimal arithmetic operations should be converted to method calls. We may also apply this to BigInteger. */ - private static boolean shouldConvertArithmeticBinaryToMethodCall(BinaryExpr.Operator operator, java.lang.reflect.Type type) { - return isArithmeticOperator(operator) && type.equals(BigDecimal.class); + public static boolean shouldConvertArithmeticBinaryToMethodCall(BinaryExpr.Operator operator, java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) { + return isArithmeticOperator(operator) && (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class)); + } + + /* + * After arithmetic to method call conversion, BigDecimal should take precedence regardless of left or right. We may also apply this to BigInteger. + */ + public static java.lang.reflect.Type getBinaryTypeAfterConversion(java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) { + return (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class)) ? BigDecimal.class : leftType; } private java.lang.reflect.Type getBinaryType(TypedExpression leftTypedExpression, TypedExpression rightTypedExpression, Operator operator) { @@ -939,6 +954,9 @@ private void promoteBigDecimalParameters(MethodCallExpr methodCallExpr, Class[] Expression argumentExpression = methodCallExpr.getArgument(i); if (argumentType != actualArgumentType) { + // unbind the original argumentExpression first, otherwise setArgument() will remove the argumentExpression from coercedExpression.childrenNodes + // It will result in failing to find DrlNameExpr in AST at DrlsParseUtil.transformDrlNameExprToNameExpr() + methodCallExpr.replace(argumentExpression, new NameExpr("placeholder")); Expression coercedExpression = new BigDecimalArgumentCoercion().coercedArgument(argumentType, actualArgumentType, argumentExpression); methodCallExpr.setArgument(i, coercedExpression); } diff --git a/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/bigdecimaltest/BigDecimalTest.java b/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/bigdecimaltest/BigDecimalTest.java index ae87e658d2a..fa665510b99 100644 --- a/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/bigdecimaltest/BigDecimalTest.java +++ b/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/bigdecimaltest/BigDecimalTest.java @@ -742,4 +742,45 @@ public void bigDecimalCoercionInNestedMethodArgument_shouldNotFailToBuild() { public static String intToString(int value) { return Integer.toString(value); } + + @Test + public void bindVariableToBigDecimalCoercion2Operands_shouldBindCorrectResult() { + bindVariableToBigDecimalCoercion("$var : (1000 * value1)"); + } + + @Test + public void bindVariableToBigDecimalCoercion3Operands_shouldBindCorrectResult() { + bindVariableToBigDecimalCoercion("$var : (100000 * value1 / 100)"); + } + + @Test + public void bindVariableToBigDecimalCoercion3OperandsWithParentheses_shouldBindCorrectResult() { + bindVariableToBigDecimalCoercion("$var : ((100000 * value1) / 100)"); + } + + private void bindVariableToBigDecimalCoercion(String binding) { + // KIE-775 + String str = + "package org.drools.modelcompiler.bigdecimals\n" + + "import " + BDFact.class.getCanonicalName() + ";\n" + + "global java.util.List result;\n" + + "rule R1\n" + + " when\n" + + " BDFact( " + binding + " )\n" + + " then\n" + + " result.add($var);\n" + + "end"; + + KieSession ksession = getKieSession(str); + List result = new ArrayList<>(); + ksession.setGlobal("result", result); + + BDFact bdFact = new BDFact(); + bdFact.setValue1(new BigDecimal("80")); + + ksession.insert(bdFact); + ksession.fireAllRules(); + + assertThat(result).contains(new BigDecimal("80000")); + } }