Skip to content

Commit

Permalink
Draft of fix of passing primitives in lambdas.
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii0lomakin committed Aug 14, 2024
1 parent 22b4fed commit 9b36b2a
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ public static CallingConvention getCallingConvention(CodeCacheProvider codeCache
} else {
argTypes = new JavaType[sigCount];
}
var declaringClass = method.getDeclaringClass();
for (int i = 0; i < sigCount; i++) {
argTypes[argIndex++] = sig.getParameterType(i, null);
argTypes[argIndex++] = sig.getParameterType(i, declaringClass);
}

final Local[] locals = method.getLocalVariableTable().getLocalsAt(0);
Expand All @@ -70,8 +71,9 @@ private static CallingConvention getCallingConvention(Type type, JavaType return
inputParameters[i] = new Variable(LIRKind.value(target.arch.getPlatformKind(JavaKind.Short)), variableIndex);
continue;
}
inputParameters[i] = new Variable(LIRKind.value(target.arch.getPlatformKind(argTypes[i].getJavaKind())), variableIndex);

var javaKind = convertJavaKind(argTypes[i]);
inputParameters[i] = new Variable(LIRKind.value(target.arch.getPlatformKind(javaKind)), variableIndex);
}

JavaKind returnKind = returnType == null ? JavaKind.Void : returnType.getJavaKind();
Expand All @@ -87,4 +89,24 @@ public static boolean isHalfFloat(JavaType type) {
return type.toJavaName().equals(HalfFloat.class.getName());
}

