Skip to content

Commit

Permalink
Add more test
Browse files Browse the repository at this point in the history
  • Loading branch information
kingcrimsontianyu committed Nov 11, 2024
1 parent b2bd66a commit dc5fe3e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
20 changes: 11 additions & 9 deletions python/kvikio/kvikio/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,37 +31,39 @@ def compat_mode() -> bool:
return kvikio._lib.defaults.compat_mode()


def compat_mode_reset(compat_mode: bool | str) -> None:
def compat_mode_reset(compatmode: bool | str) -> None:
"""Reset the compatibility mode.
Use this function to enable/disable compatibility mode explicitly.
Parameters
----------
compat_mode : bool or str
compatmode : bool or str
bool: Set to True to enable and False to disable compatibility mode
str: Set to "ON" to enable and "OFF" to disable compatibility mode, or "AUTO"
to let KvikIO determine (try "OFF", and if failed, fall back to "ON")
"""
if isinstance(compat_mode, bool):
kvikio._lib.defaults.compat_mode_reset_bool(compat_mode)
if isinstance(compatmode, bool):
kvikio._lib.defaults.compat_mode_reset_bool(compatmode)
else:
kvikio._lib.defaults.compat_mode_reset_str(compat_mode)
kvikio._lib.defaults.compat_mode_reset_str(compatmode)


@contextlib.contextmanager
def set_compat_mode(enable: bool):
def set_compat_mode(compatmode: bool | str):
"""Context for resetting the compatibility mode.
Parameters
----------
enable : bool
Set to True to enable and False to disable compatibility mode
compatmode : bool or str
bool: Set to True to enable and False to disable compatibility mode
str: Set to "ON" to enable and "OFF" to disable compatibility mode, or "AUTO"
to let KvikIO determine (try "OFF", and if failed, fall back to "ON")
"""
num_threads_reset(get_num_threads()) # Sync all running threads
old_value = compat_mode()
try:
compat_mode_reset(enable)
compat_mode_reset(compatmode)
yield
finally:
compat_mode_reset(old_value)
Expand Down
8 changes: 8 additions & 0 deletions python/kvikio/tests/test_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ def test_compat_mode():
assert not kvikio.defaults.compat_mode()


def test_compat_mode_extra():
inputs = ["", "invalidOption"]
for input in inputs:
with pytest.raises(ValueError):
with kvikio.defaults.set_compat_mode(input):
pass


def test_num_threads():
"""Test changing `num_threads`"""

Expand Down

0 comments on commit dc5fe3e

Please sign in to comment.