ggml_data_to_tensor touchups (tinygrad#7196)
* ggml_data_to_tensor touchups: tiny reordering and variable name changes
* return type: add the explicit -> Tensor annotation
* pylint
chenyuxyz authored Oct 21, 2024
1 parent 815e1a3 commit f93bd9e
Showing 1 changed file with 20 additions and 16 deletions: tinygrad/nn/state.py
@@ -227,33 +227,37 @@ def persistent_load(self, pid): return deserialized_objects.get(pid, pid)
     f.seek(rwd)
     return TorchPickle(f).load()
 
-def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int):
+def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
   """
   Converts ggml tensor data to a tinygrad tensor.
   Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
-  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q6_K (id: 14), Q8_0 (id: 8)
+  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14)
   """
   # https://github.com/ggerganov/ggml/blob/6dccc647264f5429df2624f36138f601e7ce23e5/include/ggml.h#L356
 
-  bc_dtype = { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type, None)
-  if bc_dtype is not None: return t[:bc_dtype.itemsize * n].bitcast(bc_dtype)
+  # native types
+  if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
+    return t[:dtype.itemsize * n].bitcast(dtype)
 
   def q_to_uint8(t: Tensor, b: int) -> Tensor:
     # TODO: rewrite with arange?
     shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
     return t.unsqueeze(-1).expand((*t.shape,8//b)).div(shift_tensor, upcast=False).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
 
-  blk_info = { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type, None) # map to (number of elements, number of bytes)
-  blocks = t if blk_info is None else t[:(n//blk_info[0])*blk_info[1]].reshape((-1, blk_info[1]))
-  if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
-  if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
-  if ggml_type == 3:
-    d, m = tuple(blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
-    return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
-  if ggml_type == 14:
-    xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
-    scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((blocks.shape[0], 16, 16)).reshape((-1, 256))
-    d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
-    return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
+  # map to (number of elements, number of bytes)
+  if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type)) is not None:
+    blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
+    if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
+    if ggml_type == 3:
+      d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
+      return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
+    if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
+    if ggml_type == 14:
+      xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
+      scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
+      d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
+      return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
   raise ValueError(f"GGML type '{ggml_type}' is not supported!")
 
 def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
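For context on how this function is reached: judging by the signature visible at the end of the hunk, gguf_load returns the GGUF key-value metadata dict plus a name-to-Tensor state dict, presumably calling ggml_data_to_tensor for each tensor in the file. A minimal usage sketch; "model.gguf" is a placeholder path, and it assumes Tensor accepts raw bytes as a uint8 buffer:

import pathlib
from tinygrad import Tensor
from tinygrad.nn.state import gguf_load

# "model.gguf" is a placeholder, not a file shipped with the repo;
# assumption: Tensor(bytes) builds a uint8 tensor from the raw buffer
raw = pathlib.Path("model.gguf").read_bytes()
kv_data, state_dict = gguf_load(Tensor(raw))  # (metadata dict, name -> Tensor)
for name, t in state_dict.items():
  print(name, t.shape, t.dtype)  # quantized tensors are dequantized to float32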

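The q_to_uint8 helper unpacks 8//b b-bit fields from each byte and orders them field-position-major: all of the lowest fields across a row come first, then the next position, which is what the transpose(-1, -2) achieves. A plain-NumPy mirror of that reading, for illustration only (not the tinygrad code path):

import numpy as np

def q_to_uint8_np(t: np.ndarray, b: int) -> np.ndarray:
  # unpack 8//b b-bit fields per byte, then group field-position-major:
  # all lowest fields across the row first, then the next position, ...
  shifts = np.array([2**(i*b) for i in range(8//b)], dtype=np.uint8)
  fields = (t[..., None] // shifts) & (0xff >> (8 - b))
  return np.swapaxes(fields, -1, -2).reshape(*t.shape[:-1], -1)

print(q_to_uint8_np(np.array([[0xAB, 0xCD]], dtype=np.uint8), 4))
# -> [[11 13 10 12]]: low nibbles of both bytes first, then the high nibbles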

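And for the simplest quantized branch, Q4_0 (ggml_type == 2): per the block table, each 18-byte block is a float16 scale d followed by 16 bytes packing 32 4-bit values, and each element dequantizes to (nibble - 8) * d, with a block's low nibbles ordered before its high nibbles. A plain-NumPy reference sketch of that layout, under the same reading of q_to_uint8 as above; on the same bytes it should agree with ggml_data_to_tensor(Tensor(raw), n, 2).numpy() up to float rounding:

import numpy as np

def dequant_q4_0(raw: bytes, n: int) -> np.ndarray:
  # 18-byte blocks: 2-byte float16 scale, then 16 bytes holding 32 nibbles
  blocks = np.frombuffer(raw, dtype=np.uint8)[:(n // 32) * 18].reshape(-1, 18)
  d = blocks[:, :2].copy().view(np.float16).astype(np.float32)  # (nblocks, 1)
  qs = blocks[:, 2:]
  lo, hi = qs & 0x0f, qs >> 4  # low nibbles come first, then high nibbles
  q = np.concatenate([lo, hi], axis=1).astype(np.int8) - 8  # 32 values in [-8, 7]
  return (q * d).reshape(-1)[:n]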