diff --git a/tornado-drivers/drivers-common/src/main/java/uk/ac/manchester/tornado/drivers/common/code/CodeUtil.java b/tornado-drivers/drivers-common/src/main/java/uk/ac/manchester/tornado/drivers/common/code/CodeUtil.java index 86d2cc3a50..121b3bc87b 100644 --- a/tornado-drivers/drivers-common/src/main/java/uk/ac/manchester/tornado/drivers/common/code/CodeUtil.java +++ b/tornado-drivers/drivers-common/src/main/java/uk/ac/manchester/tornado/drivers/common/code/CodeUtil.java @@ -52,8 +52,9 @@ public static CallingConvention getCallingConvention(CodeCacheProvider codeCache } else { argTypes = new JavaType[sigCount]; } + var declaringClass = method.getDeclaringClass(); for (int i = 0; i < sigCount; i++) { - argTypes[argIndex++] = sig.getParameterType(i, null); + argTypes[argIndex++] = sig.getParameterType(i, declaringClass); } final Local[] locals = method.getLocalVariableTable().getLocalsAt(0); @@ -70,8 +71,9 @@ private static CallingConvention getCallingConvention(Type type, JavaType return inputParameters[i] = new Variable(LIRKind.value(target.arch.getPlatformKind(JavaKind.Short)), variableIndex); continue; } - inputParameters[i] = new Variable(LIRKind.value(target.arch.getPlatformKind(argTypes[i].getJavaKind())), variableIndex); + var javaKind = convertJavaKind(argTypes[i]); + inputParameters[i] = new Variable(LIRKind.value(target.arch.getPlatformKind(javaKind)), variableIndex); } JavaKind returnKind = returnType == null ? JavaKind.Void : returnType.getJavaKind(); @@ -87,4 +89,24 @@ public static boolean isHalfFloat(JavaType type) { return type.toJavaName().equals(HalfFloat.class.getName()); } + /** + * Convert a {@link JavaType} to a {@link JavaKind}, all wrappers for primitive types are converted to the corresponding {@link JavaKind} of the primitive type. + * + * @param type + * @return + */ + public static JavaKind convertJavaKind(JavaType type) { + return switch (type.getName()) { + case "Ljava/lang/Boolean;" -> JavaKind.Boolean; + case "Ljava/lang/Byte;" -> JavaKind.Byte; + case "Ljava/lang/Short;" -> JavaKind.Short; + case "Ljava/lang/Character;" -> JavaKind.Char; + case "Ljava/lang/Integer;" -> JavaKind.Int; + case "Ljava/lang/Long;" -> JavaKind.Long; + case "Ljava/lang/Float;" -> JavaKind.Float; + case "Ljava/lang/Double;" -> JavaKind.Double; + default -> type.getJavaKind(); + }; + } + } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/backend/OCLBackend.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/backend/OCLBackend.java index 2db7b8f976..8099c48111 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/backend/OCLBackend.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/backend/OCLBackend.java @@ -144,8 +144,7 @@ public ReferenceMapBuilder newReferenceMapBuilder(int totalFrameSize) { } /** - * It allocated the extra internal buffers that are used by this backend - * (constant and atomic). + * It allocated the extra internal buffers that are used by this backend (constant and atomic). */ @Override public void allocateTornadoVMBuffersOnDevice() { @@ -363,19 +362,21 @@ private void emitMethodParameters(OCLAssembler asm, ResolvedJavaMethod method, C final Local[] locals = method.getLocalVariableTable().getLocalsAt(0); for (int i = 0; i < incomingArguments.getArgumentCount(); i++) { + var javaType = locals[i].getType(); + var javaKind = CodeUtil.convertJavaKind(javaType); if (isKernel) { - if (locals[i].getType().getJavaKind().isPrimitive() || isHalfFloat(locals[i].getType())) { + if (javaKind.isPrimitive() || isHalfFloat(javaType)) { final AllocatableValue param = incomingArguments.getArgument(i); OCLKind kind = (OCLKind) param.getPlatformKind(); asm.emit(", "); asm.emit("__private %s %s", kind.toString(), locals[i].getName()); } else { // Skip the kernel context object - if (locals[i].getType().toJavaName().equals(KernelContext.class.getName())) { + if (javaType.toJavaName().equals(KernelContext.class.getName())) { continue; } // Skip atomic integers - if (locals[i].getType().toJavaName().equals(AtomicInteger.class.getName())) { + if (javaType.toJavaName().equals(AtomicInteger.class.getName())) { continue; } asm.emit(", "); @@ -385,8 +386,8 @@ private void emitMethodParameters(OCLAssembler asm, ResolvedJavaMethod method, C } else { final AllocatableValue param = incomingArguments.getArgument(i); OCLKind oclKind = (OCLKind) param.getPlatformKind(); - if (locals[i].getType().getJavaKind().isObject()) { - OCLKind tmpKind = OCLKind.resolveToVectorKind(locals[i].getType().resolve(method.getDeclaringClass())); + if (javaKind.isObject()) { + OCLKind tmpKind = OCLKind.resolveToVectorKind(javaType.resolve(method.getDeclaringClass())); if (tmpKind != OCLKind.ILLEGAL) { oclKind = tmpKind; } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoTaskSpecialisation.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoTaskSpecialisation.java index 9576490087..14323602a9 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoTaskSpecialisation.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoTaskSpecialisation.java @@ -33,11 +33,14 @@ import java.util.concurrent.atomic.AtomicBoolean; import org.graalvm.compiler.core.common.type.ObjectStamp; +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.core.common.type.StampPair; import org.graalvm.compiler.debug.DebugContext; import org.graalvm.compiler.graph.Graph.Mark; import org.graalvm.compiler.graph.Node; import org.graalvm.compiler.graph.iterators.NodeIterable; import org.graalvm.compiler.nodes.ConstantNode; +import org.graalvm.compiler.nodes.FixedGuardNode; import org.graalvm.compiler.nodes.GraphState; import org.graalvm.compiler.nodes.LogicConstantNode; import org.graalvm.compiler.nodes.NodeView; @@ -47,6 +50,7 @@ import org.graalvm.compiler.nodes.StructuredGraph; import org.graalvm.compiler.nodes.calc.IntegerLessThanNode; import org.graalvm.compiler.nodes.calc.IsNullNode; +import org.graalvm.compiler.nodes.extended.UnboxNode; import org.graalvm.compiler.nodes.java.ArrayLengthNode; import org.graalvm.compiler.nodes.java.LoadFieldNode; import org.graalvm.compiler.nodes.util.GraphUtil; @@ -91,8 +95,9 @@ private static boolean hasPanamaArraySizeNode(StructuredGraph graph) { for (LoadFieldNode loadField : graph.getNodes().filter(LoadFieldNode.class)) { final ResolvedJavaField field = loadField.field(); if (field.getType().getJavaKind().isPrimitive()) { - if (loadField.toString().contains("numberOfElements")) + if (loadField.toString().contains("numberOfElements")) { return true; + } } } return false; @@ -254,18 +259,18 @@ private void evaluate(final StructuredGraph graph, final Node node, final Object } } - private ConstantNode createConstantFromObject(Object obj, StructuredGraph graph) { - ConstantNode result = null; + private ParameterNode createPrimitiveParameterFromObjectParameter(Object obj, int index, StructuredGraph graph) { + ParameterNode result = null; switch (obj) { - case Byte objByte -> result = ConstantNode.forByte(objByte, graph); - case Character objChar -> result = ConstantNode.forChar(objChar, graph); - case Short objShort -> result = ConstantNode.forShort(objShort, graph); - case HalfFloat objHalfFloat -> result = ConstantNode.forFloat(objHalfFloat.getFloat32(), graph); - case Integer objInteger -> result = ConstantNode.forInt(objInteger, graph); - case Float objFloat -> result = ConstantNode.forFloat(objFloat, graph); - case Double objDouble -> result = ConstantNode.forDouble(objDouble, graph); - case Long objLong -> result = ConstantNode.forLong(objLong, graph); - case null, default -> unimplemented("createConstantFromObject: %s", obj); + case Byte objByte -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Byte))); + case Character objChar -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Char))); + case Short objShort -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Short))); + case HalfFloat objHalfFloat -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Float))); + case Integer objInteger -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Int))); + case Float objFloat -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Float))); + case Double objDouble -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Double))); + case Long objLong -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Long))); + case null, default -> unimplemented("createPrimitiveParameterFromObjectParameter: %s", obj); } return result; } @@ -297,8 +302,35 @@ private void propagateParameters(StructuredGraph graph, ParameterNode parameterN parameterNode.replaceAtUsages(kernelContextAccessNode); index++; } else { - ConstantNode constant = createConstantFromObject(args[parameterNode.index()], graph); - parameterNode.replaceAtUsages(constant); + ParameterNode primitiveParameter = createPrimitiveParameterFromObjectParameter(args[parameterNode.index()], parameterNode.index(), graph); + + parameterNode.replaceAtAllUsages(primitiveParameter, true); + parameterNode.safeDelete(); + + //remove Unbox nodes, they are not needed for constant values and cause compilation errors + graph.getNodes().filter(n -> n instanceof PiNode piNode && piNode.object() == primitiveParameter).snapshot().forEach(node -> { + var usagesSnapshot = node.usages().snapshot(); + node.replaceAtAllUsages(primitiveParameter, true); + node.safeDelete(); + + usagesSnapshot.forEach(n -> { + if (n instanceof UnboxNode unboxNode) { + var prev = n.predecessor(); + + unboxNode.replaceAtAllUsages(primitiveParameter, true); + graph.removeFixed(unboxNode); + + if (prev instanceof FixedGuardNode fixedGuardNode) { + if (fixedGuardNode.condition() instanceof IsNullNode isNullNode && isNullNode.getValue() == primitiveParameter) { + fixedGuardNode.clearInputs(); + graph.removeFixed(fixedGuardNode); + } + } + } + }); + }); + + graph.addOrUnique(primitiveParameter); } } else { parameterNode.usages().snapshot().forEach(n -> { @@ -319,7 +351,7 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) { final Mark mark = graph.getMark(); if (context.hasArgs()) { getDebugContext().dump(DebugContext.INFO_LEVEL, graph, "Before Phase Propagate Parameters"); - for (final ParameterNode param : graph.getNodes(ParameterNode.TYPE)) { + for (final ParameterNode param : graph.getNodes(ParameterNode.TYPE).snapshot()) { propagateParameters(graph, param, context.getArgs()); } getDebugContext().dump(DebugContext.INFO_LEVEL, graph, "After Phase Propagate Parameters"); @@ -395,6 +427,7 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) { @FunctionalInterface private interface FunctionThatThrows { + R apply(T t) throws IllegalArgumentException, IllegalAccessException; } } diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/functional/TestLambdas.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/functional/TestLambdas.java index 6f80c2ca23..e3ce7e319f 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/functional/TestLambdas.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/functional/TestLambdas.java @@ -18,6 +18,7 @@ package uk.ac.manchester.tornado.unittests.functional; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import java.util.Random; @@ -32,6 +33,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException; import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; @@ -145,4 +147,28 @@ public void testVectorFunctionLambda03() throws TornadoExecutionPlanException { } } + @Test + public void testParameterUnboxing() throws Exception { + var arrayToCopy = new float[] { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }; + + var taskGraph = new TaskGraph("s0"); + taskGraph.transferToDevice(DataTransferMode.EVERY_EXECUTION, arrayToCopy); + + var resultArray = new FloatArray(arrayToCopy.length); + taskGraph.transferToHost(DataTransferMode.EVERY_EXECUTION, resultArray); + + taskGraph.task("t0", (source, sourceOffset, destination, destinationOffset, length) -> { + for (@Parallel int i = 0; i < length; i++) { + destination.set(destinationOffset + i, source[sourceOffset + i]); + } + }, arrayToCopy, 0, resultArray, 0, arrayToCopy.length); + + var snapshot = taskGraph.snapshot(); + try (var executionPlan = new TornadoExecutionPlan(snapshot)) { + executionPlan.execute(); + } + + assertArrayEquals(arrayToCopy, resultArray.toHeapArray(), 0.001f); + } + }