Skip to content

Commit

Permalink
GGUF Cleanup - raise if type is not supported (tinygrad#7194)
Browse files Browse the repository at this point in the history
* raise if ggml type is unsupported

* test raise
  • Loading branch information
leopf authored Oct 21, 2024
1 parent bc9eb32 commit 815e1a3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
4 changes: 4 additions & 0 deletions test/unit/test_disk_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,10 @@ def test_dequantization_q4_1(self): self._test_dequantization(ggml.GGML_TYPE_Q4_
def test_dequantization_q8_0(self): self._test_dequantization(ggml.GGML_TYPE_Q8_0)
def test_dequantization_q6_k(self): self._test_dequantization(ggml.GGML_TYPE_Q6_K)

def test_expected_failure_unknown_type(self):
  # An unrecognized GGML type id (1337 is not in the supported map) must raise
  # ValueError instead of silently decoding garbage bytes.
  raw = Tensor.empty(512, dtype=dtypes.uint8)
  self.assertRaises(ValueError, ggml_data_to_tensor, raw, 256, 1337)

def _test_dequantization(self, ttype: int):
type_traits = ggml.ggml_internal_get_type_traits(ttype)
n_el, n_bytes = ggml_test_block_count * type_traits.blck_size, ggml_test_block_count * type_traits.type_size
Expand Down
5 changes: 3 additions & 2 deletions tinygrad/nn/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ def q_to_uint8(t: Tensor, b: int) -> Tensor:
shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
return t.unsqueeze(-1).expand((*t.shape,8//b)).div(shift_tensor, upcast=False).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)

blk_nel, blk_nb = { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }[ggml_type]
blocks = t[:(n//blk_nel)*blk_nb].reshape((-1, blk_nb))
blk_info = { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type, None) # map to (number of elements, number of bytes)
blocks = t if blk_info is None else t[:(n//blk_info[0])*blk_info[1]].reshape((-1, blk_info[1]))
if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
if ggml_type == 3:
Expand All @@ -254,6 +254,7 @@ def q_to_uint8(t: Tensor, b: int) -> Tensor:
scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((blocks.shape[0], 16, 16)).reshape((-1, 256))
d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
raise ValueError(f"GGML type '{ggml_type}' is not supported!")

def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
"""
Expand Down

0 comments on commit 815e1a3

Please sign in to comment.