Skip to content

Commit

Permalink
Merge pull request beehive-lab#513 from babylonml/printf-dbl-input
Browse files Browse the repository at this point in the history
Issue beehive-lab#506 has been fixed.
  • Loading branch information
jjfumero authored Jul 30, 2024
2 parents c64dca1 + b0683c7 commit f1e670d
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 287 deletions.
8 changes: 8 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2088,6 +2088,14 @@
<artifactId>collections</artifactId>
<version>${graalvm.version}</version>
</dependency>

<dependency>
<groupId>org.graalvm.compiler</groupId>
<artifactId>compiler</artifactId>
<version>23.1.0</version>
<scope>provided</scope>
</dependency>

</dependencies>
<modules>
<module>tornado-runtime</module>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
import org.graalvm.compiler.core.common.memory.MemoryOrderMode;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.lir.Variable;
import org.graalvm.compiler.lir.gen.LIRGeneratorTool;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.FixedWithNextNode;
import org.graalvm.compiler.nodes.PiNode;
Expand Down Expand Up @@ -79,10 +81,14 @@
import uk.ac.manchester.tornado.api.exceptions.Debug;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.drivers.opencl.graal.OCLArchitecture;
import uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLUnaryIntrinsic;
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLKind;
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLLIRStmt.AssignStmt;
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLUnary;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.AtomicAddNodeTemplate;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.DecAtomicNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.GetAtomicNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.GlobalThreadIdNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.IncAtomicNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.LocalArrayNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLBarrierNode;
Expand All @@ -91,7 +97,6 @@
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLIntBinaryIntrinsicNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLIntUnaryIntrinsicNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.PrintfNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.TPrintfNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.TornadoAtomicIntegerNode;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;

Expand Down Expand Up @@ -309,75 +314,6 @@ private static void registerKernelContextPlugins(InvocationPlugins plugins) {
localArraysPlugins(r);
}

private static boolean printfHandler(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode... args) {
int idCount = 0;
int index = 0;
for (; index < 3; index++) {
ValueNode arg = args[index];
if (arg instanceof ConstantNode && arg.getStackKind().isObject()) {
break;
}
idCount++;
}

NewArrayNode newArrayNode = (NewArrayNode) args[index + 1];
ConstantNode lengthNode = (ConstantNode) newArrayNode.dimension(0);
int length = ((JavaConstant) lengthNode.getValue()).asInt();

ValueNode[] actualArgs = new ValueNode[4 + length];
if (idCount >= 0) {
System.arraycopy(args, 0, actualArgs, 0, idCount);
}

for (int i = idCount; i < 3; i++) {
actualArgs[i] = ConstantNode.forInt(0);
}

actualArgs[3] = args[index];

int argIndex = 0;
for (Node n : newArrayNode.usages()) {
if (n instanceof StoreIndexedNode) {
StoreIndexedNode storeNode = (StoreIndexedNode) n;
ValueNode value = storeNode.value();
if (value instanceof BoxNode) {
BoxNode box = (BoxNode) value;
value = box.getValue();
GraphUtil.unlinkFixedNode(box);
box.safeDelete();
}
actualArgs[4 + argIndex] = value;
argIndex++;
}

}

TPrintfNode printfNode = new TPrintfNode(actualArgs);

b.add(b.append(printfNode));
while (newArrayNode.hasUsages()) {
Node n = newArrayNode.usages().first();
// need to remove all nodes from the graph that operate on
// the new array,
// however, we cannot remove all inputs as they may be used
// by the currently
// unbuilt part of the graph. We also need to ensure that we
// do not leave any
// gaps inbetween fixed nodes
if (n instanceof FixedWithNextNode) {
GraphUtil.unlinkFixedNode((FixedWithNextNode) n);
}
n.clearInputs();
n.safeDelete();
}

GraphUtil.unlinkFixedNode(newArrayNode);
newArrayNode.clearInputs();
newArrayNode.safeDelete();

return true;
}

