From 1c7d9022eef416b10b7ce9eb1fb59674fba5ee5e Mon Sep 17 00:00:00 2001 From: Christopher Chianelli Date: Fri, 12 Jul 2024 15:02:18 -0400 Subject: [PATCH] feat: add support for Decimal and Decimal score types (#110) - Decimal maps mostly to BigDecimal, although its floating point concepts are ignored (Python does not have an infinite precision MathContext, so it acts more like a dynamic range floating point with an adjustable precision. The precision used is shared in a thread local object that can be changed using decimal.setcontext. - Added `str` constructors to `float` and `int` - Added sanity tests for all variants of penalize/reward/impact and score types --- .../jpyinterpreter/PythonClassTranslator.java | 4 +- .../JavaPythonTypeConversionImplementor.java | 187 ++-- .../jpyinterpreter/types/BuiltinTypes.java | 2 + .../types/collections/PythonLikeTuple.java | 8 +- .../types/numeric/PythonDecimal.java | 805 +++++++++++++++++ .../types/numeric/PythonFloat.java | 56 +- .../types/numeric/PythonInteger.java | 64 +- .../types/numeric/PythonNumber.java | 24 +- jpyinterpreter/src/main/python/conversions.py | 16 +- jpyinterpreter/tests/test_builtins.py | 13 + jpyinterpreter/tests/test_decimal.py | 782 +++++++++++++++++ tests/test_collectors.py | 2 +- tests/test_constraint_streams.py | 317 ++++++- tests/test_score.py | 155 +++- ...ableDecimalScorePythonJavaTypeMapping.java | 88 ++ .../BendableScorePythonJavaTypeMapping.java | 4 +- ...SoftDecimalScorePythonJavaTypeMapping.java | 74 ++ ...dMediumSoftScorePythonJavaTypeMapping.java | 4 +- ...SoftDecimalScorePythonJavaTypeMapping.java | 70 ++ .../HardSoftScorePythonJavaTypeMapping.java | 4 +- ...mpleDecimalScorePythonJavaTypeMapping.java | 66 ++ .../SimpleScorePythonJavaTypeMapping.java | 4 +- .../main/python/_jpype_type_conversions.py | 21 + .../src/main/python/_timefold_java_interop.py | 35 +- .../main/python/score/_constraint_stream.py | 823 +++++++++++++++++- .../main/python/score/_function_translator.py | 15 +- .../src/main/python/score/_score.py | 273 +++++- .../main/python/score/_score_conversions.py | 20 + ...DecimalScorePythonJavaTypeMappingTest.java | 95 ++ ...DecimalScorePythonJavaTypeMappingTest.java | 75 ++ ...DecimalScorePythonJavaTypeMappingTest.java | 67 ++ .../score/PythonBendableDecimalScore.java | 46 + .../PythonHardMediumSoftDecimalScore.java | 40 + .../score/PythonHardSoftDecimalScore.java | 35 + .../score/PythonSimpleDecimalScore.java | 32 + ...DecimalScorePythonJavaTypeMappingTest.java | 63 ++ 36 files changed, 4198 insertions(+), 191 deletions(-) create mode 100644 jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonDecimal.java create mode 100644 jpyinterpreter/tests/test_decimal.py create mode 100644 timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/BendableDecimalScorePythonJavaTypeMapping.java create mode 100644 timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardMediumSoftDecimalScorePythonJavaTypeMapping.java create mode 100644 timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardSoftDecimalScorePythonJavaTypeMapping.java create mode 100644 timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/SimpleDecimalScorePythonJavaTypeMapping.java create mode 100644 timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/BendableDecimalScorePythonJavaTypeMappingTest.java create mode 100644 timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/HardMediumSoftDecimalScorePythonJavaTypeMappingTest.java create mode 100644 timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/HardSoftDecimalScorePythonJavaTypeMappingTest.java create mode 100644 timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonBendableDecimalScore.java create mode 100644 timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonHardMediumSoftDecimalScore.java create mode 100644 timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonHardSoftDecimalScore.java create mode 100644 timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonSimpleDecimalScore.java create mode 100644 timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/SimpleDecimalScorePythonJavaTypeMappingTest.java diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java index 37baf4b5..3b988c16 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java @@ -66,8 +66,8 @@ public class PythonClassTranslator { // $ is illegal in variables/methods in Python public static final String TYPE_FIELD_NAME = "$TYPE"; public static final String CPYTHON_TYPE_FIELD_NAME = "$CPYTHON_TYPE"; - private static final String JAVA_METHOD_PREFIX = "$method$"; - private static final String PYTHON_JAVA_TYPE_MAPPING_PREFIX = "$pythonJavaTypeMapping"; + public static final String JAVA_METHOD_PREFIX = "$method$"; + public static final String PYTHON_JAVA_TYPE_MAPPING_PREFIX = "$pythonJavaTypeMapping"; public record PreparedClassInfo(PythonLikeType type, String className, String classInternalName) { } diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/JavaPythonTypeConversionImplementor.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/JavaPythonTypeConversionImplementor.java index 5c0861a2..ba302add 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/JavaPythonTypeConversionImplementor.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/JavaPythonTypeConversionImplementor.java @@ -1,6 +1,7 @@ package ai.timefold.jpyinterpreter.implementors; import java.lang.reflect.Field; +import java.math.BigDecimal; import java.math.BigInteger; import java.util.IdentityHashMap; import java.util.Iterator; @@ -31,6 +32,7 @@ import ai.timefold.jpyinterpreter.types.collections.PythonLikeTuple; import ai.timefold.jpyinterpreter.types.errors.TypeError; import ai.timefold.jpyinterpreter.types.numeric.PythonBoolean; +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; import ai.timefold.jpyinterpreter.types.numeric.PythonFloat; import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; import ai.timefold.jpyinterpreter.types.numeric.PythonNumber; @@ -65,76 +67,78 @@ public static PythonLikeObject wrapJavaObject(Object object, Map((Iterator) object); + if (object instanceof Iterator iterator) { + return new DelegatePythonIterator<>(iterator); } - if (object instanceof List) { - PythonLikeList out = new PythonLikeList(); + if (object instanceof List list) { + PythonLikeList out = new PythonLikeList<>(); createdObjectMap.put(object, out); - for (Object item : (List) object) { + for (Object item : list) { out.add(wrapJavaObject(item)); } return out; } - if (object instanceof Set) { - PythonLikeSet out = new PythonLikeSet(); + if (object instanceof Set set) { + PythonLikeSet out = new PythonLikeSet<>(); createdObjectMap.put(object, out); - for (Object item : (Set) object) { + for (Object item : set) { out.add(wrapJavaObject(item)); } return out; } - if (object instanceof Map) { - PythonLikeDict out = new PythonLikeDict(); + if (object instanceof Map map) { + PythonLikeDict out = new PythonLikeDict<>(); createdObjectMap.put(object, out); - Set> entrySet = ((Map) object).entrySet(); - for (Map.Entry entry : entrySet) { + var entrySet = map.entrySet(); + for (var entry : entrySet) { out.put(wrapJavaObject(entry.getKey()), wrapJavaObject(entry.getValue())); } return out; } - if (object instanceof Class) { - Class maybeFunctionClass = (Class) object; - if (Set.of(maybeFunctionClass.getInterfaces()).contains(PythonLikeFunction.class)) { - return new PythonCode((Class) maybeFunctionClass); - } + if (object instanceof Class maybeFunctionClass && + Set.of(maybeFunctionClass.getInterfaces()).contains(PythonLikeFunction.class)) { + return new PythonCode((Class) maybeFunctionClass); } - if (object instanceof OpaquePythonReference) { - return new PythonObjectWrapper((OpaquePythonReference) object); + if (object instanceof OpaquePythonReference opaquePythonReference) { + return new PythonObjectWrapper(opaquePythonReference); } // Default: return a JavaObjectWrapper @@ -161,6 +165,10 @@ public static PythonLikeType getPythonLikeType(Class javaClass) { return BuiltinTypes.INT_TYPE; } + if (BigDecimal.class.equals(javaClass) || PythonDecimal.class.equals(javaClass)) { + return BuiltinTypes.DECIMAL_TYPE; + } + if (float.class.equals(javaClass) || double.class.equals(javaClass) || Float.class.equals(javaClass) || Double.class.equals(javaClass) || PythonFloat.class.equals(javaClass)) { @@ -254,8 +262,7 @@ public static T convertPythonObjectToJavaType(Class type, Pytho return null; } - if (object instanceof JavaObjectWrapper) { - JavaObjectWrapper wrappedObject = (JavaObjectWrapper) object; + if (object instanceof JavaObjectWrapper wrappedObject) { Object javaObject = wrappedObject.getWrappedObject(); if (!type.isAssignableFrom(javaObject.getClass())) { throw new TypeError("Cannot convert from (" + getPythonLikeType(javaObject.getClass()) + ") to (" @@ -266,14 +273,13 @@ public static T convertPythonObjectToJavaType(Class type, Pytho if (type.equals(byte.class) || type.equals(short.class) || type.equals(int.class) || type.equals(long.class) || type.equals(float.class) || type.equals(double.class) || Number.class.isAssignableFrom(type)) { - if (!(object instanceof PythonNumber)) { + if (!(object instanceof PythonNumber pythonNumber)) { throw new TypeError("Cannot convert from (" + getPythonLikeType(object.getClass()) + ") to (" + getPythonLikeType(type) + ")."); } - PythonNumber pythonNumber = (PythonNumber) object; Number value = pythonNumber.getValue(); - if (type.equals(BigInteger.class)) { + if (type.equals(BigInteger.class) || type.equals(BigDecimal.class)) { return (T) value; } @@ -303,11 +309,10 @@ public static T convertPythonObjectToJavaType(Class type, Pytho } if (type.equals(boolean.class) || type.equals(Boolean.class)) { - if (!(object instanceof PythonBoolean)) { + if (!(object instanceof PythonBoolean pythonBoolean)) { throw new TypeError("Cannot convert from (" + getPythonLikeType(object.getClass()) + ") to (" + getPythonLikeType(type) + ")."); } - PythonBoolean pythonBoolean = (PythonBoolean) object; return (T) (Boolean) pythonBoolean.getBooleanValue(); } @@ -335,6 +340,53 @@ public static void loadName(MethodVisitor methodVisitor, String name) { false); } + private record ReturnValueOpDescriptor( + String wrapperClassName, + String methodName, + String methodDescriptor, + int opcode, + boolean noConversionNeeded) { + public static ReturnValueOpDescriptor noConversion() { + return new ReturnValueOpDescriptor("", "", "", + Opcodes.ARETURN, true); + } + + public static ReturnValueOpDescriptor forNumeric(String methodName, + String methodDescriptor, + int opcode) { + return new ReturnValueOpDescriptor(Type.getInternalName(Number.class), methodName, methodDescriptor, opcode, + false); + } + } + + private static final Map numericReturnValueOpDescriptorMap = Map.of( + Type.BYTE_TYPE, ReturnValueOpDescriptor.forNumeric( + "byteValue", + Type.getMethodDescriptor(Type.BYTE_TYPE), + Opcodes.IRETURN), + Type.SHORT_TYPE, ReturnValueOpDescriptor.forNumeric( + "shortValue", + Type.getMethodDescriptor(Type.SHORT_TYPE), + Opcodes.IRETURN), + Type.INT_TYPE, ReturnValueOpDescriptor.forNumeric( + "intValue", + Type.getMethodDescriptor(Type.INT_TYPE), + Opcodes.IRETURN), + Type.LONG_TYPE, ReturnValueOpDescriptor.forNumeric( + "longValue", + Type.getMethodDescriptor(Type.LONG_TYPE), + Opcodes.LRETURN), + Type.FLOAT_TYPE, ReturnValueOpDescriptor.forNumeric( + "floatValue", + Type.getMethodDescriptor(Type.FLOAT_TYPE), + Opcodes.FRETURN), + Type.DOUBLE_TYPE, ReturnValueOpDescriptor.forNumeric( + "doubleValue", + Type.getMethodDescriptor(Type.DOUBLE_TYPE), + Opcodes.DRETURN), + Type.getType(BigInteger.class), ReturnValueOpDescriptor.noConversion(), + Type.getType(BigDecimal.class), ReturnValueOpDescriptor.noConversion()); + /** * If {@code method} return type is not void, convert TOS into its Java equivalent and return it. * If {@code method} return type is void, immediately return. @@ -344,67 +396,36 @@ public static void loadName(MethodVisitor methodVisitor, String name) { public static void returnValue(MethodVisitor methodVisitor, MethodDescriptor method, StackMetadata stackMetadata) { Type returnAsmType = method.getReturnType(); + if (Type.CHAR_TYPE.equals(returnAsmType)) { + throw new IllegalStateException("Unhandled case for primitive type (char)."); + } + if (Type.VOID_TYPE.equals(returnAsmType)) { methodVisitor.visitInsn(Opcodes.RETURN); return; } - if (Type.BYTE_TYPE.equals(returnAsmType) || - Type.CHAR_TYPE.equals(returnAsmType) || - Type.SHORT_TYPE.equals(returnAsmType) || - Type.INT_TYPE.equals(returnAsmType) || - Type.LONG_TYPE.equals(returnAsmType) || - Type.FLOAT_TYPE.equals(returnAsmType) || - Type.DOUBLE_TYPE.equals(returnAsmType)) { + if (numericReturnValueOpDescriptorMap.containsKey(returnAsmType)) { + var returnValueOpDescriptor = numericReturnValueOpDescriptorMap.get(returnAsmType); methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(PythonNumber.class)); methodVisitor.visitMethodInsn(Opcodes.INVOKEINTERFACE, Type.getInternalName(PythonNumber.class), "getValue", Type.getMethodDescriptor(Type.getType(Number.class)), true); - String wrapperClassName = null; - String methodName = null; - String methodDescriptor = null; - int returnOpcode = 0; - - if (Type.BYTE_TYPE.equals(returnAsmType)) { - wrapperClassName = Type.getInternalName(Number.class); - methodName = "byteValue"; - methodDescriptor = Type.getMethodDescriptor(Type.BYTE_TYPE); - returnOpcode = Opcodes.IRETURN; - } else if (Type.CHAR_TYPE.equals(returnAsmType)) { - throw new IllegalStateException("Unhandled case for primitive type (char)."); - // returnOpcode = Opcodes.IRETURN; - } else if (Type.SHORT_TYPE.equals(returnAsmType)) { - wrapperClassName = Type.getInternalName(Number.class); - methodName = "shortValue"; - methodDescriptor = Type.getMethodDescriptor(Type.SHORT_TYPE); - returnOpcode = Opcodes.IRETURN; - } else if (Type.INT_TYPE.equals(returnAsmType)) { - wrapperClassName = Type.getInternalName(Number.class); - methodName = "intValue"; - methodDescriptor = Type.getMethodDescriptor(Type.INT_TYPE); - returnOpcode = Opcodes.IRETURN; - } else if (Type.FLOAT_TYPE.equals(returnAsmType)) { - wrapperClassName = Type.getInternalName(Number.class); - methodName = "floatValue"; - methodDescriptor = Type.getMethodDescriptor(Type.FLOAT_TYPE); - returnOpcode = Opcodes.FRETURN; - } else if (Type.LONG_TYPE.equals(returnAsmType)) { - wrapperClassName = Type.getInternalName(Number.class); - methodName = "longValue"; - methodDescriptor = Type.getMethodDescriptor(Type.LONG_TYPE); - returnOpcode = Opcodes.LRETURN; - } else if (Type.DOUBLE_TYPE.equals(returnAsmType)) { - wrapperClassName = Type.getInternalName(Number.class); - methodName = "doubleValue"; - methodDescriptor = Type.getMethodDescriptor(Type.DOUBLE_TYPE); - returnOpcode = Opcodes.DRETURN; + + if (returnValueOpDescriptor.noConversionNeeded) { + methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, returnAsmType.getInternalName()); + methodVisitor.visitInsn(Opcodes.ARETURN); + return; } + methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, - wrapperClassName, methodName, methodDescriptor, + returnValueOpDescriptor.wrapperClassName, + returnValueOpDescriptor.methodName, + returnValueOpDescriptor.methodDescriptor, false); - methodVisitor.visitInsn(returnOpcode); + methodVisitor.visitInsn(returnValueOpDescriptor.opcode); return; } diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/BuiltinTypes.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/BuiltinTypes.java index b2f323f7..191aae4c 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/BuiltinTypes.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/BuiltinTypes.java @@ -23,6 +23,7 @@ import ai.timefold.jpyinterpreter.types.collections.view.DictValueView; import ai.timefold.jpyinterpreter.types.numeric.PythonBoolean; import ai.timefold.jpyinterpreter.types.numeric.PythonComplex; +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; import ai.timefold.jpyinterpreter.types.numeric.PythonFloat; import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; import ai.timefold.jpyinterpreter.types.numeric.PythonNumber; @@ -60,6 +61,7 @@ public class BuiltinTypes { public static final PythonLikeType BOOLEAN_TYPE = new PythonLikeType("bool", PythonBoolean.class, List.of(INT_TYPE)); public static final PythonLikeType FLOAT_TYPE = new PythonLikeType("float", PythonFloat.class, List.of(NUMBER_TYPE)); public final static PythonLikeType COMPLEX_TYPE = new PythonLikeType("complex", PythonComplex.class, List.of(NUMBER_TYPE)); + public final static PythonLikeType DECIMAL_TYPE = new PythonLikeType("Decimal", PythonDecimal.class, List.of(NUMBER_TYPE)); public static final PythonLikeType STRING_TYPE = new PythonLikeType("str", PythonString.class, List.of(BASE_TYPE)); public static final PythonLikeType BYTES_TYPE = new PythonLikeType("bytes", PythonBytes.class, List.of(BASE_TYPE)); diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/collections/PythonLikeTuple.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/collections/PythonLikeTuple.java index 9318011d..23b5b17f 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/collections/PythonLikeTuple.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/collections/PythonLikeTuple.java @@ -113,14 +113,14 @@ public PythonLikeTuple createNewInstance() { return new PythonLikeTuple<>(); } - public static PythonLikeTuple fromItems(PythonLikeObject... items) { - PythonLikeTuple result = new PythonLikeTuple(); + public static PythonLikeTuple fromItems(T... items) { + PythonLikeTuple result = new PythonLikeTuple<>(); Collections.addAll(result, items); return result; } - public static PythonLikeTuple fromList(List other) { - PythonLikeTuple result = new PythonLikeTuple(); + public static PythonLikeTuple fromList(List other) { + PythonLikeTuple result = new PythonLikeTuple<>(); result.addAll(other); return result; } diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonDecimal.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonDecimal.java new file mode 100644 index 00000000..53182617 --- /dev/null +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonDecimal.java @@ -0,0 +1,805 @@ +package ai.timefold.jpyinterpreter.types.numeric; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.MathContext; +import java.math.RoundingMode; +import java.util.function.BiPredicate; +import java.util.stream.Collectors; + +import ai.timefold.jpyinterpreter.PythonClassTranslator; +import ai.timefold.jpyinterpreter.PythonLikeObject; +import ai.timefold.jpyinterpreter.PythonOverloadImplementor; +import ai.timefold.jpyinterpreter.types.AbstractPythonLikeObject; +import ai.timefold.jpyinterpreter.types.BuiltinTypes; +import ai.timefold.jpyinterpreter.types.PythonLikeType; +import ai.timefold.jpyinterpreter.types.PythonNone; +import ai.timefold.jpyinterpreter.types.PythonString; +import ai.timefold.jpyinterpreter.types.collections.PythonLikeTuple; +import ai.timefold.jpyinterpreter.types.errors.TypeError; +import ai.timefold.jpyinterpreter.types.errors.ValueError; +import ai.timefold.solver.core.impl.domain.solution.cloner.PlanningImmutable; + +public class PythonDecimal extends AbstractPythonLikeObject implements PythonNumber, PlanningImmutable { + public final BigDecimal value; + private static final ThreadLocal threadMathContext = + ThreadLocal.withInitial(() -> new MathContext(28, RoundingMode.HALF_EVEN)); + + static { + PythonOverloadImplementor.deferDispatchesFor(PythonDecimal::registerMethods); + } + + private static PythonLikeType registerMethods() throws NoSuchMethodException { + BuiltinTypes.DECIMAL_TYPE.setConstructor((positionalArguments, namedArguments, callerInstance) -> { + if (positionalArguments.size() == 0) { + return new PythonDecimal(BigDecimal.ZERO); + } else if (positionalArguments.size() == 1) { + return PythonDecimal.from(positionalArguments.get(0)); + } else if (positionalArguments.size() == 2) { + // TODO: Support context + throw new ValueError("context constructor not supported"); + } else { + throw new TypeError("function takes at most 2 arguments, got " + positionalArguments.size()); + } + }); + + for (var method : PythonDecimal.class.getDeclaredMethods()) { + if (method.getName().startsWith(PythonClassTranslator.JAVA_METHOD_PREFIX)) { + BuiltinTypes.DECIMAL_TYPE.addMethod( + method.getName().substring(PythonClassTranslator.JAVA_METHOD_PREFIX.length()), + method); + } + } + + return BuiltinTypes.DECIMAL_TYPE; + } + + // *************************** + // Constructors + // *************************** + public PythonDecimal(BigDecimal value) { + super(BuiltinTypes.DECIMAL_TYPE); + this.value = value; + } + + public static PythonDecimal from(PythonLikeObject value) { + if (value instanceof PythonInteger integer) { + return PythonDecimal.valueOf(integer); + } else if (value instanceof PythonFloat pythonFloat) { + return PythonDecimal.valueOf(pythonFloat); + } else if (value instanceof PythonString str) { + return PythonDecimal.valueOf(str); + } else { + throw new TypeError( + "conversion from %s to Decimal is not supported".formatted(value.$getType().getTypeName())); + } + } + + public static PythonDecimal $method$from_float(PythonFloat value) { + return new PythonDecimal(new BigDecimal(value.value, threadMathContext.get())); + } + + public static PythonDecimal valueOf(PythonInteger value) { + return new PythonDecimal(new BigDecimal(value.value, threadMathContext.get())); + } + + public static PythonDecimal valueOf(PythonFloat value) { + return new PythonDecimal(new BigDecimal(value.value, threadMathContext.get())); + } + + public static PythonDecimal valueOf(PythonString value) { + return valueOf(value.value); + } + + public static PythonDecimal valueOf(String value) { + return new PythonDecimal(new BigDecimal(value, threadMathContext.get())); + } + + // *************************** + // Interface methods + // *************************** + + @Override + public Number getValue() { + return value; + } + + @Override + public PythonString $method$__str__() { + return PythonString.valueOf(toString()); + } + + @Override + public PythonString $method$__repr__() { + return PythonString.valueOf("Decimal('%s')".formatted(value.toPlainString())); + } + + @Override + public String toString() { + return value.toPlainString(); + } + + public boolean equals(Object o) { + if (o instanceof PythonNumber number) { + return compareTo(number) == 0; + } else { + return false; + } + } + + @Override + public int hashCode() { + return $method$__hash__().value.intValue(); + } + + @Override + public PythonInteger $method$__hash__() { + var scale = value.scale(); + if (scale <= 0) { + return PythonNumber.computeHash(new PythonInteger(value.toBigInteger()), + PythonInteger.ONE); + } + var scaledValue = value.movePointRight(scale); + return PythonNumber.computeHash(new PythonInteger(scaledValue.toBigInteger()), + new PythonInteger(BigInteger.TEN.pow(scale))); + } + + // *************************** + // Unary operations + // *************************** + public PythonBoolean $method$__bool__() { + return PythonBoolean.valueOf(value.compareTo(BigDecimal.ZERO) != 0); + } + + public PythonInteger $method$__int__() { + return PythonInteger.valueOf(value.toBigInteger()); + } + + public PythonFloat $method$__float__() { + return PythonFloat.valueOf(value.doubleValue()); + } + + public PythonDecimal $method$__pos__() { + return this; + } + + public PythonDecimal $method$__neg__() { + return new PythonDecimal(value.negate()); + } + + public PythonDecimal $method$__abs__() { + return new PythonDecimal(value.abs()); + } + + // *************************** + // Binary operations + // *************************** + public PythonBoolean $method$__lt__(PythonDecimal other) { + return PythonBoolean.valueOf(value.compareTo(other.value) < 0); + } + + public PythonBoolean $method$__lt__(PythonInteger other) { + return $method$__lt__(PythonDecimal.valueOf(other)); + } + + public PythonBoolean $method$__lt__(PythonFloat other) { + return $method$__lt__(PythonDecimal.valueOf(other)); + } + + public PythonBoolean $method$__le__(PythonDecimal other) { + return PythonBoolean.valueOf(value.compareTo(other.value) <= 0); + } + + public PythonBoolean $method$__le__(PythonInteger other) { + return $method$__le__(PythonDecimal.valueOf(other)); + } + + public PythonBoolean $method$__le__(PythonFloat other) { + return $method$__le__(PythonDecimal.valueOf(other)); + } + + public PythonBoolean $method$__gt__(PythonDecimal other) { + return PythonBoolean.valueOf(value.compareTo(other.value) > 0); + } + + public PythonBoolean $method$__gt__(PythonInteger other) { + return $method$__gt__(PythonDecimal.valueOf(other)); + } + + public PythonBoolean $method$__gt__(PythonFloat other) { + return $method$__gt__(PythonDecimal.valueOf(other)); + } + + public PythonBoolean $method$__ge__(PythonDecimal other) { + return PythonBoolean.valueOf(value.compareTo(other.value) >= 0); + } + + public PythonBoolean $method$__ge__(PythonInteger other) { + return $method$__ge__(PythonDecimal.valueOf(other)); + } + + public PythonBoolean $method$__ge__(PythonFloat other) { + return $method$__ge__(PythonDecimal.valueOf(other)); + } + + public PythonBoolean $method$__eq__(PythonDecimal other) { + return PythonBoolean.valueOf(value.compareTo(other.value) == 0); + } + + public PythonBoolean $method$__eq__(PythonInteger other) { + return PythonBoolean.valueOf(value.compareTo(new BigDecimal(other.value)) == 0); + } + + public PythonBoolean $method$__eq__(PythonFloat other) { + return PythonBoolean.valueOf(value.compareTo(new BigDecimal(other.value)) == 0); + } + + public PythonBoolean $method$__neq__(PythonDecimal other) { + return $method$__eq__(other).not(); + } + + public PythonBoolean $method$__neq__(PythonInteger other) { + return $method$__eq__(other).not(); + } + + public PythonBoolean $method$__neq__(PythonFloat other) { + return $method$__eq__(other).not(); + } + + public PythonDecimal $method$__add__(PythonDecimal other) { + return new PythonDecimal(value.add(other.value, threadMathContext.get())); + } + + public PythonDecimal $method$__add__(PythonInteger other) { + return $method$__add__(PythonDecimal.valueOf(other)); + } + + public PythonDecimal $method$__radd__(PythonInteger other) { + return PythonDecimal.valueOf(other).$method$__add__(this); + } + + public PythonDecimal $method$__sub__(PythonDecimal other) { + return new PythonDecimal(value.subtract(other.value, threadMathContext.get())); + } + + public PythonDecimal $method$__sub__(PythonInteger other) { + return $method$__sub__(PythonDecimal.valueOf(other)); + } + + public PythonDecimal $method$__rsub__(PythonInteger other) { + return PythonDecimal.valueOf(other).$method$__sub__(this); + } + + public PythonDecimal $method$__mul__(PythonDecimal other) { + return new PythonDecimal(value.multiply(other.value, threadMathContext.get())); + } + + public PythonDecimal $method$__mul__(PythonInteger other) { + return $method$__mul__(PythonDecimal.valueOf(other)); + } + + public PythonDecimal $method$__rmul__(PythonInteger other) { + return PythonDecimal.valueOf(other).$method$__mul__(this); + } + + public PythonDecimal $method$__truediv__(PythonDecimal other) { + return new PythonDecimal(value.divide(other.value, threadMathContext.get())); + } + + public PythonDecimal $method$__truediv__(PythonInteger other) { + return $method$__truediv__(PythonDecimal.valueOf(other)); + } + + public PythonDecimal $method$__rtruediv__(PythonInteger other) { + return PythonDecimal.valueOf(other).$method$__truediv__(this); + } + + public PythonDecimal $method$__floordiv__(PythonDecimal other) { + var newSignNum = switch (value.signum() * other.value.signum()) { + case -1 -> BigDecimal.ONE.negate(); + case 0 -> BigDecimal.ZERO; + case 1 -> BigDecimal.ONE; + default -> throw new IllegalStateException("Unexpected signum (%d)." + .formatted(value.signum() * other.value.signum())); + }; + // Need to round toward 0, but Java floors the result, so take the absolute and + // multiply by the sign-num + return new PythonDecimal(value.abs().divideToIntegralValue(other.value.abs()) + .multiply(newSignNum, threadMathContext.get())); + } + + public PythonDecimal $method$__floordiv__(PythonInteger other) { + return $method$__floordiv__(PythonDecimal.valueOf(other)); + } + + public PythonDecimal $method$__rfloordiv__(PythonInteger other) { + return PythonDecimal.valueOf(other).$method$__floordiv__(this); + } + + public PythonDecimal $method$__mod__(PythonDecimal other) { + return new PythonDecimal( + value.subtract($method$__floordiv__(other).value.multiply(other.value, threadMathContext.get()))); + } + + public PythonDecimal $method$__mod__(PythonInteger other) { + return $method$__mod__(PythonDecimal.valueOf(other)); + } + + public PythonDecimal $method$__rmod__(PythonInteger other) { + return PythonDecimal.valueOf(other).$method$__mod__(this); + } + + public PythonDecimal $method$__pow__(PythonDecimal other) { + if (other.value.stripTrailingZeros().scale() <= 0) { + // other is an int + return new PythonDecimal(value.pow(other.value.intValue(), threadMathContext.get())); + } + return new PythonDecimal(new BigDecimal(Math.pow(value.doubleValue(), other.value.doubleValue()), + threadMathContext.get())); + } + + public PythonDecimal $method$__pow__(PythonInteger other) { + return $method$__pow__(PythonDecimal.valueOf(other)); + } + + public PythonDecimal $method$__rpow__(PythonInteger other) { + return PythonDecimal.valueOf(other).$method$__mod__(this); + } + + // *************************** + // Other methods + // *************************** + public PythonInteger $method$adjusted() { + // scale is the negative exponent that the big int is multiplied by + // len(unscaled) - 1 = floor(log_10(unscaled)) + // floor(log_10(unscaled)) - scale = exponent in engineering notation + return PythonInteger.valueOf(value.unscaledValue().toString().length() - 1 - value.scale()); + } + + public PythonLikeTuple $method$as_integer_ratio() { + var parts = value.divideAndRemainder(BigDecimal.ONE); + var integralPart = parts[0]; + var fractionPart = parts[1]; + if (fractionPart.compareTo(BigDecimal.ZERO) == 0) { + // No decimal part, as integer ratio = (self, 1) + return PythonLikeTuple.fromItems(PythonInteger.valueOf(integralPart.toBigInteger()), + PythonInteger.ONE); + } + var scale = fractionPart.scale(); + var scaledDenominator = BigDecimal.ONE.movePointRight(scale).toBigInteger(); + var scaledIntegralPart = integralPart.movePointRight(scale).toBigInteger(); + var scaledFractionPart = fractionPart.movePointRight(scale).toBigInteger(); + var scaledNumerator = scaledIntegralPart.add(scaledFractionPart); + var commonFactors = scaledNumerator.gcd(scaledDenominator); + var reducedNumerator = scaledNumerator.divide(commonFactors); + var reducedDenominator = scaledDenominator.divide(commonFactors); + return PythonLikeTuple.fromItems(PythonInteger.valueOf(reducedNumerator), + PythonInteger.valueOf(reducedDenominator)); + } + + public PythonLikeTuple $method$as_tuple() { + // TODO: Use named tuple + return PythonLikeTuple.fromItems(PythonInteger.valueOf(value.signum() >= 0 ? 0 : 1), + value.unscaledValue().abs().toString() + .chars() + .mapToObj(digit -> PythonInteger.valueOf(digit - '0')) + .collect(Collectors.toCollection(PythonLikeTuple::new)), + PythonInteger.valueOf(-value.scale())); + } + + public PythonDecimal $method$canonical() { + return this; + } + + public PythonDecimal $method$compare(PythonDecimal other) { + return new PythonDecimal(BigDecimal.valueOf(value.compareTo(other.value))); + } + + public PythonDecimal $method$compare_signal(PythonDecimal other) { + return $method$compare(other); + } + + // See https://speleotrove.com/decimal/damisc.html#refcotot + public PythonDecimal $method$compare_total(PythonDecimal other) { + var result = $method$compare(other); + if (result.value.compareTo(BigDecimal.ZERO) != 0) { + return result; + } + var sigNum = value.scale() - other.value.scale(); + if (sigNum < 0) { + return new PythonDecimal(BigDecimal.ONE); + } + if (sigNum > 0) { + return new PythonDecimal(BigDecimal.valueOf(-1L)); + } + return result; // Can only reach here if result == BigDecimal.ZERO + } + + public PythonDecimal $method$compare_total_mag(PythonDecimal other) { + return new PythonDecimal(value.abs()).$method$compare_total(new PythonDecimal(other.value.abs())); + } + + public PythonDecimal $method$conjugate() { + return this; + } + + public PythonDecimal $method$copy_abs() { + return new PythonDecimal(value.abs()); + } + + public PythonDecimal $method$copy_negate() { + return new PythonDecimal(value.negate()); + } + + public PythonDecimal $method$copy_sign(PythonDecimal other) { + var signChange = value.signum() * other.value.signum(); + var multiplier = switch (signChange) { + case -1 -> BigDecimal.valueOf(-1); + case 0, 1 -> BigDecimal.ONE; // Note: there also a -0 BigDecimal in Python. + default -> throw new IllegalStateException("Unexpected signum (%d).".formatted(signChange)); + }; + return new PythonDecimal(value.multiply(multiplier)); + } + + private static BigDecimal getEToPrecision(int precision) { + return getESubPowerToPrecision(BigDecimal.ONE, precision); + } + + private static BigDecimal getESubPowerToPrecision(BigDecimal value, int precision) { + // Uses taylor series e^x = sum(x^n/n! for n in 0...infinity) + var numerator = BigDecimal.ONE; + var denominator = BigDecimal.ONE; + var total = BigDecimal.ZERO; + var extendedContext = new MathContext(precision + 8, RoundingMode.HALF_EVEN); + for (var index = 1; index < 100; index++) { + total = total.add(numerator.divide(denominator, extendedContext), extendedContext); + numerator = numerator.multiply(value); + denominator = denominator.multiply(BigDecimal.valueOf(index)); + } + return total; + } + + private static BigDecimal getEPower(BigDecimal value, int precision) { + var extendedPrecision = precision + 8; + + // Do e^x = e^(int(x))*e^(frac(x)) + var e = getEToPrecision(extendedPrecision); + var integralPart = value.toBigInteger().intValue(); + var fractionPart = value.remainder(BigDecimal.ONE); + return e.pow(integralPart).multiply(getESubPowerToPrecision(fractionPart, extendedPrecision), + threadMathContext.get()); + } + + public PythonDecimal $method$exp() { + var precision = threadMathContext.get().getPrecision(); + return new PythonDecimal(getEPower(value, precision)); + } + + public PythonDecimal $method$fma(PythonDecimal multiplier, PythonDecimal summand) { + return new PythonDecimal(this.value.multiply(multiplier.value).add(summand.value, threadMathContext.get())); + } + + public PythonDecimal $method$fma(PythonInteger multiplier, PythonDecimal summand) { + return $method$fma(PythonDecimal.valueOf(multiplier), summand); + } + + public PythonDecimal $method$fma(PythonDecimal multiplier, PythonInteger summand) { + return $method$fma(multiplier, PythonDecimal.valueOf(summand)); + } + + public PythonDecimal $method$fma(PythonInteger multiplier, PythonInteger summand) { + return $method$fma(PythonDecimal.valueOf(multiplier), PythonDecimal.valueOf(summand)); + } + + public PythonBoolean $method$is_canonical() { + return PythonBoolean.TRUE; + } + + public PythonBoolean $method$is_finite() { + // We don't support infinite or NaN Decimals + return PythonBoolean.TRUE; + } + + public PythonBoolean $method$is_infinite() { + // We don't support infinite or NaN Decimals + return PythonBoolean.FALSE; + } + + public PythonBoolean $method$is_nan() { + // We don't support infinite or NaN Decimals + return PythonBoolean.FALSE; + } + + public PythonBoolean $method$is_normal() { + // We don't support subnormal Decimals + return PythonBoolean.TRUE; + } + + public PythonBoolean $method$is_qnan() { + // We don't support infinite or NaN Decimals + return PythonBoolean.FALSE; + } + + public PythonBoolean $method$is_signed() { + // Same as `isNegative()` + return value.compareTo(BigDecimal.ZERO) < 0 ? PythonBoolean.TRUE : PythonBoolean.FALSE; + } + + public PythonBoolean $method$is_snan() { + // We don't support infinite or NaN Decimals + return PythonBoolean.FALSE; + } + + public PythonBoolean $method$is_subnormal() { + // We don't support subnormal Decimals + return PythonBoolean.FALSE; + } + + public PythonBoolean $method$is_zero() { + return value.compareTo(BigDecimal.ZERO) == 0 ? PythonBoolean.TRUE : PythonBoolean.FALSE; + } + + public PythonDecimal $method$ln() { + return new PythonDecimal(new BigDecimal( + Math.log(value.doubleValue()), + threadMathContext.get())); + } + + public PythonDecimal $method$log10() { + return new PythonDecimal(new BigDecimal( + Math.log10(value.doubleValue()), + threadMathContext.get())); + } + + public PythonDecimal $method$logb() { + // Finds the exponent b in a * 10^b, where a in [1, 10) + return new PythonDecimal(BigDecimal.valueOf(value.precision() - value.scale() - 1)); + } + + private static PythonDecimal logicalOp(BiPredicate op, + BigDecimal a, BigDecimal b) { + if (a.scale() < 0 || b.scale() < 0) { + throw new ValueError("Invalid Operation: both operands must be positive integers consisting of 1's and 0's"); + } + var aText = a.toPlainString(); + var bText = b.toPlainString(); + if (aText.length() > bText.length()) { + bText = "0".repeat(aText.length() - bText.length()) + bText; + } else if (aText.length() < bText.length()) { + aText = "0".repeat(bText.length() - aText.length()) + aText; + } + + var digitCount = aText.length(); + var result = new StringBuilder(); + for (int i = 0; i < digitCount; i++) { + var aBit = switch (aText.charAt(i)) { + case '0' -> false; + case '1' -> true; + default -> throw new ValueError(("Invalid Operation: first operand (%s) is not a positive integer " + + "consisting of 1's and 0's").formatted(a)); + }; + var bBit = switch (bText.charAt(i)) { + case '0' -> false; + case '1' -> true; + default -> throw new ValueError(("Invalid Operation: second operand (%s) is not a positive integer " + + "consisting of 1's and 0's").formatted(b)); + }; + result.append(op.test(aBit, bBit) ? '1' : '0'); + } + return new PythonDecimal(new BigDecimal(result.toString())); + } + + public PythonDecimal $method$logical_and(PythonDecimal other) { + return logicalOp(Boolean::logicalAnd, this.value, other.value); + } + + public PythonDecimal $method$logical_or(PythonDecimal other) { + return logicalOp(Boolean::logicalOr, this.value, other.value); + } + + public PythonDecimal $method$logical_xor(PythonDecimal other) { + return logicalOp(Boolean::logicalXor, this.value, other.value); + } + + public PythonDecimal $method$logical_invert() { + return logicalOp(Boolean::logicalXor, this.value, new BigDecimal("1".repeat(threadMathContext.get().getPrecision()))); + } + + public PythonDecimal $method$max(PythonDecimal other) { + return new PythonDecimal(value.max(other.value)); + } + + public PythonDecimal $method$max_mag(PythonDecimal other) { + var result = $method$compare_total_mag(other).value.intValue(); + if (result >= 0) { + return this; + } else { + return other; + } + } + + public PythonDecimal $method$min(PythonDecimal other) { + return new PythonDecimal(value.min(other.value)); + } + + public PythonDecimal $method$min_mag(PythonDecimal other) { + var result = $method$compare_total_mag(other).value.intValue(); + if (result <= 0) { + return this; + } else { + return other; + } + } + + private BigDecimal getLastPlaceUnit(MathContext mathContext) { + int remainingPrecision = mathContext.getPrecision() - value.stripTrailingZeros().precision(); + return BigDecimal.ONE.movePointLeft(value.scale() + remainingPrecision + 1); + } + + public PythonDecimal $method$next_minus() { + var context = new MathContext(threadMathContext.get().getPrecision(), RoundingMode.FLOOR); + var lastPlaceUnit = getLastPlaceUnit(context); + return new PythonDecimal(value.subtract(lastPlaceUnit, context)); + } + + public PythonDecimal $method$next_plus() { + var context = new MathContext(threadMathContext.get().getPrecision(), RoundingMode.CEILING); + var lastPlaceUnit = getLastPlaceUnit(context); + return new PythonDecimal(value.add(lastPlaceUnit, context)); + } + + public PythonDecimal $method$next_toward(PythonDecimal other) { + var result = $method$compare(other).value.intValue(); + switch (result) { + case -1 -> { + return $method$next_plus(); + } + case 1 -> { + return $method$next_minus(); + } + case 0 -> { + return this; + } + default -> throw new IllegalStateException(); + } + } + + public PythonDecimal $method$normalize() { + return new PythonDecimal(value.stripTrailingZeros()); + } + + public PythonString $method$number_class() { + var result = value.compareTo(BigDecimal.ZERO); + if (result < 0) { + return PythonString.valueOf("-Normal"); + } else if (result > 0) { + return PythonString.valueOf("+Normal"); + } else { + return PythonString.valueOf("+Zero"); + } + } + + public PythonDecimal $method$quantize(PythonDecimal other) { + return new PythonDecimal(value.setScale(other.value.scale(), threadMathContext.get().getRoundingMode())); + } + + public PythonDecimal $method$radix() { + return new PythonDecimal(BigDecimal.TEN); + } + + public PythonDecimal $method$remainder_near(PythonDecimal other) { + var floorQuotient = $method$__floordiv__(other).value; + var firstRemainder = new PythonDecimal(value.subtract(floorQuotient.multiply(other.value, threadMathContext.get()))); + var secondRemainder = other.$method$__sub__(firstRemainder).$method$__neg__(); + var comparison = firstRemainder.$method$compare_total_mag(secondRemainder).value.intValue(); + return switch (comparison) { + case -1 -> firstRemainder; + case 1 -> secondRemainder; + case 0 -> { + if (floorQuotient.longValue() % 2 == 0) { + yield firstRemainder; + } else { + yield secondRemainder; + } + } + default -> throw new IllegalStateException(); + }; + } + + public PythonDecimal $method$rotate(PythonInteger other) { + var amount = -other.value.intValue(); + if (amount == 0) { + return this; + } + var precision = threadMathContext.get().getPrecision(); + if (Math.abs(amount) > precision) { + throw new ValueError("other must be between -%d and %d".formatted(amount, amount)); + } + var digitString = value.unscaledValue().toString(); + digitString = "0".repeat(precision - digitString.length()) + digitString; + if (amount < 0) { + // Turn a rotate right to a rotate left + amount = precision + amount; + } + var rotatedResult = digitString.substring(precision - amount, precision) + digitString.substring(0, precision - amount); + var unscaledResult = new BigInteger(rotatedResult); + return new PythonDecimal(new BigDecimal(unscaledResult, value.scale())); + } + + public PythonBoolean $method$same_quantum(PythonDecimal other) { + return PythonBoolean.valueOf( + value.ulp().compareTo(other.value.ulp()) == 0); + } + + public PythonDecimal $method$scaleb(PythonInteger other) { + return new PythonDecimal(value.movePointRight(other.value.intValue())); + } + + public PythonDecimal $method$shift(PythonInteger other) { + var amount = other.value.intValue(); + if (amount == 0) { + return this; + } + var precision = threadMathContext.get().getPrecision(); + if (Math.abs(amount) > precision) { + throw new ValueError("other must be between -%d and %d".formatted(amount, amount)); + } + return new PythonDecimal(value.movePointLeft(amount)); + } + + public PythonDecimal $method$sqrt() { + return new PythonDecimal(value.sqrt(threadMathContext.get())); + } + + public PythonString $method$to_eng_string() { + return new PythonString(value.toEngineeringString()); + } + + public PythonInteger $method$to_integral() { + return $method$to_integral_value(); + } + + public PythonInteger $method$to_integral_exact() { + // TODO: set signals in the context object + return $method$to_integral_value(); + } + + public PythonInteger $method$to_integral_value() { + return new PythonInteger(value.divideToIntegralValue(BigDecimal.ONE, threadMathContext.get()).toBigInteger()); + } + + public PythonInteger $method$__round__() { + // Round without an argument ignores thread math context + var first = value.toBigInteger(); + var second = first.add(BigInteger.ONE); + var firstDiff = value.subtract(new BigDecimal(first)); + var secondDiff = new BigDecimal(second).subtract(value); + var comparison = firstDiff.compareTo(secondDiff); + return switch (comparison) { + case -1 -> new PythonInteger(first); + case 1 -> new PythonInteger(second); + case 0 -> { + if (first.intValue() % 2 == 0) { + yield new PythonInteger(first); + } else { + yield new PythonInteger(second); + } + } + default -> throw new IllegalStateException(); + }; + } + + public PythonLikeObject $method$__round__(PythonLikeObject maybePrecision) { + if (maybePrecision instanceof PythonNone) { + return $method$__round__(); + } + if (!(maybePrecision instanceof PythonInteger precision)) { + throw new ValueError("ndigits must be an integer"); + } + // Round with an argument uses thread math context + var integralPart = value.toBigInteger(); + return new PythonDecimal(value.round(new MathContext( + integralPart.toString().length() + precision.value.intValue(), + threadMathContext.get().getRoundingMode()))); + } +} diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonFloat.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonFloat.java index bf584771..91ac4356 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonFloat.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonFloat.java @@ -38,11 +38,6 @@ public class PythonFloat extends AbstractPythonLikeObject implements PythonNumbe PythonOverloadImplementor.deferDispatchesFor(PythonFloat::registerMethods); } - public PythonFloat(double value) { - super(BuiltinTypes.FLOAT_TYPE); - this.value = value; - } - private static PythonLikeType registerMethods() throws NoSuchMethodException { PythonLikeComparable.setup(BuiltinTypes.FLOAT_TYPE); @@ -51,16 +46,7 @@ private static PythonLikeType registerMethods() throws NoSuchMethodException { if (positionalArguments.isEmpty()) { return new PythonFloat(0.0); } else if (positionalArguments.size() == 1) { - PythonLikeObject value = positionalArguments.get(0); - if (value instanceof PythonInteger) { - return ((PythonInteger) value).asFloat(); - } else if (value instanceof PythonFloat) { - return value; - } else { - PythonLikeType valueType = value.$getType(); - PythonLikeFunction asFloatFunction = (PythonLikeFunction) (valueType.$getAttributeOrError("__float__")); - return asFloatFunction.$call(List.of(value), Map.of(), null); - } + return PythonFloat.from(positionalArguments.get(0)); } else { throw new ValueError("float takes 0 or 1 arguments, got " + positionalArguments.size()); } @@ -172,6 +158,36 @@ private static PythonLikeType registerMethods() throws NoSuchMethodException { return BuiltinTypes.FLOAT_TYPE; } + public PythonFloat(double value) { + super(BuiltinTypes.FLOAT_TYPE); + this.value = value; + } + + public static PythonFloat from(PythonLikeObject value) { + if (value instanceof PythonInteger integer) { + return integer.asFloat(); + } else if (value instanceof PythonFloat) { + return (PythonFloat) value; + } else if (value instanceof PythonString str) { + try { + var literal = switch (str.value.toLowerCase()) { + case "nan", "+nan" -> "+NaN"; + case "-nan" -> "-NaN"; + case "inf", "+inf", "infinity" -> "+Infinity"; + case "-inf", "-infinity" -> "-Infinity"; + default -> str.value; + }; + return new PythonFloat(Double.parseDouble(literal)); + } catch (NumberFormatException e) { + throw new ValueError("invalid literal for float(): %s".formatted(value)); + } + } else { + PythonLikeType valueType = value.$getType(); + PythonLikeFunction asFloatFunction = (PythonLikeFunction) (valueType.$getAttributeOrError("__float__")); + return (PythonFloat) asFloatFunction.$call(List.of(value), Map.of(), null); + } + } + @Override public Number getValue() { return value; @@ -218,12 +234,10 @@ public String toString() { @Override public boolean equals(Object o) { - if (o instanceof Number) { - return ((Number) o).doubleValue() == value; - } else if (o instanceof PythonFloat) { - return ((PythonFloat) o).value == value; - } else if (o instanceof PythonInteger) { - return ((PythonInteger) o).getValue().doubleValue() == value; + if (o instanceof Number number) { + return number.doubleValue() == value; + } else if (o instanceof PythonNumber number) { + return compareTo(number) == 0; } else { return false; } diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonInteger.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonInteger.java index b91b0e2c..09f5f555 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonInteger.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonInteger.java @@ -44,24 +44,17 @@ public class PythonInteger extends AbstractPythonLikeObject implements PythonNum private static PythonLikeType registerMethods() throws NoSuchMethodException { // Constructor - BuiltinTypes.INT_TYPE.setConstructor(((positionalArguments, namedArguments, callerInstance) -> { - if (positionalArguments.size() == 0) { + BuiltinTypes.INT_TYPE.setConstructor((positionalArguments, namedArguments, callerInstance) -> { + if (positionalArguments.isEmpty()) { return PythonInteger.valueOf(0); } else if (positionalArguments.size() == 1) { - PythonLikeObject value = positionalArguments.get(0); - if (value instanceof PythonInteger) { - return value; - } else if (value instanceof PythonFloat) { - return ((PythonFloat) value).asInteger(); - } else { - PythonLikeType valueType = value.$getType(); - PythonLikeFunction asIntFunction = (PythonLikeFunction) (valueType.$getAttributeOrError("__int__")); - return asIntFunction.$call(List.of(value), Map.of(), null); - } + return PythonInteger.from(positionalArguments.get(0)); + } else if (positionalArguments.size() == 2) { + return PythonInteger.fromUsingBase(positionalArguments.get(0), positionalArguments.get(1)); } else { - throw new ValueError("int expects 0 or 1 arguments, got " + positionalArguments.size()); + throw new TypeError("int takes at most 2 arguments, got " + positionalArguments.size()); } - })); + }); // Unary BuiltinTypes.INT_TYPE.addUnaryMethod(PythonUnaryOperator.AS_BOOLEAN, PythonInteger.class.getMethod("asBoolean")); BuiltinTypes.INT_TYPE.addUnaryMethod(PythonUnaryOperator.AS_INT, PythonInteger.class.getMethod("asInteger")); @@ -232,6 +225,39 @@ public PythonInteger(BigInteger value) { this.value = value; } + private static PythonInteger from(PythonLikeObject value) { + if (value instanceof PythonInteger integer) { + return integer; + } else if (value instanceof PythonFloat pythonFloat) { + return pythonFloat.asInteger(); + } else if (value instanceof PythonString str) { + try { + return new PythonInteger(new BigInteger(str.value)); + } catch (NumberFormatException e) { + throw new ValueError("invalid literal for int() with base 10: %s".formatted(value)); + } + } else { + PythonLikeType valueType = value.$getType(); + PythonLikeFunction asIntFunction = (PythonLikeFunction) (valueType.$getAttributeOrError("__int__")); + return (PythonInteger) asIntFunction.$call(List.of(value), Map.of(), null); + } + } + + private static PythonInteger fromUsingBase(PythonLikeObject value, PythonLikeObject base) { + if (value instanceof PythonString str && base instanceof PythonInteger baseInt) { + try { + return new PythonInteger(new BigInteger(str.value, baseInt.value.intValue())); + } catch (NumberFormatException e) { + throw new ValueError( + "invalid literal for int() with base %d: %s".formatted(baseInt.value.intValue(), value)); + } + } else { + PythonLikeType valueType = value.$getType(); + PythonLikeFunction asIntFunction = (PythonLikeFunction) (valueType.$getAttributeOrError("__int__")); + return (PythonInteger) asIntFunction.$call(List.of(value, base), Map.of(), null); + } + } + @Override public Number getValue() { return value; @@ -256,12 +282,10 @@ public byte asByte() { @Override public boolean equals(Object o) { - if (o instanceof Number) { - return value.equals(BigInteger.valueOf(((Number) o).longValue())); - } else if (o instanceof PythonInteger) { - return ((PythonInteger) o).value.equals(value); - } else if (o instanceof PythonFloat) { - return value.doubleValue() == ((PythonFloat) o).value; + if (o instanceof Number number) { + return value.equals(BigInteger.valueOf(number.longValue())); + } else if (o instanceof PythonNumber number) { + return compareTo(number) == 0; } else { return false; } diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonNumber.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonNumber.java index 2fa84dec..3c5cb4f8 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonNumber.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonNumber.java @@ -1,5 +1,6 @@ package ai.timefold.jpyinterpreter.types.numeric; +import java.math.BigDecimal; import java.math.BigInteger; import ai.timefold.jpyinterpreter.PythonLikeObject; @@ -20,22 +21,29 @@ default int compareTo(PythonNumber pythonNumber) { Number value = getValue(); Number otherValue = pythonNumber.getValue(); - if (value instanceof BigInteger) { - if (otherValue instanceof BigInteger) { - return ((BigInteger) value).compareTo((BigInteger) otherValue); - } else { - return Double.compare(value.longValue(), otherValue.doubleValue()); + if (value instanceof BigInteger self) { + if (otherValue instanceof BigInteger other) { + return self.compareTo(other); + } else if (otherValue instanceof BigDecimal other) { + return new BigDecimal(self).compareTo(other); + } + } + if (value instanceof BigDecimal self) { + if (otherValue instanceof BigDecimal other) { + return self.compareTo(other); + } else if (otherValue instanceof BigInteger other) { + return self.compareTo(new BigDecimal(other)); } - } else { - return Double.compare(value.doubleValue(), otherValue.doubleValue()); } + // If comparing against a float, convert both arguments to float + return Double.compare(value.doubleValue(), otherValue.doubleValue()); } static PythonInteger computeHash(PythonInteger numerator, PythonInteger denominator) { PythonInteger P = MODULUS; // Remove common factors of P. (Unnecessary if m and n already coprime.) - while (numerator.modulo(P) == PythonInteger.ZERO && denominator.modulo(P) == PythonInteger.ZERO) { + while (numerator.modulo(P).equals(PythonInteger.ZERO) && denominator.modulo(P).equals(PythonInteger.ZERO)) { numerator = numerator.floorDivide(P); denominator = denominator.floorDivide(P); } diff --git a/jpyinterpreter/src/main/python/conversions.py b/jpyinterpreter/src/main/python/conversions.py index 45528be5..a719f3c2 100644 --- a/jpyinterpreter/src/main/python/conversions.py +++ b/jpyinterpreter/src/main/python/conversions.py @@ -132,6 +132,7 @@ def init_type_to_compiled_java_class(): import ai.timefold.jpyinterpreter.types.datetime as java_datetime_types import datetime import builtins + import decimal if len(type_to_compiled_java_class) > 0: return @@ -145,6 +146,7 @@ def init_type_to_compiled_java_class(): type_to_compiled_java_class[float] = BuiltinTypes.FLOAT_TYPE type_to_compiled_java_class[complex] = BuiltinTypes.COMPLEX_TYPE type_to_compiled_java_class[bool] = BuiltinTypes.BOOLEAN_TYPE + type_to_compiled_java_class[decimal.Decimal] = BuiltinTypes.DECIMAL_TYPE type_to_compiled_java_class[type(None)] = BuiltinTypes.NONE_TYPE type_to_compiled_java_class[str] = BuiltinTypes.STRING_TYPE @@ -370,12 +372,14 @@ def convert_to_java_python_like_object(value, instance_map=None): from java.util import HashMap from java.math import BigInteger from types import ModuleType + from decimal import Decimal from ai.timefold.jpyinterpreter import PythonLikeObject, CPythonBackedPythonInterpreter from ai.timefold.jpyinterpreter.types import PythonString, PythonBytes, PythonByteArray, PythonNone, \ PythonModule, PythonSlice, PythonRange, NotImplemented as JavaNotImplemented from ai.timefold.jpyinterpreter.types.collections import PythonLikeList, PythonLikeTuple, PythonLikeSet, \ PythonLikeFrozenSet, PythonLikeDict - from ai.timefold.jpyinterpreter.types.numeric import PythonInteger, PythonFloat, PythonBoolean, PythonComplex + from ai.timefold.jpyinterpreter.types.numeric import PythonInteger, PythonFloat, PythonBoolean, PythonComplex, \ + PythonDecimal from ai.timefold.jpyinterpreter.types.wrappers import PythonObjectWrapper, CPythonType, OpaquePythonReference if instance_map is None: @@ -400,6 +404,10 @@ def convert_to_java_python_like_object(value, instance_map=None): out = PythonFloat.valueOf(JDouble(value)) put_in_instance_map(instance_map, value, out) return out + elif isinstance(value, Decimal): + out = PythonDecimal.valueOf(str(value)) + put_in_instance_map(instance_map, value, out) + return out elif isinstance(value, complex): out = PythonComplex.valueOf(convert_to_java_python_like_object(value.real, instance_map), convert_to_java_python_like_object(value.imag, instance_map)) @@ -519,10 +527,12 @@ def unwrap_python_like_object(python_like_object, clone_map=None, default=NotImp PythonModule, PythonSlice, PythonRange, CPythonBackedPythonLikeObject, PythonLikeType, PythonLikeGenericType, \ NotImplemented as JavaNotImplemented, PythonCell, PythonLikeFunction from ai.timefold.jpyinterpreter.types.collections import PythonLikeTuple, PythonLikeFrozenSet - from ai.timefold.jpyinterpreter.types.numeric import PythonInteger, PythonFloat, PythonBoolean, PythonComplex + from ai.timefold.jpyinterpreter.types.numeric import PythonInteger, PythonFloat, PythonBoolean, PythonComplex, \ + PythonDecimal from ai.timefold.jpyinterpreter.types.wrappers import JavaObjectWrapper, PythonObjectWrapper, CPythonType, \ OpaquePythonReference from types import CellType + from decimal import Decimal if clone_map is None: clone_map = PythonCloneMap(IdentityHashMap(), dict()) @@ -552,6 +562,8 @@ def unwrap_python_like_object(python_like_object, clone_map=None, default=NotImp return clone_map.add_clone(python_like_object, python_like_object == PythonBoolean.TRUE) elif isinstance(python_like_object, PythonInteger): return clone_map.add_clone(python_like_object, int(python_like_object.getValue().toString(16), 16)) + elif isinstance(python_like_object, PythonDecimal): + return clone_map.add_clone(python_like_object, Decimal(str(python_like_object))) elif isinstance(python_like_object, PythonComplex): real = unwrap_python_like_object(python_like_object.getReal(), clone_map, default) imaginary = unwrap_python_like_object(python_like_object.getImaginary(), clone_map, default) diff --git a/jpyinterpreter/tests/test_builtins.py b/jpyinterpreter/tests/test_builtins.py index 23c9afb2..d89b962b 100644 --- a/jpyinterpreter/tests/test_builtins.py +++ b/jpyinterpreter/tests/test_builtins.py @@ -229,11 +229,23 @@ def my_function(function: Callable[[any], bool], iterable: Iterable) -> tuple: def test_float(): + import math + def my_function(x: any) -> float: return float(x) verifier = verifier_for(my_function) verifier.verify(10, expected_result=10.0) + verifier.verify('1.0', expected_result=1.0) + verifier.verify_property('nan', predicate=math.isnan) + verifier.verify_property('NaN', predicate=math.isnan) + verifier.verify_property('-nan', predicate=math.isnan) + verifier.verify_property('-NaN', predicate=math.isnan) + verifier.verify('inf', expected_result=float('inf')) + verifier.verify('INF', expected_result=float('inf')) + verifier.verify('-inf', expected_result=float('-inf')) + verifier.verify('infinity', expected_result=float('inf')) + verifier.verify('-infinity', expected_result=float('-inf')) def test_format(): @@ -324,6 +336,7 @@ def my_function(x: any) -> int: verifier = verifier_for(my_function) verifier.verify(1.5, expected_result=1) verifier.verify(1.0, expected_result=1) + verifier.verify('2', expected_result=2) def test_isinstance(): diff --git a/jpyinterpreter/tests/test_decimal.py b/jpyinterpreter/tests/test_decimal.py new file mode 100644 index 00000000..d513c5fb --- /dev/null +++ b/jpyinterpreter/tests/test_decimal.py @@ -0,0 +1,782 @@ +from .conftest import verifier_for +from decimal import Decimal +from typing import Callable + + +def around(a: Decimal) -> Callable[[Decimal], bool]: + def predicate(b: Decimal) -> bool: + return abs(a - b) < 0.00001 + return predicate + + +def test_add(): + def decimal_add(a: Decimal, b: Decimal) -> Decimal: + return a + b + + def int_add(a: Decimal, b: int) -> Decimal: + return a + b + + decimal_add_verifier = verifier_for(decimal_add) + int_add_verifier = verifier_for(int_add) + + decimal_add_verifier.verify(Decimal(1), Decimal(1), expected_result=Decimal(2)) + decimal_add_verifier.verify(Decimal(1), Decimal(-1), expected_result=Decimal(0)) + decimal_add_verifier.verify(Decimal(-1), Decimal(1), expected_result=Decimal(0)) + decimal_add_verifier.verify(Decimal(0), Decimal(1), expected_result=Decimal(1)) + decimal_add_verifier.verify(Decimal('1.5'), Decimal('1.5'), expected_result=Decimal('3.0')) + + int_add_verifier.verify(Decimal(1), 1, expected_result=Decimal(2)) + int_add_verifier.verify(Decimal(1), -1, expected_result=Decimal(0)) + int_add_verifier.verify(Decimal(-1), 1, expected_result=Decimal(0)) + int_add_verifier.verify(Decimal(0), 1, expected_result=Decimal(1)) + int_add_verifier.verify(Decimal('1.5'), 1, expected_result=Decimal('2.5')) + + +def test_sub(): + def decimal_sub(a: Decimal, b: Decimal) -> Decimal: + return a - b + + def int_sub(a: Decimal, b: int) -> Decimal: + return a - b + + decimal_sub_verifier = verifier_for(decimal_sub) + int_sub_verifier = verifier_for(int_sub) + + decimal_sub_verifier.verify(Decimal(1), Decimal(1), expected_result=Decimal(0)) + decimal_sub_verifier.verify(Decimal(1), Decimal(-1), expected_result=Decimal(2)) + decimal_sub_verifier.verify(Decimal(-1), Decimal(1), expected_result=Decimal(-2)) + decimal_sub_verifier.verify(Decimal(0), Decimal(1), expected_result=Decimal(-1)) + decimal_sub_verifier.verify(Decimal('1.5'), Decimal('1.5'), expected_result=Decimal(0)) + + int_sub_verifier.verify(Decimal(1), 1, expected_result=Decimal(0)) + int_sub_verifier.verify(Decimal(1), -1, expected_result=Decimal(2)) + int_sub_verifier.verify(Decimal(-1), 1, expected_result=Decimal(-2)) + int_sub_verifier.verify(Decimal(0), 1, expected_result=Decimal(-1)) + int_sub_verifier.verify(Decimal('1.5'), 1, expected_result=Decimal('0.5')) + + +def test_multiply(): + def decimal_multiply(a: Decimal, b: Decimal) -> Decimal: + return a * b + + def int_multiply(a: Decimal, b: int) -> Decimal: + return a * b + + decimal_multiply_verifier = verifier_for(decimal_multiply) + int_multiply_verifier = verifier_for(int_multiply) + + decimal_multiply_verifier.verify(Decimal(1), Decimal(1), expected_result=Decimal(1)) + decimal_multiply_verifier.verify(Decimal(1), Decimal(-1), expected_result=Decimal(-1)) + decimal_multiply_verifier.verify(Decimal(-1), Decimal(1), expected_result=Decimal(-1)) + decimal_multiply_verifier.verify(Decimal(0), Decimal(1), expected_result=Decimal(0)) + decimal_multiply_verifier.verify(Decimal('1.5'), Decimal('1.5'), expected_result=Decimal('2.25')) + + int_multiply_verifier.verify(Decimal(1), 1, expected_result=Decimal(1)) + int_multiply_verifier.verify(Decimal(1), -1, expected_result=Decimal(-1)) + int_multiply_verifier.verify(Decimal(-1), 1, expected_result=Decimal(-1)) + int_multiply_verifier.verify(Decimal(0), 1, expected_result=Decimal(0)) + int_multiply_verifier.verify(Decimal('1.5'), 2, expected_result=Decimal('3.0')) + + +def test_truediv(): + def decimal_truediv(a: Decimal, b: Decimal) -> Decimal: + return a / b + + def int_truediv(a: Decimal, b: int) -> Decimal: + return a / b + + decimal_truediv_verifier = verifier_for(decimal_truediv) + int_truediv_verifier = verifier_for(int_truediv) + + decimal_truediv_verifier.verify(Decimal(1), Decimal(1), expected_result=Decimal(1)) + decimal_truediv_verifier.verify(Decimal(1), Decimal(-1), expected_result=Decimal(-1)) + decimal_truediv_verifier.verify(Decimal(-1), Decimal(1), expected_result=Decimal(-1)) + decimal_truediv_verifier.verify(Decimal(0), Decimal(1), expected_result=Decimal(0)) + decimal_truediv_verifier.verify(Decimal(3), Decimal(2), expected_result=Decimal('1.5')) + + int_truediv_verifier.verify(Decimal(1), 1, expected_result=Decimal(1)) + int_truediv_verifier.verify(Decimal(1), -1, expected_result=Decimal(-1)) + int_truediv_verifier.verify(Decimal(-1), 1, expected_result=Decimal(-1)) + int_truediv_verifier.verify(Decimal(0), 1, expected_result=Decimal(0)) + int_truediv_verifier.verify(Decimal(3), 2, expected_result=Decimal('1.5')) + + +def test_floordiv(): + def decimal_floordiv(a: Decimal, b: Decimal) -> Decimal: + return a // b + + def int_floordiv(a: Decimal, b: int) -> Decimal: + return a // b + + decimal_floordiv_verifier = verifier_for(decimal_floordiv) + int_floordiv_verifier = verifier_for(int_floordiv) + + decimal_floordiv_verifier.verify(Decimal(1), Decimal(1), expected_result=Decimal(1)) + decimal_floordiv_verifier.verify(Decimal(1), Decimal(-1), expected_result=Decimal(-1)) + decimal_floordiv_verifier.verify(Decimal(-1), Decimal(1), expected_result=Decimal(-1)) + decimal_floordiv_verifier.verify(Decimal(0), Decimal(1), expected_result=Decimal(0)) + decimal_floordiv_verifier.verify(Decimal(-7), Decimal(4), expected_result=Decimal('-1')) + + int_floordiv_verifier.verify(Decimal(1), 1, expected_result=Decimal(1)) + int_floordiv_verifier.verify(Decimal(1), -1, expected_result=Decimal(-1)) + int_floordiv_verifier.verify(Decimal(-1), 1, expected_result=Decimal(-1)) + int_floordiv_verifier.verify(Decimal(0), 1, expected_result=Decimal(0)) + int_floordiv_verifier.verify(Decimal(3), 2, expected_result=Decimal(1)) + + +def test_mod(): + def decimal_mod(a: Decimal, b: Decimal) -> Decimal: + return a % b + + def int_mod(a: Decimal, b: int) -> Decimal: + return a % b + + decimal_mod_verifier = verifier_for(decimal_mod) + int_mod_verifier = verifier_for(int_mod) + + decimal_mod_verifier.verify(Decimal(-7), Decimal(4), expected_result=Decimal(-3)) + decimal_mod_verifier.verify(Decimal(0), Decimal(1), expected_result=Decimal(0)) + decimal_mod_verifier.verify(Decimal(3), Decimal(2), expected_result=Decimal('1')) + decimal_mod_verifier.verify(Decimal('3.5'), Decimal(2), expected_result=Decimal('1.5')) + + int_mod_verifier.verify(Decimal(1), 1, expected_result=Decimal(0)) + int_mod_verifier.verify(Decimal('3.5'), 2, expected_result=Decimal('1.5')) + int_mod_verifier.verify(Decimal(3), 2, expected_result=Decimal(1)) + + +def test_negate(): + def negate(x: Decimal) -> Decimal: + return -x + + negate_verifier = verifier_for(negate) + + negate_verifier.verify(Decimal(1), expected_result=Decimal(-1)) + negate_verifier.verify(Decimal(-1), expected_result=Decimal(1)) + + +def test_pos(): + def pos(x: Decimal) -> Decimal: + return +x + + pos_verifier = verifier_for(pos) + + pos_verifier.verify(Decimal(1), expected_result=Decimal(1)) + pos_verifier.verify(Decimal(-1), expected_result=Decimal(-1)) + + +def test_abs(): + def decimal_abs(x: Decimal) -> Decimal: + return abs(x) + + abs_verifier = verifier_for(decimal_abs) + + abs_verifier.verify(Decimal(1), expected_result=Decimal(1)) + abs_verifier.verify(Decimal(-1), expected_result=Decimal(1)) + + +def test_pow(): + def decimal_pow(a: Decimal, b: Decimal) -> Decimal: + return a ** b + + def int_pow(a: Decimal, b: int) -> Decimal: + return a ** b + + decimal_pow_verifier = verifier_for(decimal_pow) + int_pow_verifier = verifier_for(int_pow) + + decimal_pow_verifier.verify(Decimal(1), Decimal(2), expected_result=Decimal(1)) + decimal_pow_verifier.verify(Decimal(2), Decimal(2), expected_result=Decimal(4)) + decimal_pow_verifier.verify(Decimal(3), Decimal(2), expected_result=Decimal(9)) + decimal_pow_verifier.verify(Decimal(2), Decimal(3), expected_result=Decimal(8)) + decimal_pow_verifier.verify(Decimal(2), Decimal(-1), expected_result=Decimal(0.5)) + decimal_pow_verifier.verify(Decimal(4), Decimal('0.5'), expected_result=Decimal(2)) + + int_pow_verifier.verify(Decimal(1), 2, expected_result=Decimal(1)) + int_pow_verifier.verify(Decimal(2), 2, expected_result=Decimal(4)) + int_pow_verifier.verify(Decimal(3), 2, expected_result=Decimal(9)) + int_pow_verifier.verify(Decimal(2), 3, expected_result=Decimal(8)) + int_pow_verifier.verify(Decimal(2), -1, expected_result=Decimal(0.5)) + + +def test_comparisons(): + def lt(a: Decimal, b: Decimal) -> bool: + return a < b + + def gt(a: Decimal, b: Decimal) -> bool: + return a > b + + def le(a: Decimal, b: Decimal) -> bool: + return a <= b + + def ge(a: Decimal, b: Decimal) -> bool: + return a >= b + + def eq(a: Decimal, b: Decimal) -> bool: + return a == b + + def ne(a: Decimal, b: Decimal) -> bool: + return a != b + + lt_verifier = verifier_for(lt) + gt_verifier = verifier_for(gt) + le_verifier = verifier_for(le) + ge_verifier = verifier_for(ge) + eq_verifier = verifier_for(eq) + ne_verifier = verifier_for(ne) + + lt_verifier.verify(Decimal(1), Decimal(1), expected_result=False) + gt_verifier.verify(Decimal(1), Decimal(1), expected_result=False) + le_verifier.verify(Decimal(1), Decimal(1), expected_result=True) + ge_verifier.verify(Decimal(1), Decimal(1), expected_result=True) + eq_verifier.verify(Decimal(1), Decimal(1), expected_result=True) + ne_verifier.verify(Decimal(1), Decimal(1), expected_result=False) + + lt_verifier.verify(Decimal(1), Decimal('1.0'), expected_result=False) + gt_verifier.verify(Decimal(1), Decimal('1.0'), expected_result=False) + le_verifier.verify(Decimal(1), Decimal('1.0'), expected_result=True) + ge_verifier.verify(Decimal(1), Decimal('1.0'), expected_result=True) + eq_verifier.verify(Decimal(1), Decimal('1.0'), expected_result=True) + ne_verifier.verify(Decimal(1), Decimal('1.0'), expected_result=False) + + lt_verifier.verify(Decimal(1), Decimal(2), expected_result=True) + gt_verifier.verify(Decimal(1), Decimal(2), expected_result=False) + le_verifier.verify(Decimal(1), Decimal(2), expected_result=True) + ge_verifier.verify(Decimal(1), Decimal(2), expected_result=False) + eq_verifier.verify(Decimal(1), Decimal(2), expected_result=False) + ne_verifier.verify(Decimal(1), Decimal(2), expected_result=True) + + lt_verifier.verify(Decimal(2), Decimal(1), expected_result=False) + gt_verifier.verify(Decimal(2), Decimal(1), expected_result=True) + le_verifier.verify(Decimal(2), Decimal(1), expected_result=False) + ge_verifier.verify(Decimal(2), Decimal(1), expected_result=True) + eq_verifier.verify(Decimal(2), Decimal(1), expected_result=False) + ne_verifier.verify(Decimal(2), Decimal(1), expected_result=True) + + +def test_hash(): + def decimal_hash(a: Decimal) -> int: + return hash(a) + + hash_verifier = verifier_for(decimal_hash) + hash_verifier.verify(Decimal(1), expected_result=hash(Decimal(1))) + hash_verifier.verify(Decimal('1.5'), expected_result=hash(Decimal('1.5'))) + + +def test_round(): + def decimal_round(a: Decimal) -> int: + return round(a) + + def decimal_round_with_digits(a: Decimal, digits: int) -> Decimal: + return round(a, digits) + + decimal_round_verifier = verifier_for(decimal_round) + decimal_round_with_digits_verifier = verifier_for(decimal_round_with_digits) + + decimal_round_verifier.verify(Decimal('1.2'), expected_result=1) + decimal_round_verifier.verify(Decimal('1.5'), expected_result=2) + decimal_round_verifier.verify(Decimal('1.7'), expected_result=2) + decimal_round_verifier.verify(Decimal('2.5'), expected_result=2) + + decimal_round_with_digits_verifier.verify(Decimal('13.22'), 1, expected_result=Decimal('13.2')) + decimal_round_with_digits_verifier.verify(Decimal('13.22'), 2, expected_result=Decimal('13.22')) + decimal_round_with_digits_verifier.verify(Decimal('13.27'), 1, expected_result=Decimal('13.3')) + decimal_round_with_digits_verifier.verify(Decimal('13.25'), 1, expected_result=Decimal('13.2')) + + +def test_adjusted(): + def adjusted(a: Decimal) -> int: + return a.adjusted() + + adjusted_verifier = verifier_for(adjusted) + adjusted_verifier.verify(Decimal(100), expected_result=2) + adjusted_verifier.verify(Decimal('0.001'), expected_result=-3) + + +def test_as_integer_ratio(): + def as_integer_ratio(a: Decimal) -> tuple[int, int]: + return a.as_integer_ratio() + + adjusted_verifier = verifier_for(as_integer_ratio) + adjusted_verifier.verify(Decimal(100), expected_result=(100, 1)) + adjusted_verifier.verify(Decimal('-3.14'), expected_result=(-157, 50)) + + +# TODO: Make as_tuple use NamedTuple +def test_as_tuple(): + def as_tuple(a: Decimal) -> tuple[int, tuple[int,...], int]: + return a.as_tuple() + + def matches_tuple(t: tuple[int, tuple[int,...], int]) -> Callable[[tuple[int, tuple[int,...], int]], bool]: + def predicate(tested: tuple[int, tuple[int,...], int]) -> bool: + return t == tested + + return predicate + + as_tuple_verifier = verifier_for(as_tuple) + as_tuple_verifier.verify_property(Decimal(100), predicate=matches_tuple((0, (1, 0, 0), 0))) + as_tuple_verifier.verify_property(Decimal(-100), predicate=matches_tuple((1, (1, 0, 0), 0))) + as_tuple_verifier.verify_property(Decimal('123.45'), predicate=matches_tuple((0, (1, 2, 3, 4, 5), -2))) + + +def test_canonical(): + def canonical(a: Decimal) -> Decimal: + return a.canonical() + + canonical_verifier = verifier_for(canonical) + canonical_verifier.verify(Decimal(100), expected_result=Decimal(100)) + + +def test_compare(): + def compare(a: Decimal, b: Decimal) -> Decimal: + return a.compare(b) + + compare_verifier = verifier_for(compare) + compare_verifier.verify(Decimal(-5), Decimal(5), expected_result=Decimal(-1)) + compare_verifier.verify(Decimal(5), Decimal(-5), expected_result=Decimal(1)) + compare_verifier.verify(Decimal(5), Decimal(5), expected_result=Decimal(0)) + + +def test_compare_signal(): + def compare_signal(a: Decimal, b: Decimal) -> Decimal: + return a.compare_signal(b) + + compare_signal_verifier = verifier_for(compare_signal) + compare_signal_verifier.verify(Decimal(-5), Decimal(5), expected_result=Decimal(-1)) + compare_signal_verifier.verify(Decimal(5), Decimal(-5), expected_result=Decimal(1)) + compare_signal_verifier.verify(Decimal(5), Decimal(5), expected_result=Decimal(0)) + + +def test_compare_total(): + def compare_total(a: Decimal, b: Decimal) -> Decimal: + return a.compare_total(b) + + compare_total_verifier = verifier_for(compare_total) + compare_total_verifier.verify(Decimal(-5), Decimal(5), expected_result=Decimal(-1)) + compare_total_verifier.verify(Decimal(5), Decimal(-5), expected_result=Decimal(1)) + compare_total_verifier.verify(Decimal(5), Decimal(5), expected_result=Decimal(0)) + compare_total_verifier.verify(Decimal('12.0'), Decimal('12'), expected_result=Decimal(-1)) + compare_total_verifier.verify(Decimal('12'), Decimal('12.0'), expected_result=Decimal(1)) + + +def test_compare_total_mag(): + def compare_total_mag(a: Decimal, b: Decimal) -> Decimal: + return a.compare_total_mag(b) + + compare_total_mag_verifier = verifier_for(compare_total_mag) + compare_total_mag_verifier.verify(Decimal(3), Decimal(5), expected_result=Decimal(-1)) + compare_total_mag_verifier.verify(Decimal(-7), Decimal(5), expected_result=Decimal(1)) + compare_total_mag_verifier.verify(Decimal(-5), Decimal(5), expected_result=Decimal(0)) + compare_total_mag_verifier.verify(Decimal(5), Decimal(-5), expected_result=Decimal(0)) + compare_total_mag_verifier.verify(Decimal(5), Decimal(5), expected_result=Decimal(0)) + compare_total_mag_verifier.verify(Decimal('12.0'), Decimal('12'), expected_result=Decimal(-1)) + compare_total_mag_verifier.verify(Decimal('12'), Decimal('12.0'), expected_result=Decimal(1)) + compare_total_mag_verifier.verify(Decimal('12.0'), Decimal('-12'), expected_result=Decimal(-1)) + compare_total_mag_verifier.verify(Decimal('-12'), Decimal('12.0'), expected_result=Decimal(1)) + + +def test_conjugate(): + def conjugate(a: Decimal) -> Decimal: + return a.conjugate() + + conjugate_verifier = verifier_for(conjugate) + conjugate_verifier.verify(Decimal(10), expected_result=Decimal(10)) + + +def test_copy_abs(): + def copy_abs(a: Decimal) -> Decimal: + return a.copy_abs() + + copy_abs_verifier = verifier_for(copy_abs) + copy_abs_verifier.verify(Decimal(10), expected_result=Decimal(10)) + copy_abs_verifier.verify(Decimal(-10), expected_result=Decimal(10)) + + +def test_copy_negate(): + def copy_negate(a: Decimal) -> Decimal: + return a.copy_negate() + + copy_negate_verifier = verifier_for(copy_negate) + copy_negate_verifier.verify(Decimal(10), expected_result=Decimal(-10)) + copy_negate_verifier.verify(Decimal(-10), expected_result=Decimal(10)) + + +def test_copy_sign(): + def copy_sign(a: Decimal, b: Decimal) -> Decimal: + return a.copy_sign(b) + + copy_sign_verifier = verifier_for(copy_sign) + copy_sign_verifier.verify(Decimal(1), Decimal(2), expected_result=Decimal(1)) + copy_sign_verifier.verify(Decimal('2.3'), Decimal('-1.5'), expected_result=Decimal('-2.3')) + copy_sign_verifier.verify(Decimal('-1.5'), Decimal('2.3'), expected_result=Decimal('1.5')) + + +def test_exp(): + def exp(a: Decimal) -> Decimal: + return a.exp() + + exp_verifier = verifier_for(exp) + exp_verifier.verify(Decimal(1), expected_result=Decimal('2.718281828459045235360287471')) + exp_verifier.verify(Decimal(321), expected_result=Decimal('2.561702493119680037517373933E+139')) + + +def test_fma(): + def decimal_decimal_fma(a: Decimal, b: Decimal, c: Decimal) -> Decimal: + return a.fma(b, c) + + def int_decimal_fma(a: Decimal, b: int, c: Decimal) -> Decimal: + return a.fma(b, c) + + def decimal_int_fma(a: Decimal, b: Decimal, c: int) -> Decimal: + return a.fma(b, c) + + def int_int_fma(a: Decimal, b: int, c: int) -> Decimal: + return a.fma(b, c) + + fma_decimal_decimal_verifier = verifier_for(decimal_decimal_fma) + fma_int_decimal_verifier = verifier_for(int_decimal_fma) + fma_decimal_int_decimal_verifier = verifier_for(decimal_int_fma) + fma_int_int_decimal_verifier = verifier_for(int_int_fma) + + fma_decimal_decimal_verifier.verify(Decimal(2), Decimal(3), Decimal(5), expected_result=Decimal(11)) + fma_int_decimal_verifier.verify(Decimal(2), 3, Decimal(5), expected_result=Decimal(11)) + fma_decimal_int_decimal_verifier.verify(Decimal(2), Decimal(3), 5, expected_result=Decimal(11)) + fma_int_int_decimal_verifier.verify(Decimal(2), 3, 5, expected_result=Decimal(11)) + + +def test_is_canonical(): + def is_canonical(a: Decimal) -> bool: + return a.is_canonical() + + is_canonical_verifier = verifier_for(is_canonical) + is_canonical_verifier.verify(Decimal(10), expected_result=True) + + +def test_is_finite(): + def is_finite(a: Decimal) -> bool: + return a.is_finite() + + is_finite_verifier = verifier_for(is_finite) + is_finite_verifier.verify(Decimal(10), expected_result=True) + + +def test_is_infinite(): + def is_infinite(a: Decimal) -> bool: + return a.is_infinite() + + is_infinite_verifier = verifier_for(is_infinite) + is_infinite_verifier.verify(Decimal(10), expected_result=False) + + +def test_is_nan(): + def is_nan(a: Decimal) -> bool: + return a.is_nan() + + is_nan_verifier = verifier_for(is_nan) + is_nan_verifier.verify(Decimal(10), expected_result=False) + + +def test_is_normal(): + def is_normal(a: Decimal) -> bool: + return a.is_normal() + + is_normal_verifier = verifier_for(is_normal) + is_normal_verifier.verify(Decimal(10), expected_result=True) + + +def test_is_qnan(): + def is_qnan(a: Decimal) -> bool: + return a.is_qnan() + + is_qnan_verifier = verifier_for(is_qnan) + is_qnan_verifier.verify(Decimal(10), expected_result=False) + + +def test_is_signed(): + def is_signed(a: Decimal) -> bool: + return a.is_signed() + + is_signed_verifier = verifier_for(is_signed) + is_signed_verifier.verify(Decimal(10), expected_result=False) + is_signed_verifier.verify(Decimal(0), expected_result=False) + is_signed_verifier.verify(Decimal(-10), expected_result=True) + + +def test_is_snan(): + def is_snan(a: Decimal) -> bool: + return a.is_snan() + + is_snan_verifier = verifier_for(is_snan) + is_snan_verifier.verify(Decimal(10), expected_result=False) + + +def test_is_subnormal(): + def is_subnormal(a: Decimal) -> bool: + return a.is_subnormal() + + is_subnormal_verifier = verifier_for(is_subnormal) + is_subnormal_verifier.verify(Decimal(10), expected_result=False) + + +def test_is_zero(): + def is_zero(a: Decimal) -> bool: + return a.is_zero() + + is_zero_verifier = verifier_for(is_zero) + is_zero_verifier.verify(Decimal(10), expected_result=False) + is_zero_verifier.verify(Decimal(0), expected_result=True) + + +def test_ln(): + def ln(a: Decimal) -> Decimal: + return a.ln() + + ln_verifier = verifier_for(ln) + ln_verifier.verify_property(Decimal(1), predicate=around(Decimal(0))) + ln_verifier.verify_property(Decimal(1).exp(), predicate=around(Decimal(1))) + ln_verifier.verify_property(Decimal('2.5').exp(), predicate=around(Decimal('2.5'))) + + +def test_log10(): + def log10(a: Decimal) -> Decimal: + return a.log10() + + log10_verifier = verifier_for(log10) + log10_verifier.verify_property(Decimal(1), predicate=around(Decimal(0))) + log10_verifier.verify_property(Decimal(10), predicate=around(Decimal(1))) + log10_verifier.verify_property(Decimal('0.1'), predicate=around(Decimal(-1))) + log10_verifier.verify_property(Decimal('5'), predicate=around(Decimal('0.69897'))) + + +def test_logb(): + def logb(a: Decimal) -> Decimal: + return a.logb() + + logb_verifier = verifier_for(logb) + logb_verifier.verify(Decimal(1), expected_result=Decimal(0)) + logb_verifier.verify(Decimal(100), expected_result=Decimal(2)) + logb_verifier.verify(Decimal(200), expected_result=Decimal(2)) + logb_verifier.verify(Decimal('0.1'), expected_result=Decimal(-1)) + logb_verifier.verify(Decimal('0.5'), expected_result=Decimal(-1)) + + +def test_logical_and(): + def logical_and(a: Decimal, b: Decimal) -> Decimal: + return a.logical_and(b) + + logical_and_verifier = verifier_for(logical_and) + logical_and_verifier.verify(Decimal('1010'), Decimal('1100'), expected_result=Decimal('1000')) + + +def test_logical_invert(): + def logical_invert(a: Decimal) -> Decimal: + return a.logical_invert() + + logical_invert_verifier = verifier_for(logical_invert) + logical_invert_verifier.verify(Decimal('1010'), expected_result=Decimal('1111111111111111111111110101')) + + +def test_logical_or(): + def logical_or(a: Decimal, b: Decimal) -> Decimal: + return a.logical_or(b) + + logical_or_verifier = verifier_for(logical_or) + logical_or_verifier.verify(Decimal('1010'), Decimal('1100'), expected_result=Decimal('1110')) + + +def test_logical_xor(): + def logical_xor(a: Decimal, b: Decimal) -> Decimal: + return a.logical_xor(b) + + logical_xor_verifier = verifier_for(logical_xor) + logical_xor_verifier.verify(Decimal('1010'), Decimal('1100'), expected_result=Decimal('0110')) + + +def test_max(): + def decimal_max(a: Decimal, b: Decimal) -> Decimal: + return a.max(b) + + decimal_max_verifier = verifier_for(decimal_max) + decimal_max_verifier.verify(Decimal(1), Decimal(2), expected_result=Decimal(2)) + decimal_max_verifier.verify(Decimal(2), Decimal(1), expected_result=Decimal(2)) + decimal_max_verifier.verify(Decimal(1), Decimal(-2), expected_result=Decimal(1)) + + +def test_max_mag(): + def decimal_max_mag(a: Decimal, b: Decimal) -> Decimal: + return a.max_mag(b) + + decimal_max_mag_verifier = verifier_for(decimal_max_mag) + decimal_max_mag_verifier.verify(Decimal(1), Decimal(2), expected_result=Decimal(2)) + decimal_max_mag_verifier.verify(Decimal(2), Decimal(1), expected_result=Decimal(2)) + decimal_max_mag_verifier.verify(Decimal(1), Decimal(-2), expected_result=Decimal(-2)) + + +def test_min(): + def decimal_min(a: Decimal, b: Decimal) -> Decimal: + return a.min(b) + + decimal_min_verifier = verifier_for(decimal_min) + decimal_min_verifier.verify(Decimal(1), Decimal(2), expected_result=Decimal(1)) + decimal_min_verifier.verify(Decimal(2), Decimal(1), expected_result=Decimal(1)) + decimal_min_verifier.verify(Decimal(1), Decimal(-2), expected_result=Decimal(-2)) + + +def test_min_mag(): + def decimal_min_mag(a: Decimal, b: Decimal) -> Decimal: + return a.min_mag(b) + + decimal_min_mag_verifier = verifier_for(decimal_min_mag) + decimal_min_mag_verifier.verify(Decimal(1), Decimal(2), expected_result=Decimal(1)) + decimal_min_mag_verifier.verify(Decimal(2), Decimal(1), expected_result=Decimal(1)) + decimal_min_mag_verifier.verify(Decimal(1), Decimal(-2), expected_result=Decimal(1)) + + +def test_next_minus(): + def next_minus(a: Decimal) -> Decimal: + return a.next_minus() + + next_minus_verifier = verifier_for(next_minus) + next_minus_verifier.verify(Decimal(1), expected_result=Decimal('0.9999999999999999999999999999')) + next_minus_verifier.verify(Decimal('0.9999999999999999999999999999'), + expected_result=Decimal('0.9999999999999999999999999998')) + + +def test_next_plus(): + def next_plus(a: Decimal) -> Decimal: + return a.next_plus() + + next_plus_verifier = verifier_for(next_plus) + next_plus_verifier.verify(Decimal(1), expected_result=Decimal('1.000000000000000000000000001')) + next_plus_verifier.verify(Decimal('1.000000000000000000000000001'), + expected_result=Decimal('1.000000000000000000000000002')) + + +def test_next_toward(): + def next_toward(a: Decimal, b: Decimal) -> Decimal: + return a.next_toward(b) + + next_toward_verifier = verifier_for(next_toward) + next_toward_verifier.verify(Decimal(1), Decimal(0), expected_result=Decimal('0.9999999999999999999999999999')) + next_toward_verifier.verify(Decimal(1), Decimal(2), expected_result=Decimal('1.000000000000000000000000001')) + next_toward_verifier.verify(Decimal(1), Decimal(1), expected_result=Decimal(1)) + + +def test_normalize(): + def normalize(a: Decimal) -> Decimal: + return a.normalize() + + normalize_verifier = verifier_for(normalize) + normalize_verifier.verify(Decimal(10), expected_result=Decimal(10)) + + +def test_number_class(): + def number_class(a: Decimal) -> str: + return a.number_class() + + number_class_verifier = verifier_for(number_class) + number_class_verifier.verify(Decimal(1), expected_result='+Normal') + number_class_verifier.verify(Decimal(-1), expected_result='-Normal') + number_class_verifier.verify(Decimal(0), expected_result='+Zero') + + +def test_quantize(): + def quantize(a: Decimal, b: Decimal) -> Decimal: + return a.quantize(b) + + quantize_verifier = verifier_for(quantize) + quantize_verifier.verify(Decimal('1.41421356'), Decimal('1.000'), + expected_result=Decimal('1.414')) + + +def test_radix(): + def radix(a: Decimal) -> Decimal: + return a.radix() + + radix_verifier = verifier_for(radix) + radix_verifier.verify(Decimal(1), expected_result=Decimal(10)) + + +def test_remainder_near(): + def remainder_near(a: Decimal, b: Decimal) -> Decimal: + return a.remainder_near(b) + + remainder_near_verifier = verifier_for(remainder_near) + remainder_near_verifier.verify(Decimal(18), Decimal(10), expected_result=Decimal(-2)) + remainder_near_verifier.verify(Decimal(25), Decimal(10), expected_result=Decimal(5)) + remainder_near_verifier.verify(Decimal(35), Decimal(10), expected_result=Decimal(-5)) + + +def test_rotate(): + def rotate(a: Decimal, b: int) -> Decimal: + return a.rotate(b) + + rotate_verifier = verifier_for(rotate) + rotate_verifier.verify(Decimal('12.34'), 3, expected_result=Decimal('12340.00')) + rotate_verifier.verify(Decimal('12.34'), -3, expected_result=Decimal('23400000000000000000000000.01')) + + +def test_same_quantum(): + def same_quantum(a: Decimal, b: Decimal) -> bool: + return a.same_quantum(b) + + same_quantum_verifier = verifier_for(same_quantum) + same_quantum_verifier.verify(Decimal(1), Decimal(2), expected_result=True) + same_quantum_verifier.verify(Decimal(1), Decimal(10), expected_result=True) + same_quantum_verifier.verify(Decimal('0.1'), Decimal('0.01'), expected_result=False) + + +def test_scaleb(): + def scaleb(a: Decimal, b: int) -> Decimal: + return a.scaleb(b) + + scaleb_verifier = verifier_for(scaleb) + scaleb_verifier.verify(Decimal(1), 2, expected_result=Decimal(100)) + scaleb_verifier.verify(Decimal(1), -2, expected_result=Decimal('0.01')) + + +def test_sqrt(): + def sqrt(a: Decimal) -> Decimal: + return a.sqrt() + + sqrt_verifier = verifier_for(sqrt) + sqrt_verifier.verify(Decimal(1), expected_result=Decimal(1)) + sqrt_verifier.verify(Decimal(2), expected_result=Decimal('1.414213562373095048801688724')) + sqrt_verifier.verify(Decimal(9), expected_result=Decimal(3)) + + +def test_to_eng_string(): + def to_eng_string(a: Decimal) -> str: + return a.to_eng_string() + + to_eng_string_verifier = verifier_for(to_eng_string) + to_eng_string_verifier.verify(Decimal('123E+1'), expected_result='1.23E+3') + + +def test_to_integral(): + def to_integral(a: Decimal) -> Decimal: + return a.to_integral() + + to_integral_verifier = verifier_for(to_integral) + to_integral_verifier.verify(Decimal('1.23'), Decimal('1')) + to_integral_verifier.verify(Decimal('1.7'), Decimal('2')) + to_integral_verifier.verify(Decimal('1.5'), Decimal('2')) + + +def test_to_integral_exact(): + def to_integral_exact(a: Decimal) -> Decimal: + return a.to_integral_exact() + + to_integral_exact_verifier = verifier_for(to_integral_exact) + to_integral_exact_verifier.verify(Decimal('1.23'), Decimal('1')) + to_integral_exact_verifier.verify(Decimal('1.7'), Decimal('2')) + to_integral_exact_verifier.verify(Decimal('1.5'), Decimal('2')) + + +def test_to_integral_value(): + def to_to_integral_value(a: Decimal) -> Decimal: + return a.to_to_integral_value() + + to_to_integral_value_verifier = verifier_for(to_to_integral_value) + to_to_integral_value_verifier.verify(Decimal('1.23'), Decimal('1')) + to_to_integral_value_verifier.verify(Decimal('1.7'), Decimal('2')) + to_to_integral_value_verifier.verify(Decimal('1.5'), Decimal('2')) diff --git a/tests/test_collectors.py b/tests/test_collectors.py index 2c589d29..7075e4ce 100644 --- a/tests/test_collectors.py +++ b/tests/test_collectors.py @@ -572,7 +572,7 @@ def define_constraints(constraint_factory: ConstraintFactory): lambda entity: entity.value )) .reward(SimpleScore.ONE, - lambda balance: balance.unfairness().movePointRight(3).intValue()) + lambda balance: round(balance.unfairness() * 1000)) .as_constraint('Balanced value') ] diff --git a/tests/test_constraint_streams.py b/tests/test_constraint_streams.py index f908f6ff..056fc91b 100644 --- a/tests/test_constraint_streams.py +++ b/tests/test_constraint_streams.py @@ -6,6 +6,7 @@ import inspect import re from dataclasses import dataclass, field +from decimal import Decimal from typing import Annotated, List from ai.timefold.solver.core.api.score.stream import Joiners as JavaJoiners, \ ConstraintCollectors as JavaConstraintCollectors, ConstraintFactory as JavaConstraintFactory @@ -40,10 +41,18 @@ class Solution: score: Annotated[SimpleScore, PlanningScore] = field(default=None) -def create_score_manager(constraint_provider): +@planning_solution +@dataclass +class DecimalSolution: + entity_list: Annotated[List[Entity], PlanningEntityCollectionProperty] + value_list: Annotated[List[Value], ProblemFactCollectionProperty, ValueRangeProvider] + score: Annotated[SimpleDecimalScore, PlanningScore] = field(default=None) + + +def create_score_manager(constraint_provider, solution_class: type = Solution, entity_classes: list[type] = (Entity,)): return SolutionManager.create(SolverFactory.create( - SolverConfig(solution_class=Solution, - entity_class_list=[Entity], + SolverConfig(solution_class=solution_class, + entity_class_list=entity_classes, score_director_factory_config=ScoreDirectorFactoryConfig( constraint_provider_function=constraint_provider )))) @@ -722,6 +731,308 @@ def define_constraints(constraint_factory: ConstraintFactory): assert score_manager.explain(problem).score == SimpleScore.of(9_000_000_000) +def test_sanity(): + int_impact_functions = [ + None, + lambda a: a.value.number, + lambda a, b: a.value.number, + lambda a, b, c: a.value.number, + lambda a, b, c, d: a.value.number, + ] + + i = 0 + + def build_stream(constraint_factory: ConstraintFactory, + method: str, + cardinality: int, + has_impact_function: bool) -> Constraint: + nonlocal i + i += 1 + + def expander(x): + return None + + expanders = [expander] * (cardinality - 1) + current = constraint_factory.for_each(Entity) + if expanders: + current = current.expand(*expanders) + + impact_method = getattr(current, method) + + if has_impact_function: + return (impact_method(SimpleScore.ONE, int_impact_functions[cardinality]) + .as_constraint(f'Constraint {i}')) + else: + return (impact_method(SimpleScore.ONE) + .as_constraint(f'Constraint {i}')) + + + @constraint_provider + def define_constraints(constraint_factory: ConstraintFactory): + return [ + build_stream(constraint_factory, method, cardinality, + use_impact_function) + for method in ['penalize', 'reward', 'impact'] + for cardinality in [1, 2, 3, 4] + for use_impact_function in [True, False] + ] + + score_manager = create_score_manager(define_constraints) + entity_a: Entity = Entity('A') + entity_b: Entity = Entity('B') + + value_1 = Value(1) + + entity_a.value = value_1 + entity_b.value = value_1 + + problem = Solution([entity_a, entity_b], [value_1]) + + # 3 positive method + 1 negative methods = 2 positive + # 4 cardinalities + # 1 impact + 1 non-impact = 2 + # 2 * 4 * 2 = 16 + assert score_manager.explain(problem).score == SimpleScore.of(16) + + +def test_sanity_decimal(): + decimal_impact_functions = [ + None, + lambda a: a.value.number, + lambda a, b: a.value.number, + lambda a, b, c: a.value.number, + lambda a, b, c, d: a.value.number, + ] + + i = 0 + + def build_stream(constraint_factory: ConstraintFactory, + method: str, + cardinality: int, + has_impact_function: bool) -> Constraint: + nonlocal i + i += 1 + + def expander(x): + return None + + expanders = [expander] * (cardinality - 1) + current = constraint_factory.for_each(Entity) + if expanders: + current = current.expand(*expanders) + + impact_method = getattr(current, method) + + if has_impact_function: + return (impact_method(SimpleDecimalScore.ONE, decimal_impact_functions[cardinality]) + .as_constraint(f'Constraint {i}')) + else: + return (impact_method(SimpleDecimalScore.ONE) + .as_constraint(f'Constraint {i}')) + + + @constraint_provider + def define_constraints(constraint_factory: ConstraintFactory): + return [ + build_stream(constraint_factory, method, cardinality, + use_impact_function) + for method in ['penalize_decimal', 'reward_decimal', 'impact_decimal'] + for cardinality in [1, 2, 3, 4] + for use_impact_function in [True, False] + ] + + score_manager = create_score_manager(define_constraints, solution_class=DecimalSolution) + entity_a: Entity = Entity('A') + entity_b: Entity = Entity('B') + + value_1 = Value(Decimal(1)) + + entity_a.value = value_1 + entity_b.value = value_1 + + problem = DecimalSolution([entity_a, entity_b], [value_1]) + + # 3 positive method + 1 negative methods = 2 positive + # 4 cardinalities + # 1 impact + 1 non-impact = 2 + # 2 * 4 * 2 = 16 + assert score_manager.explain(problem).score == SimpleDecimalScore.of(Decimal(16)) + + +def test_sanity_configurable(): + class ConstraintConfiguration: + pass + + for i in range(3 * 4 * 2): + weight_name = f'w{i + 1}' + weight_annotation = Annotated[SimpleScore, ConstraintWeight(f'Constraint {i + 1}', + constraint_package='pkg')] + weight_value = field(default=SimpleScore.ONE) + setattr(ConstraintConfiguration, weight_name, weight_value) + ConstraintConfiguration.__annotations__[weight_name] = weight_annotation + + ConstraintConfiguration = constraint_configuration(dataclass(ConstraintConfiguration)) + + @planning_solution + @dataclass + class ConfigurationSolution: + configuration: Annotated[ConstraintConfiguration, ConstraintConfigurationProvider] + entity_list: Annotated[List[Entity], PlanningEntityCollectionProperty] + value_list: Annotated[List[Value], ProblemFactCollectionProperty, ValueRangeProvider] + score: Annotated[SimpleScore, PlanningScore] = field(default=None) + + + int_impact_functions = [ + None, + lambda a: a.value.number, + lambda a, b: a.value.number, + lambda a, b, c: a.value.number, + lambda a, b, c, d: a.value.number, + ] + + i = 0 + + def build_stream(constraint_factory: ConstraintFactory, + method: str, + cardinality: int, + has_impact_function: bool) -> Constraint: + nonlocal i + i += 1 + + def expander(x): + return None + + expanders = [expander] * (cardinality - 1) + current = constraint_factory.for_each(Entity) + if expanders: + current = current.expand(*expanders) + + impact_method = getattr(current, method) + + if has_impact_function: + return (impact_method(int_impact_functions[cardinality]) + .as_constraint('pkg', f'Constraint {i}')) + else: + return (impact_method() + .as_constraint('pkg', f'Constraint {i}')) + + + @constraint_provider + def define_constraints(constraint_factory: ConstraintFactory): + return [ + build_stream(constraint_factory, method, cardinality, + use_impact_function) + for method in ['penalize_configurable', 'reward_configurable', 'impact_configurable'] + for cardinality in [1, 2, 3, 4] + for use_impact_function in [True, False] + ] + + score_manager = create_score_manager(define_constraints, solution_class=ConfigurationSolution) + entity_a: Entity = Entity('A') + entity_b: Entity = Entity('B') + + value_1 = Value(1) + + entity_a.value = value_1 + entity_b.value = value_1 + + problem = ConfigurationSolution(ConstraintConfiguration(), [entity_a, entity_b], [value_1]) + + # 3 positive method + 1 negative methods = 2 positive + # 4 cardinalities + # 1 impact + 1 non-impact = 2 + # 2 * 4 * 2 = 16 + assert score_manager.explain(problem).score == SimpleScore.of(16) + + +def test_sanity_configurable_decimal(): + class ConstraintConfiguration: + pass + + for i in range(3 * 4 * 2): + weight_name = f'w{i + 1}' + weight_annotation = Annotated[SimpleDecimalScore, ConstraintWeight(f'Constraint {i + 1}', + constraint_package='pkg')] + weight_value = field(default=SimpleDecimalScore.ONE) + setattr(ConstraintConfiguration, weight_name, weight_value) + ConstraintConfiguration.__annotations__[weight_name] = weight_annotation + + ConstraintConfiguration = constraint_configuration(dataclass(ConstraintConfiguration)) + + @planning_solution + @dataclass + class ConfigurationSolution: + configuration: Annotated[ConstraintConfiguration, ConstraintConfigurationProvider] + entity_list: Annotated[List[Entity], PlanningEntityCollectionProperty] + value_list: Annotated[List[Value], ProblemFactCollectionProperty, ValueRangeProvider] + score: Annotated[SimpleDecimalScore, PlanningScore] = field(default=None) + + + decimal_impact_functions = [ + None, + lambda a: a.value.number, + lambda a, b: a.value.number, + lambda a, b, c: a.value.number, + lambda a, b, c, d: a.value.number, + ] + + i = 0 + + def build_stream(constraint_factory: ConstraintFactory, + method: str, + cardinality: int, + has_impact_function: bool) -> Constraint: + nonlocal i + i += 1 + + def expander(x): + return None + + expanders = [expander] * (cardinality - 1) + current = constraint_factory.for_each(Entity) + if expanders: + current = current.expand(*expanders) + + impact_method = getattr(current, method) + + if has_impact_function: + return (impact_method(decimal_impact_functions[cardinality]) + .as_constraint('pkg', f'Constraint {i}')) + else: + return (impact_method() + .as_constraint('pkg', f'Constraint {i}')) + + + @constraint_provider + def define_constraints(constraint_factory: ConstraintFactory): + return [ + build_stream(constraint_factory, method, cardinality, + use_impact_function) + for method in ['penalize_configurable_decimal', + 'reward_configurable_decimal', + 'impact_configurable_decimal'] + for cardinality in [1, 2, 3, 4] + for use_impact_function in [True, False] + ] + + score_manager = create_score_manager(define_constraints, solution_class=ConfigurationSolution) + entity_a: Entity = Entity('A') + entity_b: Entity = Entity('B') + + value_1 = Value(Decimal(1)) + + entity_a.value = value_1 + entity_b.value = value_1 + + problem = ConfigurationSolution(ConstraintConfiguration(), [entity_a, entity_b], [value_1]) + + # 3 positive method + 1 negative methods = 2 positive + # 4 cardinalities + # 1 impact + 1 non-impact = 2 + # 2 * 4 * 2 = 16 + assert score_manager.explain(problem).score == SimpleDecimalScore.of(Decimal(16)) + + ignored_python_functions = { '_call_comparison_java_joiner', '__init__', diff --git a/tests/test_score.py b/tests/test_score.py index dbcd5c8f..822df64a 100644 --- a/tests/test_score.py +++ b/tests/test_score.py @@ -1,5 +1,10 @@ -from timefold.solver.score import SimpleScore, HardSoftScore, HardMediumSoftScore, BendableScore - +from dataclasses import dataclass, field +from decimal import Decimal +from timefold.solver import * +from timefold.solver.config import * +from timefold.solver.domain import * +from timefold.solver.score import * +from typing import Annotated def test_simple_score(): uninit_score = SimpleScore(10, init_score=-2) @@ -43,3 +48,149 @@ def test_bendable_score(): assert BendableScore.parse('-500init/[1/-2/3]hard/[-30/40]soft') == uninit_score assert BendableScore.parse('[1/-2/3]hard/[-30/40]soft') == score + + +def test_simple_decimal_score(): + uninit_score = SimpleDecimalScore(Decimal('10.1'), init_score=-2) + score = SimpleDecimalScore.of(Decimal('10.1')) + + assert str(uninit_score) == '-2init/10.1' + assert str(score) == '10.1' + + assert SimpleDecimalScore.parse('-2init/10.1') == uninit_score + assert SimpleDecimalScore.parse('10.1') == score + + +def test_hard_soft_decimal_score(): + uninit_score = HardSoftDecimalScore(Decimal('100.1'), Decimal('20.2'), init_score=-3) + score = HardSoftDecimalScore.of(Decimal('100.1'), Decimal('20.2')) + + assert str(uninit_score) == '-3init/100.1hard/20.2soft' + assert str(score) == '100.1hard/20.2soft' + + assert HardSoftDecimalScore.parse('-3init/100.1hard/20.2soft') == uninit_score + assert HardSoftDecimalScore.parse('100.1hard/20.2soft') == score + + +def test_hard_medium_soft_decimal_score(): + uninit_score = HardMediumSoftDecimalScore(Decimal('1000.1'), Decimal('200.2'), Decimal('30.3'), init_score=-4) + score = HardMediumSoftDecimalScore.of(Decimal('1000.1'), Decimal('200.2'), Decimal('30.3')) + + assert str(uninit_score) == '-4init/1000.1hard/200.2medium/30.3soft' + assert str(score) == '1000.1hard/200.2medium/30.3soft' + + assert HardMediumSoftDecimalScore.parse('-4init/1000.1hard/200.2medium/30.3soft') == uninit_score + assert HardMediumSoftDecimalScore.parse('1000.1hard/200.2medium/30.3soft') == score + + +def test_bendable_decimal_score(): + uninit_score = BendableDecimalScore((Decimal('1.1'), Decimal('-2.2'), Decimal('3.3')), + (Decimal('-30.3'), Decimal('40.4')), init_score=-500) + score = BendableDecimalScore.of((Decimal('1.1'), Decimal('-2.2'), Decimal('3.3')), + (Decimal('-30.3'), Decimal('40.4'))) + + print(str(uninit_score)) + assert str(uninit_score) == '-500init/[1.1/-2.2/3.3]hard/[-30.3/40.4]soft' + assert str(score) == '[1.1/-2.2/3.3]hard/[-30.3/40.4]soft' + + assert BendableDecimalScore.parse('-500init/[1.1/-2.2/3.3]hard/[-30.3/40.4]soft') == uninit_score + assert BendableDecimalScore.parse('[1.1/-2.2/3.3]hard/[-30.3/40.4]soft') == score + + +def test_sanity_score_type(): + @planning_entity + @dataclass + class Entity: + value: Annotated[int | None, PlanningVariable] = field(default=None) + + for score_type, score_value in ( + (SimpleScore, SimpleScore.ONE), + (HardSoftScore, HardSoftScore.ONE_HARD), + (HardMediumSoftScore, HardMediumSoftScore.ONE_HARD), + (BendableScore, BendableScore.of((1, ), (0, ))), + (SimpleDecimalScore, SimpleDecimalScore.ONE), + (HardSoftDecimalScore, HardSoftDecimalScore.ONE_HARD), + (HardMediumSoftDecimalScore, HardMediumSoftDecimalScore.ONE_HARD), + (BendableDecimalScore, BendableDecimalScore.of((Decimal(1), ), (Decimal(0), ))) + ): + score_annotation = PlanningScore + if score_type == BendableScore or score_type == BendableDecimalScore: + score_annotation = PlanningScore(bendable_hard_levels_size=1, + bendable_soft_levels_size=1) + + @planning_solution + @dataclass + class Solution: + entities: Annotated[list[Entity], PlanningEntityCollectionProperty] + values: Annotated[list[int], ValueRangeProvider] + score: Annotated[score_type | None, score_annotation] = field(default=None) + + @constraint_provider + def constraints(constraint_factory: ConstraintFactory): + return [ + constraint_factory.for_each(Entity) + .reward(score_value) + .as_constraint('Minimize value') + ] + + solver_config = SolverConfig( + solution_class=Solution, + entity_class_list=[Entity], + score_director_factory_config=ScoreDirectorFactoryConfig( + constraint_provider_function=constraints + ), + termination_config=TerminationConfig( + best_score_limit=str(score_value) + ) + ) + + solver_factory = SolverFactory.create(solver_config) + solver = solver_factory.build_solver() + problem = Solution(entities=[Entity()], + values=[1]) + solution = solver.solve(problem) + assert solution.entities[0].value == 1 + assert solution.score == score_value + + +def test_simple_decimal_score_domain(): + @planning_entity + @dataclass + class Entity: + value: Annotated[Decimal | None, PlanningVariable] = field(default=None) + + @planning_solution + @dataclass + class Solution: + entities: Annotated[list[Entity], PlanningEntityCollectionProperty] + values: Annotated[list[Decimal], ValueRangeProvider] + score: Annotated[SimpleDecimalScore | None, PlanningScore] = field(default=None) + + + @constraint_provider + def constraints(constraint_factory: ConstraintFactory): + return [ + constraint_factory.for_each(Entity) + .penalize_decimal(SimpleDecimalScore.of(Decimal('0.1')), lambda e: e.value) + .as_constraint('Minimize value') + ] + + solver_config = SolverConfig( + solution_class=Solution, + entity_class_list=[Entity], + score_director_factory_config=ScoreDirectorFactoryConfig( + constraint_provider_function=constraints + ), + termination_config=TerminationConfig( + best_score_limit='-0.2' + ) + ) + + solver_factory = SolverFactory.create(solver_config) + solver = solver_factory.build_solver() + problem = Solution(entities=[Entity() for i in range(2)], + values=[Decimal(1), Decimal(2), Decimal(3)]) + solution = solver.solve(problem) + assert solution.entities[0].value == 1 + assert solution.entities[1].value == 1 + assert solution.score == SimpleDecimalScore.of(Decimal('-0.2')) diff --git a/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/BendableDecimalScorePythonJavaTypeMapping.java b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/BendableDecimalScorePythonJavaTypeMapping.java new file mode 100644 index 00000000..741be6ae --- /dev/null +++ b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/BendableDecimalScorePythonJavaTypeMapping.java @@ -0,0 +1,88 @@ +package ai.timefold.solver.python.score; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.math.BigDecimal; + +import ai.timefold.jpyinterpreter.PythonLikeObject; +import ai.timefold.jpyinterpreter.types.PythonJavaTypeMapping; +import ai.timefold.jpyinterpreter.types.PythonLikeType; +import ai.timefold.jpyinterpreter.types.collections.PythonLikeTuple; +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; +import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; +import ai.timefold.solver.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore; + +public final class BendableDecimalScorePythonJavaTypeMapping + implements PythonJavaTypeMapping { + private final PythonLikeType type; + private final Constructor constructor; + private final Field initScoreField; + private final Field hardScoresField; + private final Field softScoresField; + + public BendableDecimalScorePythonJavaTypeMapping(PythonLikeType type) + throws ClassNotFoundException, NoSuchFieldException, NoSuchMethodException { + this.type = type; + Class clazz = type.getJavaClass(); + constructor = clazz.getConstructor(); + initScoreField = clazz.getField("init_score"); + hardScoresField = clazz.getField("hard_scores"); + softScoresField = clazz.getField("soft_scores"); + } + + @Override + public PythonLikeType getPythonType() { + return type; + } + + @Override + public Class getJavaType() { + return BendableBigDecimalScore.class; + } + + private static PythonLikeTuple toPythonList(BigDecimal[] scores) { + PythonLikeTuple out = new PythonLikeTuple<>(); + for (var score : scores) { + out.add(new PythonDecimal(score)); + } + return out; + } + + @Override + public PythonLikeObject toPythonObject(BendableBigDecimalScore javaObject) { + try { + var instance = constructor.newInstance(); + initScoreField.set(instance, PythonInteger.valueOf(javaObject.initScore())); + hardScoresField.set(instance, toPythonList(javaObject.hardScores())); + softScoresField.set(instance, toPythonList(javaObject.softScores())); + return (PythonLikeObject) instance; + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + throw new IllegalStateException(e); + } + } + + @Override + public BendableBigDecimalScore toJavaObject(PythonLikeObject pythonObject) { + try { + var initScore = ((PythonInteger) initScoreField.get(pythonObject)).value.intValue(); + var hardScoreTuple = ((PythonLikeTuple) hardScoresField.get(pythonObject)); + var softScoreTuple = ((PythonLikeTuple) softScoresField.get(pythonObject)); + BigDecimal[] hardScores = new BigDecimal[hardScoreTuple.size()]; + BigDecimal[] softScores = new BigDecimal[softScoreTuple.size()]; + for (int i = 0; i < hardScores.length; i++) { + hardScores[i] = ((PythonDecimal) hardScoreTuple.get(i)).value; + } + for (int i = 0; i < softScores.length; i++) { + softScores[i] = ((PythonDecimal) softScoreTuple.get(i)).value; + } + if (initScore == 0) { + return BendableBigDecimalScore.of(hardScores, softScores); + } else { + return BendableBigDecimalScore.ofUninitialized(initScore, hardScores, softScores); + } + } catch (IllegalAccessException e) { + throw new IllegalStateException(e); + } + } +} diff --git a/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/BendableScorePythonJavaTypeMapping.java b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/BendableScorePythonJavaTypeMapping.java index 7a9e193b..2f05fdc8 100644 --- a/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/BendableScorePythonJavaTypeMapping.java +++ b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/BendableScorePythonJavaTypeMapping.java @@ -55,7 +55,7 @@ public PythonLikeObject toPythonObject(BendableLongScore javaObject) { softScoresField.set(instance, toPythonList(javaObject.softScores())); return (PythonLikeObject) instance; } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { - throw new RuntimeException(e); + throw new IllegalStateException(e); } } @@ -79,7 +79,7 @@ public BendableLongScore toJavaObject(PythonLikeObject pythonObject) { return BendableLongScore.ofUninitialized(initScore, hardScores, softScores); } } catch (IllegalAccessException e) { - throw new RuntimeException(e); + throw new IllegalStateException(e); } } } diff --git a/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardMediumSoftDecimalScorePythonJavaTypeMapping.java b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardMediumSoftDecimalScorePythonJavaTypeMapping.java new file mode 100644 index 00000000..2edafd72 --- /dev/null +++ b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardMediumSoftDecimalScorePythonJavaTypeMapping.java @@ -0,0 +1,74 @@ +package ai.timefold.solver.python.score; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; + +import ai.timefold.jpyinterpreter.PythonLikeObject; +import ai.timefold.jpyinterpreter.types.PythonJavaTypeMapping; +import ai.timefold.jpyinterpreter.types.PythonLikeType; +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; +import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; +import ai.timefold.solver.core.api.score.buildin.hardmediumsoftbigdecimal.HardMediumSoftBigDecimalScore; + +public final class HardMediumSoftDecimalScorePythonJavaTypeMapping + implements PythonJavaTypeMapping { + private final PythonLikeType type; + private final Constructor constructor; + private final Field initScoreField; + private final Field hardScoreField; + private final Field mediumScoreField; + private final Field softScoreField; + + public HardMediumSoftDecimalScorePythonJavaTypeMapping(PythonLikeType type) + throws ClassNotFoundException, NoSuchFieldException, NoSuchMethodException { + this.type = type; + Class clazz = type.getJavaClass(); + constructor = clazz.getConstructor(); + initScoreField = clazz.getField("init_score"); + hardScoreField = clazz.getField("hard_score"); + mediumScoreField = clazz.getField("medium_score"); + softScoreField = clazz.getField("soft_score"); + } + + @Override + public PythonLikeType getPythonType() { + return type; + } + + @Override + public Class getJavaType() { + return HardMediumSoftBigDecimalScore.class; + } + + @Override + public PythonLikeObject toPythonObject(HardMediumSoftBigDecimalScore javaObject) { + try { + var instance = constructor.newInstance(); + initScoreField.set(instance, PythonInteger.valueOf(javaObject.initScore())); + hardScoreField.set(instance, new PythonDecimal(javaObject.hardScore())); + mediumScoreField.set(instance, new PythonDecimal(javaObject.mediumScore())); + softScoreField.set(instance, new PythonDecimal(javaObject.softScore())); + return (PythonLikeObject) instance; + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + throw new IllegalStateException(e); + } + } + + @Override + public HardMediumSoftBigDecimalScore toJavaObject(PythonLikeObject pythonObject) { + try { + var initScore = ((PythonInteger) initScoreField.get(pythonObject)).value.intValue(); + var hardScore = ((PythonDecimal) hardScoreField.get(pythonObject)).value; + var mediumScore = ((PythonDecimal) mediumScoreField.get(pythonObject)).value; + var softScore = ((PythonDecimal) softScoreField.get(pythonObject)).value; + if (initScore == 0) { + return HardMediumSoftBigDecimalScore.of(hardScore, mediumScore, softScore); + } else { + return HardMediumSoftBigDecimalScore.ofUninitialized(initScore, hardScore, mediumScore, softScore); + } + } catch (IllegalAccessException e) { + throw new IllegalStateException(e); + } + } +} diff --git a/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardMediumSoftScorePythonJavaTypeMapping.java b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardMediumSoftScorePythonJavaTypeMapping.java index f98cf322..4f73c357 100644 --- a/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardMediumSoftScorePythonJavaTypeMapping.java +++ b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardMediumSoftScorePythonJavaTypeMapping.java @@ -50,7 +50,7 @@ public PythonLikeObject toPythonObject(HardMediumSoftLongScore javaObject) { softScoreField.set(instance, PythonInteger.valueOf(javaObject.softScore())); return (PythonLikeObject) instance; } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { - throw new RuntimeException(e); + throw new IllegalStateException(e); } } @@ -67,7 +67,7 @@ public HardMediumSoftLongScore toJavaObject(PythonLikeObject pythonObject) { return HardMediumSoftLongScore.ofUninitialized(initScore, hardScore, mediumScore, softScore); } } catch (IllegalAccessException e) { - throw new RuntimeException(e); + throw new IllegalStateException(e); } } } diff --git a/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardSoftDecimalScorePythonJavaTypeMapping.java b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardSoftDecimalScorePythonJavaTypeMapping.java new file mode 100644 index 00000000..af0ffd85 --- /dev/null +++ b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardSoftDecimalScorePythonJavaTypeMapping.java @@ -0,0 +1,70 @@ +package ai.timefold.solver.python.score; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; + +import ai.timefold.jpyinterpreter.PythonLikeObject; +import ai.timefold.jpyinterpreter.types.PythonJavaTypeMapping; +import ai.timefold.jpyinterpreter.types.PythonLikeType; +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; +import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; +import ai.timefold.solver.core.api.score.buildin.hardsoftbigdecimal.HardSoftBigDecimalScore; + +public final class HardSoftDecimalScorePythonJavaTypeMapping + implements PythonJavaTypeMapping { + private final PythonLikeType type; + private final Constructor constructor; + private final Field initScoreField; + private final Field hardScoreField; + private final Field softScoreField; + + public HardSoftDecimalScorePythonJavaTypeMapping(PythonLikeType type) + throws ClassNotFoundException, NoSuchFieldException, NoSuchMethodException { + this.type = type; + Class clazz = type.getJavaClass(); + constructor = clazz.getConstructor(); + initScoreField = clazz.getField("init_score"); + hardScoreField = clazz.getField("hard_score"); + softScoreField = clazz.getField("soft_score"); + } + + @Override + public PythonLikeType getPythonType() { + return type; + } + + @Override + public Class getJavaType() { + return HardSoftBigDecimalScore.class; + } + + @Override + public PythonLikeObject toPythonObject(HardSoftBigDecimalScore javaObject) { + try { + var instance = constructor.newInstance(); + initScoreField.set(instance, PythonInteger.valueOf(javaObject.initScore())); + hardScoreField.set(instance, new PythonDecimal(javaObject.hardScore())); + softScoreField.set(instance, new PythonDecimal(javaObject.softScore())); + return (PythonLikeObject) instance; + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + throw new IllegalStateException(e); + } + } + + @Override + public HardSoftBigDecimalScore toJavaObject(PythonLikeObject pythonObject) { + try { + var initScore = ((PythonInteger) initScoreField.get(pythonObject)).value.intValue(); + var hardScore = ((PythonDecimal) hardScoreField.get(pythonObject)).value; + var softScore = ((PythonDecimal) softScoreField.get(pythonObject)).value; + if (initScore == 0) { + return HardSoftBigDecimalScore.of(hardScore, softScore); + } else { + return HardSoftBigDecimalScore.ofUninitialized(initScore, hardScore, softScore); + } + } catch (IllegalAccessException e) { + throw new IllegalStateException(e); + } + } +} diff --git a/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardSoftScorePythonJavaTypeMapping.java b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardSoftScorePythonJavaTypeMapping.java index 20b94866..15bc998c 100644 --- a/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardSoftScorePythonJavaTypeMapping.java +++ b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/HardSoftScorePythonJavaTypeMapping.java @@ -46,7 +46,7 @@ public PythonLikeObject toPythonObject(HardSoftLongScore javaObject) { softScoreField.set(instance, PythonInteger.valueOf(javaObject.softScore())); return (PythonLikeObject) instance; } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { - throw new RuntimeException(e); + throw new IllegalStateException(e); } } @@ -62,7 +62,7 @@ public HardSoftLongScore toJavaObject(PythonLikeObject pythonObject) { return HardSoftLongScore.ofUninitialized(initScore, hardScore, softScore); } } catch (IllegalAccessException e) { - throw new RuntimeException(e); + throw new IllegalStateException(e); } } } diff --git a/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/SimpleDecimalScorePythonJavaTypeMapping.java b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/SimpleDecimalScorePythonJavaTypeMapping.java new file mode 100644 index 00000000..4f8cf3f7 --- /dev/null +++ b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/SimpleDecimalScorePythonJavaTypeMapping.java @@ -0,0 +1,66 @@ +package ai.timefold.solver.python.score; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; + +import ai.timefold.jpyinterpreter.PythonLikeObject; +import ai.timefold.jpyinterpreter.types.PythonJavaTypeMapping; +import ai.timefold.jpyinterpreter.types.PythonLikeType; +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; +import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; +import ai.timefold.solver.core.api.score.buildin.simplebigdecimal.SimpleBigDecimalScore; + +public final class SimpleDecimalScorePythonJavaTypeMapping + implements PythonJavaTypeMapping { + private final PythonLikeType type; + private final Constructor constructor; + private final Field initScoreField; + private final Field scoreField; + + public SimpleDecimalScorePythonJavaTypeMapping(PythonLikeType type) + throws ClassNotFoundException, NoSuchFieldException, NoSuchMethodException { + this.type = type; + Class clazz = type.getJavaClass(); + constructor = clazz.getConstructor(); + initScoreField = clazz.getField("init_score"); + scoreField = clazz.getField("score"); + } + + @Override + public PythonLikeType getPythonType() { + return type; + } + + @Override + public Class getJavaType() { + return SimpleBigDecimalScore.class; + } + + @Override + public PythonLikeObject toPythonObject(SimpleBigDecimalScore javaObject) { + try { + var instance = constructor.newInstance(); + initScoreField.set(instance, PythonInteger.valueOf(javaObject.initScore())); + scoreField.set(instance, new PythonDecimal(javaObject.score())); + return (PythonLikeObject) instance; + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + throw new IllegalStateException(e); + } + } + + @Override + public SimpleBigDecimalScore toJavaObject(PythonLikeObject pythonObject) { + try { + var initScore = ((PythonInteger) initScoreField.get(pythonObject)).value.intValue(); + var score = ((PythonDecimal) scoreField.get(pythonObject)).value; + if (initScore == 0) { + return SimpleBigDecimalScore.of(score); + } else { + return SimpleBigDecimalScore.ofUninitialized(initScore, score); + } + } catch (IllegalAccessException e) { + throw new IllegalStateException(e); + } + } +} diff --git a/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/SimpleScorePythonJavaTypeMapping.java b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/SimpleScorePythonJavaTypeMapping.java index 749ec0f6..bff42450 100644 --- a/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/SimpleScorePythonJavaTypeMapping.java +++ b/timefold-solver-python-core/src/main/java/ai/timefold/solver/python/score/SimpleScorePythonJavaTypeMapping.java @@ -43,7 +43,7 @@ public PythonLikeObject toPythonObject(SimpleLongScore javaObject) { scoreField.set(instance, PythonInteger.valueOf(javaObject.score())); return (PythonLikeObject) instance; } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { - throw new RuntimeException(e); + throw new IllegalStateException(e); } } @@ -58,7 +58,7 @@ public SimpleLongScore toJavaObject(PythonLikeObject pythonObject) { return SimpleLongScore.ofUninitialized(initScore, score); } } catch (IllegalAccessException e) { - throw new RuntimeException(e); + throw new IllegalStateException(e); } } } diff --git a/timefold-solver-python-core/src/main/python/_jpype_type_conversions.py b/timefold-solver-python-core/src/main/python/_jpype_type_conversions.py index 0ac85d1f..638a342d 100644 --- a/timefold-solver-python-core/src/main/python/_jpype_type_conversions.py +++ b/timefold-solver-python-core/src/main/python/_jpype_type_conversions.py @@ -2,6 +2,7 @@ from jpype.types import * from types import FunctionType from typing import TYPE_CHECKING +from decimal import Decimal import timefold.solver._timefold_java_interop as _timefold_java_interop if TYPE_CHECKING: @@ -256,6 +257,26 @@ def to_python_score(score) -> 'Score': return _timefold_java_interop._python_score_mapping_dict['BendableScore'](score.hardScores(), score.softScores(), init_score=score.initScore()) + elif isinstance(score, _timefold_java_interop._java_score_mapping_dict['SimpleDecimalScore']): + return _timefold_java_interop._python_score_mapping_dict['SimpleDecimalScore'](Decimal(score.score().toPlainString()), + init_score=score.initScore()) + elif isinstance(score, _timefold_java_interop._java_score_mapping_dict['HardSoftDecimalScore']): + return _timefold_java_interop._python_score_mapping_dict['HardSoftDecimalScore'](Decimal(score.hardScore().toPlainString()), + Decimal(score.softScore().toPlainString()), + init_score=score.initScore()) + elif isinstance(score, _timefold_java_interop._java_score_mapping_dict['HardMediumSoftDecimalScore']): + return _timefold_java_interop._python_score_mapping_dict['HardMediumSoftDecimalScore'](Decimal(score.hardScore().toPlainString()), + Decimal(score.mediumScore().toPlainString()), + Decimal(score.softScore().toPlainString()), + init_score=score.initScore()) + elif isinstance(score, _timefold_java_interop._java_score_mapping_dict['BendableDecimalScore']): + return _timefold_java_interop._python_score_mapping_dict['BendableDecimalScore']([Decimal(part.toPlainString()) + for part in score.hardScores() + ], + [Decimal(part.toPlainString()) + for part in score.softScores() + ], + init_score=score.initScore()) else: raise TypeError(f'Unexpected score type: {type(score)}') diff --git a/timefold-solver-python-core/src/main/python/_timefold_java_interop.py b/timefold-solver-python-core/src/main/python/_timefold_java_interop.py index 05a7f833..09f041fe 100644 --- a/timefold-solver-python-core/src/main/python/_timefold_java_interop.py +++ b/timefold-solver-python-core/src/main/python/_timefold_java_interop.py @@ -105,38 +105,69 @@ def register_score_python_java_type_mappings(): _scores_registered = True - from .score._score import SimpleScore, HardSoftScore, HardMediumSoftScore, BendableScore + from decimal import Decimal + from .score._score import (SimpleScore, HardSoftScore, HardMediumSoftScore, BendableScore, + SimpleDecimalScore, HardSoftDecimalScore, HardMediumSoftDecimalScore, + BendableDecimalScore) from ai.timefold.solver.core.api.score.buildin.simplelong import SimpleLongScore as _SimpleScore from ai.timefold.solver.core.api.score.buildin.hardsoftlong import HardSoftLongScore as _HardSoftScore from ai.timefold.solver.core.api.score.buildin.hardmediumsoftlong import HardMediumSoftLongScore as _HardMediumSoftScore from ai.timefold.solver.core.api.score.buildin.bendablelong import BendableLongScore as _BendableScore + from ai.timefold.solver.core.api.score.buildin.simplebigdecimal import SimpleBigDecimalScore as _SimpleDecimalScore + from ai.timefold.solver.core.api.score.buildin.hardsoftbigdecimal import HardSoftBigDecimalScore as _HardSoftDecimalScore + from ai.timefold.solver.core.api.score.buildin.hardmediumsoftbigdecimal import HardMediumSoftBigDecimalScore as _HardMediumSoftDecimalScore + from ai.timefold.solver.core.api.score.buildin.bendablebigdecimal import BendableBigDecimalScore as _BendableDecimalScore + from ai.timefold.solver.python.score import (SimpleScorePythonJavaTypeMapping, HardSoftScorePythonJavaTypeMapping, HardMediumSoftScorePythonJavaTypeMapping, - BendableScorePythonJavaTypeMapping) + BendableScorePythonJavaTypeMapping, + SimpleDecimalScorePythonJavaTypeMapping, + HardSoftDecimalScorePythonJavaTypeMapping, + HardMediumSoftDecimalScorePythonJavaTypeMapping, + BendableDecimalScorePythonJavaTypeMapping, + ) from _jpyinterpreter import translate_python_class_to_java_class, add_python_java_type_mapping _python_score_mapping_dict['SimpleScore'] = SimpleScore _python_score_mapping_dict['HardSoftScore'] = HardSoftScore _python_score_mapping_dict['HardMediumSoftScore'] = HardMediumSoftScore _python_score_mapping_dict['BendableScore'] = BendableScore + _python_score_mapping_dict['SimpleDecimalScore'] = SimpleDecimalScore + _python_score_mapping_dict['HardSoftDecimalScore'] = HardSoftDecimalScore + _python_score_mapping_dict['HardMediumSoftDecimalScore'] = HardMediumSoftDecimalScore + _python_score_mapping_dict['BendableDecimalScore'] = BendableDecimalScore _java_score_mapping_dict['SimpleScore'] = _SimpleScore _java_score_mapping_dict['HardSoftScore'] = _HardSoftScore _java_score_mapping_dict['HardMediumSoftScore'] = _HardMediumSoftScore _java_score_mapping_dict['BendableScore'] = _BendableScore + _java_score_mapping_dict['SimpleDecimalScore'] = _SimpleDecimalScore + _java_score_mapping_dict['HardSoftDecimalScore'] = _HardSoftDecimalScore + _java_score_mapping_dict['HardMediumSoftDecimalScore'] = _HardMediumSoftDecimalScore + _java_score_mapping_dict['BendableDecimalScore'] = _BendableDecimalScore SimpleScoreType = translate_python_class_to_java_class(SimpleScore) HardSoftScoreType = translate_python_class_to_java_class(HardSoftScore) HardMediumSoftScoreType = translate_python_class_to_java_class(HardMediumSoftScore) BendableScoreType = translate_python_class_to_java_class(BendableScore) + SimpleDecimalScoreType = translate_python_class_to_java_class(SimpleDecimalScore) + HardSoftDecimalScoreType = translate_python_class_to_java_class(HardSoftDecimalScore) + HardMediumSoftDecimalScoreType = translate_python_class_to_java_class(HardMediumSoftDecimalScore) + BendableDecimalScoreType = translate_python_class_to_java_class(BendableDecimalScore) + add_python_java_type_mapping(SimpleScorePythonJavaTypeMapping(SimpleScoreType)) add_python_java_type_mapping(HardSoftScorePythonJavaTypeMapping(HardSoftScoreType)) add_python_java_type_mapping(HardMediumSoftScorePythonJavaTypeMapping(HardMediumSoftScoreType)) add_python_java_type_mapping(BendableScorePythonJavaTypeMapping(BendableScoreType)) + add_python_java_type_mapping(SimpleDecimalScorePythonJavaTypeMapping(SimpleDecimalScoreType)) + add_python_java_type_mapping(HardSoftDecimalScorePythonJavaTypeMapping(HardSoftDecimalScoreType)) + add_python_java_type_mapping(HardMediumSoftDecimalScorePythonJavaTypeMapping(HardMediumSoftDecimalScoreType)) + add_python_java_type_mapping(BendableDecimalScorePythonJavaTypeMapping(BendableDecimalScoreType)) + def forward_logging_events(event: 'PythonLoggingEvent') -> None: logger.log(event.level().getPythonLevelNumber(), diff --git a/timefold-solver-python-core/src/main/python/score/_constraint_stream.py b/timefold-solver-python-core/src/main/python/score/_constraint_stream.py index 03047e09..bcdf128c 100644 --- a/timefold-solver-python-core/src/main/python/score/_constraint_stream.py +++ b/timefold-solver-python-core/src/main/python/score/_constraint_stream.py @@ -2,6 +2,7 @@ import jpype.imports # noqa from jpype import JClass from typing import TYPE_CHECKING, Type, Callable, overload, TypeVar, Generic, Any, Union, cast +from decimal import Decimal if TYPE_CHECKING: from ai.timefold.solver.core.api.score.stream.uni import (UniConstraintCollector, @@ -537,7 +538,7 @@ def penalize(self, constraint_weight: ScoreType, match_weigher: Callable[[A], in constraint_weight : Score the weight of the constraint. - match_weigher : Callable[[A], int] + match_weigher : Callable[[A], int], optional a function that computes the weight of a match. If absent, each match has weight ``1``. @@ -553,6 +554,36 @@ def penalize(self, constraint_weight: ScoreType, match_weigher: Callable[[A], in to_long_function_cast(match_weigher, self.a_type)), self.a_type) + def penalize_decimal(self, constraint_weight: ScoreType, match_weigher: Callable[[A], Decimal] = None) -> \ + 'UniConstraintBuilder[A, ScoreType]': + """ + Applies a negative Score impact, subtracting the constraint_weight multiplied by the match weight, + and returns a builder to apply optional constraint properties. + + Parameters + ---------- + constraint_weight : Score + the weight of the constraint. + + match_weigher : Callable[[A], Decimal], optional + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + UniConstraintBuilder + a `UniConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return UniConstraintBuilder(self.delegate.penalizeBigDecimal(constraint_weight), self.a_type) + else: + return UniConstraintBuilder(self.delegate.penalizeBigDecimal(constraint_weight, + function_cast(match_weigher, + self.a_type, + return_type=BigDecimal)), + self.a_type) + def reward(self, constraint_weight: ScoreType, match_weigher: Callable[[A], int] = None) -> \ 'UniConstraintBuilder[A, ScoreType]': """ @@ -580,6 +611,36 @@ def reward(self, constraint_weight: ScoreType, match_weigher: Callable[[A], int] to_long_function_cast(match_weigher, self.a_type)), self.a_type) + def reward_decimal(self, constraint_weight: ScoreType, match_weigher: Callable[[A], Decimal] = None) -> \ + 'UniConstraintBuilder[A, ScoreType]': + """ + Applies a positive Score impact, adding the constraint_weight multiplied by the match weight, + and returns a builder to apply optional constraint properties. + + Parameters + ---------- + constraint_weight : Score + the weight of the constraint. + + match_weigher : Callable[[A], Decimal], optional + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + UniConstraintBuilder + a `UniConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return UniConstraintBuilder(self.delegate.reward(constraint_weight), self.a_type) + else: + return UniConstraintBuilder(self.delegate.rewardBigDecimal(constraint_weight, + function_cast(match_weigher, + self.a_type, + return_type=BigDecimal)), + self.a_type) + def impact(self, constraint_weight: ScoreType, match_weigher: Callable[[A], int] = None) -> \ 'UniConstraintBuilder[A, ScoreType]': """ @@ -609,6 +670,37 @@ def impact(self, constraint_weight: ScoreType, match_weigher: Callable[[A], int] self.a_type)), self.a_type) + def impact_decimal(self, constraint_weight: ScoreType, match_weigher: Callable[[A], Decimal] = None) -> \ + 'UniConstraintBuilder[A, ScoreType]': + """ + Positively or negatively impacts the `Score` by `constraint_weight` multiplied by match weight for each match + and returns a builder to apply optional constraint properties. + Use `penalize` or `reward` instead, unless this constraint can both have positive and negative weights. + + Parameters + ---------- + constraint_weight : Score + the weight of the constraint. + + match_weigher : Callable[[A], Decimal], optional + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + UniConstraintBuilder + a `UniConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return UniConstraintBuilder(self.delegate.impact(constraint_weight), self.a_type) + else: + return UniConstraintBuilder(self.delegate.impactBigDecimal(constraint_weight, + function_cast(match_weigher, + self.a_type, + return_type=BigDecimal)), + self.a_type) + def penalize_configurable(self, match_weigher: Callable[[A], int] = None) -> \ 'UniConstraintBuilder[A, ScoreType]': """ @@ -637,6 +729,36 @@ def penalize_configurable(self, match_weigher: Callable[[A], int] = None) -> \ self.a_type)), self.a_type) + def penalize_configurable_decimal(self, match_weigher: Callable[[A], Decimal] = None) \ + -> 'UniConstraintBuilder[A, ScoreType]': + """ + Negatively impacts the Score, subtracting the ConstraintWeight for each match, + and returns a builder to apply optional constraint properties. + The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, + so end users can change the constraint weights dynamically. + This constraint may be deactivated if the `ConstraintWeight` is zero. + If there is no `constraint_configuration`, use `penalize` instead. + + Parameters + ---------- + match_weigher : Callable[[A], Decimal], optional + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + UniConstraintBuilder + a `UniConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return UniConstraintBuilder(self.delegate.penalizeConfigurable(), self.a_type) + else: + return UniConstraintBuilder(self.delegate.penalizeConfigurableBigDecimal(function_cast(match_weigher, + self.a_type, + return_type=BigDecimal)), + self.a_type) + def reward_configurable(self, match_weigher: Callable[[A], int] = None) -> \ 'UniConstraintBuilder[A, ScoreType]': """ @@ -645,7 +767,7 @@ def reward_configurable(self, match_weigher: Callable[[A], int] = None) -> \ The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, so end users can change the constraint weights dynamically. This constraint may be deactivated if the `ConstraintWeight` is zero. - If there is no `constraint_configuration`, use `penalize` instead. + If there is no `constraint_configuration`, use `reward` instead. Parameters ---------- @@ -665,6 +787,36 @@ def reward_configurable(self, match_weigher: Callable[[A], int] = None) -> \ self.a_type)), self.a_type) + def reward_configurable_decimal(self, match_weigher: Callable[[A], Decimal] = None) \ + -> 'UniConstraintBuilder[A, ScoreType]': + """ + Positively impacts the Score, adding the ConstraintWeight for each match, + and returns a builder to apply optional constraint properties. + The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, + so end users can change the constraint weights dynamically. + This constraint may be deactivated if the `ConstraintWeight` is zero. + If there is no `constraint_configuration`, use `reward` instead. + + Parameters + ---------- + match_weigher : Callable[[A], Decimal], optional + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + UniConstraintBuilder + a `UniConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return UniConstraintBuilder(self.delegate.rewardConfigurable(), self.a_type) + else: + return UniConstraintBuilder(self.delegate.rewardConfigurableBigDecimal(function_cast(match_weigher, + self.a_type, + return_type=BigDecimal)), + self.a_type) + def impact_configurable(self, match_weigher: Callable[[A], int] = None) -> \ 'UniConstraintBuilder[A, ScoreType]': """ @@ -673,7 +825,7 @@ def impact_configurable(self, match_weigher: Callable[[A], int] = None) -> \ The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, so end users can change the constraint weights dynamically. This constraint may be deactivated if the `ConstraintWeight` is zero. - If there is no `constraint_configuration`, use `penalize` instead. + If there is no `constraint_configuration`, use `impact` instead. Parameters ---------- @@ -694,6 +846,37 @@ def impact_configurable(self, match_weigher: Callable[[A], int] = None) -> \ self.a_type) + def impact_configurable_decimal(self, match_weigher: Callable[[A], Decimal] = None) \ + -> 'UniConstraintBuilder[A, ScoreType]': + """ + Positively or negatively impacts the Score, adding the ConstraintWeight for each match, + and returns a builder to apply optional constraint properties. + The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, + so end users can change the constraint weights dynamically. + This constraint may be deactivated if the `ConstraintWeight` is zero. + If there is no `constraint_configuration`, use `impact` instead. + + Parameters + ---------- + match_weigher : Callable[[A], Decimal], optional + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + UniConstraintBuilder + a `UniConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return UniConstraintBuilder(self.delegate.impactConfigurable(), self.a_type) + else: + return UniConstraintBuilder(self.delegate.impactConfigurableBigDecimal(function_cast(match_weigher, + self.a_type, + return_type=BigDecimal)), + self.a_type) + + class BiConstraintStream(Generic[A, B]): """ A ConstraintStream that matches two facts. @@ -1165,7 +1348,7 @@ def complement(self, cls: type[A], padding=None): padding : Callable[[A], B] a function that computes the padding value for the second fact in the new tuple. """ - if None == padding: + if None is padding: result = self.delegate.complement(get_class(cls)) return BiConstraintStream(result, self.package, self.a_type, self.b_type) java_padding = function_cast(padding, self.a_type) @@ -1201,6 +1384,38 @@ def penalize(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B], self.b_type)), self.a_type, self.b_type) + + def penalize_decimal(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B], Decimal] = None) -> \ + 'BiConstraintBuilder[A, B, ScoreType]': + """ + Applies a negative Score impact, subtracting the constraint_weight multiplied by the match weight, + and returns a builder to apply optional constraint properties. + + Parameters + ---------- + constraint_weight : Score + the weight of the constraint. + + match_weigher : Callable[[A, B], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + BiConstraintBuilder + a `BiConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return BiConstraintBuilder(self.delegate.penalize(constraint_weight), self.a_type, self.b_type) + else: + return BiConstraintBuilder(self.delegate.penalizeBigDecimal(constraint_weight, + function_cast(match_weigher, + self.a_type, + self.b_type, + return_type=BigDecimal)), + self.a_type, self.b_type) + def reward(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B], int] = None) -> \ 'BiConstraintBuilder[A, B, ScoreType]': """ @@ -1230,6 +1445,37 @@ def reward(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B], i self.b_type)), self.a_type, self.b_type) + def reward_decimal(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B], Decimal] = None) -> \ + 'BiConstraintBuilder[A, B, ScoreType]': + """ + Applies a positive Score impact, adding the constraint_weight multiplied by the match weight, + and returns a builder to apply optional constraint properties. + + Parameters + ---------- + constraint_weight : Score + the weight of the constraint. + + match_weigher : Callable[[A, B], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + BiConstraintBuilder + a `BiConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return BiConstraintBuilder(self.delegate.reward(constraint_weight), self.a_type, self.b_type) + else: + return BiConstraintBuilder(self.delegate.rewardBigDecimal(constraint_weight, + function_cast(match_weigher, + self.a_type, + self.b_type, + return_type=BigDecimal)), + self.a_type, self.b_type) + def impact(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B], int] = None) -> \ 'BiConstraintBuilder[A, B, ScoreType]': """ @@ -1260,6 +1506,39 @@ def impact(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B], i self.b_type)), self.a_type, self.b_type) + + def impact_decimal(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B], Decimal] = None) -> \ + 'BiConstraintBuilder[A, B, ScoreType]': + """ + Positively or negatively impacts the `Score` by `constraint_weight` multiplied by match weight for each match + and returns a builder to apply optional constraint properties. + Use `penalize` or `reward` instead, unless this constraint can both have positive and negative weights. + + Parameters + ---------- + constraint_weight : Score + the weight of the constraint. + + match_weigher : Callable[[A, B], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + BiConstraintBuilder + a `BiConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return BiConstraintBuilder(self.delegate.impact(constraint_weight), self.a_type, self.b_type) + else: + return BiConstraintBuilder(self.delegate.impactBigDecimal(constraint_weight, + function_cast(match_weigher, + self.a_type, + self.b_type, + return_type=BigDecimal)), + self.a_type, self.b_type) + def penalize_configurable(self, match_weigher: Callable[[A, B], int] = None) -> \ 'BiConstraintBuilder[A, B, ScoreType]': """ @@ -1290,6 +1569,37 @@ def penalize_configurable(self, match_weigher: Callable[[A, B], int] = None) -> self.b_type)), self.a_type, self.b_type) + def penalize_configurable_decimal(self, match_weigher: Callable[[A, B], Decimal] = None) -> \ + 'BiConstraintBuilder[A, B, ScoreType]': + """ + Negatively impacts the Score, subtracting the ConstraintWeight for each match, + and returns a builder to apply optional constraint properties. + The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, + so end users can change the constraint weights dynamically. + This constraint may be deactivated if the `ConstraintWeight` is zero. + If there is no `constraint_configuration`, use `penalize` instead. + + Parameters + ---------- + match_weigher : Callable[[A, B], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + BiConstraintBuilder + a `BiConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return BiConstraintBuilder(self.delegate.penalizeConfigurable(), self.a_type, self.b_type) + else: + return BiConstraintBuilder(self.delegate.penalizeConfigurableBigDecimal(function_cast(match_weigher, + self.a_type, + self.b_type, + return_type=BigDecimal)), + self.a_type, self.b_type) + def reward_configurable(self, match_weigher: Callable[[A, B], int] = None) -> \ 'BiConstraintBuilder[A, B, ScoreType]': """ @@ -1298,7 +1608,7 @@ def reward_configurable(self, match_weigher: Callable[[A, B], int] = None) -> \ The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, so end users can change the constraint weights dynamically. This constraint may be deactivated if the `ConstraintWeight` is zero. - If there is no `constraint_configuration`, use `penalize` instead. + If there is no `constraint_configuration`, use `reward` instead. Parameters ---------- @@ -1320,6 +1630,37 @@ def reward_configurable(self, match_weigher: Callable[[A, B], int] = None) -> \ self.b_type)), self.a_type, self.b_type) + def reward_configurable_decimal(self, match_weigher: Callable[[A, B], Decimal] = None) -> \ + 'BiConstraintBuilder[A, B, ScoreType]': + """ + Positively impacts the Score, adding the ConstraintWeight for each match, + and returns a builder to apply optional constraint properties. + The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, + so end users can change the constraint weights dynamically. + This constraint may be deactivated if the `ConstraintWeight` is zero. + If there is no `constraint_configuration`, use `reward` instead. + + Parameters + ---------- + match_weigher : Callable[[A, B], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + BiConstraintBuilder + a `BiConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return BiConstraintBuilder(self.delegate.rewardConfigurable(), self.a_type, self.b_type) + else: + return BiConstraintBuilder(self.delegate.rewardConfigurableBigDecimal(function_cast(match_weigher, + self.a_type, + self.b_type, + return_type=BigDecimal)), + self.a_type, self.b_type) + def impact_configurable(self, match_weigher: Callable[[A, B], int] = None) -> \ 'BiConstraintBuilder[A, B, ScoreType]': """ @@ -1328,7 +1669,7 @@ def impact_configurable(self, match_weigher: Callable[[A, B], int] = None) -> \ The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, so end users can change the constraint weights dynamically. This constraint may be deactivated if the `ConstraintWeight` is zero. - If there is no `constraint_configuration`, use `penalize` instead. + If there is no `constraint_configuration`, use `impact` instead. Parameters ---------- @@ -1350,6 +1691,37 @@ def impact_configurable(self, match_weigher: Callable[[A, B], int] = None) -> \ self.b_type)), self.a_type, self.b_type) + def impact_configurable_decimal(self, match_weigher: Callable[[A, B], Decimal] = None) -> \ + 'BiConstraintBuilder[A, B, ScoreType]': + """ + Positively or negatively impacts the Score, adding the ConstraintWeight for each match, + and returns a builder to apply optional constraint properties. + The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, + so end users can change the constraint weights dynamically. + This constraint may be deactivated if the `ConstraintWeight` is zero. + If there is no `constraint_configuration`, use `impact` instead. + + Parameters + ---------- + match_weigher : Callable[[A, B], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + BiConstraintBuilder + a `BiConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return BiConstraintBuilder(self.delegate.impactConfigurable(), self.a_type, self.b_type) + else: + return BiConstraintBuilder(self.delegate.impactConfigurableBigDecimal(function_cast(match_weigher, + self.a_type, + self.b_type, + return_type=BigDecimal)), + self.a_type, self.b_type) + class TriConstraintStream(Generic[A, B, C]): """ @@ -1864,11 +2236,109 @@ def penalize(self, constraint_weight: ScoreType, self.c_type)), self.a_type, self.b_type, self.c_type) - def reward(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B, C], int] = None) -> \ - 'TriConstraintBuilder[A, B, C, ScoreType]': + def penalize_decimal(self, constraint_weight: ScoreType, + match_weigher: Callable[[A, B, C], Decimal] = None) -> 'TriConstraintBuilder[A, B, C, ScoreType]': + """ + Applies a negative Score impact, subtracting the constraint_weight multiplied by the match weight, + and returns a builder to apply optional constraint properties. + + Parameters + ---------- + constraint_weight : Score + the weight of the constraint. + + match_weigher : Callable[[A, B, C], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + TriConstraintBuilder + a `TriConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return TriConstraintBuilder(self.delegate.penalize(constraint_weight), + self.a_type, self.b_type, self.c_type) + else: + return TriConstraintBuilder(self.delegate.penalizeBigDecimal(constraint_weight, + function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type, + return_type=BigDecimal)), + self.a_type, self.b_type, self.c_type) + + def reward(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B, C], int] = None) -> \ + 'TriConstraintBuilder[A, B, C, ScoreType]': + """ + Applies a positive Score impact, adding the constraint_weight multiplied by the match weight, + and returns a builder to apply optional constraint properties. + + Parameters + ---------- + constraint_weight : Score + the weight of the constraint. + + match_weigher : Callable[[A, B, C], int] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + TriConstraintBuilder + a `TriConstraintBuilder` + """ + if match_weigher is None: + return TriConstraintBuilder(self.delegate.reward(constraint_weight), self.a_type, self.b_type, + self.c_type) + else: + return TriConstraintBuilder(self.delegate.rewardLong(constraint_weight, + to_long_function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type)), + self.a_type, self.b_type, self.c_type) + + def reward_decimal(self, constraint_weight: ScoreType, + match_weigher: Callable[[A, B, C], Decimal] = None) -> 'TriConstraintBuilder[A, B, C, ScoreType]': + """ + Applies a positive Score impact, adding the constraint_weight multiplied by the match weight, + and returns a builder to apply optional constraint properties. + + Parameters + ---------- + constraint_weight : Score + the weight of the constraint. + + match_weigher : Callable[[A, B, C], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + TriConstraintBuilder + a `TriConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return TriConstraintBuilder(self.delegate.reward(constraint_weight), + self.a_type, self.b_type, self.c_type) + else: + return TriConstraintBuilder(self.delegate.rewardBigDecimal(constraint_weight, + function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type, + return_type=BigDecimal)), + self.a_type, self.b_type, self.c_type) + + def impact(self, constraint_weight: ScoreType, + match_weigher: Callable[[A, B, C], int] = None) -> 'TriConstraintBuilder[A, B, C, ScoreType]': """ - Applies a positive Score impact, adding the constraint_weight multiplied by the match weight, + Positively or negatively impacts the `Score` by `constraint_weight` multiplied by match weight for each match and returns a builder to apply optional constraint properties. + Use `penalize` or `reward` instead, unless this constraint can both have positive and negative weights. Parameters ---------- @@ -1885,18 +2355,18 @@ def reward(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B, C] a `TriConstraintBuilder` """ if match_weigher is None: - return TriConstraintBuilder(self.delegate.reward(constraint_weight), self.a_type, self.b_type, - self.c_type) + return TriConstraintBuilder(self.delegate.impact(constraint_weight), + self.a_type, self.b_type, self.c_type) else: - return TriConstraintBuilder(self.delegate.rewardLong(constraint_weight, + return TriConstraintBuilder(self.delegate.impactLong(constraint_weight, to_long_function_cast(match_weigher, self.a_type, self.b_type, self.c_type)), self.a_type, self.b_type, self.c_type) - def impact(self, constraint_weight: ScoreType, - match_weigher: Callable[[A, B, C], int] = None) -> 'TriConstraintBuilder[A, B, C, ScoreType]': + def impact_decimal(self, constraint_weight: ScoreType, + match_weigher: Callable[[A, B, C], Decimal] = None) -> 'TriConstraintBuilder[A, B, C, ScoreType]': """ Positively or negatively impacts the `Score` by `constraint_weight` multiplied by match weight for each match and returns a builder to apply optional constraint properties. @@ -1907,7 +2377,7 @@ def impact(self, constraint_weight: ScoreType, constraint_weight : Score the weight of the constraint. - match_weigher : Callable[[A, B, C], int] + match_weigher : Callable[[A, B, C], Decimal] a function that computes the weight of a match. If absent, each match has weight ``1``. @@ -1916,15 +2386,17 @@ def impact(self, constraint_weight: ScoreType, TriConstraintBuilder a `TriConstraintBuilder` """ + from java.math import BigDecimal if match_weigher is None: return TriConstraintBuilder(self.delegate.impact(constraint_weight), self.a_type, self.b_type, self.c_type) else: - return TriConstraintBuilder(self.delegate.impactLong(constraint_weight, - to_long_function_cast(match_weigher, - self.a_type, - self.b_type, - self.c_type)), + return TriConstraintBuilder(self.delegate.impactBigDecimal(constraint_weight, + function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type, + return_type=BigDecimal)), self.a_type, self.b_type, self.c_type) def penalize_configurable(self, match_weigher: Callable[[A, B, C], int] = None) \ @@ -1959,6 +2431,39 @@ def penalize_configurable(self, match_weigher: Callable[[A, B, C], int] = None) self.c_type)), self.a_type, self.b_type, self.c_type) + def penalize_configurable_decimal(self, match_weigher: Callable[[A, B, C], Decimal] = None) -> 'TriConstraintBuilder[A, B, C, ScoreType]': + """ + Negatively impacts the Score, subtracting the ConstraintWeight for each match, + and returns a builder to apply optional constraint properties. + The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, + so end users can change the constraint weights dynamically. + This constraint may be deactivated if the `ConstraintWeight` is zero. + If there is no `constraint_configuration`, use `penalize` instead. + + Parameters + ---------- + match_weigher : Callable[[A, B, C], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + TriConstraintBuilder + a `TriConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return TriConstraintBuilder(self.delegate.penalizeConfigurable(), + self.a_type, self.b_type, self.c_type) + else: + return TriConstraintBuilder(self.delegate.penalizeConfigurableBigDecimal( + function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type, + return_type=BigDecimal)), + self.a_type, self.b_type, self.c_type) + def reward_configurable(self, match_weigher: Callable[[A, B, C], int] = None) -> \ 'TriConstraintBuilder[A, B, C, ScoreType]': """ @@ -1967,7 +2472,7 @@ def reward_configurable(self, match_weigher: Callable[[A, B, C], int] = None) -> The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, so end users can change the constraint weights dynamically. This constraint may be deactivated if the `ConstraintWeight` is zero. - If there is no `constraint_configuration`, use `penalize` instead. + If there is no `constraint_configuration`, use `reward` instead. Parameters ---------- @@ -1991,6 +2496,40 @@ def reward_configurable(self, match_weigher: Callable[[A, B, C], int] = None) -> self.c_type)), self.a_type, self.b_type, self.c_type) + + def reward_configurable_decimal(self, match_weigher: Callable[[A, B, C], Decimal] = None) -> 'TriConstraintBuilder[A, B, C, ScoreType]': + """ + Positively impacts the Score, adding the ConstraintWeight for each match, + and returns a builder to apply optional constraint properties. + The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, + so end users can change the constraint weights dynamically. + This constraint may be deactivated if the `ConstraintWeight` is zero. + If there is no `constraint_configuration`, use `reward` instead. + + Parameters + ---------- + match_weigher : Callable[[A, B, C], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + TriConstraintBuilder + a `TriConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return TriConstraintBuilder(self.delegate.rewardConfigurable(), + self.a_type, self.b_type, self.c_type) + else: + return TriConstraintBuilder(self.delegate.rewardConfigurableBigDecimal( + function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type, + return_type=BigDecimal)), + self.a_type, self.b_type, self.c_type) + def impact_configurable(self, match_weigher: Callable[[A, B, C], int] = None) \ -> 'TriConstraintBuilder[A, B, C, ScoreType]': """ @@ -1999,7 +2538,7 @@ def impact_configurable(self, match_weigher: Callable[[A, B, C], int] = None) \ The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, so end users can change the constraint weights dynamically. This constraint may be deactivated if the `ConstraintWeight` is zero. - If there is no `constraint_configuration`, use `penalize` instead. + If there is no `constraint_configuration`, use `impact` instead. Parameters ---------- @@ -2023,6 +2562,39 @@ def impact_configurable(self, match_weigher: Callable[[A, B, C], int] = None) \ self.c_type)), self.a_type, self.b_type, self.c_type) + def impact_configurable_decimal(self, match_weigher: Callable[[A, B, C], Decimal] = None) -> 'TriConstraintBuilder[A, B, C, ScoreType]': + """ + Positively or negatively impacts the Score, adding the ConstraintWeight for each match, + and returns a builder to apply optional constraint properties. + The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, + so end users can change the constraint weights dynamically. + This constraint may be deactivated if the `ConstraintWeight` is zero. + If there is no `constraint_configuration`, use `impact` instead. + + Parameters + ---------- + match_weigher : Callable[[A, B, C], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + TriConstraintBuilder + a `TriConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return TriConstraintBuilder(self.delegate.impactConfigurable(), + self.a_type, self.b_type, self.c_type) + else: + return TriConstraintBuilder(self.delegate.impactConfigurableBigDecimal( + function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type, + return_type=BigDecimal)), + self.a_type, self.b_type, self.c_type) + class QuadConstraintStream(Generic[A, B, C, D]): """ @@ -2522,6 +3094,40 @@ def penalize(self, constraint_weight: ScoreType, self.d_type)), self.a_type, self.b_type, self.c_type, self.d_type) + def penalize_decimal(self, constraint_weight: ScoreType, + match_weigher: Callable[[A, B, C, D], Decimal] = None) -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': + """ + Applies a negative Score impact, subtracting the constraint_weight multiplied by the match weight, + and returns a builder to apply optional constraint properties. + + Parameters + ---------- + constraint_weight : Score + the weight of the constraint. + + match_weigher : Callable[[A, B, C, D], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + QuadConstraintBuilder + a `QuadConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return QuadConstraintBuilder(self.delegate.penalize(constraint_weight), + self.a_type, self.b_type, self.c_type, self.d_type) + else: + return QuadConstraintBuilder(self.delegate.penalizeBigDecimal(constraint_weight, + function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type, + self.d_type, + return_type=BigDecimal)), + self.a_type, self.b_type, self.c_type, self.d_type) + def reward(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B, C, D], int] = None) -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': """ @@ -2554,6 +3160,40 @@ def reward(self, constraint_weight: ScoreType, self.d_type)), self.a_type, self.b_type, self.c_type, self.d_type) + def reward_decimal(self, constraint_weight: ScoreType, + match_weigher: Callable[[A, B, C, D], Decimal] = None) -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': + """ + Applies a positive Score impact, adding the constraint_weight multiplied by the match weight, + and returns a builder to apply optional constraint properties. + + Parameters + ---------- + constraint_weight : Score + the weight of the constraint. + + match_weigher : Callable[[A, B, C, D], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + QuadConstraintBuilder + a `QuadConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return QuadConstraintBuilder(self.delegate.reward(constraint_weight), + self.a_type, self.b_type, self.c_type, self.d_type) + else: + return QuadConstraintBuilder(self.delegate.rewardBigDecimal(constraint_weight, + function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type, + self.d_type, + return_type=BigDecimal)), + self.a_type, self.b_type, self.c_type, self.d_type) + def impact(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B, C, D], int] = None) -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': """ @@ -2587,6 +3227,41 @@ def impact(self, constraint_weight: ScoreType, self.d_type)), self.a_type, self.b_type, self.c_type, self.d_type) + def impact_decimal(self, constraint_weight: ScoreType, + match_weigher: Callable[[A, B, C, D], Decimal] = None) -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': + """ + Positively or negatively impacts the `Score` by `constraint_weight` multiplied by match weight for each match + and returns a builder to apply optional constraint properties. + Use `penalize` or `reward` instead, unless this constraint can both have positive and negative weights. + + Parameters + ---------- + constraint_weight : Score + the weight of the constraint. + + match_weigher : Callable[[A, B, C, D], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + QuadConstraintBuilder + a `QuadConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return QuadConstraintBuilder(self.delegate.impact(constraint_weight), + self.a_type, self.b_type, self.c_type, self.d_type) + else: + return QuadConstraintBuilder(self.delegate.impactBigDecimal(constraint_weight, + function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type, + self.d_type, + return_type=BigDecimal)), + self.a_type, self.b_type, self.c_type, self.d_type) + def penalize_configurable(self, match_weigher: Callable[[A, B, C, D], int] = None) \ -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': """ @@ -2620,6 +3295,39 @@ def penalize_configurable(self, match_weigher: Callable[[A, B, C, D], int] = Non self.d_type)), self.a_type, self.b_type, self.c_type, self.d_type) + def penalize_configurable_decimal(self, match_weigher: Callable[[A, B, C, D], Decimal] = None) -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': + """ + Negatively impacts the Score, subtracting the ConstraintWeight for each match, + and returns a builder to apply optional constraint properties. + The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, + so end users can change the constraint weights dynamically. + This constraint may be deactivated if the `ConstraintWeight` is zero. + If there is no `constraint_configuration`, use `penalize` instead. + + Parameters + ---------- + match_weigher : Callable[[A, B, C, D], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + QuadConstraintBuilder + a `QuadConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return QuadConstraintBuilder(self.delegate.penalizeConfigurable(), + self.a_type, self.b_type, self.c_type, self.d_type) + else: + return QuadConstraintBuilder(self.delegate.penalizeConfigurableBigDecimal(function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type, + self.d_type, + return_type=BigDecimal)), + self.a_type, self.b_type, self.c_type, self.d_type) + def reward_configurable(self, match_weigher: Callable[[A, B, C, D], int] = None) \ -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': """ @@ -2628,7 +3336,8 @@ def reward_configurable(self, match_weigher: Callable[[A, B, C, D], int] = None) The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, so end users can change the constraint weights dynamically. This constraint may be deactivated if the `ConstraintWeight` is zero. - If there is no `constraint_configuration`, use `penalize` instead. + If there is no `constraint_configuration`, use `reward` instead. + Parameters ---------- match_weigher : Callable[[A, B, C, D], int] @@ -2652,6 +3361,39 @@ def reward_configurable(self, match_weigher: Callable[[A, B, C, D], int] = None) self.d_type)), self.a_type, self.b_type, self.c_type, self.d_type) + def reward_configurable_decimal(self, match_weigher: Callable[[A, B, C, D], Decimal] = None) -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': + """ + Positively impacts the Score, adding the ConstraintWeight for each match, + and returns a builder to apply optional constraint properties. + The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, + so end users can change the constraint weights dynamically. + This constraint may be deactivated if the `ConstraintWeight` is zero. + If there is no `constraint_configuration`, use `reward` instead. + + Parameters + ---------- + match_weigher : Callable[[A, B, C, D], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + QuadConstraintBuilder + a `QuadConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return QuadConstraintBuilder(self.delegate.rewardConfigurable(), + self.a_type, self.b_type, self.c_type, self.d_type) + else: + return QuadConstraintBuilder(self.delegate.rewardConfigurableBigDecimal(function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type, + self.d_type, + return_type=BigDecimal)), + self.a_type, self.b_type, self.c_type, self.d_type) + def impact_configurable(self, match_weigher: Callable[[A, B, C, D], int] = None) \ -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': """ @@ -2660,7 +3402,7 @@ def impact_configurable(self, match_weigher: Callable[[A, B, C, D], int] = None) The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, so end users can change the constraint weights dynamically. This constraint may be deactivated if the `ConstraintWeight` is zero. - If there is no `constraint_configuration`, use `penalize` instead. + If there is no `constraint_configuration`, use `impact` instead. Parameters ---------- @@ -2685,6 +3427,39 @@ def impact_configurable(self, match_weigher: Callable[[A, B, C, D], int] = None) self.d_type)), self.a_type, self.b_type, self.c_type, self.d_type) + def impact_configurable_decimal(self, match_weigher: Callable[[A, B, C, D], Decimal] = None) -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': + """ + Positively or negatively impacts the Score, adding the ConstraintWeight for each match, + and returns a builder to apply optional constraint properties. + The constraint weight comes from a `ConstraintWeight` annotated member on the `constraint_configuration`, + so end users can change the constraint weights dynamically. + This constraint may be deactivated if the `ConstraintWeight` is zero. + If there is no `constraint_configuration`, use `impact` instead. + + Parameters + ---------- + match_weigher : Callable[[A, B, C, D], Decimal] + a function that computes the weight of a match. + If absent, each match has weight ``1``. + + Returns + ------- + QuadConstraintBuilder + a `QuadConstraintBuilder` + """ + from java.math import BigDecimal + if match_weigher is None: + return QuadConstraintBuilder(self.delegate.impactConfigurable(), + self.a_type, self.b_type, self.c_type, self.d_type) + else: + return QuadConstraintBuilder(self.delegate.impactConfigurableBigDecimal(function_cast(match_weigher, + self.a_type, + self.b_type, + self.c_type, + self.d_type, + return_type=BigDecimal)), + self.a_type, self.b_type, self.c_type, self.d_type) + # Must be on the bottom, .group_by depends on this module from ._constraint_factory import * diff --git a/timefold-solver-python-core/src/main/python/score/_function_translator.py b/timefold-solver-python-core/src/main/python/score/_function_translator.py index ed80bbd6..18daf809 100644 --- a/timefold-solver-python-core/src/main/python/score/_function_translator.py +++ b/timefold-solver-python-core/src/main/python/score/_function_translator.py @@ -78,7 +78,7 @@ def _check_if_type_args_are_python_object_wrappers(type_args): return False -def function_cast(function, *type_args): +def function_cast(function, *type_args, return_type=None): arg_count = len(inspect.signature(function).parameters) if len(type_args) != arg_count: raise ValueError(f'Invalid function: expected {len(type_args)} arguments but got {arg_count}') @@ -90,18 +90,21 @@ def function_cast(function, *type_args): from ai.timefold.solver.core.api.function import TriFunction, QuadFunction, PentaFunction from ai.timefold.jpyinterpreter import PythonLikeObject + if return_type is None: + return_type = PythonLikeObject + try: _check_if_bytecode_translation_possible() if arg_count == 1: - return translate_python_bytecode_to_java_bytecode(function, Function, *type_args, PythonLikeObject) + return translate_python_bytecode_to_java_bytecode(function, Function, *type_args, return_type) elif arg_count == 2: - return translate_python_bytecode_to_java_bytecode(function, BiFunction, *type_args, PythonLikeObject) + return translate_python_bytecode_to_java_bytecode(function, BiFunction, *type_args, return_type) elif arg_count == 3: - return translate_python_bytecode_to_java_bytecode(function, TriFunction, *type_args, PythonLikeObject) + return translate_python_bytecode_to_java_bytecode(function, TriFunction, *type_args, return_type) elif arg_count == 4: - return translate_python_bytecode_to_java_bytecode(function, QuadFunction, *type_args, PythonLikeObject) + return translate_python_bytecode_to_java_bytecode(function, QuadFunction, *type_args, return_type) elif arg_count == 5: - return translate_python_bytecode_to_java_bytecode(function, PentaFunction, *type_args, PythonLikeObject) + return translate_python_bytecode_to_java_bytecode(function, PentaFunction, *type_args, return_type) except: # noqa return default_function_cast(function, arg_count) diff --git a/timefold-solver-python-core/src/main/python/score/_score.py b/timefold-solver-python-core/src/main/python/score/_score.py index 73f7fd02..982011ce 100644 --- a/timefold-solver-python-core/src/main/python/score/_score.py +++ b/timefold-solver-python-core/src/main/python/score/_score.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod from typing import ClassVar from dataclasses import dataclass, field -from jpype import JArray, JInt +from jpype import JArray, JLong +from decimal import Decimal from .._timefold_java_interop import _java_score_mapping_dict @@ -288,9 +289,9 @@ def parse(score_text: str) -> 'BendableScore': return BendableScore(hard_scores, soft_scores, init_score=int(init.rstrip('init'))) def _to_java_score(self): - IntArrayCls = JArray(JInt) - hard_scores = IntArrayCls(self.hard_scores) - soft_scores = IntArrayCls(self.soft_scores) + LongArrayCls = JArray(JLong) + hard_scores = LongArrayCls(self.hard_scores) + soft_scores = LongArrayCls(self.soft_scores) if self.init_score < 0: return _java_score_mapping_dict['BendableScore'].ofUninitialized(self.init_score, hard_scores, soft_scores) else: @@ -303,7 +304,269 @@ def __str__(self): f'{self.init_score}init/{hard_text}/{soft_text}') +############################################################## +# Decimal variants +############################################################## +@dataclass(unsafe_hash=True, order=True) +class SimpleDecimalScore(Score): + """ + This Score is based on one level of `Decimal` constraints. + This class is immutable. + + Attributes + ---------- + score : Decimal + The total of the broken negative constraints and fulfilled positive constraints. + Their weight is included in the total. + The score is usually a negative number because most use cases only have negative constraints. + """ + ZERO: ClassVar['SimpleDecimalScore'] + ONE: ClassVar['SimpleDecimalScore'] + + score: Decimal = field(compare=True) + + @property + def is_feasible(self) -> bool: + return self.is_solution_initialized + + @staticmethod + def of(score: Decimal) -> 'SimpleDecimalScore': + return SimpleDecimalScore(score, init_score=0) + + @staticmethod + def parse(score_text: str) -> 'SimpleDecimalScore': + if 'init' in score_text: + init, score = score_text.split('/') + else: + init = '0init' + score = score_text + + return SimpleDecimalScore(Decimal(score), init_score=int(init.rstrip('init'))) + + def _to_java_score(self): + if self.init_score < 0: + return _java_score_mapping_dict['SimpleDecimalScore'].ofUninitialized(self.init_score, self.score) + else: + return _java_score_mapping_dict['SimpleDecimalScore'].of(self.score) + + def __str__(self): + return (f'{self.score}' if self.is_solution_initialized else + f'{self.init_score}init/{self.score}') + + +SimpleDecimalScore.ZERO = SimpleDecimalScore.of(Decimal(0)) +SimpleDecimalScore.ONE = SimpleDecimalScore.of(Decimal(1)) + + +@dataclass(unsafe_hash=True, order=True) +class HardSoftDecimalScore(Score): + """ + This Score is based on two levels of int constraints: hard and soft. + Hard constraints have priority over soft constraints. + Hard constraints determine feasibility. + + This class is immutable. + + Attributes + ---------- + hard_score : Decimal + The total of the broken negative hard constraints and fulfilled positive hard constraints. + Their weight is included in the total. + The hard score is usually a negative number because most use cases only have negative constraints. + + soft_score : Decimal + The total of the broken negative soft constraints and fulfilled positive soft constraints. + Their weight is included in the total. + The soft score is usually a negative number because most use cases only have negative constraints. + + In a normal score comparison, the soft score is irrelevant if the two scores don't have the same hard score. + """ + ZERO: ClassVar['HardSoftDecimalScore'] + ONE_HARD: ClassVar['HardSoftDecimalScore'] + ONE_SOFT: ClassVar['HardSoftDecimalScore'] + + hard_score: Decimal = field(compare=True) + soft_score: Decimal = field(compare=True) + + @property + def is_feasible(self) -> bool: + return self.is_solution_initialized and self.hard_score >= 0 + + @staticmethod + def of(hard_score: Decimal, soft_score: Decimal) -> 'HardSoftDecimalScore': + return HardSoftDecimalScore(hard_score, soft_score, init_score=0) + + @staticmethod + def parse(score_text: str) -> 'HardSoftDecimalScore': + if 'init' in score_text: + init, hard, soft = score_text.split('/') + else: + init = '0init' + hard, soft = score_text.split('/') + + return HardSoftDecimalScore(Decimal(hard.rstrip('hard')), Decimal(soft.rstrip('soft')), + init_score=int(init.rstrip('init'))) + + def _to_java_score(self): + if self.init_score < 0: + return _java_score_mapping_dict['HardSoftDecimalScore'].ofUninitialized(self.init_score, self.hard_score, self.soft_score) + else: + return _java_score_mapping_dict['HardSoftDecimalScore'].of(self.hard_score, self.soft_score) + + def __str__(self): + return (f'{self.hard_score}hard/{self.soft_score}soft' if self.is_solution_initialized else + f'{self.init_score}init/{self.hard_score}hard/{self.soft_score}soft') + + +HardSoftDecimalScore.ZERO = HardSoftDecimalScore.of(Decimal(0), Decimal(0)) +HardSoftDecimalScore.ONE_HARD = HardSoftDecimalScore.of(Decimal(1), Decimal(0)) +HardSoftDecimalScore.ONE_SOFT = HardSoftDecimalScore.of(Decimal(0), Decimal(1)) + + +@dataclass(unsafe_hash=True, order=True) +class HardMediumSoftDecimalScore(Score): + """ + This Score is based on three levels of int constraints: hard, medium and soft. + Hard constraints have priority over medium constraints. + Medium constraints have priority over soft constraints. + Hard constraints determine feasibility. + + This class is immutable. + + Attributes + ---------- + hard_score : Decimal + The total of the broken negative hard constraints and fulfilled positive hard constraints. + Their weight is included in the total. + The hard score is usually a negative number because most use cases only have negative constraints. + + medium_score : Decimal + The total of the broken negative medium constraints and fulfilled positive medium constraints. + Their weight is included in the total. + The medium score is usually a negative number because most use cases only have negative constraints. + + In a normal score comparison, + the medium score is irrelevant if the two scores don't have the same hard score. + + soft_score : Decimal + The total of the broken negative soft constraints and fulfilled positive soft constraints. + Their weight is included in the total. + The soft score is usually a negative number because most use cases only have negative constraints. + + In a normal score comparison, + the soft score is irrelevant if the two scores don't have the same hard and medium score. + """ + ZERO: ClassVar['HardMediumSoftDecimalScore'] + ONE_HARD: ClassVar['HardMediumSoftDecimalScore'] + ONE_MEDIUM: ClassVar['HardMediumSoftDecimalScore'] + ONE_SOFT: ClassVar['HardMediumSoftDecimalScore'] + + hard_score: Decimal = field(compare=True) + medium_score: Decimal = field(compare=True) + soft_score: Decimal = field(compare=True) + + @property + def is_feasible(self) -> bool: + return self.is_solution_initialized and self.hard_score >= 0 + + @staticmethod + def of(hard_score: Decimal, medium_score: Decimal, soft_score: Decimal) -> 'HardMediumSoftDecimalScore': + return HardMediumSoftDecimalScore(hard_score, medium_score, soft_score, init_score=0) + + @staticmethod + def parse(score_text: str) -> 'HardMediumSoftDecimalScore': + if 'init' in score_text: + init, hard, medium, soft = score_text.split('/') + else: + init = '0init' + hard, medium, soft = score_text.split('/') + + return HardMediumSoftDecimalScore(Decimal(hard.rstrip('hard')), Decimal(medium.rstrip('medium')), + Decimal(soft.rstrip('soft')), init_score=int(init.rstrip('init'))) + + def _to_java_score(self): + if self.init_score < 0: + return _java_score_mapping_dict['HardMediumSoftDecimalScore'].ofUninitialized(self.init_score, self.hard_score, + self.medium_score, self.soft_score) + else: + return _java_score_mapping_dict['HardMediumSoftDecimalScore'].of(self.hard_score, self.medium_score, self.soft_score) + + def __str__(self): + return (f'{self.hard_score}hard/{self.medium_score}medium/{self.soft_score}soft' + if self.is_solution_initialized else + f'{self.init_score}init/{self.hard_score}hard/{self.medium_score}medium/{self.soft_score}soft') + + +HardMediumSoftDecimalScore.ZERO = HardMediumSoftDecimalScore.of(Decimal(0), Decimal(0), Decimal(0)) +HardMediumSoftDecimalScore.ONE_HARD = HardMediumSoftDecimalScore.of(Decimal(1), Decimal(0), Decimal(0)) +HardMediumSoftDecimalScore.ONE_MEDIUM = HardMediumSoftDecimalScore.of(Decimal(0), Decimal(1), Decimal(0)) +HardMediumSoftDecimalScore.ONE_SOFT = HardMediumSoftDecimalScore.of(Decimal(0), Decimal(0), Decimal(1)) + + +@dataclass(unsafe_hash=True, order=True) +class BendableDecimalScore(Score): + """ + This Score is based on n levels of int constraints. + The number of levels is bendable at configuration time. + + This class is immutable. + + Attributes + ---------- + hard_scores : tuple[Decimal, ...] + A tuple of hard scores, with earlier hard scores having higher priority than later ones. + + soft_scores : tuple[Decimal, ...] + A tuple of soft scores, with earlier soft scores having higher priority than later ones + """ + hard_scores: tuple[Decimal, ...] = field(compare=True) + soft_scores: tuple[Decimal, ...] = field(compare=True) + + @property + def is_feasible(self) -> bool: + return self.is_solution_initialized and all(score >= 0 for score in self.hard_scores) + + @staticmethod + def of(hard_scores: tuple[Decimal, ...], soft_scores: tuple[Decimal, ...]) -> 'BendableDecimalScore': + return BendableDecimalScore(hard_scores, soft_scores, init_score=0) + + @staticmethod + def parse(score_text: str) -> 'BendableDecimalScore': + if 'init' in score_text: + init, hard_score_text, soft_score_text = score_text.split('/[') + else: + hard_score_text, soft_score_text = score_text.split('/[') + # Remove leading [ from hard score text, + # since there is no init score in the text + # (and thus the split will not consume it) + hard_score_text = hard_score_text[1:] + init = '0init' + + hard_scores = tuple([Decimal(score) for score in hard_score_text[:hard_score_text.index(']')].split('/')]) + soft_scores = tuple([Decimal(score) for score in soft_score_text[:soft_score_text.index(']')].split('/')]) + return BendableDecimalScore(hard_scores, soft_scores, init_score=int(init.rstrip('init'))) + + def _to_java_score(self): + from java.math import BigDecimal + BigDecimalArrayCls = JArray(BigDecimal) + hard_scores = BigDecimalArrayCls([BigDecimal(str(score)) for score in self.hard_scores]) + soft_scores = BigDecimalArrayCls([BigDecimal(str(score)) for score in self.soft_scores]) + if self.init_score < 0: + return _java_score_mapping_dict['BendableDecimalScore'].ofUninitialized(self.init_score, hard_scores, + soft_scores) + else: + return _java_score_mapping_dict['BendableDecimalScore'].of(hard_scores, soft_scores) + + def __str__(self): + hard_text = f'[{"/".join([str(score) for score in self.hard_scores])}]hard' + soft_text = f'[{"/".join([str(score) for score in self.soft_scores])}]soft' + return (f'{hard_text}/{soft_text}' if self.is_solution_initialized else + f'{self.init_score}init/{hard_text}/{soft_text}') + + # Import score conversions here to register conversions (circular import) from ._score_conversions import * -__all__ = ['Score', 'SimpleScore', 'HardSoftScore', 'HardMediumSoftScore', 'BendableScore'] +__all__ = ['Score', + 'SimpleScore', 'HardSoftScore', 'HardMediumSoftScore', 'BendableScore', + 'SimpleDecimalScore', 'HardSoftDecimalScore', 'HardMediumSoftDecimalScore', 'BendableDecimalScore'] diff --git a/timefold-solver-python-core/src/main/python/score/_score_conversions.py b/timefold-solver-python-core/src/main/python/score/_score_conversions.py index 8be78820..841bcd49 100644 --- a/timefold-solver-python-core/src/main/python/score/_score_conversions.py +++ b/timefold-solver-python-core/src/main/python/score/_score_conversions.py @@ -20,3 +20,23 @@ def _convert_hard_medium_soft_score(jcls, score: HardMediumSoftScore): @JConversion('ai.timefold.solver.core.api.score.Score', exact=BendableScore) def _convert_bendable_score(jcls, score: BendableScore): return score._to_java_score() + + +@JConversion('ai.timefold.solver.core.api.score.Score', exact=SimpleDecimalScore) +def _convert_simple_decimal_score(jcls, score: SimpleDecimalScore): + return score._to_java_score() + + +@JConversion('ai.timefold.solver.core.api.score.Score', exact=HardSoftDecimalScore) +def _convert_hard_soft_decimal_score(jcls, score: HardSoftDecimalScore): + return score._to_java_score() + + +@JConversion('ai.timefold.solver.core.api.score.Score', exact=HardMediumSoftDecimalScore) +def _convert_hard_medium_soft_decimal_score(jcls, score: HardMediumSoftDecimalScore): + return score._to_java_score() + + +@JConversion('ai.timefold.solver.core.api.score.Score', exact=BendableDecimalScore) +def _convert_bendable_decimal_score(jcls, score: BendableDecimalScore): + return score._to_java_score() diff --git a/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/BendableDecimalScorePythonJavaTypeMappingTest.java b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/BendableDecimalScorePythonJavaTypeMappingTest.java new file mode 100644 index 00000000..cf5f0656 --- /dev/null +++ b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/BendableDecimalScorePythonJavaTypeMappingTest.java @@ -0,0 +1,95 @@ +package ai.timefold.solver.python.score; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import java.math.BigDecimal; + +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; +import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; +import ai.timefold.solver.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class BendableDecimalScorePythonJavaTypeMappingTest { + BendableDecimalScorePythonJavaTypeMapping typeMapping; + + @BeforeEach + void setUp() throws NoSuchFieldException, ClassNotFoundException, NoSuchMethodException { + this.typeMapping = new BendableDecimalScorePythonJavaTypeMapping(PythonBendableDecimalScore.TYPE); + } + + @Test + void getPythonType() { + assertThat(typeMapping.getPythonType()).isEqualTo(PythonBendableDecimalScore.TYPE); + } + + @Test + void getJavaType() { + assertThat(typeMapping.getJavaType()).isEqualTo(BendableBigDecimalScore.class); + } + + @Test + void toPythonObject() { + var initializedScore = BendableBigDecimalScore.of( + new BigDecimal[] { BigDecimal.valueOf(10), BigDecimal.valueOf(20), BigDecimal.valueOf(30) }, + new BigDecimal[] { BigDecimal.valueOf(4), BigDecimal.valueOf(5) }); + + var initializedPythonScore = (PythonBendableDecimalScore) typeMapping.toPythonObject(initializedScore); + + assertThat(initializedPythonScore.init_score).isEqualTo(PythonInteger.ZERO); + + assertThat(initializedPythonScore.hard_scores.size()).isEqualTo(3); + assertThat(initializedPythonScore.hard_scores.get(0)).isEqualTo(PythonDecimal.valueOf("10")); + assertThat(initializedPythonScore.hard_scores.get(1)).isEqualTo(PythonDecimal.valueOf("20")); + assertThat(initializedPythonScore.hard_scores.get(2)).isEqualTo(PythonDecimal.valueOf("30")); + + assertThat(initializedPythonScore.soft_scores.size()).isEqualTo(2); + assertThat(initializedPythonScore.soft_scores.get(0)).isEqualTo(PythonDecimal.valueOf("4")); + assertThat(initializedPythonScore.soft_scores.get(1)).isEqualTo(PythonDecimal.valueOf("5")); + + var uninitializedScore = BendableBigDecimalScore.ofUninitialized(-300, + new BigDecimal[] { BigDecimal.valueOf(10), BigDecimal.valueOf(20), BigDecimal.valueOf(30) }, + new BigDecimal[] { BigDecimal.valueOf(4), BigDecimal.valueOf(5) }); + var uninitializedPythonScore = (PythonBendableDecimalScore) typeMapping.toPythonObject(uninitializedScore); + + assertThat(uninitializedPythonScore.init_score).isEqualTo(PythonInteger.valueOf(-300)); + + assertThat(uninitializedPythonScore.hard_scores.size()).isEqualTo(3); + assertThat(uninitializedPythonScore.hard_scores.get(0)).isEqualTo(PythonDecimal.valueOf("10")); + assertThat(uninitializedPythonScore.hard_scores.get(1)).isEqualTo(PythonDecimal.valueOf("20")); + assertThat(uninitializedPythonScore.hard_scores.get(2)).isEqualTo(PythonDecimal.valueOf("30")); + + assertThat(uninitializedPythonScore.soft_scores.size()).isEqualTo(2); + assertThat(uninitializedPythonScore.soft_scores.get(0)).isEqualTo(PythonDecimal.valueOf("4")); + assertThat(uninitializedPythonScore.soft_scores.get(1)).isEqualTo(PythonDecimal.valueOf("5")); + } + + @Test + void toJavaObject() { + var initializedScore = PythonBendableDecimalScore.of(new int[] { 10, 20, 30 }, new int[] { 4, 5 }); + + var initializedJavaScore = typeMapping.toJavaObject(initializedScore); + + assertThat(initializedJavaScore.initScore()).isEqualTo(0); + assertThat(initializedJavaScore.hardScores()).containsExactly( + BigDecimal.valueOf(10), + BigDecimal.valueOf(20), + BigDecimal.valueOf(30)); + assertThat(initializedJavaScore.softScores()).containsExactly( + BigDecimal.valueOf(4), + BigDecimal.valueOf(5)); + + var uninitializedScore = PythonBendableDecimalScore.ofUninitialized(-300, new int[] { 10, 20, 30 }, new int[] { 4, 5 }); + var uninitializedJavaScore = typeMapping.toJavaObject(uninitializedScore); + + assertThat(uninitializedJavaScore.initScore()).isEqualTo(-300); + assertThat(uninitializedJavaScore.hardScores()).containsExactly( + BigDecimal.valueOf(10), + BigDecimal.valueOf(20), + BigDecimal.valueOf(30)); + assertThat(uninitializedJavaScore.softScores()).containsExactly( + BigDecimal.valueOf(4), + BigDecimal.valueOf(5)); + } +} \ No newline at end of file diff --git a/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/HardMediumSoftDecimalScorePythonJavaTypeMappingTest.java b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/HardMediumSoftDecimalScorePythonJavaTypeMappingTest.java new file mode 100644 index 00000000..a84fe3a7 --- /dev/null +++ b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/HardMediumSoftDecimalScorePythonJavaTypeMappingTest.java @@ -0,0 +1,75 @@ +package ai.timefold.solver.python.score; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import java.math.BigDecimal; + +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; +import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; +import ai.timefold.solver.core.api.score.buildin.hardmediumsoftbigdecimal.HardMediumSoftBigDecimalScore; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class HardMediumSoftDecimalScorePythonJavaTypeMappingTest { + HardMediumSoftDecimalScorePythonJavaTypeMapping typeMapping; + + @BeforeEach + void setUp() throws NoSuchFieldException, ClassNotFoundException, NoSuchMethodException { + this.typeMapping = new HardMediumSoftDecimalScorePythonJavaTypeMapping(PythonHardMediumSoftDecimalScore.TYPE); + } + + @Test + void getPythonType() { + assertThat(typeMapping.getPythonType()).isEqualTo(PythonHardMediumSoftDecimalScore.TYPE); + } + + @Test + void getJavaType() { + assertThat(typeMapping.getJavaType()).isEqualTo(HardMediumSoftBigDecimalScore.class); + } + + @Test + void toPythonObject() { + var initializedScore = HardMediumSoftBigDecimalScore.of(BigDecimal.valueOf(300), + BigDecimal.valueOf(20), + BigDecimal.valueOf(1)); + + var initializedPythonScore = (PythonHardMediumSoftDecimalScore) typeMapping.toPythonObject(initializedScore); + + assertThat(initializedPythonScore.init_score).isEqualTo(PythonInteger.ZERO); + assertThat(initializedPythonScore.hard_score).isEqualTo(new PythonDecimal(BigDecimal.valueOf(300))); + assertThat(initializedPythonScore.medium_score).isEqualTo(new PythonDecimal(BigDecimal.valueOf(20))); + assertThat(initializedPythonScore.soft_score).isEqualTo(new PythonDecimal(BigDecimal.valueOf(1))); + + var uninitializedScore = HardMediumSoftBigDecimalScore.ofUninitialized(-4000, BigDecimal.valueOf(300), + BigDecimal.valueOf(20), + BigDecimal.valueOf(1)); + var uninitializedPythonScore = (PythonHardMediumSoftDecimalScore) typeMapping.toPythonObject(uninitializedScore); + + assertThat(uninitializedPythonScore.init_score).isEqualTo(PythonInteger.valueOf(-4000)); + assertThat(initializedPythonScore.hard_score).isEqualTo(new PythonDecimal(BigDecimal.valueOf(300))); + assertThat(initializedPythonScore.medium_score).isEqualTo(new PythonDecimal(BigDecimal.valueOf(20))); + assertThat(initializedPythonScore.soft_score).isEqualTo(new PythonDecimal(BigDecimal.valueOf(1))); + } + + @Test + void toJavaObject() { + var initializedScore = PythonHardMediumSoftDecimalScore.of(300, 20, 1); + + var initializedJavaScore = typeMapping.toJavaObject(initializedScore); + + assertThat(initializedJavaScore.initScore()).isEqualTo(0); + assertThat(initializedJavaScore.hardScore()).isEqualTo(BigDecimal.valueOf(300)); + assertThat(initializedJavaScore.mediumScore()).isEqualTo(BigDecimal.valueOf(20)); + assertThat(initializedJavaScore.softScore()).isEqualTo(BigDecimal.valueOf(1)); + + var uninitializedScore = PythonHardMediumSoftDecimalScore.ofUninitialized(-4000, 300, 20, 1); + var uninitializedJavaScore = typeMapping.toJavaObject(uninitializedScore); + + assertThat(uninitializedJavaScore.initScore()).isEqualTo(-4000); + assertThat(initializedJavaScore.hardScore()).isEqualTo(BigDecimal.valueOf(300)); + assertThat(initializedJavaScore.mediumScore()).isEqualTo(BigDecimal.valueOf(20)); + assertThat(initializedJavaScore.softScore()).isEqualTo(BigDecimal.valueOf(1)); + } +} \ No newline at end of file diff --git a/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/HardSoftDecimalScorePythonJavaTypeMappingTest.java b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/HardSoftDecimalScorePythonJavaTypeMappingTest.java new file mode 100644 index 00000000..e2e934e2 --- /dev/null +++ b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/HardSoftDecimalScorePythonJavaTypeMappingTest.java @@ -0,0 +1,67 @@ +package ai.timefold.solver.python.score; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import java.math.BigDecimal; + +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; +import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; +import ai.timefold.solver.core.api.score.buildin.hardsoftbigdecimal.HardSoftBigDecimalScore; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class HardSoftDecimalScorePythonJavaTypeMappingTest { + HardSoftDecimalScorePythonJavaTypeMapping typeMapping; + + @BeforeEach + void setUp() throws NoSuchFieldException, ClassNotFoundException, NoSuchMethodException { + this.typeMapping = new HardSoftDecimalScorePythonJavaTypeMapping(PythonHardSoftDecimalScore.TYPE); + } + + @Test + void getPythonType() { + assertThat(typeMapping.getPythonType()).isEqualTo(PythonHardSoftDecimalScore.TYPE); + } + + @Test + void getJavaType() { + assertThat(typeMapping.getJavaType()).isEqualTo(HardSoftBigDecimalScore.class); + } + + @Test + void toPythonObject() { + var initializedScore = HardSoftBigDecimalScore.of(BigDecimal.valueOf(10), BigDecimal.valueOf(2)); + + var initializedPythonScore = (PythonHardSoftDecimalScore) typeMapping.toPythonObject(initializedScore); + + assertThat(initializedPythonScore.init_score).isEqualTo(PythonInteger.ZERO); + assertThat(initializedPythonScore.hard_score).isEqualTo(new PythonDecimal(BigDecimal.valueOf(10))); + assertThat(initializedPythonScore.soft_score).isEqualTo(new PythonDecimal(BigDecimal.valueOf(2))); + + var uninitializedScore = HardSoftBigDecimalScore.ofUninitialized(-300, BigDecimal.valueOf(20), BigDecimal.valueOf(1)); + var uninitializedPythonScore = (PythonHardSoftDecimalScore) typeMapping.toPythonObject(uninitializedScore); + + assertThat(uninitializedPythonScore.init_score).isEqualTo(PythonInteger.valueOf(-300)); + assertThat(uninitializedPythonScore.hard_score).isEqualTo(new PythonDecimal(BigDecimal.valueOf(20))); + assertThat(uninitializedPythonScore.soft_score).isEqualTo(new PythonDecimal(BigDecimal.valueOf(1))); + } + + @Test + void toJavaObject() { + var initializedScore = PythonHardSoftDecimalScore.of(10, 2); + + var initializedJavaScore = typeMapping.toJavaObject(initializedScore); + + assertThat(initializedJavaScore.initScore()).isEqualTo(0); + assertThat(initializedJavaScore.hardScore()).isEqualTo(BigDecimal.valueOf(10)); + assertThat(initializedJavaScore.softScore()).isEqualTo(BigDecimal.valueOf(2)); + + var uninitializedScore = PythonHardSoftDecimalScore.ofUninitialized(-300, 20, 1); + var uninitializedJavaScore = typeMapping.toJavaObject(uninitializedScore); + + assertThat(uninitializedJavaScore.initScore()).isEqualTo(-300); + assertThat(uninitializedJavaScore.hardScore()).isEqualTo(BigDecimal.valueOf(20)); + assertThat(uninitializedJavaScore.softScore()).isEqualTo(BigDecimal.valueOf(1)); + } +} \ No newline at end of file diff --git a/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonBendableDecimalScore.java b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonBendableDecimalScore.java new file mode 100644 index 00000000..797de616 --- /dev/null +++ b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonBendableDecimalScore.java @@ -0,0 +1,46 @@ +package ai.timefold.solver.python.score; + +import java.math.BigDecimal; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import ai.timefold.jpyinterpreter.types.AbstractPythonLikeObject; +import ai.timefold.jpyinterpreter.types.PythonLikeType; +import ai.timefold.jpyinterpreter.types.collections.PythonLikeTuple; +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; +import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; + +public class PythonBendableDecimalScore extends AbstractPythonLikeObject { + public static final PythonLikeType TYPE = new PythonLikeType("BendableDecimalScore", PythonBendableDecimalScore.class); + public PythonInteger init_score; + public PythonLikeTuple hard_scores; + public PythonLikeTuple soft_scores; + + public PythonBendableDecimalScore() { + super(TYPE); + } + + public static PythonBendableDecimalScore of(int[] hardScores, int[] softScores) { + var out = new PythonBendableDecimalScore(); + out.init_score = PythonInteger.ZERO; + out.hard_scores = IntStream.of(hardScores) + .mapToObj(i -> new PythonDecimal(BigDecimal.valueOf(i))) + .collect(Collectors.toCollection(PythonLikeTuple::new)); + out.soft_scores = IntStream.of(softScores) + .mapToObj(i -> new PythonDecimal(BigDecimal.valueOf(i))) + .collect(Collectors.toCollection(PythonLikeTuple::new)); + return out; + } + + public static PythonBendableDecimalScore ofUninitialized(int initScore, int[] hardScores, int[] softScores) { + var out = new PythonBendableDecimalScore(); + out.init_score = PythonInteger.valueOf(initScore); + out.hard_scores = IntStream.of(hardScores) + .mapToObj(i -> new PythonDecimal(BigDecimal.valueOf(i))) + .collect(Collectors.toCollection(PythonLikeTuple::new)); + out.soft_scores = IntStream.of(softScores) + .mapToObj(i -> new PythonDecimal(BigDecimal.valueOf(i))) + .collect(Collectors.toCollection(PythonLikeTuple::new)); + return out; + } +} diff --git a/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonHardMediumSoftDecimalScore.java b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonHardMediumSoftDecimalScore.java new file mode 100644 index 00000000..940d31e8 --- /dev/null +++ b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonHardMediumSoftDecimalScore.java @@ -0,0 +1,40 @@ +package ai.timefold.solver.python.score; + +import java.math.BigDecimal; + +import ai.timefold.jpyinterpreter.types.AbstractPythonLikeObject; +import ai.timefold.jpyinterpreter.types.PythonLikeType; +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; +import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; + +public class PythonHardMediumSoftDecimalScore extends AbstractPythonLikeObject { + public static final PythonLikeType TYPE = + new PythonLikeType("HardMediumSoftDecimalScore", PythonHardMediumSoftDecimalScore.class); + public PythonInteger init_score; + public PythonDecimal hard_score; + public PythonDecimal medium_score; + public PythonDecimal soft_score; + + public PythonHardMediumSoftDecimalScore() { + super(TYPE); + } + + public static PythonHardMediumSoftDecimalScore of(int hardScore, int mediumScore, int softScore) { + var out = new PythonHardMediumSoftDecimalScore(); + out.init_score = PythonInteger.ZERO; + out.hard_score = new PythonDecimal(BigDecimal.valueOf(hardScore)); + out.medium_score = new PythonDecimal(BigDecimal.valueOf(mediumScore)); + out.soft_score = new PythonDecimal(BigDecimal.valueOf(softScore)); + return out; + } + + public static PythonHardMediumSoftDecimalScore ofUninitialized(int initScore, int hardScore, int mediumScore, + int softScore) { + var out = new PythonHardMediumSoftDecimalScore(); + out.init_score = PythonInteger.valueOf(initScore); + out.hard_score = new PythonDecimal(BigDecimal.valueOf(hardScore)); + out.medium_score = new PythonDecimal(BigDecimal.valueOf(mediumScore)); + out.soft_score = new PythonDecimal(BigDecimal.valueOf(softScore)); + return out; + } +} diff --git a/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonHardSoftDecimalScore.java b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonHardSoftDecimalScore.java new file mode 100644 index 00000000..a1595a63 --- /dev/null +++ b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonHardSoftDecimalScore.java @@ -0,0 +1,35 @@ +package ai.timefold.solver.python.score; + +import java.math.BigDecimal; + +import ai.timefold.jpyinterpreter.types.AbstractPythonLikeObject; +import ai.timefold.jpyinterpreter.types.PythonLikeType; +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; +import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; + +public class PythonHardSoftDecimalScore extends AbstractPythonLikeObject { + public static final PythonLikeType TYPE = new PythonLikeType("HardSoftDecimalScore", PythonHardSoftDecimalScore.class); + public PythonInteger init_score; + public PythonDecimal hard_score; + public PythonDecimal soft_score; + + public PythonHardSoftDecimalScore() { + super(TYPE); + } + + public static PythonHardSoftDecimalScore of(int hardScore, int softScore) { + var out = new PythonHardSoftDecimalScore(); + out.init_score = PythonInteger.ZERO; + out.hard_score = new PythonDecimal(BigDecimal.valueOf(hardScore)); + out.soft_score = new PythonDecimal(BigDecimal.valueOf(softScore)); + return out; + } + + public static PythonHardSoftDecimalScore ofUninitialized(int initScore, int hardScore, int softScore) { + var out = new PythonHardSoftDecimalScore(); + out.init_score = PythonInteger.valueOf(initScore); + out.hard_score = new PythonDecimal(BigDecimal.valueOf(hardScore)); + out.soft_score = new PythonDecimal(BigDecimal.valueOf(softScore)); + return out; + } +} diff --git a/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonSimpleDecimalScore.java b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonSimpleDecimalScore.java new file mode 100644 index 00000000..4568a5e1 --- /dev/null +++ b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/PythonSimpleDecimalScore.java @@ -0,0 +1,32 @@ +package ai.timefold.solver.python.score; + +import java.math.BigDecimal; + +import ai.timefold.jpyinterpreter.types.AbstractPythonLikeObject; +import ai.timefold.jpyinterpreter.types.PythonLikeType; +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; +import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; + +public class PythonSimpleDecimalScore extends AbstractPythonLikeObject { + public static final PythonLikeType TYPE = new PythonLikeType("SimpleDecimalScore", PythonSimpleDecimalScore.class); + public PythonInteger init_score; + public PythonDecimal score; + + public PythonSimpleDecimalScore() { + super(TYPE); + } + + public static PythonSimpleDecimalScore of(int score) { + var out = new PythonSimpleDecimalScore(); + out.init_score = PythonInteger.ZERO; + out.score = new PythonDecimal(BigDecimal.valueOf(score)); + return out; + } + + public static PythonSimpleDecimalScore ofUninitialized(int initScore, int score) { + var out = new PythonSimpleDecimalScore(); + out.init_score = PythonInteger.valueOf(initScore); + out.score = new PythonDecimal(BigDecimal.valueOf(score)); + return out; + } +} diff --git a/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/SimpleDecimalScorePythonJavaTypeMappingTest.java b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/SimpleDecimalScorePythonJavaTypeMappingTest.java new file mode 100644 index 00000000..5dc2574c --- /dev/null +++ b/timefold-solver-python-core/src/test/java/ai/timefold/solver/python/score/SimpleDecimalScorePythonJavaTypeMappingTest.java @@ -0,0 +1,63 @@ +package ai.timefold.solver.python.score; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import java.math.BigDecimal; + +import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal; +import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; +import ai.timefold.solver.core.api.score.buildin.simplebigdecimal.SimpleBigDecimalScore; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class SimpleDecimalScorePythonJavaTypeMappingTest { + SimpleDecimalScorePythonJavaTypeMapping typeMapping; + + @BeforeEach + void setUp() throws NoSuchFieldException, ClassNotFoundException, NoSuchMethodException { + this.typeMapping = new SimpleDecimalScorePythonJavaTypeMapping(PythonSimpleDecimalScore.TYPE); + } + + @Test + void getPythonType() { + assertThat(typeMapping.getPythonType()).isEqualTo(PythonSimpleDecimalScore.TYPE); + } + + @Test + void getJavaType() { + assertThat(typeMapping.getJavaType()).isEqualTo(SimpleBigDecimalScore.class); + } + + @Test + void toPythonObject() { + var initializedScore = SimpleBigDecimalScore.of(BigDecimal.valueOf(10)); + + var initializedPythonScore = (PythonSimpleDecimalScore) typeMapping.toPythonObject(initializedScore); + + assertThat(initializedPythonScore.init_score).isEqualTo(PythonInteger.ZERO); + assertThat(initializedPythonScore.score).isEqualTo(PythonDecimal.valueOf("10")); + + var uninitializedScore = SimpleBigDecimalScore.ofUninitialized(-5, BigDecimal.valueOf(20)); + var uninitializedPythonScore = (PythonSimpleDecimalScore) typeMapping.toPythonObject(uninitializedScore); + + assertThat(uninitializedPythonScore.init_score).isEqualTo(PythonInteger.valueOf(-5)); + assertThat(uninitializedPythonScore.score).isEqualTo(PythonDecimal.valueOf("20")); + } + + @Test + void toJavaObject() { + var initializedScore = PythonSimpleDecimalScore.of(10); + + var initializedJavaScore = typeMapping.toJavaObject(initializedScore); + + assertThat(initializedJavaScore.initScore()).isEqualTo(0); + assertThat(initializedJavaScore.score()).isEqualTo(BigDecimal.valueOf(10)); + + var uninitializedScore = PythonSimpleDecimalScore.ofUninitialized(-5, 20); + var uninitializedJavaScore = typeMapping.toJavaObject(uninitializedScore); + + assertThat(uninitializedJavaScore.initScore()).isEqualTo(-5); + assertThat(uninitializedJavaScore.score()).isEqualTo(BigDecimal.valueOf(20)); + } +} \ No newline at end of file