diff --git a/src/thread/thread.py b/src/thread/thread.py index 060057d..9297236 100644 --- a/src/thread/thread.py +++ b/src/thread/thread.py @@ -11,6 +11,7 @@ class ParallelProcessing: ... import sys import time +import ctypes import signal import threading from functools import wraps @@ -32,7 +33,7 @@ class ParallelProcessing: ... HookFunction, ) from typing_extensions import Generic, ParamSpec -from typing import List, Callable, Optional, Union, Mapping, Sequence, Tuple, Generator +from typing import List, Optional, Union, Mapping, Sequence, Tuple, Generator Threads: set['Thread'] = set() @@ -56,7 +57,6 @@ class Thread(threading.Thread, Generic[_Target_P, _Target_T]): # threading.Thread stuff _initialized: bool - _run: Callable def __init__( self, @@ -116,23 +116,29 @@ def _wrap_target( def wrapper( *args: _Target_P.args, **kwargs: _Target_P.kwargs ) -> Union[_Target_T, None]: - self.status = 'Running' + try: + self.status = 'Running' - global Threads - Threads.add(self) + global Threads + Threads.add(self) - try: - self._returned_value = target(*args, **kwargs) - except Exception as e: - if not any(isinstance(e, ignore) for ignore in self.ignore_errors): - self.status = 'Errored' - self.errors.append(e) - return + try: + self._returned_value = target(*args, **kwargs) + except Exception as e: + if not any(isinstance(e, ignore) for ignore in self.ignore_errors): + self.status = 'Errored' + self.errors.append(e) + return + + self.status = 'Invoking hooks' + self._invoke_hooks() + Threads.remove(self) + self.status = 'Completed' - self.status = 'Invoking hooks' - self._invoke_hooks() - Threads.remove(self) - self.status = 'Completed' + except SystemExit: + self.status = 'Killed' + print('KILLED ident: %s' % self.ident) + return return wrapper @@ -157,27 +163,6 @@ def _handle_exceptions(self) -> None: for e in self.errors: raise e - def global_trace(self, frame, event: str, arg) -> Optional[Callable]: - if event == 'call': - return self.local_trace - - def local_trace(self, frame, event: str, arg): - if self.status == 'Kill Scheduled' and event == 'line': - print('KILLED ident: %s' % self.ident) - self.status = 'Killed' - raise SystemExit() - return self.local_trace - - def _run_with_trace(self) -> None: - """This will replace `threading.Thread`'s `run()` method""" - if not self._run: - raise exceptions.ThreadNotInitializedError( - 'Running `_run_with_trace` may cause unintended behaviour, run `start` instead' - ) - - sys.settrace(self.global_trace) - self._run() - @property def result(self) -> _Target_T: """ @@ -274,10 +259,11 @@ def kill(self, yielding: bool = False, timeout: float = 5) -> bool: Returns ------- - :returns bool: False if the it exceeded the timeout + :returns bool: False if the it exceeded the timeout without being killed Raises ------ + ValueError: If the thread ident does not exist ThreadNotInitializedError: If the thread is not initialized ThreadNotRunningError: If the thread is not running """ @@ -285,6 +271,21 @@ def kill(self, yielding: bool = False, timeout: float = 5) -> bool: raise exceptions.ThreadNotRunningError() self.status = 'Kill Scheduled' + + res: int = ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_long(self.ident), ctypes.py_object(SystemExit) + ) + + if res == 0: + raise ValueError('Thread IDENT does not exist') + elif res > 1: + # Unexpected behaviour, something seriously went wrong + # https://docs.python.org/3/c-api/init.html#c.PyThreadState_SetAsyncExc + ctypes.pythonapi.PyThreadState_SetAsyncExc(self.ident, None) + raise SystemError( + f'Killing thread with ident [{self.ident}] failed!\nPyThreadState_SetAsyncExc returned: {res}' + ) + if not yielding: return True @@ -308,8 +309,6 @@ def start(self) -> None: if self.is_alive(): raise exceptions.ThreadStillRunningError() - self._run = self.run - self.run = self._run_with_trace super().start()