Skip to content

Commit

Permalink
Fix the race condition issue by using a threading.RLock.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sachaa-Thanasius committed Sep 1, 2024
1 parent 1bd1b9a commit 2a98962
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ extend-ignore = [
unfixable = [
"ERA", # Prevent unlikely erroneous deletion.
]
typing-modules = ["deferred._typing"]

[tool.ruff.lint.isort]
lines-after-imports = 2
Expand Down Expand Up @@ -184,7 +185,7 @@ keep-runtime-typing = true
]
"tests/stdlib_imports.py" = [
"F401", # Unused imports are fine; we're testing import success.
"ERA001", # Plenty of imports are commented out with explanation in comment.
"ERA001", # Plenty of imports are commented out with explanations next to them.
]
"benchmark/**/*.py" = [
"T201", # Printing is fine.
Expand Down
31 changes: 19 additions & 12 deletions src/deferred/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
# region -------- Compile-time hook


SourceDataType: _tp.TypeAlias = "_tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive]"

BYTECODE_HEADER = f"deferred{__version__}".encode()
"""Custom header for deferred-instrumented bytecode files. Should be updated with every version release."""

SourceDataType: _tp.TypeAlias = "_tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive]"


class DeferredInstrumenter(ast.NodeTransformer):
"""AST transformer that "instruments" imports within "with defer_imports_until_use: ..." blocks so that their
Expand Down Expand Up @@ -468,8 +468,10 @@ def __getattr__(self, name: str, /):

if name in self.defer_proxy_fromlist:
sub_proxy.defer_proxy_fromlist = (name,)
else:
elif name == self.defer_proxy_name.rpartition(".")[2]:
sub_proxy.defer_proxy_sub = name
else:
raise AttributeError(name)

return sub_proxy

Expand All @@ -480,7 +482,7 @@ class DeferredImportKey(str):
When referenced, the key will replace itself in the namespace with the resolved import or the right name from it.
"""

__slots__ = ("defer_key_str", "defer_key_proxy", "is_recursing", "_lock")
__slots__ = ("defer_key_str", "defer_key_proxy", "is_recursing", "_rlock")

def __new__(cls, key: str, proxy: DeferredImportProxy, /):
return super().__new__(cls, key)
Expand All @@ -490,6 +492,8 @@ def __init__(self, key: str, proxy: DeferredImportProxy, /):
self.defer_key_proxy = proxy
self.is_recursing = False

self._rlock = original_import.get()("threading").RLock()

def __repr__(self) -> str:
return f"<key for {self.defer_key_str!r} import>"

Expand All @@ -499,12 +503,18 @@ def __eq__(self, value: object, /) -> bool:
if self.defer_key_str != value:
return False

# The recursion guard allows self-referential imports within __init__.py files.
if not self.is_recursing:
self.is_recursing = True
# NOTE: This RLock prevents a scenario where the proxy begins resolution in one thread, but before it completes
# resolution, a context switch occurs and another thread that tries to resolve the same proxy just gets
# the proxy back before it has been resolved and replaced. This also partially happens because
# is_recursing is a guard that is only intended for one thread, but other threads will see it without the
# RLock.
with self._rlock:
# This recursion guard allows self-referential imports within __init__.py files.
if not self.is_recursing:
self.is_recursing = True

if not is_deferred.get():
self._resolve()
if not is_deferred.get():
self._resolve()

return True

Expand All @@ -518,8 +528,6 @@ def _resolve(self) -> None:

# Perform the original __import__ and pray.
module: _tp.ModuleType = original_import.get()(*proxy.defer_proxy_import_args)
# FIXME: The below print is enough to trigger a race condition.
# print(f"{module=}, {proxy.defer_proxy_name=}, {proxy.defer_proxy_fromlist=}, {proxy.defer_proxy_sub=}") # noqa: ERA001

# Transfer nested proxies over to the resolved module.
module_vars = vars(module)
Expand All @@ -540,7 +548,6 @@ def _resolve(self) -> None:
# NOTE: This is necessary to prevent recursive resolution for proxies, since __eq__ will be triggered again.
_is_def_tok = is_deferred.set(True)
try:
# TODO: Figure out why this works, but del namespace[key] doesn't.
namespace[key] = namespace.pop(key)
finally:
is_deferred.reset(_is_def_tok)
Expand Down
8 changes: 0 additions & 8 deletions tests/test_deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,14 +845,6 @@ def access_module_attr() -> object:
for thread in threads:
thread.join()
assert callable(thread.result) # pyright: ignore
# FIXME: Error on various versions; sometimes the accessed signature isn't resolved?
# Recreatable by putting "print(module)" after line 525:
# module: _tp.ModuleType = original_import.get()(*proxy.defer_proxy_import_args) # noqa: ERA001
#
# The error:
# AssertionError: assert False
# + where False = callable(<proxy for 'import inspect as ...'>)
# + where <proxy for 'import inspect as ...'> = <CapturingThread(Thread 4, stopped 22024)>.result
assert thread.exc is None


Expand Down

0 comments on commit 2a98962

Please sign in to comment.