Skip to content

Commit

Permalink
Fix the pytest error for async io
Browse files Browse the repository at this point in the history
  • Loading branch information
kingcrimsontianyu committed Nov 20, 2024
1 parent 19e44c0 commit c661e9d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
8 changes: 8 additions & 0 deletions python/kvikio/kvikio/_lib/file_handle.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ cdef extern from "<kvikio/file_handle.hpp>" namespace "kvikio" nogil:
size_t devPtr_offset,
CUstream stream
) except +
bool is_compat_mode_preferred() except +
bool is_compat_mode_preferred_for_async() except +


cdef class CuFile:
Expand Down Expand Up @@ -175,3 +177,9 @@ cdef class CuFile:
dev_offset,
stream,
))

def is_compat_mode_preferred(self) -> bool:
return self._handle.is_compat_mode_preferred()

def is_compat_mode_preferred_for_async(self) -> bool:
return self._handle.is_compat_mode_preferred_for_async()
6 changes: 6 additions & 0 deletions python/kvikio/kvikio/cufile.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,9 @@ def raw_write(
to be a multiple of 4096 bytes. When GDS isn't used, this is less critical.
"""
return self._handle.write(buf, size, file_offset, dev_offset)

def is_compat_mode_preferred(self) -> bool:
return self._handle.is_compat_mode_preferred()

def is_compat_mode_preferred_for_async(self) -> bool:
return self._handle.is_compat_mode_preferred_for_async()
7 changes: 6 additions & 1 deletion python/kvikio/tests/test_async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ def test_read_write(tmp_path, size):
assert check_bit_flags(f.open_flags(), os.O_WRONLY)
assert f.raw_write_async(a, stream.ptr).check_bytes_done() == a.nbytes

if f.is_compat_mode_preferred_for_async():
expected_except_msg = "Operation not permitted"
else:
expected_except_msg = "unsupported file open flags"

# Try to read file opened in write-only mode
with pytest.raises(RuntimeError, match="Operation not permitted"):
with pytest.raises(RuntimeError, match=expected_except_msg):
# The exception is raised when we call the raw_read_async API.
future_stream = f.raw_read_async(a, stream.ptr)
future_stream.check_bytes_done()
Expand Down

0 comments on commit c661e9d

Please sign in to comment.