Skip to content

Commit

Permalink
hip mutex signal (tinygrad#3234)
Browse files Browse the repository at this point in the history
* hip mutex

* hip mutex 2

* sync
  • Loading branch information
geohot authored Jan 24, 2024
1 parent 47f9887 commit ed8a327
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 47 deletions.
4 changes: 3 additions & 1 deletion tinygrad/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:int, var_vals: Optiona
class BufferOptions:
image: Optional[ImageDType] = None
uncached: bool = False
host: bool = False
signal: bool = False

class Buffer:
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None):
Expand Down Expand Up @@ -156,7 +158,7 @@ def free_cache(self):
for opaque in opaques: self._free(opaque)
opaques.clear()
def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
if getenv("LRU", 1): self.cache[(size, options)].append(opaque)
if getenv("LRU", 1) and (options is None or not options.signal): self.cache[(size, options)].append(opaque)
else: self._free(opaque)

class _MallocAllocator(LRUAllocator):
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __exit__(self, *exc):
self.et = time.perf_counter_ns() - self.st
if self.enabled: print(f"{self.prefix}{self.et*1e-6:6.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))

def _format_fcn(fcn): return f"{fcn[0]}:{fcn[2]}" if fcn[2] != "<genexpr>" else f"{fcn[0]}:{fcn[1]}"
def _format_fcn(fcn): return f"{fcn[0]}:{fcn[1]}:{fcn[2]}"
class Profiling(contextlib.ContextDecorator):
def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1):
self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3/ts
Expand Down
3 changes: 2 additions & 1 deletion tinygrad/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def is_unrealized_contiguous_const(self): return self.base == self and not self.
def schedule(self, seen=None): return create_schedule([self], seen)

def _copy(self, device:str) -> LazyBuffer:
sync = LazyBuffer.loadop(LoadOps.SYNC, (0,), dtypes.uint32, self.device, src=self, enable_cache=True)
sync_size = 1 if self.device.startswith("HIP") else 0
sync = LazyBuffer.loadop(LoadOps.SYNC, (sync_size,), dtypes.uint32, self.device, src=self, enable_cache=True)
wait = LazyBuffer.loadop(LoadOps.WAIT, (0,), dtypes.uint32, device, src=sync, enable_cache=True)
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, wait), enable_cache=False)

Expand Down
28 changes: 5 additions & 23 deletions tinygrad/realize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Dict, Optional, cast
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats, InterpretedASTRunner, Compiled
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats, InterpretedASTRunner, Compiled, BufferOptions
from tinygrad.graph import print_tree, realized_lazybuffer
from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG
from tinygrad.shape.symbolic import Variable
Expand All @@ -21,26 +21,6 @@ def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=Fals
et = cpu_time_execution(self.device.synchronize, enable=wait or DEBUG >= 1)
update_stats(colored("synchronize", "RED"), 0, 0, {}, et, 1, device=self.dname)

class SyncEvent(JITRunner):
  """Runner that records a device event for a lazy buffer.

  Construction creates an event on the buffer's device and attaches it to the
  buffer as ``lb.event``; calling the runner records that event on the device.
  """
  def __init__(self, lb):
    self.lb = lb
    self.dname = lb.device
    self.device = Device[lb.device]
    # the device backend must expose the event API
    assert hasattr(self.device, "event_create")
    self.lb.event = self.device.event_create()
    super().__init__()

  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
    """Record the buffer's event on the device and log a "sync" stats entry; the arguments are unused."""
    assert hasattr(self.device, "event_record")
    event = self.lb.event
    self.device.event_record(event)
    update_stats(colored("sync", "red"), 0, 0, {}, None, 1, device=self.dname)

class WaitEvent(JITRunner):
  """Runner that makes ``device`` wait on the event stored on ``lb_sync``.

  ``lb_sync`` is expected to carry an ``event`` attribute by the time the runner is called.
  """
  def __init__(self, device, lb_sync):
    self.dname = device
    self.device = Device[device]
    self.lb_sync = lb_sync
    super().__init__()

  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
    """Wait on ``lb_sync.event`` on this device and log a "wait" stats entry; the arguments are unused."""
    # the device backend must expose the event API
    assert hasattr(self.device, "event_wait")
    event = self.lb_sync.event
    self.device.event_wait(event)
    update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, device=self.dname)

