diff --git a/CHANGELOG.md b/CHANGELOG.md index 70b44628..6489696c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Refactor the raw import statement in `setup.py` with `importlib` utilities by [@XuehaiPan](https://github.com/XuehaiPan) in [#214](https://github.com/metaopt/torchopt/pull/214). ### Fixed diff --git a/setup.py b/setup.py index dc1103df..c50ba5ed 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +import contextlib import os import pathlib import platform @@ -5,22 +6,13 @@ import shutil import sys import sysconfig +from importlib.util import module_from_spec, spec_from_file_location -from setuptools import setup +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext -try: - from pybind11.setup_helpers import Pybind11Extension as Extension - from pybind11.setup_helpers import build_ext -except ImportError: - from setuptools import Extension - from setuptools.command.build_ext import build_ext - HERE = pathlib.Path(__file__).absolute().parent -VERSION_FILE = HERE / 'torchopt' / 'version.py' - -sys.path.insert(0, str(VERSION_FILE.parent)) -import version # noqa class CMakeExtension(Extension): @@ -47,7 +39,6 @@ def build_extension(self, ext): build_temp.mkdir(parents=True, exist_ok=True) config = 'Debug' if self.debug else 'Release' - cmake_args = [ f'-DCMAKE_BUILD_TYPE={config}', f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={ext_path.parent}', @@ -83,13 +74,53 @@ def build_extension(self, ext): build_args.extend(['--target', ext.target, '--']) + cwd = os.getcwd() try: os.chdir(build_temp) self.spawn([cmake, ext.source_dir, *cmake_args]) if not self.dry_run: self.spawn([cmake, '--build', '.', *build_args]) finally: - os.chdir(HERE) + os.chdir(cwd) + + +@contextlib.contextmanager +def vcs_version(name, path): + path = pathlib.Path(path).absolute() + assert path.is_file() + module_spec = spec_from_file_location(name=name, location=path) + assert module_spec is not None + assert module_spec.loader is not None + module = sys.modules.get(name) + if module is None: + module = module_from_spec(module_spec) + sys.modules[name] = module + module_spec.loader.exec_module(module) + + if module.__release__: + yield module + return + + content = None + try: + try: + content = path.read_text(encoding='utf-8') + path.write_text( + data=re.sub( + r"""__version__\s*=\s*('[^']+'|"[^"]+")""", + f'__version__ = {module.__version__!r}', + string=content, + ), + encoding='utf-8', + ) + except OSError: + content = None + + yield module + finally: + if content is not None: + with path.open(mode='wt', encoding='utf-8', newline='') as file: + file.write(content) CIBUILDWHEEL = os.getenv('CIBUILDWHEEL', '0') == '1' @@ -112,29 +143,9 @@ def build_extension(self, ext): ext_kwargs.clear() -VERSION_CONTENT = None - -try: - if not version.__release__: - try: - VERSION_CONTENT = VERSION_FILE.read_text(encoding='utf-8') - VERSION_FILE.write_text( - data=re.sub( - r"""__version__\s*=\s*('[^']+'|"[^"]+")""", - f'__version__ = {version.__version__!r}', - string=VERSION_CONTENT, - ), - encoding='utf-8', - ) - except OSError: - VERSION_CONTENT = None - +with vcs_version(name='torchopt.version', path=(HERE / 'torchopt' / 'version.py')) as version: setup( name='torchopt', version=version.__version__, **ext_kwargs, ) -finally: - if VERSION_CONTENT is not None: - with VERSION_FILE.open(mode='wt', encoding='utf-8', newline='') as file: - file.write(VERSION_CONTENT)