diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 3c30cfe..08862b7 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -64,7 +64,7 @@ jobs:
         run: python -m pip install .[test]
       - name: Test package
         run: |
-          python -m pytest -vv tests --reruns 10 --reruns-delay 30 --only-rerun "(?i)OSError|timeout|expired|connection|socket"
+          python -m pytest -vv tests --reruns 10 --reruns-delay 30 --only-rerun "(?i)OSError|FileNotFoundError|timeout|expired|connection|socket"
 
       - name: Run fsspec-xrootd tests from uproot latest release
         run: |
@@ -75,7 +75,7 @@ jobs:
           python -m pip install ./uproot[test]
           # Install xrootd-fsspec again because it may have been overwritten by uproot
           python -m pip install .[test]
-          python -m pytest -vv -k "xrootd" uproot/tests --reruns 10 --reruns-delay 30 --only-rerun "(?i)OSError|timeout|expired|connection|socket"
+          python -m pytest -vv -k "xrootd" uproot/tests --reruns 10 --reruns-delay 30 --only-rerun "(?i)OSError|FileNotFoundError|timeout|expired|connection|socket"
 
   dist:
     name: Distribution build
diff --git a/src/fsspec_xrootd/xrootd.py b/src/fsspec_xrootd/xrootd.py
index 8d36793..4dea736 100644
--- a/src/fsspec_xrootd/xrootd.py
+++ b/src/fsspec_xrootd/xrootd.py
@@ -259,7 +259,7 @@ async def _rmdir(self, path: str) -> None:
 
     rmdir = sync_wrapper(_rmdir)
 
-    async def _rm_file(self, path: str) -> None:
+    async def _rm_file(self, path: str, **kwargs: Any) -> None:
         status, n = await _async_wrap(self._myclient.rm, path, self.timeout)
         if not status.ok:
             raise OSError(f"File not removed properly: {status.message}")
@@ -391,7 +391,7 @@ async def _cat_file(self, path: str, start: int, end: int, **kwargs: Any) -> Any
         try:
             status, _n = await _async_wrap(
                 _myFile.open,
-                self.protocol + "://" + self.storage_options["hostid"] + "/" + path,
+                self.unstrip_protocol(path),
                 OpenFlags.READ,
                 self.timeout,
             )
@@ -412,6 +412,50 @@ async def _cat_file(self, path: str, start: int, end: int, **kwargs: Any) -> Any
                 self.timeout,
             )
 
+    async def _get_file(
+        self, rpath: str, lpath: str, chunk_size: int = 262_144, **kwargs: Any
+    ) -> None:
+        """Download the remote file ``rpath`` into the local file ``lpath``.
+
+        The remote file is read sequentially in requests of ``chunk_size`` bytes
+        and each chunk is written to ``lpath`` as soon as it arrives. Extra
+        keyword arguments are accepted and ignored.
+
+        Raises OSError if the remote file cannot be opened or a read fails.
+        """
+        remote_file = client.File()
+
+        try:
+            status, _n = await _async_wrap(
+                remote_file.open,
+                self.unstrip_protocol(rpath),
+                OpenFlags.READ,
+                self.timeout,
+            )
+            if not status.ok:
+                raise OSError(f"Remote file failed to open: {status.message}")
+
+            with open(lpath, "wb") as local_file:
+                offset = 0
+                while True:
+                    status, chunk = await _async_wrap(
+                        remote_file.read, offset, chunk_size, self.timeout
+                    )
+                    if not status.ok:
+                        raise OSError(f"Remote file failed to read: {status.message}")
+
+                    # An empty read marks the end of the remote file
+                    if not chunk:
+                        break
+
+                    local_file.write(chunk)
+                    # Advance by the bytes actually received, so a short read
+                    # can never leave a gap in the local copy
+                    offset += len(chunk)
+        finally:
+            # Always release the remote handle, even when the copy failed
+            await _async_wrap(remote_file.close, self.timeout)
+
     async def _get_max_chunk_info(self, file: Any) -> tuple[int, int]:
         """Queries the XRootD server for info required for pyxrootd vector_read() function.
         Queries for maximum number of chunks and the maximum chunk size allowed by the server.
diff --git a/tests/test_basicio.py b/tests/test_basicio.py
index 771804d..32f5dc6 100644
--- a/tests/test_basicio.py
+++ b/tests/test_basicio.py
@@ -412,3 +412,40 @@ def test_glob_full_names(localserver, clear_server):
     for name in full_names:
         with fsspec.open(name) as f:
             assert f.read() in [bytes(data, "utf-8") for data in [TESTDATA1, TESTDATA2]]
+
+
+@pytest.mark.parametrize("protocol_prefix", ["", "simplecache::"])
+def test_cache(localserver, clear_server, protocol_prefix):
+    """Read a large file both directly and through the simplecache layer."""
+    # bigger than the chunk size
+    payload = (TESTDATA1 * int(1e7 / len(TESTDATA1))).encode("utf-8")
+    remoteurl, localpath = localserver
+    # Write bytes so the served content does not depend on the locale encoding
+    with open(localpath + "/testfile.txt", "wb") as fout:
+        fout.write(payload)
+
+    with fsspec.open(protocol_prefix + remoteurl + "/testfile.txt", "rb") as f:
+        contents = f.read()
+    assert contents == payload
+
+
+def test_cache_directory(localserver, clear_server, tmp_path):
+    """Check that simplecache keeps one complete copy in ``cache_storage``."""
+    payload = TESTDATA1.encode("utf-8")
+    remoteurl, localpath = localserver
+    with open(localpath + "/testfile.txt", "wb") as fout:
+        fout.write(payload)
+
+    cache_directory = tmp_path / "cache"
+    with fsspec.open(
+        "simplecache::" + remoteurl + "/testfile.txt",
+        "rb",
+        simplecache={"cache_storage": str(cache_directory)},
+    ) as f:
+        contents = f.read()
+    assert contents == payload
+
+    # List the cache once; tmp_path is a pathlib.Path, so no os helpers needed
+    cached_files = list(cache_directory.iterdir())
+    assert len(cached_files) == 1
+    assert cached_files[0].read_bytes() == payload