Skip to content
This repository has been archived by the owner on Jul 17, 2024. It is now read-only.

Commit

Permalink
feat: add support for Decimal and Decimal score types (#110)
Browse files Browse the repository at this point in the history
- Decimal maps mostly to BigDecimal, although its floating
  point concepts are ignored (Python does not have an infinite
  precision MathContext, so it acts more like a dynamic range
  floating point with an adjustable precision.
  The precision used is shared in
  a thread local object that can be changed using
  decimal.setcontext.

- Added `str` constructors to `float` and `int`

- Added sanity tests for all variants of penalize/reward/impact
  and score types
  • Loading branch information
Christopher-Chianelli authored Jul 12, 2024
1 parent e1dac95 commit 1c7d902
Show file tree
Hide file tree
Showing 36 changed files with 4,198 additions and 191 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ public class PythonClassTranslator {
// $ is illegal in variables/methods in Python
public static final String TYPE_FIELD_NAME = "$TYPE";
public static final String CPYTHON_TYPE_FIELD_NAME = "$CPYTHON_TYPE";
private static final String JAVA_METHOD_PREFIX = "$method$";
private static final String PYTHON_JAVA_TYPE_MAPPING_PREFIX = "$pythonJavaTypeMapping";
public static final String JAVA_METHOD_PREFIX = "$method$";
public static final String PYTHON_JAVA_TYPE_MAPPING_PREFIX = "$pythonJavaTypeMapping";

public record PreparedClassInfo(PythonLikeType type, String className, String classInternalName) {
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.timefold.jpyinterpreter.implementors;

import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.IdentityHashMap;
import java.util.Iterator;
Expand Down Expand Up @@ -31,6 +32,7 @@
import ai.timefold.jpyinterpreter.types.collections.PythonLikeTuple;
import ai.timefold.jpyinterpreter.types.errors.TypeError;
import ai.timefold.jpyinterpreter.types.numeric.PythonBoolean;
import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal;
import ai.timefold.jpyinterpreter.types.numeric.PythonFloat;
import ai.timefold.jpyinterpreter.types.numeric.PythonInteger;
import ai.timefold.jpyinterpreter.types.numeric.PythonNumber;
Expand Down Expand Up @@ -65,76 +67,78 @@ public static PythonLikeObject wrapJavaObject(Object object, Map<Object, PythonL
return existingObject;
}

if (object instanceof OpaqueJavaReference) {
return ((OpaqueJavaReference) object).proxy();
if (object instanceof OpaqueJavaReference opaqueJavaReference) {
return opaqueJavaReference.proxy();
}

if (object instanceof PythonLikeObject) {
if (object instanceof PythonLikeObject instance) {
// Object already a PythonLikeObject; need to do nothing
return (PythonLikeObject) object;
return instance;
}

if (object instanceof Byte || object instanceof Short || object instanceof Integer || object instanceof Long) {
return PythonInteger.valueOf(((Number) object).longValue());
}

if (object instanceof BigInteger) {
return PythonInteger.valueOf((BigInteger) object);
if (object instanceof BigInteger integer) {
return PythonInteger.valueOf(integer);
}

if (object instanceof BigDecimal decimal) {
return new PythonDecimal(decimal);
}

if (object instanceof Float || object instanceof Double) {
return PythonFloat.valueOf(((Number) object).doubleValue());
}

if (object instanceof Boolean) {
return PythonBoolean.valueOf((Boolean) object);
if (object instanceof Boolean booleanValue) {
return PythonBoolean.valueOf(booleanValue);
}

if (object instanceof String) {
return PythonString.valueOf((String) object);
if (object instanceof String string) {
return PythonString.valueOf(string);
}

if (object instanceof Iterator) {
return new DelegatePythonIterator<>((Iterator) object);
if (object instanceof Iterator<?> iterator) {
return new DelegatePythonIterator<>(iterator);
}

if (object instanceof List) {
PythonLikeList out = new PythonLikeList();
if (object instanceof List<?> list) {
PythonLikeList<?> out = new PythonLikeList<>();
createdObjectMap.put(object, out);
for (Object item : (List) object) {
for (Object item : list) {
out.add(wrapJavaObject(item));
}
return out;
}

if (object instanceof Set) {
PythonLikeSet out = new PythonLikeSet();
if (object instanceof Set<?> set) {
PythonLikeSet<?> out = new PythonLikeSet<>();
createdObjectMap.put(object, out);
for (Object item : (Set) object) {
for (Object item : set) {
out.add(wrapJavaObject(item));
}
return out;
}

if (object instanceof Map) {
PythonLikeDict out = new PythonLikeDict();
if (object instanceof Map<?, ?> map) {
PythonLikeDict<?, ?> out = new PythonLikeDict<>();
createdObjectMap.put(object, out);
Set<Map.Entry<?, ?>> entrySet = ((Map) object).entrySet();
for (Map.Entry<?, ?> entry : entrySet) {
var entrySet = map.entrySet();
for (var entry : entrySet) {
out.put(wrapJavaObject(entry.getKey()), wrapJavaObject(entry.getValue()));
}
return out;
}

if (object instanceof Class) {
Class<?> maybeFunctionClass = (Class<?>) object;
if (Set.of(maybeFunctionClass.getInterfaces()).contains(PythonLikeFunction.class)) {
return new PythonCode((Class<? extends PythonLikeFunction>) maybeFunctionClass);
}
if (object instanceof Class<?> maybeFunctionClass &&
Set.of(maybeFunctionClass.getInterfaces()).contains(PythonLikeFunction.class)) {
return new PythonCode((Class<? extends PythonLikeFunction>) maybeFunctionClass);
}

if (object instanceof OpaquePythonReference) {
return new PythonObjectWrapper((OpaquePythonReference) object);
if (object instanceof OpaquePythonReference opaquePythonReference) {
return new PythonObjectWrapper(opaquePythonReference);
}

// Default: return a JavaObjectWrapper
Expand All @@ -161,6 +165,10 @@ public static PythonLikeType getPythonLikeType(Class<?> javaClass) {
return BuiltinTypes.INT_TYPE;
}

if (BigDecimal.class.equals(javaClass) || PythonDecimal.class.equals(javaClass)) {
return BuiltinTypes.DECIMAL_TYPE;
}

if (float.class.equals(javaClass) || double.class.equals(javaClass) ||
Float.class.equals(javaClass) || Double.class.equals(javaClass) ||
PythonFloat.class.equals(javaClass)) {
Expand Down Expand Up @@ -254,8 +262,7 @@ public static <T> T convertPythonObjectToJavaType(Class<? extends T> type, Pytho
return null;
}

if (object instanceof JavaObjectWrapper) {
JavaObjectWrapper wrappedObject = (JavaObjectWrapper) object;
if (object instanceof JavaObjectWrapper wrappedObject) {
Object javaObject = wrappedObject.getWrappedObject();
if (!type.isAssignableFrom(javaObject.getClass())) {
throw new TypeError("Cannot convert from (" + getPythonLikeType(javaObject.getClass()) + ") to ("
Expand All @@ -266,14 +273,13 @@ public static <T> T convertPythonObjectToJavaType(Class<? extends T> type, Pytho

if (type.equals(byte.class) || type.equals(short.class) || type.equals(int.class) || type.equals(long.class) ||
type.equals(float.class) || type.equals(double.class) || Number.class.isAssignableFrom(type)) {
if (!(object instanceof PythonNumber)) {
if (!(object instanceof PythonNumber pythonNumber)) {
throw new TypeError("Cannot convert from (" + getPythonLikeType(object.getClass()) + ") to ("
+ getPythonLikeType(type) + ").");
}
PythonNumber pythonNumber = (PythonNumber) object;
Number value = pythonNumber.getValue();

if (type.equals(BigInteger.class)) {
if (type.equals(BigInteger.class) || type.equals(BigDecimal.class)) {
return (T) value;
}

Expand Down Expand Up @@ -303,11 +309,10 @@ public static <T> T convertPythonObjectToJavaType(Class<? extends T> type, Pytho
}

if (type.equals(boolean.class) || type.equals(Boolean.class)) {
if (!(object instanceof PythonBoolean)) {
if (!(object instanceof PythonBoolean pythonBoolean)) {
throw new TypeError("Cannot convert from (" + getPythonLikeType(object.getClass()) + ") to ("
+ getPythonLikeType(type) + ").");
}
PythonBoolean pythonBoolean = (PythonBoolean) object;
return (T) (Boolean) pythonBoolean.getBooleanValue();
}

Expand Down Expand Up @@ -335,6 +340,53 @@ public static void loadName(MethodVisitor methodVisitor, String name) {
false);
}

private record ReturnValueOpDescriptor(
String wrapperClassName,
String methodName,
String methodDescriptor,
int opcode,
boolean noConversionNeeded) {
public static ReturnValueOpDescriptor noConversion() {
return new ReturnValueOpDescriptor("", "", "",
Opcodes.ARETURN, true);
}

public static ReturnValueOpDescriptor forNumeric(String methodName,
String methodDescriptor,
int opcode) {
return new ReturnValueOpDescriptor(Type.getInternalName(Number.class), methodName, methodDescriptor, opcode,
false);
}
}

private static final Map<Type, ReturnValueOpDescriptor> numericReturnValueOpDescriptorMap = Map.of(
Type.BYTE_TYPE, ReturnValueOpDescriptor.forNumeric(
"byteValue",
Type.getMethodDescriptor(Type.BYTE_TYPE),
Opcodes.IRETURN),
Type.SHORT_TYPE, ReturnValueOpDescriptor.forNumeric(
"shortValue",
Type.getMethodDescriptor(Type.SHORT_TYPE),
Opcodes.IRETURN),
Type.INT_TYPE, ReturnValueOpDescriptor.forNumeric(
"intValue",
Type.getMethodDescriptor(Type.INT_TYPE),
Opcodes.IRETURN),
Type.LONG_TYPE, ReturnValueOpDescriptor.forNumeric(
"longValue",
Type.getMethodDescriptor(Type.LONG_TYPE),
Opcodes.LRETURN),
Type.FLOAT_TYPE, ReturnValueOpDescriptor.forNumeric(
"floatValue",
Type.getMethodDescriptor(Type.FLOAT_TYPE),
Opcodes.FRETURN),
Type.DOUBLE_TYPE, ReturnValueOpDescriptor.forNumeric(
"doubleValue",
Type.getMethodDescriptor(Type.DOUBLE_TYPE),
Opcodes.DRETURN),
Type.getType(BigInteger.class), ReturnValueOpDescriptor.noConversion(),
Type.getType(BigDecimal.class), ReturnValueOpDescriptor.noConversion());

/**
* If {@code method} return type is not void, convert TOS into its Java equivalent and return it.
* If {@code method} return type is void, immediately return.
Expand All @@ -344,67 +396,36 @@ public static void loadName(MethodVisitor methodVisitor, String name) {
public static void returnValue(MethodVisitor methodVisitor, MethodDescriptor method, StackMetadata stackMetadata) {
Type returnAsmType = method.getReturnType();

if (Type.CHAR_TYPE.equals(returnAsmType)) {
throw new IllegalStateException("Unhandled case for primitive type (char).");
}

if (Type.VOID_TYPE.equals(returnAsmType)) {
methodVisitor.visitInsn(Opcodes.RETURN);
return;
}

if (Type.BYTE_TYPE.equals(returnAsmType) ||
Type.CHAR_TYPE.equals(returnAsmType) ||
Type.SHORT_TYPE.equals(returnAsmType) ||
Type.INT_TYPE.equals(returnAsmType) ||
Type.LONG_TYPE.equals(returnAsmType) ||
Type.FLOAT_TYPE.equals(returnAsmType) ||
Type.DOUBLE_TYPE.equals(returnAsmType)) {
if (numericReturnValueOpDescriptorMap.containsKey(returnAsmType)) {
var returnValueOpDescriptor = numericReturnValueOpDescriptorMap.get(returnAsmType);
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(PythonNumber.class));
methodVisitor.visitMethodInsn(Opcodes.INVOKEINTERFACE,
Type.getInternalName(PythonNumber.class),
"getValue",
Type.getMethodDescriptor(Type.getType(Number.class)),
true);
String wrapperClassName = null;
String methodName = null;
String methodDescriptor = null;
int returnOpcode = 0;

if (Type.BYTE_TYPE.equals(returnAsmType)) {
wrapperClassName = Type.getInternalName(Number.class);
methodName = "byteValue";
methodDescriptor = Type.getMethodDescriptor(Type.BYTE_TYPE);
returnOpcode = Opcodes.IRETURN;
} else if (Type.CHAR_TYPE.equals(returnAsmType)) {
throw new IllegalStateException("Unhandled case for primitive type (char).");
// returnOpcode = Opcodes.IRETURN;
} else if (Type.SHORT_TYPE.equals(returnAsmType)) {
wrapperClassName = Type.getInternalName(Number.class);
methodName = "shortValue";
methodDescriptor = Type.getMethodDescriptor(Type.SHORT_TYPE);
returnOpcode = Opcodes.IRETURN;
} else if (Type.INT_TYPE.equals(returnAsmType)) {
wrapperClassName = Type.getInternalName(Number.class);
methodName = "intValue";
methodDescriptor = Type.getMethodDescriptor(Type.INT_TYPE);
returnOpcode = Opcodes.IRETURN;
} else if (Type.FLOAT_TYPE.equals(returnAsmType)) {
wrapperClassName = Type.getInternalName(Number.class);
methodName = "floatValue";
methodDescriptor = Type.getMethodDescriptor(Type.FLOAT_TYPE);
returnOpcode = Opcodes.FRETURN;
} else if (Type.LONG_TYPE.equals(returnAsmType)) {
wrapperClassName = Type.getInternalName(Number.class);
methodName = "longValue";
methodDescriptor = Type.getMethodDescriptor(Type.LONG_TYPE);
returnOpcode = Opcodes.LRETURN;
} else if (Type.DOUBLE_TYPE.equals(returnAsmType)) {
wrapperClassName = Type.getInternalName(Number.class);
methodName = "doubleValue";
methodDescriptor = Type.getMethodDescriptor(Type.DOUBLE_TYPE);
returnOpcode = Opcodes.DRETURN;

if (returnValueOpDescriptor.noConversionNeeded) {
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, returnAsmType.getInternalName());
methodVisitor.visitInsn(Opcodes.ARETURN);
return;
}

methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL,
wrapperClassName, methodName, methodDescriptor,
returnValueOpDescriptor.wrapperClassName,
returnValueOpDescriptor.methodName,
returnValueOpDescriptor.methodDescriptor,
false);
methodVisitor.visitInsn(returnOpcode);
methodVisitor.visitInsn(returnValueOpDescriptor.opcode);
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ai.timefold.jpyinterpreter.types.collections.view.DictValueView;
import ai.timefold.jpyinterpreter.types.numeric.PythonBoolean;
import ai.timefold.jpyinterpreter.types.numeric.PythonComplex;
import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal;
import ai.timefold.jpyinterpreter.types.numeric.PythonFloat;
import ai.timefold.jpyinterpreter.types.numeric.PythonInteger;
import ai.timefold.jpyinterpreter.types.numeric.PythonNumber;
Expand Down Expand Up @@ -60,6 +61,7 @@ public class BuiltinTypes {
public static final PythonLikeType BOOLEAN_TYPE = new PythonLikeType("bool", PythonBoolean.class, List.of(INT_TYPE));
public static final PythonLikeType FLOAT_TYPE = new PythonLikeType("float", PythonFloat.class, List.of(NUMBER_TYPE));
public final static PythonLikeType COMPLEX_TYPE = new PythonLikeType("complex", PythonComplex.class, List.of(NUMBER_TYPE));
public final static PythonLikeType DECIMAL_TYPE = new PythonLikeType("Decimal", PythonDecimal.class, List.of(NUMBER_TYPE));

public static final PythonLikeType STRING_TYPE = new PythonLikeType("str", PythonString.class, List.of(BASE_TYPE));
public static final PythonLikeType BYTES_TYPE = new PythonLikeType("bytes", PythonBytes.class, List.of(BASE_TYPE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,14 @@ public PythonLikeTuple<T> createNewInstance() {
return new PythonLikeTuple<>();
}

public static PythonLikeTuple fromItems(PythonLikeObject... items) {
PythonLikeTuple result = new PythonLikeTuple();
public static <T extends PythonLikeObject> PythonLikeTuple<T> fromItems(T... items) {
PythonLikeTuple<T> result = new PythonLikeTuple<>();
Collections.addAll(result, items);
return result;
}

public static PythonLikeTuple fromList(List<PythonLikeObject> other) {
PythonLikeTuple result = new PythonLikeTuple();
public static <T extends PythonLikeObject> PythonLikeTuple<T> fromList(List<T> other) {
PythonLikeTuple<T> result = new PythonLikeTuple<>();
result.addAll(other);
return result;
}
Expand Down
Loading

0 comments on commit 1c7d902

Please sign in to comment.