Skip to content

Commit

Permalink
fix: Add PlanningPinToIndex to Python (#1191)
Browse files Browse the repository at this point in the history
JPyInterpreter

- Allow annotations to specify a type to override the generated getter
return type. When a type is specified, it goes through the generic
PythonLikeObject to Java Object convertor.

Python

- Add PlanningPinToIndex
- Make PlanningPin use the new JPyInterpreter mechanism to remove code
that converts a PlanningPin to Pinning filter

Build

- Fix Python 3.12 environment name in `tox.ini`
  • Loading branch information
Christopher-Chianelli authored Nov 11, 2024
1 parent 19e2a79 commit 0750c9d
Show file tree
Hide file tree
Showing 15 changed files with 348 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
* Example: Assuming a list of values {@code [A, B, C]}:
*
* <ul>
* <li>0 or null allows the entire list to be modified.</li>
* <li>0 allows the entire list to be modified.</li>
* <li>1 pins {@code A}; rest of the list may be modified or added to.</li>
* <li>2 pins {@code A, B}; rest of the list may be modified or added to.</li>
* <li>3 pins {@code A, B, C}; the list can only be added to.</li>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,8 @@ private void createEffectivePlanningPinIndexReader() {
case 0 -> effectivePlanningPinToIndexReader = null;
case 1 -> {
var memberAccessor = planningPinIndexMemberAccessorList.get(0);
effectivePlanningPinToIndexReader = (solution, entity) -> (int) memberAccessor.executeGetter(entity);
effectivePlanningPinToIndexReader =
(solution, entity) -> (int) memberAccessor.executeGetter(entity);
}
default -> throw new IllegalStateException(
"The entityClass (%s) has (%d) @%s-annotated members (%s), where it should only have one."
Expand Down
6 changes: 6 additions & 0 deletions python/jpyinterpreter/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@
<scope>runtime</scope>
</dependency>

<!-- standardized nullness annotations -->
<dependency>
<groupId>org.jspecify</groupId>
<artifactId>jspecify</artifactId>
</dependency>

<!-- Testing -->
<dependency>
<groupId>org.junit.jupiter</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
import java.util.List;
import java.util.Map;

import org.jspecify.annotations.NonNull;
import org.jspecify.annotations.Nullable;
import org.objectweb.asm.AnnotationVisitor;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Type;

public record AnnotationMetadata(Class<? extends Annotation> annotationType, Map<String, Object> annotationValueMap) {
public record AnnotationMetadata(@NonNull Class<? extends Annotation> annotationType,
@NonNull Map<String, Object> annotationValueMap,
@Nullable Class<?> fieldTypeOverride) {
public void addAnnotationTo(ClassVisitor classVisitor) {
visitAnnotation(classVisitor.visitAnnotation(Type.getDescriptor(annotationType), true));
}
Expand All @@ -30,19 +34,22 @@ public void addAnnotationTo(MethodVisitor methodVisitor) {
public static List<AnnotationMetadata> getAnnotationListWithoutRepeatable(List<AnnotationMetadata> metadata) {
List<AnnotationMetadata> out = new ArrayList<>();
Map<Class<? extends Annotation>, List<AnnotationMetadata>> repeatableAnnotationMap = new LinkedHashMap<>();
Map<Class<? extends Annotation>, Class<?>> fieldTypeOverrideMap = new LinkedHashMap<>();
for (AnnotationMetadata annotation : metadata) {
Repeatable repeatable = annotation.annotationType().getAnnotation(Repeatable.class);
if (repeatable == null) {
out.add(annotation);
continue;
}
var annotationContainer = repeatable.value();
fieldTypeOverrideMap.put(annotationContainer, annotation.fieldTypeOverride());
repeatableAnnotationMap.computeIfAbsent(annotationContainer,
ignored -> new ArrayList<>()).add(annotation);
}
for (var entry : repeatableAnnotationMap.entrySet()) {
out.add(new AnnotationMetadata(entry.getKey(),
Map.of("value", entry.getValue().toArray(AnnotationMetadata[]::new))));
Map.of("value", entry.getValue().toArray(AnnotationMetadata[]::new)),
fieldTypeOverrideMap.get(entry.getKey())));
}
return out;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ private static void createMethodDelegate(ClassWriter classWriter,
interfaceMethodVisitor.visitInsn(Opcodes.RETURN);
} else {
if (returnType.isPrimitive()) {
DelegatingInterfaceImplementor.loadBoxedPrimitiveTypeClass(returnType, interfaceMethodVisitor);
JavaPythonTypeConversionImplementor.loadTypeClass(returnType, interfaceMethodVisitor);
} else {
interfaceMethodVisitor.visitLdcInsn(Type.getType(returnType));
}
Expand All @@ -279,7 +279,7 @@ private static void createMethodDelegate(ClassWriter classWriter,
PythonLikeObject.class)),
false);
if (returnType.isPrimitive()) {
DelegatingInterfaceImplementor.unboxBoxedPrimitiveType(returnType, interfaceMethodVisitor);
JavaPythonTypeConversionImplementor.unboxBoxedPrimitiveType(returnType, interfaceMethodVisitor);
interfaceMethodVisitor.visitInsn(Type.getType(returnType).getOpcode(Opcodes.IRETURN));
} else {
interfaceMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(returnType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,11 @@ private static void createJavaGetterSetter(ClassWriter classWriter,
private static void createJavaGetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo,
MatchedMapping matchedMapping, String attributeName,
Type attributeType, Type getterType, String signature, TypeHint typeHint) {
var typeOverride = typeHint.getOverrideTypeDescriptor();
var isTypeOverridden = typeOverride != null;
if (isTypeOverridden) {
getterType = Type.getType(typeOverride);
}
var getterName = "get" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1);
if (signature != null && Objects.equals(attributeType, getterType)) {
signature = "()" + signature;
Expand All @@ -858,6 +863,9 @@ private static void createJavaGetter(ClassWriter classWriter, PreparedClassInfo
}

getterVisitor.visitCode();
if (isTypeOverridden && !Objects.equals(attributeType, getterType)) {
JavaPythonTypeConversionImplementor.loadTypeClass(getterType, getterVisitor);
}
getterVisitor.visitVarInsn(Opcodes.ALOAD, 0);
getterVisitor.visitFieldInsn(Opcodes.GETFIELD, preparedClassInfo.classInternalName,
attributeName, attributeType.getDescriptor());
Expand Down Expand Up @@ -890,9 +898,22 @@ private static void createJavaGetter(ClassWriter classWriter, PreparedClassInfo
true);
getterVisitor.visitLabel(skipMapping);
}
getterVisitor.visitTypeInsn(Opcodes.CHECKCAST, getterType.getInternalName());
if (isTypeOverridden) {
getterVisitor.visitMethodInsn(Opcodes.INVOKESTATIC,
Type.getInternalName(JavaPythonTypeConversionImplementor.class),
"convertPythonObjectToJavaType",
Type.getMethodDescriptor(Type.getType(Object.class),
Type.getType(Class.class),
Type.getType(PythonLikeObject.class)),
false);
}
if (getterType.getSort() == Type.OBJECT) {
getterVisitor.visitTypeInsn(Opcodes.CHECKCAST, getterType.getInternalName());
} else {
JavaPythonTypeConversionImplementor.unboxBoxedPrimitiveType(getterType, getterVisitor);
}
}
getterVisitor.visitInsn(Opcodes.ARETURN);
getterVisitor.visitInsn(getterType.getOpcode(Opcodes.IRETURN));
getterVisitor.visitMaxs(maxStack, 0);
getterVisitor.visitEnd();
}
Expand All @@ -901,6 +922,11 @@ private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo
MatchedMapping matchedMapping, String attributeName,
Type attributeType, Type setterType, String signature, TypeHint typeHint) {
var setterName = "set" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1);
var typeOverride = typeHint.getOverrideTypeDescriptor();
var isTypeOverridden = typeOverride != null;
if (isTypeOverridden) {
setterType = Type.getType(typeOverride);
}
if (signature != null && Objects.equals(attributeType, setterType)) {
signature = "(" + signature + ")V";
}
Expand All @@ -910,7 +936,10 @@ private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo
var maxStack = 2;
setterVisitor.visitCode();
setterVisitor.visitVarInsn(Opcodes.ALOAD, 0);
setterVisitor.visitVarInsn(Opcodes.ALOAD, 1);
setterVisitor.visitVarInsn(setterType.getOpcode(Opcodes.ILOAD), 1);
if (setterType.getSort() != Type.OBJECT) {
JavaPythonTypeConversionImplementor.boxPrimitiveType(setterType, setterVisitor);
}
if (typeHint.type().isInstance(PythonNone.INSTANCE)) {
maxStack = 4;
// We want to replace null with None
Expand Down Expand Up @@ -941,6 +970,14 @@ private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo
true);
setterVisitor.visitLabel(skipMapping);
}
if (isTypeOverridden) {
setterVisitor.visitMethodInsn(Opcodes.INVOKESTATIC,
Type.getInternalName(JavaPythonTypeConversionImplementor.class),
"wrapJavaObject",
Type.getMethodDescriptor(Type.getType(PythonLikeObject.class),
Type.getType(Object.class)),
false);
}
setterVisitor.visitTypeInsn(Opcodes.CHECKCAST, attributeType.getInternalName());
}
setterVisitor.visitFieldInsn(Opcodes.PUTFIELD, preparedClassInfo.classInternalName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

import ai.timefold.jpyinterpreter.types.PythonLikeType;

import org.jspecify.annotations.Nullable;
import org.objectweb.asm.Type;

public record TypeHint(PythonLikeType type, List<AnnotationMetadata> annotationList, TypeHint[] genericArgs,
PythonLikeType javaGetterType) {
public TypeHint {
Expand All @@ -22,6 +25,20 @@ public TypeHint(PythonLikeType type, List<AnnotationMetadata> annotationList, Py
this(type, annotationList, null, javaGetterType);
}

@Nullable
public String getOverrideTypeDescriptor() {
Class<?> override = null;
for (var annotation : annotationList) {
var newOverride = annotation.fieldTypeOverride();
if (override != null && !override.equals(newOverride)) {
throw new IllegalArgumentException(
"Multiple override specified that do not match in annotations (" + annotationList + ").");
}
override = newOverride;
}
return (override != null) ? Type.getDescriptor(override) : null;
}

public TypeHint addAnnotations(List<AnnotationMetadata> addedAnnotations) {
List<AnnotationMetadata> combinedAnnotations = new ArrayList<>(annotationList.size() + addedAnnotations.size());
combinedAnnotations.addAll(annotationList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ private void implementMethod(ClassWriter classWriter, PythonCompiledClass compil
interfaceMethodVisitor.visitInsn(Opcodes.RETURN);
} else {
if (returnType.isPrimitive()) {
loadBoxedPrimitiveTypeClass(returnType, interfaceMethodVisitor);
JavaPythonTypeConversionImplementor.loadTypeClass(returnType, interfaceMethodVisitor);
} else {
interfaceMethodVisitor.visitLdcInsn(Type.getType(returnType));
}
Expand All @@ -131,7 +131,7 @@ private void implementMethod(ClassWriter classWriter, PythonCompiledClass compil
PythonLikeObject.class)),
false);
if (returnType.isPrimitive()) {
unboxBoxedPrimitiveType(returnType, interfaceMethodVisitor);
JavaPythonTypeConversionImplementor.unboxBoxedPrimitiveType(returnType, interfaceMethodVisitor);
interfaceMethodVisitor.visitInsn(Type.getType(returnType).getOpcode(Opcodes.IRETURN));
} else {
interfaceMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(returnType));
Expand All @@ -156,7 +156,7 @@ public static void prepareParametersForMethodCallFromArgumentSpec(Method interfa
interfaceMethodVisitor.visitVarInsn(Type.getType(parameterType).getOpcode(Opcodes.ILOAD),
i + 1);
if (parameterType.isPrimitive()) {
convertPrimitiveToObjectType(parameterType, interfaceMethodVisitor);
JavaPythonTypeConversionImplementor.boxPrimitiveType(parameterType, interfaceMethodVisitor);
}
interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, interfaceMethod.getParameterCount() + 1);
interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC,
Expand Down Expand Up @@ -191,93 +191,4 @@ public static void prepareParametersForMethodCallFromArgumentSpec(Method interfa
interfaceMethodVisitor.visitInsn(Opcodes.POP);
}

public static void convertPrimitiveToObjectType(Class<?> primitiveType, MethodVisitor methodVisitor) {
if (primitiveType.equals(boolean.class)) {
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Boolean.class),
"valueOf", Type.getMethodDescriptor(Type.getType(Boolean.class), Type.BOOLEAN_TYPE), false);
} else if (primitiveType.equals(byte.class)) {
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Byte.class),
"valueOf", Type.getMethodDescriptor(Type.getType(Byte.class), Type.BYTE_TYPE), false);
} else if (primitiveType.equals(char.class)) {
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Character.class),
"valueOf", Type.getMethodDescriptor(Type.getType(Character.class), Type.CHAR_TYPE), false);
} else if (primitiveType.equals(short.class)) {
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Short.class),
"valueOf", Type.getMethodDescriptor(Type.getType(Short.class), Type.SHORT_TYPE), false);
} else if (primitiveType.equals(int.class)) {
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Integer.class),
"valueOf", Type.getMethodDescriptor(Type.getType(Integer.class), Type.INT_TYPE), false);
} else if (primitiveType.equals(long.class)) {
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Long.class),
"valueOf", Type.getMethodDescriptor(Type.getType(Long.class), Type.LONG_TYPE), false);
} else if (primitiveType.equals(float.class)) {
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Float.class),
"valueOf", Type.getMethodDescriptor(Type.getType(Float.class), Type.FLOAT_TYPE), false);
} else if (primitiveType.equals(double.class)) {
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Double.class),
"valueOf", Type.getMethodDescriptor(Type.getType(Double.class), Type.DOUBLE_TYPE), false);
} else {
throw new IllegalStateException("Unknown primitive type %s.".formatted(primitiveType));
}
}