def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
assert all(si.out.device == x.device for x in si.inputs) or si.ast.op in {LoadOps.COPY, LoadOps.WAIT}, \
f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
Expand All @@ -52,8 +32,9 @@ def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg)
# TODO: this doesn't have to be only HIP, check if it has the event functions
if si.ast.op in {LoadOps.SYNC, LoadOps.WAIT} and si.out.device.startswith("HIP") and si.inputs[0].device.startswith("HIP"):
from tinygrad.runtime.ops_hip import SyncEvent, WaitEvent
if si.ast.op is LoadOps.SYNC: return SyncEvent(si.out)
if si.ast.op is LoadOps.WAIT: return WaitEvent(si.out.device, si.inputs[0])
if si.ast.op is LoadOps.WAIT: return WaitEvent(si.out.device)
else:
if si.ast.op is LoadOps.SYNC: return SyncOp(si.out.device) if isinstance(Device[si.out.device], Compiled) else None
if si.ast.op is LoadOps.WAIT: return None
Expand All @@ -78,8 +59,9 @@ def run_schedule(schedule:List[ScheduleItem]):

# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
if si.out.size > 0:
options = BufferOptions(host=True, signal=True) if si.ast.op is LoadOps.SYNC else None
si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
Buffer(si.out.device, si.out.size, si.out.dtype, "PLACEHOLDER" if isinstance(prg, InterpretedASTRunner) else None)
Buffer(si.out.device, si.out.size, si.out.dtype, "PLACEHOLDER" if isinstance(prg, InterpretedASTRunner) else None, options=options)
del si.out.srcs

# run the function (put it in JIT)
Expand Down
9 changes: 5 additions & 4 deletions tinygrad/runtime/ops_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,11 @@ def __init__(self, device:CLDevice):
def _alloc(self, size:int) -> cl.cl_mem:
return checked(cl.clCreateBuffer(self.device.context, cl.CL_MEM_READ_WRITE, size, None, ctypes.byref(status := ctypes.c_int32())), status)
def _alloc_with_options(self, size:int, options:BufferOptions) -> cl.cl_mem:
assert options.image is not None
return checked(cl.clCreateImage2D(self.device.context, cl.CL_MEM_READ_WRITE,
cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]),
options.image.shape[1], options.image.shape[0], 0, None, ctypes.byref(status := ctypes.c_int32())), status)
if options.image is not None:
return checked(cl.clCreateImage2D(self.device.context, cl.CL_MEM_READ_WRITE,
cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]),
options.image.shape[1], options.image.shape[0], 0, None, ctypes.byref(status := ctypes.c_int32())), status)
else: return self._alloc(size)
def _free(self, buf:cl.cl_mem): check(cl.clReleaseMemObject(buf))
def copyin(self, dest:cl.cl_mem, src:memoryview):
check(cl.clEnqueueWriteBuffer(self.device.queue, dest, False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
Expand Down
47 changes: 30 additions & 17 deletions tinygrad/runtime/ops_hip.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations
import ctypes, functools, subprocess, io
from typing import Tuple, TypeVar, List, Any
from typing import Tuple, TypeVar, List, Any, cast
import gpuctypes.hip as hip
from tinygrad.helpers import DEBUG, getenv, init_c_var, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style
from tinygrad.helpers import from_mv, round_up, to_mv
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator, BufferOptions
from tinygrad.helpers import from_mv, round_up, to_mv, colored
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator, BufferOptions, JITRunner, Device, Buffer, update_stats
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.codegen.kernel import LinearizerOptions

Expand Down Expand Up @@ -55,15 +55,18 @@ def _alloc(self, size:int):
check(hip.hipSetDevice(self.device.device))
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
def _alloc_with_options(self, size:int, options:BufferOptions):
assert options.uncached
check(hip.hipSetDevice(self.device.device))
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipExtMallocWithFlags(ctypes.byref(x), size, 3))) # hipDeviceMallocUncached = 3
if options.uncached:
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipExtMallocWithFlags(ctypes.byref(x), size, 3))) # hipDeviceMallocUncached = 3
elif options.host:
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipHostMalloc(ctypes.byref(x), size, 2 if options.signal else 0)))
else:
raise Exception("no options")
def _free(self, opaque:T): check(hip.hipFree(opaque))
def _hostalloc(self, size:int): return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipHostMalloc(ctypes.byref(x), size, 0)))
def copy_from_fd(self, dest, fd, offset, size):
check(hip.hipSetDevice(self.device.device))
if not hasattr(self, 'hb'):
self.hb = [self._hostalloc(CHUNK_SIZE) for _ in range(2)]
self.hb = [self._alloc_with_options(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)]
self.hb_events = [None, None]
self.hb_polarity = 0
fo = io.FileIO(fd, "a+b", closefd=False)
Expand All @@ -86,7 +89,7 @@ def copy_from_fd(self, dest, fd, offset, size):
minor_offset = 0 # only on the first
def copyin(self, dest:T, src: memoryview):
check(hip.hipSetDevice(self.device.device))
host_mem = self._hostalloc(len(src))
host_mem = self._alloc_with_options(len(src), BufferOptions(host=True))
self.device.pending_copyin.append(host_mem)
ctypes.memmove(host_mem, from_mv(src), len(src))
check(hip.hipMemcpyAsync(dest, host_mem, len(src), hip.hipMemcpyHostToDevice, None))
Expand Down Expand Up @@ -114,12 +117,22 @@ def synchronize(self):
for opaque in self.pending_copyin: check(hip.hipFree(opaque))
self.track_cross_buffer.clear()
self.pending_copyin.clear()
def event_create(self):
  """Select this device, then initialise a fresh ``hipEvent_t`` via ``hipEventCreate``.

  Returns whatever ``init_c_var`` yields (presumably the initialised event handle -- confirm).
  """
  check(hip.hipSetDevice(self.device))
  def _create(handle): return check(hip.hipEventCreate(ctypes.byref(handle)))
  return init_c_var(hip.hipEvent_t(), _create)
