
Commit

load_gguf -> gguf_load in doc and test (tinygrad#7199)
chenyuxyz authored Oct 21, 2024
1 parent f93bd9e commit f37e6b4
Showing 2 changed files with 7 additions and 6 deletions.
11 changes: 6 additions & 5 deletions test/unit/test_disk_tensor.py
@@ -486,10 +486,10 @@ def setUp(self) -> None:
     self.ctx = ctypes.cast(ggml.ggml_init(params), ctypes.POINTER(ctypes.c_void_p))
   def tearDown(self) -> None: ggml.ggml_free(self.ctx)

-  def test_load_tinyllama_q8_0(self): self._test_load_gguf("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q8_0.gguf?download=true")
-  def test_load_tinyllama_q4_0(self): self._test_load_gguf("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf?download=true")
-  def test_load_gpt2_q4_1(self): self._test_load_gguf("https://huggingface.co/PrunaAI/gpt2-GGUF-smashed/resolve/main/gpt2.Q4_1.gguf?download=true")
-  def test_load_sample_q6_k(self): self._test_load_gguf("https://huggingface.co/Isotr0py/test-gguf-sample/resolve/main/Quant_Q6_K_1024.gguf?download=true")
+  def test_load_tinyllama_q8_0(self): self._test_gguf_load("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q8_0.gguf?download=true")
+  def test_load_tinyllama_q4_0(self): self._test_gguf_load("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf?download=true")
+  def test_load_gpt2_q4_1(self): self._test_gguf_load("https://huggingface.co/PrunaAI/gpt2-GGUF-smashed/resolve/main/gpt2.Q4_1.gguf?download=true")
+  def test_load_sample_q6_k(self): self._test_gguf_load("https://huggingface.co/Isotr0py/test-gguf-sample/resolve/main/Quant_Q6_K_1024.gguf?download=true")

   def test_dequantization_q4_0(self): self._test_dequantization(ggml.GGML_TYPE_Q4_0)
   def test_dequantization_q4_1(self): self._test_dequantization(ggml.GGML_TYPE_Q4_1)
@@ -514,7 +514,8 @@ def _test_dequantization(self, ttype: int):
     dq_tensor = ggml_data_to_tensor(q_tensor, n_el, ttype).reshape(n_el)

     np.testing.assert_equal(dq_tensor.numpy(), np.frombuffer(c_dq_data, dtype=np.float32))
-  def _test_load_gguf(self, url: str):
+
+  def _test_gguf_load(self, url: str):
     fp = fetch(url)
     model_size = os.stat(fp).st_size
     gguf_tensor = Tensor.empty(model_size, dtype=dtypes.uint8, device=f"disk:{fp}").to(Device.DEFAULT)
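The hunk above is truncated before the renamed helper's assertions. As a reading aid, here is a minimal hypothetical sketch of how such a test might continue from the `gguf_tensor` built above (illustrative only, not the actual test body):

```python
# Hypothetical continuation of _test_gguf_load (the diff hunk is truncated);
# it exercises the renamed public API on the bytes mapped from disk.
from tinygrad.nn.state import gguf_load

kv_data, state_dict = gguf_load(gguf_tensor)
assert len(kv_data) > 0 and len(state_dict) > 0
for name, t in state_dict.items(): t.realize()  # force dequantization kernels to run
```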
2 changes: 1 addition & 1 deletion tinygrad/nn/state.py
@@ -267,7 +267,7 @@ def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
   ```python
   fn = "Meta-Llama-3-8B-Instruct.Q4_0.gguf"
   gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
-  kv_data, state_dict = load_gguf(gguf_tensor)
+  kv_data, state_dict = gguf_load(gguf_tensor)
   ```
   """
   if tensor.dtype != dtypes.uint8 or len(tensor.shape) != 1: raise ValueError("GGUF tensor must be 1d and of dtype uint8!")
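With both files updated, the documented entry point is consistently `gguf_load`. Below is a self-contained version of the docstring example with its implied imports spelled out; the stories15M URL is borrowed from the tests above, and `fetch` from `tinygrad.helpers` is assumed for the download:

```python
import os
from tinygrad import Tensor, Device, dtypes
from tinygrad.helpers import fetch
from tinygrad.nn.state import gguf_load

# Download a small GGUF file and map it as a 1-D uint8 disk tensor,
# mirroring the docstring example above.
fp = fetch("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q8_0.gguf?download=true")
gguf_tensor = Tensor.empty(os.stat(fp).st_size, dtype=dtypes.uint8, device=f"disk:{fp}").to(Device.DEFAULT)
kv_data, state_dict = gguf_load(gguf_tensor)
print(len(kv_data), "metadata keys,", len(state_dict), "tensors")
```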