public static Class getValueLayoutClass(Class k) {
if (k == int.class) {
return ValueLayout.OfInt.class;
Expand Down Expand Up @@ -430,39 +366,6 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}

private static void registerTornadoVMIntrinsicsPlugins(InvocationPlugins plugins) {
final InvocationPlugin tprintfPlugin2 = new InvocationPlugin("tprintf", String.class, Object[].class) {
@Override
public boolean defaultHandler(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode... args) {
return printfHandler(b, targetMethod, receiver, args);
}
};

final InvocationPlugin tprintfPlugin3 = new InvocationPlugin("tprintf", int.class, String.class, Object[].class) {
@Override
public boolean defaultHandler(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode... args) {
return printfHandler(b, targetMethod, receiver, args);
}
};

final InvocationPlugin tprintfPlugin4 = new InvocationPlugin("tprintf", int.class, int.class, String.class, Object[].class) {
@Override
public boolean defaultHandler(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode... args) {
return printfHandler(b, targetMethod, receiver, args);
}
};

final InvocationPlugin tprintfPlugin5 = new InvocationPlugin("tprintf", int.class, int.class, int.class, String.class, Object[].class) {
@Override
public boolean defaultHandler(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode... args) {
return printfHandler(b, targetMethod, receiver, args);
}
};

plugins.register(Debug.class, tprintfPlugin2);
plugins.register(Debug.class, tprintfPlugin3);
plugins.register(Debug.class, tprintfPlugin4);
plugins.register(Debug.class, tprintfPlugin5);

final InvocationPlugin printfPlugin = new InvocationPlugin("printf", String.class, Object[].class) {

@Override
Expand All @@ -472,9 +375,13 @@ public boolean defaultHandler(GraphBuilderContext b, ResolvedJavaMethod targetMe
ConstantNode lengthNode = (ConstantNode) newArrayNode.dimension(0);
int length = ((JavaConstant) lengthNode.getValue()).asInt();

ValueNode[] actualArgs = new ValueNode[length + 1];
ValueNode[] actualArgs = new ValueNode[length + 4];
actualArgs[0] = args[0];

actualArgs[1] = b.append(new GlobalThreadIdNode(ConstantNode.forInt(0)));
actualArgs[2] = b.append(new GlobalThreadIdNode(ConstantNode.forInt(1)));
actualArgs[3] = b.append(new GlobalThreadIdNode(ConstantNode.forInt(2)));

int argIndex = 0;
for (Node n : newArrayNode.usages()) {
if (n instanceof StoreIndexedNode) {
Expand All @@ -486,14 +393,15 @@ public boolean defaultHandler(GraphBuilderContext b, ResolvedJavaMethod targetMe
GraphUtil.unlinkFixedNode(box);
box.safeDelete();
}
actualArgs[argIndex + 1] = value;
actualArgs[argIndex + 4] = value;
argIndex++;
}

}


PrintfNode printfNode = new PrintfNode(actualArgs);
b.add(b.append(printfNode));
b.append(printfNode);

while (newArrayNode.hasUsages()) {
Node n = newArrayNode.usages().first();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,39 +43,43 @@ public OCLPrintf(Value[] inputs) {

@Override
public void emit(OCLCompilationResultBuilder crb, OCLAssembler asm) {
asm.emit("printf( \"tornado[%%3d,%%3d,%%3d]> %s", asm.formatConstant((ConstantValue) inputs[0]));
asm.emit("printf( \"tornado[%%3d,%%3d,%%3d]> %s\"",
asm.formatConstant((ConstantValue) inputs[0]));

asm.emit("\", ");
for (int i = 0; i < 2; i++) {
asm.emit("get_global_id(%d), ", i);
for (int i = 1; i < 4; i++) {
asm.emit(", ");
asm.emitValue(crb, inputs[i]);
}
asm.emit("get_global_id(%d) ", 2);
if (inputs.length > 1) {

if (inputs.length > 4) {
asm.emit(", ");
}
for (int i = 1; i < inputs.length - 1; i++) {

for (int i = 4; i < inputs.length - 1; i++) {
asm.emitValue(crb, inputs[i]);
asm.emit(", ");
}

if (inputs.length > 1) {
if (inputs.length > 4) {
asm.emitValue(crb, inputs[inputs.length - 1]);
}

asm.emit(")");
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(String.format("printf( %s", inputs[0]));
if (inputs.length > 1) {
sb.append(String.format("printf( \"%s\"", inputs[0]));

if (inputs.length > 4) {
sb.append(", ");
}
for (int i = 1; i < inputs.length - 1; i++) {
for (int i = 4; i < inputs.length - 1; i++) {
sb.append(inputs[i]);
sb.append(", ");
}
if (inputs.length > 1) {
if (inputs.length > 4) {
sb.append(inputs[inputs.length - 1]);
}
sb.append(" )");
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ public class PrintfNode extends FixedWithNextNode implements LIRLowerable {

public PrintfNode(ValueNode... values) {
super(TYPE, StampFactory.forVoid());
this.inputs = new NodeInputList<>(this, values.length);
for (int i = 0; i < values.length; i++) {
inputs.set(i, values[i]);
}
this.inputs = new NodeInputList<>(this, values);
}

@Override
Expand Down
Loading

0 comments on commit f1e670d

Please sign in to comment.