From 31d786569e03680da4f9aaae573551f152431f94 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 16 Jan 2024 15:57:05 +0100 Subject: [PATCH] [KIE-775] drools executable-model fails with a bind variable to a calculation result of int and BigDecimal (#5636) (#2) (#4) Co-authored-by: Toshiya Kobayashi --- .../ArithmeticCoercedExpression.java | 4 ++ .../generator/drlxparse/ConstraintParser.java | 50 +++++++++++++++---- .../expressiontyper/ExpressionTyper.java | 30 ++++++++--- .../bigdecimaltest/BigDecimalTest.java | 41 +++++++++++++++ 4 files changed, 110 insertions(+), 15 deletions(-) diff --git a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ArithmeticCoercedExpression.java b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ArithmeticCoercedExpression.java index d9074df9ca1..279a23f3fb5 100644 --- a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ArithmeticCoercedExpression.java +++ b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ArithmeticCoercedExpression.java @@ -51,6 +51,10 @@ public ArithmeticCoercedExpression(TypedExpression left, TypedExpression right, this.operator = operator; } + /* + * This coercion only deals with String vs Numeric types. + * BigDecimal arithmetic operation is handled by ExpressionTyper.convertArithmeticBinaryToMethodCall() + */ public ArithmeticCoercedExpressionResult coerce() { if (!requiresCoercion()) { diff --git a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ConstraintParser.java b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ConstraintParser.java index 4397ccfb10f..55bac2bd1ec 100644 --- a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ConstraintParser.java +++ b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/drlxparse/ConstraintParser.java @@ -97,6 +97,9 @@ import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.stripEnclosedExpr; import static org.drools.modelcompiler.builder.generator.drlxparse.MultipleDrlxParseSuccess.createMultipleDrlxParseSuccess; import static org.drools.modelcompiler.builder.generator.drlxparse.SpecialComparisonCase.specialComparisonFactory; +import static org.drools.modelcompiler.builder.generator.expressiontyper.ExpressionTyper.convertArithmeticBinaryToMethodCall; +import static org.drools.modelcompiler.builder.generator.expressiontyper.ExpressionTyper.getBinaryTypeAfterConversion; +import static org.drools.modelcompiler.builder.generator.expressiontyper.ExpressionTyper.shouldConvertArithmeticBinaryToMethodCall; import static org.drools.modelcompiler.builder.generator.expressiontyper.FlattenScope.transformFullyQualifiedInlineCastExpr; import static org.drools.mvel.parser.printer.PrintUtil.printNode; import static org.drools.mvel.parser.utils.AstUtils.isLogicalOperator; @@ -189,11 +192,23 @@ private void logWarnIfNoReactOnCausedByVariableFromDifferentPattern(DrlxParseRes } private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleResult, String bindId) { - DeclarationSpec decl = context.addDeclaration( bindId, singleResult.getLeftExprTypeBeforeCoercion() ); + DeclarationSpec decl = context.addDeclaration(bindId, getDeclarationType(drlx, singleResult)); if (drlx.getExpr() instanceof NameExpr) { decl.setBoundVariable( PrintUtil.printNode(drlx.getExpr()) ); } else if (drlx.getExpr() instanceof BinaryExpr) { - decl.setBoundVariable(PrintUtil.printNode(drlx.getExpr().asBinaryExpr().getLeft())); + Expression leftMostExpression = getLeftMostExpression(drlx.getExpr().asBinaryExpr()); + decl.setBoundVariable(PrintUtil.printNode(leftMostExpression)); + if (singleResult.getExpr() instanceof MethodCallExpr) { + // BinaryExpr was converted to MethodCallExpr. Create a TypedExpression for the leftmost expression of the BinaryExpr + ExpressionTyperContext expressionTyperContext = new ExpressionTyperContext(); + ExpressionTyper expressionTyper = new ExpressionTyper(context, singleResult.getPatternType(), bindId, false, expressionTyperContext); + TypedExpressionResult leftTypedExpressionResult = expressionTyper.toTypedExpression(leftMostExpression); + Optional optLeft = leftTypedExpressionResult.getTypedExpression(); + if (!optLeft.isPresent()) { + throw new IllegalStateException("Cannot create TypedExpression for " + drlx.getExpr().asBinaryExpr().getLeft()); + } + singleResult.setBoundExpr(optLeft.get()); + } } decl.setBelongingPatternDescr(context.getCurrentPatternDescr()); singleResult.setExprBinding( bindId ); @@ -203,6 +218,24 @@ private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleRe } } + private static Class getDeclarationType(DrlxExpression drlx, SingleDrlxParseSuccess singleResult) { + if (drlx.getBind() != null && drlx.getExpr() instanceof EnclosedExpr) { + // in case of enclosed, bind type should be the calculation result type + // If drlx.getBind() == null, a bind variable is inside the enclosed expression, leave it to the default behavior + return (Class)singleResult.getExprType(); + } else { + return singleResult.getLeftExprTypeBeforeCoercion(); + } + } + + private Expression getLeftMostExpression(BinaryExpr binaryExpr) { + Expression left = binaryExpr.getLeft(); + if (left instanceof BinaryExpr) { + return getLeftMostExpression((BinaryExpr) left); + } + return left; + } + /* This is the entry point for Constraint Transformation from a parsed MVEL constraint to a Java Expression @@ -626,17 +659,16 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class patternT Expression combo; - boolean arithmeticExpr = isArithmeticOperator(operator); boolean isBetaConstraint = right.getExpression() != null && hasDeclarationFromOtherPattern( expressionTyperContext ); boolean requiresSplit = operator == BinaryExpr.Operator.AND && binaryExpr.getRight() instanceof HalfBinaryExpr && !isBetaConstraint; + Type exprType = isBooleanOperator( operator ) ? boolean.class : left.getType(); + if (equalityExpr) { combo = getEqualityExpression( left, right, operator ).expression; - } else if (arithmeticExpr && (left.isBigDecimal())) { - ConstraintCompiler constraintCompiler = createConstraintCompiler(this.context, of(patternType)); - CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(binaryExpr); - - combo = compiledExpressionResult.getExpression(); + } else if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) { + combo = convertArithmeticBinaryToMethodCall(binaryExpr, of(patternType), this.context); + exprType = getBinaryTypeAfterConversion(left.getType(), right.getType()); } else { if (left.getExpression() == null || right.getExpression() == null) { return new DrlxParseFail(new ParseExpressionErrorResult(drlxExpr)); @@ -664,7 +696,7 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class patternT constraintType = Index.ConstraintType.FORALL_SELF_JOIN; } - return new SingleDrlxParseSuccess(patternType, bindingId, combo, isBooleanOperator( operator ) ? boolean.class : left.getType()) + return new SingleDrlxParseSuccess(patternType, bindingId, combo, exprType) .setDecodeConstraintType( constraintType ) .setUsedDeclarations( expressionTyperContext.getUsedDeclarations() ) .setUsedDeclarationsOnLeft( usedDeclarationsOnLeft ) diff --git a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/expressiontyper/ExpressionTyper.java b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/expressiontyper/ExpressionTyper.java index 69c477c133b..9020f2307e6 100644 --- a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/expressiontyper/ExpressionTyper.java +++ b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/expressiontyper/ExpressionTyper.java @@ -237,7 +237,14 @@ private Optional toTypedExpressionRec(Expression drlxExpr) { right = coerced.getCoercedRight(); final BinaryExpr combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator); - return of(new TypedExpression(combo, left.getType())); + + if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) { + Expression expression = convertArithmeticBinaryToMethodCall(combo, of(typeCursor), ruleContext); + java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(left.getType(), right.getType()); + return of(new TypedExpression(expression, binaryType)); + } else { + return of(new TypedExpression(combo, left.getType())); + } } if (drlxExpr instanceof HalfBinaryExpr) { @@ -810,11 +817,12 @@ private TypedExpressionCursor binaryExpr(BinaryExpr binaryExpr) { TypedExpression rightTypedExpression = right.getTypedExpression() .orElseThrow(() -> new NoSuchElementException("TypedExpressionResult doesn't contain TypedExpression!")); binaryExpr.setRight(rightTypedExpression.getExpression()); - java.lang.reflect.Type binaryType = getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator()); - if (shouldConvertArithmeticBinaryToMethodCall(binaryExpr.getOperator(), binaryType)) { + if (shouldConvertArithmeticBinaryToMethodCall(binaryExpr.getOperator(), leftTypedExpression.getType(), rightTypedExpression.getType())) { Expression compiledExpression = convertArithmeticBinaryToMethodCall(binaryExpr, leftTypedExpression.getOriginalPatternType(), ruleContext); + java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(leftTypedExpression.getType(), rightTypedExpression.getType()); return new TypedExpressionCursor(compiledExpression, binaryType); } else { + java.lang.reflect.Type binaryType = getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator()); return new TypedExpressionCursor(binaryExpr, binaryType); } } @@ -823,7 +831,7 @@ private TypedExpressionCursor binaryExpr(BinaryExpr binaryExpr) { * Converts arithmetic binary expression (including coercion) to method call using ConstraintCompiler. * This method can be generic, so we may centralize the calls in drools-model */ - private static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryExpr, Optional> originalPatternType, RuleContext ruleContext) { + public static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryExpr, Optional> originalPatternType, RuleContext ruleContext) { ConstraintCompiler constraintCompiler = createConstraintCompiler(ruleContext, originalPatternType); CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(printNode(binaryExpr)); return compiledExpressionResult.getExpression(); @@ -832,8 +840,15 @@ private static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryE /* * BigDecimal arithmetic operations should be converted to method calls. We may also apply this to BigInteger. */ - private static boolean shouldConvertArithmeticBinaryToMethodCall(BinaryExpr.Operator operator, java.lang.reflect.Type type) { - return isArithmeticOperator(operator) && type.equals(BigDecimal.class); + public static boolean shouldConvertArithmeticBinaryToMethodCall(BinaryExpr.Operator operator, java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) { + return isArithmeticOperator(operator) && (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class)); + } + + /* + * After arithmetic to method call conversion, BigDecimal should take precedence regardless of left or right. We may also apply this to BigInteger. + */ + public static java.lang.reflect.Type getBinaryTypeAfterConversion(java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) { + return (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class)) ? BigDecimal.class : leftType; } private java.lang.reflect.Type getBinaryType(TypedExpression leftTypedExpression, TypedExpression rightTypedExpression, Operator operator) { @@ -940,6 +955,9 @@ private void promoteBigDecimalParameters(MethodCallExpr methodCallExpr, Class[] Expression argumentExpression = methodCallExpr.getArgument(i); if (argumentType != actualArgumentType) { + // unbind the original argumentExpression first, otherwise setArgument() will remove the argumentExpression from coercedExpression.childrenNodes + // It will result in failing to find DrlNameExpr in AST at DrlsParseUtil.transformDrlNameExprToNameExpr() + methodCallExpr.replace(argumentExpression, new NameExpr("placeholder")); Expression coercedExpression = new BigDecimalArgumentCoercion().coercedArgument(argumentType, actualArgumentType, argumentExpression); methodCallExpr.setArgument(i, coercedExpression); } diff --git a/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/bigdecimaltest/BigDecimalTest.java b/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/bigdecimaltest/BigDecimalTest.java index 90655bdb715..05acc4e97e3 100644 --- a/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/bigdecimaltest/BigDecimalTest.java +++ b/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/bigdecimaltest/BigDecimalTest.java @@ -682,4 +682,45 @@ public void bigDecimalCoercionInNestedMethodArgument_shouldNotFailToBuild() { public static String intToString(int value) { return Integer.toString(value); } + + @Test + public void bindVariableToBigDecimalCoercion2Operands_shouldBindCorrectResult() { + bindVariableToBigDecimalCoercion("$var : (1000 * value1)"); + } + + @Test + public void bindVariableToBigDecimalCoercion3Operands_shouldBindCorrectResult() { + bindVariableToBigDecimalCoercion("$var : (100000 * value1 / 100)"); + } + + @Test + public void bindVariableToBigDecimalCoercion3OperandsWithParentheses_shouldBindCorrectResult() { + bindVariableToBigDecimalCoercion("$var : ((100000 * value1) / 100)"); + } + + private void bindVariableToBigDecimalCoercion(String binding) { + // KIE-775 + String str = + "package org.drools.modelcompiler.bigdecimals\n" + + "import " + BDFact.class.getCanonicalName() + ";\n" + + "global java.util.List result;\n" + + "rule R1\n" + + " when\n" + + " BDFact( " + binding + " )\n" + + " then\n" + + " result.add($var);\n" + + "end"; + + KieSession ksession = getKieSession(str); + List result = new ArrayList<>(); + ksession.setGlobal("result", result); + + BDFact bdFact = new BDFact(); + bdFact.setValue1(new BigDecimal("80")); + + ksession.insert(bdFact); + ksession.fireAllRules(); + + assertThat(result).contains(new BigDecimal("80000")); + } }