/**
* Convert a {@link JavaType} to a {@link JavaKind}, all wrappers for primitive types are converted to the corresponding {@link JavaKind} of the primitive type.
*
* @param type
* @return
*/
public static JavaKind convertJavaKind(JavaType type) {
return switch (type.getName()) {
case "Ljava/lang/Boolean;" -> JavaKind.Boolean;
case "Ljava/lang/Byte;" -> JavaKind.Byte;
case "Ljava/lang/Short;" -> JavaKind.Short;
case "Ljava/lang/Character;" -> JavaKind.Char;
case "Ljava/lang/Integer;" -> JavaKind.Int;
case "Ljava/lang/Long;" -> JavaKind.Long;
case "Ljava/lang/Float;" -> JavaKind.Float;
case "Ljava/lang/Double;" -> JavaKind.Double;
default -> type.getJavaKind();
};
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ public ReferenceMapBuilder newReferenceMapBuilder(int totalFrameSize) {
}

/**
* It allocated the extra internal buffers that are used by this backend
* (constant and atomic).
* It allocated the extra internal buffers that are used by this backend (constant and atomic).
*/
@Override
public void allocateTornadoVMBuffersOnDevice() {
Expand Down Expand Up @@ -363,19 +362,21 @@ private void emitMethodParameters(OCLAssembler asm, ResolvedJavaMethod method, C
final Local[] locals = method.getLocalVariableTable().getLocalsAt(0);

for (int i = 0; i < incomingArguments.getArgumentCount(); i++) {
var javaType = locals[i].getType();
var javaKind = CodeUtil.convertJavaKind(javaType);
if (isKernel) {
if (locals[i].getType().getJavaKind().isPrimitive() || isHalfFloat(locals[i].getType())) {
if (javaKind.isPrimitive() || isHalfFloat(javaType)) {
final AllocatableValue param = incomingArguments.getArgument(i);
OCLKind kind = (OCLKind) param.getPlatformKind();
asm.emit(", ");
asm.emit("__private %s %s", kind.toString(), locals[i].getName());
} else {
// Skip the kernel context object
if (locals[i].getType().toJavaName().equals(KernelContext.class.getName())) {
if (javaType.toJavaName().equals(KernelContext.class.getName())) {
continue;
}
// Skip atomic integers
if (locals[i].getType().toJavaName().equals(AtomicInteger.class.getName())) {
if (javaType.toJavaName().equals(AtomicInteger.class.getName())) {
continue;
}
asm.emit(", ");
Expand All @@ -385,8 +386,8 @@ private void emitMethodParameters(OCLAssembler asm, ResolvedJavaMethod method, C
} else {
final AllocatableValue param = incomingArguments.getArgument(i);
OCLKind oclKind = (OCLKind) param.getPlatformKind();
if (locals[i].getType().getJavaKind().isObject()) {
OCLKind tmpKind = OCLKind.resolveToVectorKind(locals[i].getType().resolve(method.getDeclaringClass()));
if (javaKind.isObject()) {
OCLKind tmpKind = OCLKind.resolveToVectorKind(javaType.resolve(method.getDeclaringClass()));
if (tmpKind != OCLKind.ILLEGAL) {
oclKind = tmpKind;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,14 @@
import java.util.concurrent.atomic.AtomicBoolean;

import org.graalvm.compiler.core.common.type.ObjectStamp;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.core.common.type.StampPair;
import org.graalvm.compiler.debug.DebugContext;
import org.graalvm.compiler.graph.Graph.Mark;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.iterators.NodeIterable;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.FixedGuardNode;
import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.LogicConstantNode;
import org.graalvm.compiler.nodes.NodeView;
Expand All @@ -47,6 +50,7 @@
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.calc.IntegerLessThanNode;
import org.graalvm.compiler.nodes.calc.IsNullNode;
import org.graalvm.compiler.nodes.extended.UnboxNode;
import org.graalvm.compiler.nodes.java.ArrayLengthNode;
import org.graalvm.compiler.nodes.java.LoadFieldNode;
import org.graalvm.compiler.nodes.util.GraphUtil;
Expand Down Expand Up @@ -91,8 +95,9 @@ private static boolean hasPanamaArraySizeNode(StructuredGraph graph) {
for (LoadFieldNode loadField : graph.getNodes().filter(LoadFieldNode.class)) {
final ResolvedJavaField field = loadField.field();
if (field.getType().getJavaKind().isPrimitive()) {
if (loadField.toString().contains("numberOfElements"))
if (loadField.toString().contains("numberOfElements")) {
return true;
}
}
}
return false;
Expand Down Expand Up @@ -254,18 +259,18 @@ private void evaluate(final StructuredGraph graph, final Node node, final Object
}
}

private ConstantNode createConstantFromObject(Object obj, StructuredGraph graph) {
ConstantNode result = null;
private ParameterNode createPrimitiveParameterFromObjectParameter(Object obj, int index, StructuredGraph graph) {
ParameterNode result = null;
switch (obj) {
case Byte objByte -> result = ConstantNode.forByte(objByte, graph);
case Character objChar -> result = ConstantNode.forChar(objChar, graph);
case Short objShort -> result = ConstantNode.forShort(objShort, graph);
case HalfFloat objHalfFloat -> result = ConstantNode.forFloat(objHalfFloat.getFloat32(), graph);
case Integer objInteger -> result = ConstantNode.forInt(objInteger, graph);
case Float objFloat -> result = ConstantNode.forFloat(objFloat, graph);
case Double objDouble -> result = ConstantNode.forDouble(objDouble, graph);
case Long objLong -> result = ConstantNode.forLong(objLong, graph);
case null, default -> unimplemented("createConstantFromObject: %s", obj);
case Byte objByte -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Byte)));
case Character objChar -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Char)));
case Short objShort -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Short)));
case HalfFloat objHalfFloat -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Float)));
case Integer objInteger -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Int)));
case Float objFloat -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Float)));
case Double objDouble -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Double)));
case Long objLong -> result = new ParameterNode(index, StampPair.createSingle(StampFactory.forKind(JavaKind.Long)));
case null, default -> unimplemented("createPrimitiveParameterFromObjectParameter: %s", obj);
}
return result;
}
Expand Down Expand Up @@ -297,8 +302,35 @@ private void propagateParameters(StructuredGraph graph, ParameterNode parameterN
parameterNode.replaceAtUsages(kernelContextAccessNode);
index++;
} else {
ConstantNode constant = createConstantFromObject(args[parameterNode.index()], graph);
parameterNode.replaceAtUsages(constant);
ParameterNode primitiveParameter = createPrimitiveParameterFromObjectParameter(args[parameterNode.index()], parameterNode.index(), graph);

parameterNode.replaceAtAllUsages(primitiveParameter, true);
parameterNode.safeDelete();

//remove Unbox nodes, they are not needed for constant values and cause compilation errors
graph.getNodes().filter(n -> n instanceof PiNode piNode && piNode.object() == primitiveParameter).snapshot().forEach(node -> {
var usagesSnapshot = node.usages().snapshot();
node.replaceAtAllUsages(primitiveParameter, true);
node.safeDelete();

usagesSnapshot.forEach(n -> {
if (n instanceof UnboxNode unboxNode) {
var prev = n.predecessor();

unboxNode.replaceAtAllUsages(primitiveParameter, true);
graph.removeFixed(unboxNode);

if (prev instanceof FixedGuardNode fixedGuardNode) {
if (fixedGuardNode.condition() instanceof IsNullNode isNullNode && isNullNode.getValue() == primitiveParameter) {
fixedGuardNode.clearInputs();
graph.removeFixed(fixedGuardNode);
}
}
}
});
});

graph.addOrUnique(primitiveParameter);
}
} else {
parameterNode.usages().snapshot().forEach(n -> {
Expand All @@ -319,7 +351,7 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
final Mark mark = graph.getMark();
if (context.hasArgs()) {
getDebugContext().dump(DebugContext.INFO_LEVEL, graph, "Before Phase Propagate Parameters");
for (final ParameterNode param : graph.getNodes(ParameterNode.TYPE)) {
for (final ParameterNode param : graph.getNodes(ParameterNode.TYPE).snapshot()) {
propagateParameters(graph, param, context.getArgs());
}
getDebugContext().dump(DebugContext.INFO_LEVEL, graph, "After Phase Propagate Parameters");
Expand Down Expand Up @@ -395,6 +427,7 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {

@FunctionalInterface
private interface FunctionThatThrows<T, R> {

R apply(T t) throws IllegalArgumentException, IllegalAccessException;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package uk.ac.manchester.tornado.unittests.functional;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

import java.util.Random;
Expand All @@ -32,6 +33,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.DoubleArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

Expand Down Expand Up @@ -145,4 +147,28 @@ public void testVectorFunctionLambda03() throws TornadoExecutionPlanException {
}
}

@Test
public void testParameterUnboxing() throws Exception {
var arrayToCopy = new float[] { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f };

var taskGraph = new TaskGraph("s0");
taskGraph.transferToDevice(DataTransferMode.EVERY_EXECUTION, arrayToCopy);

var resultArray = new FloatArray(arrayToCopy.length);
taskGraph.transferToHost(DataTransferMode.EVERY_EXECUTION, resultArray);

taskGraph.task("t0", (source, sourceOffset, destination, destinationOffset, length) -> {
for (@Parallel int i = 0; i < length; i++) {
destination.set(destinationOffset + i, source[sourceOffset + i]);
}
}, arrayToCopy, 0, resultArray, 0, arrayToCopy.length);

var snapshot = taskGraph.snapshot();
try (var executionPlan = new TornadoExecutionPlan(snapshot)) {
executionPlan.execute();
}

assertArrayEquals(arrayToCopy, resultArray.toHeapArray(), 0.001f);
}

}

0 comments on commit 9b36b2a

Please sign in to comment.