Skip to content

Commit

Permalink
fix device canonicalize for :0 in middle [pr] (tinygrad#8193)
Browse files Browse the repository at this point in the history
replace is wrong because it does not check if `:0` is at the end. use re.sub instead
  • Loading branch information
chenyuxyz authored Dec 12, 2024
1 parent 40a4c60 commit d47530c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
31 changes: 15 additions & 16 deletions test/unit/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,17 @@

class TestDevice(unittest.TestCase):
def test_canonicalize(self):
assert Device.canonicalize(None) == Device.DEFAULT
assert Device.canonicalize("CPU") == "CPU"
assert Device.canonicalize("cpu") == "CPU"
assert Device.canonicalize("GPU") == "GPU"
assert Device.canonicalize("GPU:0") == "GPU"
assert Device.canonicalize("gpu:0") == "GPU"
assert Device.canonicalize("GPU:1") == "GPU:1"
assert Device.canonicalize("gpu:1") == "GPU:1"
assert Device.canonicalize("GPU:2") == "GPU:2"
assert Device.canonicalize("disk:/dev/shm/test") == "DISK:/dev/shm/test"
# TODO: fix this
# assert Device.canonicalize("disk:000.txt") == "DISK:000.txt"
self.assertEqual(Device.canonicalize(None), Device.DEFAULT)
self.assertEqual(Device.canonicalize("CPU"), "CPU")
self.assertEqual(Device.canonicalize("cpu"), "CPU")
self.assertEqual(Device.canonicalize("GPU"), "GPU")
self.assertEqual(Device.canonicalize("GPU:0"), "GPU")
self.assertEqual(Device.canonicalize("gpu:0"), "GPU")
self.assertEqual(Device.canonicalize("GPU:1"), "GPU:1")
self.assertEqual(Device.canonicalize("gpu:1"), "GPU:1")
self.assertEqual(Device.canonicalize("GPU:2"), "GPU:2")
self.assertEqual(Device.canonicalize("disk:/dev/shm/test"), "DISK:/dev/shm/test")
self.assertEqual(Device.canonicalize("disk:000.txt"), "DISK:000.txt")

def test_getitem_not_exist(self):
with self.assertRaises(ModuleNotFoundError):
Expand All @@ -34,15 +33,15 @@ def test_compile_cached(self):
diskcache_put("key", "123", None) # clear cache
getenv.cache_clear()
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "0"}, clear=True):
assert MockCompiler("key").compile_cached("123") == str.encode("123")
assert diskcache_get("key", "123") == str.encode("123")
self.assertEqual(MockCompiler("key").compile_cached("123"), str.encode("123"))
self.assertEqual(diskcache_get("key", "123"), str.encode("123"))

def test_compile_cached_disabled(self):
diskcache_put("disabled_key", "123", None) # clear cache
getenv.cache_clear()
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}, clear=True):
assert MockCompiler("disabled_key").compile_cached("123") == str.encode("123")
assert diskcache_get("disabled_key", "123") is None
self.assertEqual(MockCompiler("disabled_key").compile_cached("123"), str.encode("123"))
self.assertIsNone(diskcache_get("disabled_key", "123"))

def test_device_compile(self):
getenv.cache_clear()
Expand Down
4 changes: 2 additions & 2 deletions tinygrad/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Optional, Dict, Tuple, Any, Iterator
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys, re
from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
from tinygrad.renderer import Renderer
Expand All @@ -14,7 +14,7 @@ class _Device:
def __init__(self) -> None:
self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def _canonicalize(self, device:str) -> str: return ((d:=device.split(":", 1)[0].upper()) + device[len(d):]).replace(":0", "")
def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
# NOTE: you can't cache canonicalize in case Device.DEFAULT changes
def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
Expand Down

0 comments on commit d47530c

Please sign in to comment.