ggml_data_to_tensor touchups #454

Closed · wants to merge 3 commits
36 changes: 20 additions & 16 deletions tinygrad/nn/state.py
@@ -227,33 +227,37 @@ def persistent_load(self, pid): return deserialized_objects.get(pid, pid)
   f.seek(rwd)
   return TorchPickle(f).load()
 
-def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int):
+def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
   """
   Converts ggml tensor data to a tinygrad tensor.
 
   Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
-  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q6_K (id: 14), Q8_0 (id: 8)
+  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14)
   """
   # https://github.com/ggerganov/ggml/blob/6dccc647264f5429df2624f36138f601e7ce23e5/include/ggml.h#L356
 
-  bc_dtype = { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type, None)
-  if bc_dtype is not None: return t[:bc_dtype.itemsize * n].bitcast(bc_dtype)
+  # native types
+  if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
+    return t[:dtype.itemsize * n].bitcast(dtype)
 
   def q_to_uint8(t: Tensor, b: int) -> Tensor:
     # TODO: rewrite with arange?
     shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
     return t.unsqueeze(-1).expand((*t.shape,8//b)).div(shift_tensor, upcast=False).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
 
-  blk_info = { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type, None) # map to (number of elements, number of bytes)
-  blocks = t if blk_info is None else t[:(n//blk_info[0])*blk_info[1]].reshape((-1, blk_info[1]))
-  if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
-  if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
-  if ggml_type == 3:
-    d, m = tuple(blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
-    return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
-  if ggml_type == 14:
-    xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
-    scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((blocks.shape[0], 16, 16)).reshape((-1, 256))
-    d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
-    return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
+  # map to (number of elements, number of bytes)
+  if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type)) is not None:
+    blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
+    if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
+    if ggml_type == 3:
+      d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
+      return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
+    if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
+    if ggml_type == 14:
+      xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
+      scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
+      d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
+      return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
   raise ValueError(f"GGML type '{ggml_type}' is not supported!")
 
 def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
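
Note for reviewers: the q_to_uint8 helper splits each byte into 8//b little-endian b-bit fields by dividing through 2**0, 2**b, ... (an elementwise right shift) and masking, then regroups so that field 0 of every byte comes before field 1 of every byte. A minimal NumPy sketch of the same trick (unpack_fields is a hypothetical name, not part of this PR):

import numpy as np

def unpack_fields(data: np.ndarray, b: int) -> np.ndarray:
  # right-shift by 0, b, 2b, ... and mask off b bits per field,
  # then swap axes so all field-0 values precede all field-1 values, etc.
  shifts = np.arange(0, 8, b, dtype=np.uint8)
  fields = (data[..., None] >> shifts) & (0xFF >> (8 - b))
  return np.swapaxes(fields, -1, -2).reshape(*data.shape[:-1], -1)

# the byte 0b11100100 holds the 2-bit fields 0, 1, 2, 3 (lowest bits first)
assert unpack_fields(np.array([[0b11100100]], dtype=np.uint8), 2).tolist() == [[0, 1, 2, 3]]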
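
For the 2: (32, 18) entry in the block table: a Q4_0 block is 18 bytes, a float16 scale d followed by 16 bytes packing 32 4-bit quants, and dequantizes as d * (q - 8). A NumPy reference decode under that layout assumption (low nibbles are elements 0-15 of the block, high nibbles elements 16-31), handy for sanity-checking the tensor version:

import numpy as np

def dequant_q4_0(raw: bytes, n: int) -> np.ndarray:
  # one block = 18 bytes: 2-byte float16 scale, then 16 bytes of packed 4-bit quants
  blocks = np.frombuffer(raw, dtype=np.uint8)[:(n // 32) * 18].reshape(-1, 18)
  d = blocks[:, :2].copy().view(np.float16).astype(np.float32)  # (nblocks, 1)
  qs = blocks[:, 2:]                                            # (nblocks, 16)
  # low nibbles come first within a block, then the high nibbles
  q = np.concatenate([qs & 0x0F, qs >> 4], axis=1).astype(np.int8) - 8
  return (q * d).reshape(-1)[:n]

Q4_1 and Q8_0 follow the same pattern, with an extra float16 minimum (y = d*q + m) and plain int8 quants respectively; Q6_K is the odd one out with per-sub-block int8 scales, which is what the xl/xh/scales arithmetic above reconstructs.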
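
And a quick round trip through the native-type path, assuming a uint8 byte tensor as input (which is what gguf_load hands this function); an untested sketch:

import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.nn.state import ggml_data_to_tensor

raw = np.arange(4, dtype=np.float32).tobytes()   # 16 bytes backing 4 float32 values
t = Tensor(list(raw), dtype=dtypes.uint8)        # byte tensor, as read from a GGUF file
out = ggml_data_to_tensor(t, 4, 0)               # ggml_type 0 -> the float32 bitcast path
print(out.numpy())                               # expected: [0. 1. 2. 3.]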