From b06011a19acae8f6fd6dad82eeeb84d3a2a321c6 Mon Sep 17 00:00:00 2001 From: Florin Blanaru Date: Sun, 30 Jun 2024 22:36:59 +0300 Subject: [PATCH 1/3] Fix memory leaks on the OpenCL backend --- .../tornado/api/TornadoDeviceContext.java | 2 ++ .../drivers/opencl/OCLCommandQueueTable.java | 25 +++++++++++++++++++ .../tornado/drivers/opencl/OCLContext.java | 10 -------- .../drivers/opencl/OCLDeviceContext.java | 11 ++++++++ .../tornado/drivers/opencl/OCLProgram.java | 1 + .../opencl/graal/OCLInstalledCode.java | 2 +- .../opencl/mm/OCLKernelStackFrame.java | 14 +++++++++++ .../drivers/opencl/mm/OCLMemoryManager.java | 7 ++++++ .../tornado/drivers/ptx/PTXDeviceContext.java | 1 + .../drivers/spirv/SPIRVDeviceContext.java | 1 + .../runtime/common/KernelStackFrame.java | 4 +++ .../interpreter/TornadoVMInterpreter.java | 2 +- .../runtime/tasks/TornadoTaskGraph.java | 2 ++ 13 files changed, 70 insertions(+), 12 deletions(-) diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoDeviceContext.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoDeviceContext.java index 9d040411df..ba756444f4 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoDeviceContext.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoDeviceContext.java @@ -30,6 +30,8 @@ public interface TornadoDeviceContext { boolean wasReset(); + void reset(long executionPlanId); + void setResetToFalse(); boolean isPlatformFPGA(); diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLCommandQueueTable.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLCommandQueueTable.java index 38c3402104..e62a399188 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLCommandQueueTable.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLCommandQueueTable.java @@ -24,6 +24,7 @@ package uk.ac.manchester.tornado.drivers.opencl; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException; @@ -46,6 +47,19 @@ public OCLCommandQueue get(OCLTargetDevice device, OCLContext context) { return deviceCommandMap.get(device).get(Thread.currentThread().threadId(), device, context); } + public void cleanup(OCLTargetDevice device) { + if (deviceCommandMap.containsKey(device)) { + deviceCommandMap.get(device).cleanup(Thread.currentThread().threadId()); + } + if (deviceCommandMap.get(device).size() == 0) { + deviceCommandMap.remove(device); + } + } + + public int size() { + return deviceCommandMap.size(); + } + private static class ThreadCommandQueueTable { private final Map commandQueueMap; @@ -68,5 +82,16 @@ public OCLCommandQueue get(long threadId, OCLTargetDevice device, OCLContext con } return commandQueueMap.get(threadId); } + + public void cleanup(long threadId) { + if (commandQueueMap.containsKey(threadId)) { + OCLCommandQueue queue = commandQueueMap.remove(threadId); + queue.cleanup(); + } + } + + public int size() { + return commandQueueMap.size(); + } } } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLContext.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLContext.java index b8c26a5d6d..d88bd73972 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLContext.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLContext.java @@ -43,7 +43,6 @@ public class OCLContext implements OCLContextInterface { private final List devices; private final List deviceContexts; - private final List programs; private final OCLPlatform platform; private final TornadoLogger logger; @@ -53,7 +52,6 @@ public OCLContext(OCLPlatform platform, long contextPointer, List(devices.size()); - this.programs = new ArrayList<>(); this.logger = new TornadoLogger(this.getClass()); } @@ -135,7 +133,6 @@ public OCLProgram createProgramWithSource(byte[] source, long[] lengths, OCLDevi try { program = new OCLProgram(clCreateProgramWithSource(contextID, source, lengths), deviceContext); - programs.add(program); } catch (OCLException e) { logger.error(e.getMessage()); } @@ -151,7 +148,6 @@ public OCLProgram createProgramWithIL(byte[] spirvBinary, long[] lengths, OCLDev throw new TornadoNoOpenCLPlatformException("OpenCL version <= 2.1. clCreateProgramWithIL is not supported"); } program = new OCLProgram(programID, deviceContext); - programs.add(program); } catch (OCLException e) { throw new TornadoRuntimeException(e); } @@ -180,18 +176,12 @@ public void cleanup() { } try { - long t0 = System.nanoTime(); - for (OCLProgram program : programs) { - program.cleanup(); - } long t1 = System.nanoTime(); clReleaseContext(contextID); long t2 = System.nanoTime(); if (TornadoOptions.FULL_DEBUG) { - System.out.printf("cleanup: %-10s..........%.9f s%n", "programs", (t1 - t0) * 1e-9); System.out.printf("cleanup: %-10s..........%.9f s%n", "context", (t2 - t1) * 1e-9); - System.out.printf("cleanup: %-10s..........%.9f s%n", "total", (t2 - t0) * 1e-9); } } catch (OCLException e) { logger.error(e.getMessage()); diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContext.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContext.java index 4d773e634e..ac94bc1a72 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContext.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContext.java @@ -513,6 +513,17 @@ public int enqueueMarker(long executionPlanId, int[] events) { public void reset(long executionPlanId) { OCLEventPool eventPool = getOCLEventPool(executionPlanId); eventPool.reset(); + oclEventPool.remove(executionPlanId); + OCLCommandQueueTable table = commandQueueTable.get(executionPlanId); + if (table != null) { + OCLTargetDevice device = context.devices().get(getDeviceIndex()); + table.cleanup(device); + if (table.size() == 0) { + commandQueueTable.remove(executionPlanId); + } + executionIDs.remove(executionPlanId); + } + getMemoryManager().releaseKernelStackFrame(executionPlanId); codeCache.reset(); wasReset = true; } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLProgram.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLProgram.java index bc30a98366..fd622d3c74 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLProgram.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLProgram.java @@ -222,6 +222,7 @@ public OCLKernel clCreateKernel(String entryPoint) { OCLKernel kernel; try { kernel = new OCLKernel(clCreateKernel(programPointer, entryPoint), deviceContext); + kernels.add(kernel); } catch (OCLException e) { throw new TornadoBailoutRuntimeException(e.getMessage()); } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/OCLInstalledCode.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/OCLInstalledCode.java index b7723a69f5..1caef66b6e 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/OCLInstalledCode.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/OCLInstalledCode.java @@ -85,7 +85,7 @@ public OCLInstalledCode(final String entryPoint, final byte[] code, final OCLDev @Override public void invalidate() { if (valid) { - kernel.cleanup(); + program.cleanup(); valid = false; } } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLKernelStackFrame.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLKernelStackFrame.java index c537c254bf..6151f5fa24 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLKernelStackFrame.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLKernelStackFrame.java @@ -37,10 +37,13 @@ public class OCLKernelStackFrame extends OCLByteBuffer implements KernelStackFra private final ArrayList callArguments; + private boolean isValid; + OCLKernelStackFrame(long bufferId, int numArgs, OCLDeviceContext device) { super(device, bufferId, 0, RESERVED_SLOTS << 3); this.callArguments = new ArrayList<>(numArgs); buffer.clear(); + this.isValid = true; } @Override @@ -53,6 +56,17 @@ public void reset() { callArguments.clear(); } + @Override + public boolean isValid() { + return isValid; + } + + @Override + public void invalidate() { + isValid = false; + deviceContext.getPlatformContext().releaseBuffer(toBuffer()); + } + @Override public List getCallArguments() { return callArguments; diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemoryManager.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemoryManager.java index 02a80707eb..b2f3892f11 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemoryManager.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemoryManager.java @@ -62,6 +62,13 @@ public OCLKernelStackFrame createKernelStackFrame(long executionPlanId, final in return oclKernelStackFrame.get(executionPlanId); } + public void releaseKernelStackFrame(long executionPlanId) { + OCLKernelStackFrame stackFrame = oclKernelStackFrame.remove(executionPlanId); + if (stackFrame != null) { + stackFrame.invalidate(); + } + } + public XPUBuffer createAtomicsBuffer(final int[] array) { return new AtomicsBuffer(array, deviceContext); } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java index 389afd05e6..648a9f8798 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java @@ -243,6 +243,7 @@ public void flush(long executionPlanId) { sync(executionPlanId); } + @Override public void reset(long executionPlanId) { PTXStream stream = getStream(executionPlanId); stream.reset(); diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java index 68a37ead26..7cba8bfe09 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java @@ -153,6 +153,7 @@ public SPIRVTornadoDevice asMapping() { return tornadoDevice; } + @Override public void reset(long executionPlanId) { spirvEventPool.put(executionPlanId, new SPIRVEventPool(TornadoOptions.EVENT_WINDOW)); codeCache.reset(); diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/KernelStackFrame.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/KernelStackFrame.java index 030271b436..fbcbf596f1 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/KernelStackFrame.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/KernelStackFrame.java @@ -57,4 +57,8 @@ public boolean isReferenceType() { void addCallArgument(Object value, boolean isReferenceType); void setKernelContext(HashMap map); + + boolean isValid(); + + void invalidate(); } diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/interpreter/TornadoVMInterpreter.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/interpreter/TornadoVMInterpreter.java index aad972a46b..940a8700ac 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/interpreter/TornadoVMInterpreter.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/interpreter/TornadoVMInterpreter.java @@ -855,7 +855,7 @@ private KernelStackFrame resolveCallWrapper(int index, int numArgs, KernelStackF if (executionContext.meta().isDebug() && redeployOnDevice) { logger.debug("Recompiling task on device " + device); } - if (kernelStackFrame[index] == null || redeployOnDevice) { + if (kernelStackFrame[index] == null || !kernelStackFrame[index].isValid() || redeployOnDevice) { kernelStackFrame[index] = device.createKernelStackFrame(executionContext.getExecutionPlanId(), numArgs); } return kernelStackFrame[index]; diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/TornadoTaskGraph.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/TornadoTaskGraph.java index 3199428518..740807806d 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/TornadoTaskGraph.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/TornadoTaskGraph.java @@ -62,6 +62,7 @@ import uk.ac.manchester.tornado.api.Policy; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.TornadoBackend; +import uk.ac.manchester.tornado.api.TornadoDeviceContext; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.TornadoTaskGraphInterface; import uk.ac.manchester.tornado.api.common.Event; @@ -1072,6 +1073,7 @@ private void free() { } inputModesObjects.forEach(inputStreamObject -> freeDeviceMemoryObject(inputStreamObject.getObject())); outputModeObjects.forEach(outputStreamObject -> freeDeviceMemoryObject(outputStreamObject.getObject())); + meta().getLogicDevice().getDeviceContext().reset(executionPlanId); } private void freeDeviceMemoryObject(Object object) { From 24d3635becd8a1d64be5576d82cc301ae3a1a5f1 Mon Sep 17 00:00:00 2001 From: Florin Blanaru Date: Fri, 19 Jul 2024 22:26:27 +0300 Subject: [PATCH 2/3] Fix memory leaks on the SPIRV backend --- .../tornado/drivers/opencl/OCLEventPool.java | 2 +- .../tornado/drivers/spirv/SPIRVContext.java | 2 ++ .../drivers/spirv/SPIRVDeviceContext.java | 4 ++- .../SPIRVLevelZeroCommandQueueTable.java | 27 +++++++++++++++++++ .../drivers/spirv/SPIRVLevelZeroContext.java | 11 ++++++++ .../spirv/SPIRVOCLCommandQueueTable.java | 25 +++++++++++++++++ .../drivers/spirv/SPIRVOCLContext.java | 18 +++++++++++++ .../drivers/spirv/mm/SPIRVByteBuffer.java | 2 +- .../spirv/mm/SPIRVKernelStackFrame.java | 14 ++++++++++ .../drivers/spirv/mm/SPIRVMemoryManager.java | 8 ++++++ 10 files changed, 110 insertions(+), 3 deletions(-) diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLEventPool.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLEventPool.java index f2d1f6085c..879804dfd6 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLEventPool.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLEventPool.java @@ -148,7 +148,7 @@ public List getEvents() { return result; } - protected void reset() { + public void reset() { for (int index = 0; index < events.length; index++) { if (events[index] > 0) { internalEvent.setEventId(index, events[index]); diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVContext.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVContext.java index 6a1d027256..f04cebec61 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVContext.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVContext.java @@ -95,4 +95,6 @@ public abstract int enqueueWriteBuffer(long executionPlanId, int deviceIndex, lo public abstract void readBuffer(long executionPlanId, int deviceIndex, long bufferId, long offset, long bytes, long offHeapSegmentAddress, long hostOffset, int[] waitEvents, ProfilerTransfer profilerTransfer); + + public abstract void reset(long executionPlanId, int deviceIndex); } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java index 7cba8bfe09..a75a990eaa 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java @@ -155,7 +155,9 @@ public SPIRVTornadoDevice asMapping() { @Override public void reset(long executionPlanId) { - spirvEventPool.put(executionPlanId, new SPIRVEventPool(TornadoOptions.EVENT_WINDOW)); + spirvContext.reset(executionPlanId, getDeviceIndex()); + spirvEventPool.remove(executionPlanId); + getMemoryManager().releaseKernelStackFrame(executionPlanId); codeCache.reset(); wasReset = true; } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroCommandQueueTable.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroCommandQueueTable.java index 08afd61ef4..607ae2aec0 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroCommandQueueTable.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroCommandQueueTable.java @@ -26,6 +26,8 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import uk.ac.manchester.tornado.drivers.opencl.OCLCommandQueue; +import uk.ac.manchester.tornado.drivers.opencl.OCLTargetDevice; import uk.ac.manchester.tornado.drivers.spirv.levelzero.LevelZeroCommandList; import uk.ac.manchester.tornado.drivers.spirv.levelzero.LevelZeroCommandQueue; import uk.ac.manchester.tornado.drivers.spirv.levelzero.LevelZeroContext; @@ -58,6 +60,19 @@ public SPIRVLevelZeroCommandQueue get(SPIRVDevice device, LevelZeroContext level return deviceCommandMap.get(device).get(Thread.currentThread().threadId(), device, levelZeroContext); } + public void cleanup(SPIRVDevice device, LevelZeroContext levelZeroContext) { + if (deviceCommandMap.containsKey(device)) { + deviceCommandMap.get(device).cleanup(Thread.currentThread().threadId(), levelZeroContext); + } + if (deviceCommandMap.get(device).size() == 0) { + deviceCommandMap.remove(device); + } + } + + public int size() { + return deviceCommandMap.size(); + } + private static class ThreadCommandQueueTable { private final Map commandQueueMap; @@ -129,6 +144,18 @@ private int getCommandQueueOrdinal(LevelZeroDevice device) { } return ordinal; } + + public void cleanup(long threadId, LevelZeroContext levelZeroContext) { + if (commandQueueMap.containsKey(threadId)) { + SPIRVLevelZeroCommandQueue queue = commandQueueMap.remove(threadId); + levelZeroContext.zeCommandQueueDestroy(queue.getCommandQueue().getCommandQueueHandle()); + levelZeroContext.zeCommandListDestroy(queue.getCommandList().getCommandListHandler()); + } + } + + public int size() { + return commandQueueMap.size(); + } } } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroContext.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroContext.java index 112a29dea8..1070fcbda0 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroContext.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroContext.java @@ -579,4 +579,15 @@ public void readBuffer(long executionPlanId, int deviceIndex, long bufferId, lon } } + @Override + public void reset(long executionPlanId, int deviceIndex) { + SPIRVLevelZeroCommandQueueTable table = commmandQueueTable.get(executionPlanId); + if (table != null) { + SPIRVDevice device = devices.get(deviceIndex); + table.cleanup(device, levelZeroContext); + if (table.size() == 0) { + commmandQueueTable.remove(executionPlanId); + } + } + } } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLCommandQueueTable.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLCommandQueueTable.java index 93c6f82183..523b88af77 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLCommandQueueTable.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLCommandQueueTable.java @@ -29,6 +29,7 @@ import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException; import uk.ac.manchester.tornado.drivers.opencl.OCLCommandQueue; import uk.ac.manchester.tornado.drivers.opencl.OCLContext; +import uk.ac.manchester.tornado.drivers.opencl.OCLTargetDevice; import uk.ac.manchester.tornado.drivers.opencl.exceptions.OCLException; public class SPIRVOCLCommandQueueTable { @@ -48,6 +49,19 @@ public OCLCommandQueue get(SPIRVOCLDevice device, OCLContext context) { return deviceCommandMap.get(device).get(Thread.currentThread().threadId(), device, context); } + public void cleanup(SPIRVOCLDevice device) { + if (deviceCommandMap.containsKey(device)) { + deviceCommandMap.get(device).cleanup(Thread.currentThread().threadId()); + } + if (deviceCommandMap.get(device).size() == 0) { + deviceCommandMap.remove(device); + } + } + + public int size() { + return deviceCommandMap.size(); + } + private static class ThreadCommandQueueTable { private final Map commandQueueMap; @@ -70,5 +84,16 @@ public OCLCommandQueue get(long threadId, SPIRVOCLDevice device, OCLContext cont } return commandQueueMap.get(threadId); } + + public void cleanup(long threadId) { + if (commandQueueMap.containsKey(threadId)) { + OCLCommandQueue queue = commandQueueMap.remove(threadId); + queue.cleanup(); + } + } + + public int size() { + return commandQueueMap.size(); + } } } \ No newline at end of file diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLContext.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLContext.java index b6e1c66675..1e48b27156 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLContext.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLContext.java @@ -36,9 +36,11 @@ import uk.ac.manchester.tornado.drivers.common.CommandQueue; import uk.ac.manchester.tornado.drivers.common.utils.EventDescriptor; import uk.ac.manchester.tornado.drivers.opencl.OCLCommandQueue; +import uk.ac.manchester.tornado.drivers.opencl.OCLCommandQueueTable; import uk.ac.manchester.tornado.drivers.opencl.OCLContext; import uk.ac.manchester.tornado.drivers.opencl.OCLContextInterface; import uk.ac.manchester.tornado.drivers.opencl.OCLEventPool; +import uk.ac.manchester.tornado.drivers.opencl.OCLTargetDevice; import uk.ac.manchester.tornado.drivers.opencl.OpenCLBlocking; import uk.ac.manchester.tornado.drivers.opencl.enums.OCLMemFlags; @@ -281,4 +283,20 @@ public void readBuffer(long executionPlanId, int deviceIndex, long bufferId, lon ? eventPool.waitEventsBuffer : null), EventDescriptor.DESC_READ_BYTE, commandQueue); } + + @Override + public void reset(long executionPlanId, int deviceIndex) { + OCLEventPool eventPool = getOCLEventPool(executionPlanId); + eventPool.reset(); + oclEventPool.remove(executionPlanId); + SPIRVOCLCommandQueueTable table = commmandQueueTable.get(executionPlanId); + if (table != null) { + SPIRVDevice device = devices.get(deviceIndex); + table.cleanup((SPIRVOCLDevice) device); + if (table.size() == 0) { + commmandQueueTable.remove(executionPlanId); + } + executionIDs.remove(executionPlanId); + } + } } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVByteBuffer.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVByteBuffer.java index 494db134e6..e27d9be76c 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVByteBuffer.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVByteBuffer.java @@ -36,7 +36,7 @@ public class SPIRVByteBuffer { private final long bufferId; protected final long bytes; private final long offset; - private SPIRVDeviceContext deviceContext; + protected SPIRVDeviceContext deviceContext; public SPIRVByteBuffer(final SPIRVDeviceContext deviceContext, final long bufferId, final long offset, final long numBytes) { this.deviceContext = deviceContext; diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVKernelStackFrame.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVKernelStackFrame.java index 7be8915994..2af3d2cd56 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVKernelStackFrame.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVKernelStackFrame.java @@ -36,10 +36,13 @@ public class SPIRVKernelStackFrame extends SPIRVByteBuffer implements KernelStac private final ArrayList callArguments; + private boolean isValid; + public SPIRVKernelStackFrame(long bufferId, int numArgs, SPIRVDeviceContext device) { super(device, bufferId, 0, RESERVED_SLOTS << 3); this.callArguments = new ArrayList<>(numArgs); buffer.clear(); + this.isValid = true; } @Override @@ -52,6 +55,17 @@ public void reset() { callArguments.clear(); } + @Override + public boolean isValid() { + return isValid; + } + + @Override + public void invalidate() { + isValid = false; + deviceContext.getSpirvContext().freeMemory(toBuffer(), deviceContext.getDeviceIndex()); + } + @Override public List getCallArguments() { return callArguments; diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVMemoryManager.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVMemoryManager.java index 8930f3f0ed..2fdabf20f8 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVMemoryManager.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVMemoryManager.java @@ -30,6 +30,7 @@ import java.util.concurrent.ConcurrentHashMap; import uk.ac.manchester.tornado.api.memory.TornadoMemoryProvider; +import uk.ac.manchester.tornado.drivers.opencl.mm.OCLKernelStackFrame; import uk.ac.manchester.tornado.drivers.spirv.SPIRVDeviceContext; public class SPIRVMemoryManager implements TornadoMemoryProvider { @@ -55,4 +56,11 @@ public SPIRVKernelStackFrame createKernelStackFrame(long threadId, final int max return spirvKernelStackFrame.get(threadId); } + public void releaseKernelStackFrame(long executionPlanId) { + SPIRVKernelStackFrame stackFrame = spirvKernelStackFrame.remove(executionPlanId); + if (stackFrame != null) { + stackFrame.invalidate(); + } + } + } From 4802ed3e3c54c5786dec26e9bc8fa1aceeffdabd Mon Sep 17 00:00:00 2001 From: Florin Blanaru Date: Sat, 20 Jul 2024 08:18:02 +0300 Subject: [PATCH 3/3] Fix memory leaks on the PTX backend --- .../drivers/opencl/OCLDeviceContext.java | 1 - .../ptx-jni/src/main/cpp/source/PTXStream.cpp | 108 ++++++++---------- .../tornado/drivers/ptx/PTXDeviceContext.java | 11 +- .../tornado/drivers/ptx/PTXStreamTable.java | 25 ++++ .../tornado/drivers/ptx/mm/PTXByteBuffer.java | 4 +- .../drivers/ptx/mm/PTXKernelStackFrame.java | 14 +++ .../drivers/ptx/mm/PTXMemoryManager.java | 7 ++ .../drivers/spirv/SPIRVDeviceContext.java | 1 + 8 files changed, 106 insertions(+), 65 deletions(-) diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContext.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContext.java index ac94bc1a72..5fdaaa7022 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContext.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContext.java @@ -516,7 +516,6 @@ public void reset(long executionPlanId) { oclEventPool.remove(executionPlanId); OCLCommandQueueTable table = commandQueueTable.get(executionPlanId); if (table != null) { - OCLTargetDevice device = context.devices().get(getDeviceIndex()); table.cleanup(device); if (table.size() == 0) { commandQueueTable.remove(executionPlanId); diff --git a/tornado-drivers/ptx-jni/src/main/cpp/source/PTXStream.cpp b/tornado-drivers/ptx-jni/src/main/cpp/source/PTXStream.cpp index 160e71068f..cb570bd456 100644 --- a/tornado-drivers/ptx-jni/src/main/cpp/source/PTXStream.cpp +++ b/tornado-drivers/ptx-jni/src/main/cpp/source/PTXStream.cpp @@ -34,35 +34,27 @@ #include "ptx_log.h" /* - A singly linked list (with elements of type StagingAreaList) is used to keep all the allocated pinned memory through cuMemAllocHost. A queue (with elements of type QueueNode) is used to hold all the free (no longer used) pinned memory regions. - On a new read/write we call get_first_free_staging_area which will try to dequeue a pinned memory region to use it. + On a new read/write we call get_first_free_staging_block which will try to dequeue a pinned memory region to use it. */ /* - Linked list which holds information regarding the pinned memory allocated. + Holds information regarding the pinned memory allocated. - next -- next element of the list staging_area -- pointer to the pinned memory region length -- length in bytes of the memory region referenced by staging_area */ typedef struct area_list { - struct area_list *next; void *staging_area; size_t length; -} StagingAreaList; - -/* - Head of the allocated pinned memory list - */ -static StagingAreaList *head = NULL; +} StagingBlock; /* Linked list used to implement a queue which holds the free (no longer used) pinned memory regions. */ typedef struct queue_list { - StagingAreaList* element; + StagingBlock* element; struct queue_list *next; } QueueNode; @@ -75,7 +67,7 @@ static QueueNode *rear = NULL; /* Adds a free pinned memory region to the queue. */ -static void enqueue(StagingAreaList *region) { +static void enqueue(StagingBlock *region) { if (front == NULL) { front = static_cast(malloc(sizeof(QueueNode))); front->next = NULL; @@ -95,11 +87,11 @@ static void enqueue(StagingAreaList *region) { /* Returns the first element (free pinned memory region) of the queue. */ -static StagingAreaList* dequeue() { +static StagingBlock* dequeue() { if (front == NULL) { return NULL; } - StagingAreaList* region = front->element; + StagingBlock* region = front->element; QueueNode *oldFront = front; front = front->next; free(oldFront); @@ -107,6 +99,8 @@ static StagingAreaList* dequeue() { return region; } +static CUresult free_staging_block(StagingBlock *block); + /* Free the queue. */ @@ -116,6 +110,7 @@ static void free_queue() { QueueNode *node; while(front != NULL) { node = front; + free_staging_block(node->element); front = front->next; free(node); } @@ -124,75 +119,69 @@ static void free_queue() { /* Checks if the given staging region can fit into the required size. If not, it allocates the required pinned memory. */ -static StagingAreaList *check_or_init_staging_area(size_t size, StagingAreaList *list) { +static StagingBlock *check_or_init_staging_block(size_t size, StagingBlock *block) { // Create - if (list == NULL) { - list = static_cast(malloc(sizeof(StagingAreaList))); - CUresult result = cuMemAllocHost(&(list->staging_area), size); + if (block == NULL) { + block = static_cast(malloc(sizeof(StagingBlock))); + CUresult result = cuMemAllocHost(&(block->staging_area), size); if (result != CUDA_SUCCESS) { std::cout << "\t[JNI] " << __FILE__ << ":" << __LINE__ << " in function: " << __FUNCTION__ << " result = " << result << std::endl; std::flush(std::cout); return NULL; } - list->length = size; - list->next = NULL; + block->length = size; } // Update - else if (list->length < size) { - CUresult result = cuMemFreeHost(list->staging_area); + else if (block->length < size) { + CUresult result = cuMemFreeHost(block->staging_area); if (result != CUDA_SUCCESS) { std::cout << "\t[JNI] " << __FILE__ << ":" << __LINE__ << " in function: " << __FUNCTION__ << " result = " << result << std::endl; std::flush(std::cout); return NULL; } - result = cuMemAllocHost(&(list->staging_area), size); + result = cuMemAllocHost(&(block->staging_area), size); if (result != CUDA_SUCCESS) { std::cout << "\t[JNI] " << __FILE__ << ":" << __LINE__ << " in function: " << __FUNCTION__ << " result = " << result << std::endl; std::flush(std::cout); return NULL; } - list->length = size; + block->length = size; } - return list; + return block; } /* - Returns a StagingAreaList with pinned memory of given size. + Returns a StagingBlock with pinned memory of given size. */ -static StagingAreaList *get_first_free_staging_area(size_t size) { +static StagingBlock *get_first_free_staging_block(size_t size) { // Dequeue the first free staging area - StagingAreaList *list = dequeue(); + StagingBlock *block = dequeue(); - list = check_or_init_staging_area(size, list); - if (head == NULL) head = list; + block = check_or_init_staging_block(size, block); - return list; + return block; } /* - Called by cuStreamAddCallback, enqueues a StagingAreaList to the free queue for memory reuse. + Called by cuStreamAddCallback, enqueues a StagingBlock to the free queue for memory reuse. */ -static void set_to_unused(CUstream hStream, CUresult status, void *list) { - StagingAreaList *stagingList = (StagingAreaList *) list; - enqueue(stagingList); +static void set_to_unused(CUstream hStream, CUresult status, void *block) { + StagingBlock *stagingBlock = (StagingBlock *) block; + enqueue(stagingBlock); } /* Free all the allocated pinned memory. */ -static CUresult free_staging_area_list() { +static CUresult free_staging_block(StagingBlock *block) { CUresult result; - while (head != NULL) { - result = cuMemFreeHost(head->staging_area); - if (result != CUDA_SUCCESS) { - std::cout << "\t[JNI] " << __FILE__ << ":" << __LINE__ << " in function: " << __FUNCTION__ << " result = " << result << std::endl; - std::flush(std::cout); - } - StagingAreaList *list = head; - head = head->next; - free(list); + result = cuMemFreeHost(block->staging_area); + if (result != CUDA_SUCCESS) { + std::cout << "\t[JNI] " << __FILE__ << ":" << __LINE__ << " in function: " << __FUNCTION__ << " result = " << result << std::endl; + std::flush(std::cout); } + free(block); return result; } @@ -210,18 +199,18 @@ static jbyteArray array_from_stream(JNIEnv *env, CUstream *stream) { CUevent beforeEvent, afterEvent; \ CUstream stream; \ stream_from_array(env, &stream, stream_wrapper); \ - StagingAreaList *staging_list = get_first_free_staging_area(length);\ + StagingBlock *staging_block = get_first_free_staging_block(length);\ record_events_create(&beforeEvent, &afterEvent); \ record_event(&beforeEvent, &stream); \ - CUresult result = cuMemcpyDtoHAsync(staging_list->staging_area, device_ptr, (size_t) length, stream); \ + CUresult result = cuMemcpyDtoHAsync(staging_block->staging_area, device_ptr, (size_t) length, stream); \ LOG_PTX_AND_VALIDATE("cuMemcpyDtoHAsync", result); \ record_event(&afterEvent, &stream); \ if (cuEventQuery(afterEvent) != CUDA_SUCCESS) { \ cuEventSynchronize(afterEvent); \ } \ env->Set ## TYPE ## ArrayRegion(array, host_offset / sizeof(JAVATYPE), \ - length / sizeof(JAVATYPE), static_cast(staging_list->staging_area)); \ - set_to_unused(stream, result, staging_list); \ + length / sizeof(JAVATYPE), static_cast(staging_block->staging_area)); \ + set_to_unused(stream, result, staging_block); \ return wrapper_from_events(env, &beforeEvent, &afterEvent); @@ -384,30 +373,30 @@ JNIEXPORT jobjectArray JNICALL Java_uk_ac_manchester_tornado_drivers_ptx_PTXStre CUevent beforeEvent, afterEvent; \ CUstream stream; \ stream_from_array(env, &stream, stream_wrapper); \ - StagingAreaList *staging_list = get_first_free_staging_area(length);\ - env->Get## TYPE ##ArrayRegion(array, host_offset / sizeof(JAVATYPE), length / sizeof(JAVATYPE), static_cast(staging_list->staging_area)); \ + StagingBlock *staging_block = get_first_free_staging_block(length);\ + env->Get## TYPE ##ArrayRegion(array, host_offset / sizeof(JAVATYPE), length / sizeof(JAVATYPE), static_cast(staging_block->staging_area)); \ record_events_create(&beforeEvent, &afterEvent); \ record_event(&beforeEvent, &stream); \ - CUresult result = cuMemcpyHtoDAsync(device_ptr, staging_list->staging_area, (size_t) length, stream);\ + CUresult result = cuMemcpyHtoDAsync(device_ptr, staging_block->staging_area, (size_t) length, stream);\ LOG_PTX_AND_VALIDATE("cuMemcpyHtoDAsync", result); \ record_event(&afterEvent, &stream); \ - result = cuStreamAddCallback(stream, set_to_unused, staging_list, 0);\ + result = cuStreamAddCallback(stream, set_to_unused, staging_block, 0);\ LOG_PTX_AND_VALIDATE("cuStreamAddCallback", result); \ return wrapper_from_events(env, &beforeEvent, &afterEvent); #define TRANSFER_FROM_HOST_TO_DEVICE_ASYNC(TYPE, JAVATYPE) \ CUevent beforeEvent, afterEvent; \ - StagingAreaList *staging_list = get_first_free_staging_area(length);\ - env->Get## TYPE ##ArrayRegion(array, host_offset / sizeof(JAVATYPE), length / sizeof(JAVATYPE), static_cast(staging_list->staging_area));\ + StagingBlock *staging_block = get_first_free_staging_block(length);\ + env->Get## TYPE ##ArrayRegion(array, host_offset / sizeof(JAVATYPE), length / sizeof(JAVATYPE), static_cast(staging_block->staging_area));\ CUstream stream; \ stream_from_array(env, &stream, stream_wrapper); \ record_events_create(&beforeEvent, &afterEvent); \ record_event(&beforeEvent, &stream); \ - CUresult result = cuMemcpyHtoDAsync(device_ptr, staging_list->staging_area, (size_t) length, stream);\ + CUresult result = cuMemcpyHtoDAsync(device_ptr, staging_block->staging_area, (size_t) length, stream);\ LOG_PTX_AND_VALIDATE("cuMemcpyHtoDAsync", result); \ record_event(&afterEvent, &stream); \ - result = cuStreamAddCallback(stream, set_to_unused, staging_list, 0);\ + result = cuStreamAddCallback(stream, set_to_unused, staging_block, 0);\ LOG_PTX_AND_VALIDATE("cuStreamAddCallback", result); \ return wrapper_from_events(env, &beforeEvent, &afterEvent); @@ -644,8 +633,7 @@ JNIEXPORT jlong JNICALL Java_uk_ac_manchester_tornado_drivers_ptx_PTXStream_cuDe LOG_PTX_AND_VALIDATE("cuStreamDestroy", result); free_queue(); - CUresult stagingAreaResult = free_staging_area_list(); - return (jlong) result & stagingAreaResult; + return (jlong) result; } /* diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java index 648a9f8798..71af2eee9d 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java @@ -245,8 +245,15 @@ public void flush(long executionPlanId) { @Override public void reset(long executionPlanId) { - PTXStream stream = getStream(executionPlanId); - stream.reset(); + PTXStreamTable table = streamTable.get(executionPlanId); + if (table != null) { + table.cleanup(device); + if (table.size() == 0) { + streamTable.remove(executionPlanId); + } + executionIDs.remove(executionPlanId); + } + getMemoryManager().releaseKernelStackFrame(executionPlanId); codeCache.reset(); wasReset = true; } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXStreamTable.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXStreamTable.java index 1d3d5267d9..62e6e7287c 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXStreamTable.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXStreamTable.java @@ -46,6 +46,19 @@ public PTXStream get(PTXDevice device) { return deviceStream.get(device).get(Thread.currentThread().threadId()); } + public void cleanup(PTXDevice device) { + if (deviceStream.containsKey(device)) { + deviceStream.get(device).cleanup(Thread.currentThread().threadId()); + } + if (deviceStream.get(device).size() == 0) { + deviceStream.remove(device); + } + } + + public int size() { + return deviceStream.size(); + } + private static class ThreadStreamTable { private final Map streamTable; @@ -62,5 +75,17 @@ public PTXStream get(long threadId) { return streamTable.get(threadId); } + public void cleanup(long threadId) { + if (streamTable.containsKey(threadId)) { + PTXStream queue = streamTable.remove(threadId); + queue.reset(); + queue.cuDestroyStream(); + } + } + + public int size() { + return streamTable.size(); + } + } } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXByteBuffer.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXByteBuffer.java index c34def1aa5..0615ae3a27 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXByteBuffer.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXByteBuffer.java @@ -33,7 +33,7 @@ public class PTXByteBuffer { private final long address; private final long bytes; private final long offset; - private final PTXDeviceContext deviceContext; + protected final PTXDeviceContext deviceContext; public PTXByteBuffer(long address, long bytes, long offset, PTXDeviceContext deviceContext) { this.address = address; @@ -93,7 +93,7 @@ public void write(long executionPlanId, int[] events) { deviceContext.writeBuffer(executionPlanId, getAddress() + offset, bytes, buffer.array(), 0, events); } - private long getAddress() { + protected long getAddress() { return address; } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXKernelStackFrame.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXKernelStackFrame.java index be432060af..f29aa7647a 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXKernelStackFrame.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXKernelStackFrame.java @@ -35,11 +35,14 @@ public class PTXKernelStackFrame extends PTXByteBuffer implements KernelStackFra public static final int RESERVED_SLOTS = 3; private final ArrayList callArguments; + private boolean isValid; + public PTXKernelStackFrame(long address, int numArgs, PTXDeviceContext deviceContext) { super(address, RESERVED_SLOTS << 3, 0, deviceContext); this.callArguments = new ArrayList<>(numArgs); buffer.clear(); + this.isValid = true; } @Override @@ -83,4 +86,15 @@ public void setKernelContext(HashMap map) { } } } + + @Override + public boolean isValid() { + return isValid; + } + + @Override + public void invalidate() { + isValid = false; + deviceContext.getDevice().getPTXContext().freeMemory(getAddress()); + } } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXMemoryManager.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXMemoryManager.java index 2cc8085ffa..9c40ede577 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXMemoryManager.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXMemoryManager.java @@ -53,4 +53,11 @@ public PTXKernelStackFrame createCallWrapper(final long threadId, final int maxA } return ptxKernelStackFrame.get(threadId); } + + public void releaseKernelStackFrame(long executionPlanId) { + PTXKernelStackFrame stackFrame = ptxKernelStackFrame.remove(executionPlanId); + if (stackFrame != null) { + stackFrame.invalidate(); + } + } } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java index a75a990eaa..4dd3068cbe 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java @@ -155,6 +155,7 @@ public SPIRVTornadoDevice asMapping() { @Override public void reset(long executionPlanId) { + executionIds.remove(executionPlanId); spirvContext.reset(executionPlanId, getDeviceIndex()); spirvEventPool.remove(executionPlanId); getMemoryManager().releaseKernelStackFrame(executionPlanId);