def event_record(self, evt):
  """Select this device, then call ``hipEventRecord`` on ``evt`` with ``None`` as the stream argument."""
  check(hip.hipSetDevice(self.device))
  status = hip.hipEventRecord(evt, None)
  check(status)
def event_wait(self, evt):
  """Select this device, then call ``hipStreamWaitEvent`` for ``evt`` with ``None`` as the stream and flags 0."""
  check(hip.hipSetDevice(self.device))
  status = hip.hipStreamWaitEvent(None, evt, 0)
  check(status)

class SyncEvent(JITRunner):
  """Runner that raises a 32-bit signal in its output buffer.

  The first uint32 of ``rawbufs[0]`` is cleared from the host, then a
  ``hipStreamWriteValue32`` of the value 1 is issued with ``None`` as the stream.
  """
  def __init__(self, lb):
    self.lb = lb
    self.dname = lb.device
    self.device = cast(HIPDevice, Device[lb.device])
    super().__init__()

  def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
    """Reset the signal word to 0, then issue the write of 1; ``var_vals``, ``wait`` and ``jit`` are unused."""
    signal = rawbufs[0]._buf
    # view the first 4 bytes as a single unsigned int and zero it before signalling
    word = to_mv(signal, 4).cast("I")
    word[0] = 0
    check(hip.hipSetDevice(self.device.device))
    # NOTE(review): arguments are presumably (stream, ptr, value, flags) -- confirm against the HIP docs
    check(hip.hipStreamWriteValue32(None, signal, 1, 0))
    update_stats(colored("sync", "red"), 0, 0, {}, None, 1, device=self.dname)

class WaitEvent(JITRunner):
  """Runner that issues a ``hipStreamWaitValue32`` on the 32-bit word in ``rawbufs[0]``."""
  def __init__(self, device):
    self.dname = device
    self.device = cast(HIPDevice, Device[device])
    super().__init__()

  def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
    """Issue the wait on this device and log a "wait" stats entry; ``var_vals``, ``wait`` and ``jit`` are unused."""
    signal = rawbufs[0]._buf
    check(hip.hipSetDevice(self.device.device))
    # NOTE(review): arguments are presumably (stream, ptr, value=1, flags=1, mask) with flags=1 meaning
    # an equality compare -- confirm against the HIP docs
    check(hip.hipStreamWaitValue32(None, signal, 1, 1, 0xFFFFFFFF))
    update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, device=self.dname)

0 comments on commit ed8a327

Please sign in to comment.