diff --git a/README.md b/README.md index d436904c..a21f573c 100644 --- a/README.md +++ b/README.md @@ -190,6 +190,21 @@ Most methods and properties from `pathlib.Path` are supported except for the one | `key` | ❌ | ✅ | ❌ | | `md5` | ✅ | ❌ | ❌ | + +## Writing to cloud files + +**Warning:** You can't call `open(CloudPath("s3://path"), "w")` and have it write to the cloud file (reading works fine with the built-in open). Instead of using the Python built-in open, you must use `CloudPath("s3://path").open("w")`. For more information, see [#128](https://github.com/drivendataorg/cloudpathlib/issues/128) and [#140](https://github.com/drivendataorg/cloudpathlib/pull/140). + +We try to detect this scenario and raise a `BuiltInOpenWriteError` exception for you. There is a slight performance hit for this check, and if _you are sure_ that either (1) you are not writing to cloud files or (2) you are writing, but you are using the `CloudPath.open` method every time, you can skip this check by setting the environment variable `CLOUDPATHLIB_CHECK_UNSAFE_OPEN=False`. + +If you are passing the `CloudPath` into another library and you see `BuiltInOpenWriteError`, try opening and passing the buffer into that function instead: + +```python +with CloudPath("s3://bucket/path_to_write.txt").open("w") as fp: + function_that_writes(fp) +``` + + ---- Icon made by srip from www.flaticon.com. 
diff --git a/cloudpathlib/client.py b/cloudpathlib/client.py index 544e34ee..226d2550 100644 --- a/cloudpathlib/client.py +++ b/cloudpathlib/client.py @@ -37,7 +37,7 @@ def __init__(self, local_cache_dir: Optional[Union[str, os.PathLike]] = None): def __del__(self) -> None: # make sure temp is cleaned up if we created it - if self._cache_tmp_dir is not None: + if hasattr(self, "_cache_tmp_dir") and self._cache_tmp_dir is not None: self._cache_tmp_dir.cleanup() @classmethod diff --git a/cloudpathlib/cloudpath.py b/cloudpathlib/cloudpath.py index 29559d1a..df421f8a 100644 --- a/cloudpathlib/cloudpath.py +++ b/cloudpathlib/cloudpath.py @@ -1,14 +1,19 @@ import abc +import ast from collections import defaultdict import collections.abc import fnmatch +import inspect import os from pathlib import Path, PosixPath, PurePosixPath, WindowsPath +import sys +from textwrap import dedent from typing import Any, IO, Iterable, Optional, TYPE_CHECKING, Union from urllib.parse import urlparse from warnings import warn from .exceptions import ( + BuiltInOpenWriteError, ClientMismatchError, CloudPathFileExistsError, CloudPathIsADirectoryError, @@ -27,6 +32,11 @@ if TYPE_CHECKING: from .client import Client +CHECK_UNSAFE_OPEN = str(os.getenv("CLOUDPATHLIB_CHECK_UNSAFE_OPEN", "True").lower()) not in { + "false", + "0", +} + class CloudImplementation: def __init__(self): @@ -170,7 +180,7 @@ def __init__(self, cloud_path: Union[str, "CloudPath"], client: Optional["Client def __del__(self): # make sure that file handle to local path is closed - if self._handle is not None: + if hasattr(self, "_handle") and self._handle is not None: self._handle.close() @property @@ -205,6 +215,49 @@ def __eq__(self, other: Any) -> bool: return isinstance(other, type(self)) and str(self) == str(other) def __fspath__(self): + # make sure that we're not getting called by the builtin open + # in a write mode, since we won't actually write to the cloud in + # that scenario + if CHECK_UNSAFE_OPEN: + frame = 
inspect.currentframe().f_back + + # line number of the call for this frame + lineno = frame.f_lineno + + # get source lines and start of the entire function + lines, start_lineno = inspect.getsourcelines(frame) + + # in some contexts like jupyter, start_lineno is 0, but should be 1-indexed + if start_lineno == 0: + start_lineno = 1 + + all_lines = "".join(lines) + + if "open" in all_lines: + # walk from this call until we find the line + # that actually has "open" call on it + # only needed on Python <= 3.7 + if (sys.version_info.major, sys.version_info.minor) <= (3, 7): + while "open" not in lines[lineno - start_lineno]: + lineno -= 1 + + # 1-indexed line within this scope + line_to_check = (lineno - start_lineno) + 1 + + # Walk the AST of the previous frame source and see if we + # ended up here from a call to the builtin open with and a writeable mode + if any( + _is_open_call_write_with_var(n, line_to_check) + for n in ast.walk(ast.parse(dedent(all_lines))) + ): + raise BuiltInOpenWriteError( + "Cannot use built-in open function with a CloudPath in a writeable mode. " + "Changes would not be uploaded to the cloud; instead, " + "please use the .open() method instead. " + "NOTE: If you are sure and want to skip this check with " + "set the env var CLOUDPATHLIB_CHECK_UNSAFE_OPEN=False" + ) + if self.is_file(): self._refresh_cache(force_overwrite_from_cloud=False) return str(self._local) @@ -749,3 +802,38 @@ def _resolve(path: PurePosixPath) -> str: newpath = newpath + sep + name return newpath or sep + + +WRITE_MODES = {"r+", "w", "w+", "a", "a+", "rb+", "wb", "wb+", "ab", "ab+"} + + +# This function is used to check if our `__fspath__` implementation has been +# called in a writeable mode from the built-in open function. 
+def _is_open_call_write_with_var(ast_node, lineno): + """For a given AST node, check that the node is a `Call`, and that the + call is to a function with the name `open` at line number `lineno`, + and that the last argument or the `mode` kwarg is one of the writeable modes. + """ + if not isinstance(ast_node, ast.Call): + return False + if not hasattr(ast_node, "func"): + return False + if not hasattr(ast_node.func, "id"): + return False + if ast_node.func.id != "open": + return False + + # there may be an invalid open call in the scope, + # but it is not on the line for our current stack, + # so we skip it for now since it will get parsed later + if ast_node.func.lineno != lineno: + return False + + # get the mode as second arg or kwarg where arg==mode + mode = ( + ast_node.args[1] + if len(ast_node.args) >= 2 + else [kwarg for kwarg in ast_node.keywords if kwarg.arg == "mode"][0].value + ) + + return mode.s.lower() in WRITE_MODES diff --git a/cloudpathlib/exceptions.py b/cloudpathlib/exceptions.py index b27b8aab..aa1fd9cd 100644 --- a/cloudpathlib/exceptions.py +++ b/cloudpathlib/exceptions.py @@ -12,6 +12,10 @@ class AnyPathTypeError(CloudPathException, TypeError): pass +class BuiltInOpenWriteError(CloudPathException): + pass + + class ClientMismatchError(CloudPathException, ValueError): pass diff --git a/tests/test_cloudpath_file_io.py b/tests/test_cloudpath_file_io.py index eeb7781c..03cdf6aa 100644 --- a/tests/test_cloudpath_file_io.py +++ b/tests/test_cloudpath_file_io.py @@ -5,7 +5,11 @@ import pytest -from cloudpathlib.exceptions import CloudPathIsADirectoryError, DirectoryNotEmptyError +from cloudpathlib.exceptions import ( + BuiltInOpenWriteError, + CloudPathIsADirectoryError, + DirectoryNotEmptyError, +) def test_file_discovery(rig): @@ -109,7 +113,70 @@ def test_fspath(rig): assert os.fspath(p) == p.fspath -def test_os_open(rig): +def test_os_open_read(rig): p = rig.create_cloud_path("dir_0/file0_0.txt") with open(p, "r") as f: assert f.readable() + + 
+# entire function is passed as source, so check separately +# that all of the built in open write modes fail +def test_os_open_write1(rig): + p = rig.create_cloud_path("dir_0/file0_0.txt") + + with open(p, "r") as f: + assert f.readable() + + with pytest.raises(BuiltInOpenWriteError): + with open(p, "w") as f: + assert f.writable() + + with pytest.raises(BuiltInOpenWriteError): + with open(p, "W") as f: + assert f.writable() + + with pytest.raises(BuiltInOpenWriteError): + with open(p, "wb") as f: + assert f.writable() + + +def test_os_open_write2(rig): + p = rig.create_cloud_path("dir_0/file0_0.txt") + + with pytest.raises(BuiltInOpenWriteError): + with open(p, "a") as f: + assert f.writable() + + with pytest.raises(BuiltInOpenWriteError): + with open(rig.create_cloud_path("dir_0/file0_0.txt"), "r+") as f: + assert f.readable() + + +def test_os_open_write3(rig): + p = rig.create_cloud_path("dir_0/file0_0.txt") + # first call should not raise even though there is an unsafe open in same scope + with open( + p, + "r", + ) as f: + assert f.readable() + + with pytest.raises(BuiltInOpenWriteError): + with open( + p, + "w", + ) as f: + assert f.writable() + + +def test_os_open_write4(monkeypatch, rig): + p = rig.create_cloud_path("dir_0/file0_0.txt") + + monkeypatch.setattr("cloudpathlib.cloudpath.CHECK_UNSAFE_OPEN", False) + + # unsafe write check is skipped + with open( + p, + "w+", + ) as f: + assert f.readable()