Skip to content

Commit

Permalink
Merge pull request #1168 from beehive-lab/florin/fix-tornado-memory-leaks
Browse files Browse the repository at this point in the history

Fix tornado memory leaks
  • Loading branch information
jjfumero authored Aug 3, 2024
2 parents f1e670d + 4802ed3 commit 22b4fed
Show file tree
Hide file tree
Showing 27 changed files with 285 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ public interface TornadoDeviceContext {

// NOTE(review): presumably reports whether reset(long) has run since the flag was last cleared — confirm against implementations.
boolean wasReset();

/**
 * Resets the state this device context keeps for one execution plan.
 *
 * @param executionPlanId identifier of the execution plan to reset
 */
void reset(long executionPlanId);

// NOTE(review): presumably clears the flag reported by wasReset() — confirm against implementations.
void setResetToFalse();

// NOTE(review): presumably true when the underlying platform is an FPGA — confirm against implementations.
boolean isPlatformFPGA();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
package uk.ac.manchester.tornado.drivers.opencl;

import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
Expand All @@ -46,6 +47,19 @@ public OCLCommandQueue get(OCLTargetDevice device, OCLContext context) {
return deviceCommandMap.get(device).get(Thread.currentThread().threadId(), device, context);
}

/**
 * Releases the command queue that the calling thread holds for {@code device} and drops the
 * device entry once no thread has a queue left for it.
 *
 * <p>Does nothing when no queue table is registered for {@code device}.
 *
 * @param device
 *     the device whose command queue for the current thread is released
 */
public void cleanup(OCLTargetDevice device) {
    // A single lookup guards both the per-thread cleanup and the emptiness check, so an
    // unregistered device never leads to dereferencing a missing map entry.
    ThreadCommandQueueTable threadQueues = deviceCommandMap.get(device);
    if (threadQueues == null) {
        return;
    }
    threadQueues.cleanup(Thread.currentThread().threadId());
    if (threadQueues.size() == 0) {
        deviceCommandMap.remove(device);
    }
}

/**
 * Returns the number of devices that currently have an entry in this table.
 *
 * @return the size of the device-to-queue-table map
 */
public int size() {
    final int registeredDevices = deviceCommandMap.size();
    return registeredDevices;
}

private static class ThreadCommandQueueTable {
private final Map<Long, OCLCommandQueue> commandQueueMap;

Expand All @@ -68,5 +82,16 @@ public OCLCommandQueue get(long threadId, OCLTargetDevice device, OCLContext con
}
return commandQueueMap.get(threadId);
}

/**
 * Removes the command queue registered for {@code threadId}, if there is one, and cleans it up.
 *
 * @param threadId
 *     identifier of the thread whose command queue is released
 */
public void cleanup(long threadId) {
    // remove() hands back the previous mapping, so one call both detaches the queue and
    // tells us whether there was one; no separate containsKey() probe is needed.
    OCLCommandQueue detachedQueue = commandQueueMap.remove(threadId);
    if (detachedQueue != null) {
        detachedQueue.cleanup();
    }
}

/**
 * Returns how many threads currently have a command queue in this table.
 *
 * @return the size of the thread-to-queue map
 */
public int size() {
    final int registeredThreads = commandQueueMap.size();
    return registeredThreads;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ public class OCLContext implements OCLContextInterface {
private final List<OCLTargetDevice> devices;
private final List<OCLDeviceContext> deviceContexts;

private final List<OCLProgram> programs;
private final OCLPlatform platform;

private final TornadoLogger logger;
Expand All @@ -53,7 +52,6 @@ public OCLContext(OCLPlatform platform, long contextPointer, List<OCLTargetDevic
this.contextID = contextPointer;
this.devices = devices;
this.deviceContexts = new ArrayList<>(devices.size());
this.programs = new ArrayList<>();
this.logger = new TornadoLogger(this.getClass());
}

Expand Down Expand Up @@ -135,7 +133,6 @@ public OCLProgram createProgramWithSource(byte[] source, long[] lengths, OCLDevi

try {
program = new OCLProgram(clCreateProgramWithSource(contextID, source, lengths), deviceContext);
programs.add(program);
} catch (OCLException e) {
logger.error(e.getMessage());
}
Expand All @@ -151,7 +148,6 @@ public OCLProgram createProgramWithIL(byte[] spirvBinary, long[] lengths, OCLDev
throw new TornadoNoOpenCLPlatformException("OpenCL version <= 2.1. clCreateProgramWithIL is not supported");
}
program = new OCLProgram(programID, deviceContext);
programs.add(program);
} catch (OCLException e) {
throw new TornadoRuntimeException(e);
}
Expand Down Expand Up @@ -180,18 +176,12 @@ public void cleanup() {
}

try {
long t0 = System.nanoTime();
for (OCLProgram program : programs) {
program.cleanup();
}
long t1 = System.nanoTime();
clReleaseContext(contextID);
long t2 = System.nanoTime();

if (TornadoOptions.FULL_DEBUG) {
System.out.printf("cleanup: %-10s..........%.9f s%n", "programs", (t1 - t0) * 1e-9);
System.out.printf("cleanup: %-10s..........%.9f s%n", "context", (t2 - t1) * 1e-9);
System.out.printf("cleanup: %-10s..........%.9f s%n", "total", (t2 - t0) * 1e-9);
}
} catch (OCLException e) {
logger.error(e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,16 @@ public int enqueueMarker(long executionPlanId, int[] events) {
/**
 * Releases the state this device context holds for one execution plan and flags the context as reset.
 *
 * @param executionPlanId identifier of the execution plan whose resources are released
 */
public void reset(long executionPlanId) {
// Reset the plan's event pool before dropping it from the pool map.
OCLEventPool eventPool = getOCLEventPool(executionPlanId);
eventPool.reset();
oclEventPool.remove(executionPlanId);
// Clean up this device's entry in the plan's command-queue table, if the plan has a table.
OCLCommandQueueTable table = commandQueueTable.get(executionPlanId);
if (table != null) {
table.cleanup(device);
// Drop the plan's table once it reports no remaining entries.
if (table.size() == 0) {
commandQueueTable.remove(executionPlanId);
}
// NOTE(review): the plan id is removed from executionIDs only when a table existed — confirm this is intended.
executionIDs.remove(executionPlanId);
}
getMemoryManager().releaseKernelStackFrame(executionPlanId);
// NOTE(review): codeCache.reset() takes no plan id, so it presumably is not scoped to this plan — confirm.
codeCache.reset();
wasReset = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public List<OCLEvent> getEvents() {
return result;
}

protected void reset() {
public void reset() {
for (int index = 0; index < events.length; index++) {
if (events[index] > 0) {
internalEvent.setEventId(index, events[index]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ public OCLKernel clCreateKernel(String entryPoint) {
OCLKernel kernel;
try {
kernel = new OCLKernel(clCreateKernel(programPointer, entryPoint), deviceContext);
kernels.add(kernel);
} catch (OCLException e) {
throw new TornadoBailoutRuntimeException(e.getMessage());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public OCLInstalledCode(final String entryPoint, final byte[] code, final OCLDev
/**
 * Invalidates this installed code: while the {@code valid} flag is set, runs the cleanup calls
 * below and clears the flag, so further calls are no-ops for as long as the flag stays false.
 */
@Override
public void invalidate() {
if (valid) {
// NOTE(review): this listing is a diff with its +/- markers stripped; kernel.cleanup() and program.cleanup() are presumably the removed and added sides of a one-line change — confirm against the actual file.
kernel.cleanup();
program.cleanup();
valid = false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@ public class OCLKernelStackFrame extends OCLByteBuffer implements KernelStackFra

private final ArrayList<CallArgument> callArguments;

private boolean isValid;

/**
 * Creates a stack frame on top of an existing device buffer and marks it valid.
 *
 * @param bufferId
 *     identifier of the buffer passed to the superclass
 * @param numArgs
 *     initial capacity of the call-argument list
 * @param device
 *     device context passed to the superclass
 */
OCLKernelStackFrame(long bufferId, int numArgs, OCLDeviceContext device) {
    // The size handed to the superclass is RESERVED_SLOTS shifted left by 3, i.e. multiplied by 8.
    super(device, bufferId, 0, RESERVED_SLOTS << 3);
    buffer.clear();
    this.isValid = true;
    this.callArguments = new ArrayList<>(numArgs);
}

@Override
Expand All @@ -53,6 +56,17 @@ public void reset() {
callArguments.clear();
}

/**
 * Reports the current value of the validity flag, which is set in the constructor and cleared
 * by {@link #invalidate()}.
 *
 * @return {@code true} while the flag is set
 */
@Override
public boolean isValid() {
    return this.isValid;
}

/**
 * Marks this stack frame as invalid and releases its buffer through the platform context.
 *
 * <p>The call is idempotent: once the frame has been invalidated, further calls return without
 * releasing the buffer a second time.
 */
@Override
public void invalidate() {
    // Guard on the flag so the underlying buffer is handed to releaseBuffer() at most once.
    if (!isValid) {
        return;
    }
    isValid = false;
    deviceContext.getPlatformContext().releaseBuffer(toBuffer());
}

@Override
public List<CallArgument> getCallArguments() {
return callArguments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ public OCLKernelStackFrame createKernelStackFrame(long executionPlanId, final in
return oclKernelStackFrame.get(executionPlanId);
}

/**
 * Detaches the kernel stack frame registered for an execution plan and invalidates it.
 * Does nothing when the plan has no registered frame.
 *
 * @param executionPlanId
 *     identifier of the execution plan whose stack frame is released
 */
public void releaseKernelStackFrame(long executionPlanId) {
    final OCLKernelStackFrame detachedFrame = oclKernelStackFrame.remove(executionPlanId);
    if (detachedFrame == null) {
        return;
    }
    detachedFrame.invalidate();
}

/**
 * Builds a new {@code AtomicsBuffer} for the given array, bound to this manager's device context.
 *
 * @param array
 *     the values passed to the {@code AtomicsBuffer} constructor
 * @return the newly created buffer
 */
public XPUBuffer createAtomicsBuffer(final int[] array) {
    final XPUBuffer atomicsBuffer = new AtomicsBuffer(array, deviceContext);
    return atomicsBuffer;
}
Expand Down
Loading

0 comments on commit 22b4fed

Please sign in to comment.