diff --git a/audeer/__init__.py b/audeer/__init__.py index f35709b..a1eee50 100644 --- a/audeer/__init__.py +++ b/audeer/__init__.py @@ -14,6 +14,7 @@ from audeer.core.io import move_file from audeer.core.io import replace_file_extension from audeer.core.io import rmdir +from audeer.core.io import script_dir from audeer.core.io import touch from audeer.core.path import path from audeer.core.path import safe_path diff --git a/audeer/core/io.py b/audeer/core/io.py index afdc4d2..04825d0 100644 --- a/audeer/core/io.py +++ b/audeer/core/io.py @@ -1,6 +1,7 @@ import errno import fnmatch import hashlib +import inspect import itertools import os import platform @@ -1007,6 +1008,36 @@ def rmdir( shutil.rmtree(path) +def script_dir() -> str: + r"""Folder in which caller of this function is located. + + When called from a file, + it returns the directory, + in which the file is stored. + When called in an interactive session, + it returns the current working directory + of the interactive session. + + Returns: + current directory of caller + + Examples: + >>> os.path.basename(script_dir()) # folder of docstring test + 'audeer_core_io_script_dir0' + + """ + # Returning the script dir is usually done with + # `os.path.dirname(os.path.realpath(__file__))`, + # see https://stackoverflow.com/a/5137509. + # We cannot use `__file__` here, + # as this would always point to this file (`io.py`). + # Instead we find the script + # of the caller of `audeer.script_dir()`, + # see https://stackoverflow.com/a/37792573 + caller = inspect.stack()[1].filename + return os.path.dirname(os.path.realpath(caller)) + + def touch( path: typing.Union[str, bytes], *paths: typing.Sequence[typing.Union[str, bytes]], diff --git a/docs/api-src/audeer.rst b/docs/api-src/audeer.rst index aa780a2..413bca8 100644 --- a/docs/api-src/audeer.rst +++ b/docs/api-src/audeer.rst @@ -39,6 +39,7 @@ audeer rmdir run_tasks safe_path + script_dir sort_versions StrictVersion to_list diff --git a/tests/test_io.py b/tests/test_io.py index 2275de4..49a08fd 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1543,6 +1543,23 @@ def test_rmdir(tmpdir): assert not os.path.exists(path) +def test_script_dir(tmpdir): + r"""Test estimation of current directory of caller. + + See https://stackoverflow.com/a/5137509. + + Args: + tmpdir: tmpdir fixture + + """ + expected_script_dir = os.path.dirname(os.path.realpath(__file__)) + assert audeer.script_dir() == expected_script_dir + current_dir = os.getcwd() + os.chdir(tmpdir) + assert audeer.script_dir() == expected_script_dir + os.chdir(current_dir) + + def test_touch(tmpdir): path = audeer.mkdir(tmpdir, "folder1") path = os.path.join(path, "file")