From 0986f6b59086e2e0947906654c1642cf264b462e Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 21 Nov 2023 10:53:24 +0100 Subject: [PATCH] Typing: Add overload signatures for `open` Added for the `FolderData` and `NodeRepository` classes. The signature of the `SinglefileData` was actually incorrect as it defined: t.Iterator[t.BinaryIO | t.TextIO] as the return type, but which should really be: t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO] The former will cause `mypy` to raise an error. --- aiida/orm/nodes/data/folder.py | 12 +++++++++- aiida/orm/nodes/data/singlefile.py | 22 +++++++++++++------ aiida/orm/nodes/repository.py | 14 ++++++++++-- aiida/parsers/plugins/arithmetic/add.py | 2 +- .../parsers/plugins/diff_tutorial/parsers.py | 5 +++-- docs/source/nitpick-exceptions | 1 + 6 files changed, 43 insertions(+), 13 deletions(-) diff --git a/aiida/orm/nodes/data/folder.py b/aiida/orm/nodes/data/folder.py index 6972edd271..67b0fe7b41 100644 --- a/aiida/orm/nodes/data/folder.py +++ b/aiida/orm/nodes/data/folder.py @@ -71,8 +71,18 @@ def list_object_names(self, path: str | None = None) -> list[str]: """ return self.base.repository.list_object_names(path) + @t.overload + @contextlib.contextmanager + def open(self, path: FilePath, mode: t.Literal['r']) -> t.Iterator[t.TextIO]: + ... + + @t.overload + @contextlib.contextmanager + def open(self, path: FilePath, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]: + ... + @contextlib.contextmanager - def open(self, path: str, mode='r') -> t.Iterator[t.BinaryIO | t.TextIO]: + def open(self, path: FilePath, mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO]: """Open a file handle to an object stored under the given key. .. note:: this should only be used to open a handle to read an existing file. To write a new file use the method diff --git a/aiida/orm/nodes/data/singlefile.py b/aiida/orm/nodes/data/singlefile.py index 1e4249f9f5..ed9d9079df 100644 --- a/aiida/orm/nodes/data/singlefile.py +++ b/aiida/orm/nodes/data/singlefile.py @@ -22,6 +22,8 @@ __all__ = ('SinglefileData',) +FilePath = t.Union[str, pathlib.PurePosixPath] + class SinglefileData(Data): """Data class that can be used to store a single file in its repository.""" @@ -37,7 +39,9 @@ def from_string(cls, content: str, filename: str | pathlib.Path | None = None, * """ return cls(io.StringIO(content), filename, **kwargs) - def __init__(self, file: str | t.IO, filename: str | pathlib.Path | None = None, **kwargs: t.Any) -> None: + def __init__( + self, file: str | pathlib.Path | t.IO, filename: str | pathlib.Path | None = None, **kwargs: t.Any + ) -> None: """Construct a new instance and set the contents to that of the file. :param file: an absolute filepath or filelike object whose contents to copy. @@ -60,26 +64,30 @@ def filename(self) -> str: @t.overload @contextlib.contextmanager - def open(self, path: str, mode: t.Literal['r']) -> t.Iterator[t.TextIO]: + def open(self, path: FilePath, mode: t.Literal['r'] = ...) -> t.Iterator[t.TextIO]: ... @t.overload @contextlib.contextmanager - def open(self, path: None, mode: t.Literal['r']) -> t.Iterator[t.TextIO]: + def open(self, path: FilePath, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]: ... @t.overload @contextlib.contextmanager - def open(self, path: str, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]: + def open( # type: ignore[overload-overlap] + self, path: None = None, mode: t.Literal['r'] = ... + ) -> t.Iterator[t.TextIO]: ... @t.overload @contextlib.contextmanager - def open(self, path: None, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]: + def open(self, path: None = None, mode: t.Literal['rb'] = ...) -> t.Iterator[t.BinaryIO]: ... @contextlib.contextmanager - def open(self, path: str | None = None, mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO | t.TextIO]: + def open(self, + path: FilePath | None = None, + mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO]: """Return an open file handle to the content of this data node. :param path: the relative path of the object within the repository. @@ -113,7 +121,7 @@ def get_content(self, mode: str = 'r') -> str | bytes: with self.open(mode=mode) as handle: # type: ignore[call-overload] return handle.read() - def set_file(self, file: str | t.IO, filename: str | pathlib.Path | None = None) -> None: + def set_file(self, file: str | pathlib.Path | t.IO, filename: str | pathlib.Path | None = None) -> None: """Store the content of the file in the node's repository, deleting any other existing objects. :param file: an absolute filepath or filelike object whose contents to copy diff --git a/aiida/orm/nodes/repository.py b/aiida/orm/nodes/repository.py index daf97b6a86..e6d1b53c8e 100644 --- a/aiida/orm/nodes/repository.py +++ b/aiida/orm/nodes/repository.py @@ -164,8 +164,18 @@ def list_object_names(self, path: str | None = None) -> list[str]: """ return self._repository.list_object_names(path) + @t.overload + @contextlib.contextmanager + def open(self, path: FilePath, mode: t.Literal['r']) -> t.Iterator[t.TextIO]: + ... + + @t.overload + @contextlib.contextmanager + def open(self, path: FilePath, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]: + ... + @contextlib.contextmanager - def open(self, path: FilePath, mode='r') -> t.Iterator[t.BinaryIO | t.TextIO]: + def open(self, path: FilePath, mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO]: """Open a file handle to an object stored under the given key. .. note:: this should only be used to open a handle to read an existing file. To write a new file use the method @@ -210,7 +220,7 @@ def as_path(self, path: FilePath | None = None) -> t.Iterator[pathlib.Path]: assert path is not None with self.open(path, mode='rb') as source: with filepath.open('wb') as target: - shutil.copyfileobj(source, target) # type: ignore[misc] + shutil.copyfileobj(source, target) yield filepath def get_object(self, path: FilePath | None = None) -> File: diff --git a/aiida/parsers/plugins/arithmetic/add.py b/aiida/parsers/plugins/arithmetic/add.py index 8c2e32ee3d..7ba9d75ee5 100644 --- a/aiida/parsers/plugins/arithmetic/add.py +++ b/aiida/parsers/plugins/arithmetic/add.py @@ -9,7 +9,7 @@ ########################################################################### # Warning: this implementation is used directly in the documentation as a literal-include, which means that if any part # of this code is changed, the snippets in the file `docs/source/howto/codes.rst` have to be checked for consistency. -# mypy: disable_error_code=arg-type +# mypy: disable_error_code=call-overload """Parser for an `ArithmeticAddCalculation` job.""" from aiida.parsers.parser import Parser diff --git a/aiida/parsers/plugins/diff_tutorial/parsers.py b/aiida/parsers/plugins/diff_tutorial/parsers.py index cba01284d3..76272cd4b5 100644 --- a/aiida/parsers/plugins/diff_tutorial/parsers.py +++ b/aiida/parsers/plugins/diff_tutorial/parsers.py @@ -4,6 +4,7 @@ Register parsers via the "aiida.parsers" entry point in the pyproject.toml file. """ +# mypy: disable_error_code=call-overload # START PARSER HEAD from aiida.engine import ExitCode from aiida.orm import SinglefileData @@ -38,7 +39,7 @@ def parse(self, **kwargs): # add output file self.logger.info(f"Parsing '{output_filename}'") - with self.retrieved.open(output_filename, 'rb') as handle: # type: ignore[arg-type] + with self.retrieved.open(output_filename, 'rb') as handle: output_node = SinglefileData(file=handle) self.out('diff', output_node) @@ -59,7 +60,7 @@ def parse(self, **kwargs): # add output file self.logger.info(f"Parsing '{output_filename}'") - with self.retrieved.open(output_filename, 'rb') as handle: # type: ignore[arg-type] + with self.retrieved.open(output_filename, 'rb') as handle: output_node = SinglefileData(file=handle) self.out('diff', output_node) diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 1c0be753a2..0f3aed4811 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -14,6 +14,7 @@ py:class BinaryIO py:class EntryPoint py:class EntryPoints py:class IO +py:class FilePath py:class Path py:class str | list[str] py:class str | Path