public static void loadBoxedPrimitiveTypeClass(Class<?> primitiveType, MethodVisitor methodVisitor) {
if (primitiveType.equals(boolean.class)) {
methodVisitor.visitLdcInsn(Type.getType(Boolean.class));
} else if (primitiveType.equals(byte.class)) {
methodVisitor.visitLdcInsn(Type.getType(Byte.class));
} else if (primitiveType.equals(char.class)) {
methodVisitor.visitLdcInsn(Type.getType(Character.class));
} else if (primitiveType.equals(short.class)) {
methodVisitor.visitLdcInsn(Type.getType(Short.class));
} else if (primitiveType.equals(int.class)) {
methodVisitor.visitLdcInsn(Type.getType(Integer.class));
} else if (primitiveType.equals(long.class)) {
methodVisitor.visitLdcInsn(Type.getType(Long.class));
} else if (primitiveType.equals(float.class)) {
methodVisitor.visitLdcInsn(Type.getType(Float.class));
} else if (primitiveType.equals(double.class)) {
methodVisitor.visitLdcInsn(Type.getType(Double.class));
} else {
throw new IllegalStateException("Unknown primitive type %s.".formatted(primitiveType));
}
}

public static void unboxBoxedPrimitiveType(Class<?> primitiveType, MethodVisitor methodVisitor) {
if (primitiveType.equals(boolean.class)) {
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Boolean.class));
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Boolean.class),
"booleanValue", Type.getMethodDescriptor(Type.BOOLEAN_TYPE), false);
} else if (primitiveType.equals(byte.class)) {
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Byte.class));
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Byte.class),
"byteValue", Type.getMethodDescriptor(Type.BYTE_TYPE), false);
} else if (primitiveType.equals(char.class)) {
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Character.class));
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Character.class),
"charValue", Type.getMethodDescriptor(Type.CHAR_TYPE), false);
} else if (primitiveType.equals(short.class)) {
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Short.class));
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Short.class),
"shortValue", Type.getMethodDescriptor(Type.SHORT_TYPE), false);
} else if (primitiveType.equals(int.class)) {
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Integer.class));
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Integer.class),
"intValue", Type.getMethodDescriptor(Type.INT_TYPE), false);
} else if (primitiveType.equals(long.class)) {
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Long.class));
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Long.class),
"longValue", Type.getMethodDescriptor(Type.LONG_TYPE), false);
} else if (primitiveType.equals(float.class)) {
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Float.class));
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Float.class),
"floatValue", Type.getMethodDescriptor(Type.FLOAT_TYPE), false);
} else if (primitiveType.equals(double.class)) {
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Double.class));
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Double.class),
"doubleValue", Type.getMethodDescriptor(Type.DOUBLE_TYPE), false);
} else {
throw new IllegalStateException("Unknown primitive type %s.".formatted(primitiveType));
}
}
}
Loading

0 comments on commit 0750c9d

Please sign in to comment.