From 2b4622b06c7f491a5b3a6c7c6ea0ad927fcc6c3a Mon Sep 17 00:00:00 2001 From: Sebastiaan la Fleur Date: Tue, 16 Jul 2024 13:29:12 +0200 Subject: [PATCH 1/3] Add typehints, missing documentation and a small fix to stop the capturing if the underlying stream is closed. --- streamcapture/__init__.py | 298 ++++++++++++++++++++++++++------------ 1 file changed, 203 insertions(+), 95 deletions(-) diff --git a/streamcapture/__init__.py b/streamcapture/__init__.py index c61c622..e47b792 100644 --- a/streamcapture/__init__.py +++ b/streamcapture/__init__.py @@ -42,7 +42,7 @@ `monkeypatch` optional parameter to the constructor. When enabled, the workaround overwrites `stream.write(...)` by an implementation that sends everything to `os.write(self.fd,...)`. This workaround is enabled when `monkeypatch=True` and disabled when `monkeypatch=False`. -The default is `monkeypatch=None`, in which case monkeypatching is enabled only when +The default is `monkeypatch=None`, in which case monkeypatching is enabled only when `platform.system()=='Windows'`. When writing to multiple streams and file descriptors, sometimes the order in which the writes @@ -99,8 +99,8 @@ import sys, streamcapture writer = streamcapture.Writer(open('logfile.txt','wb'),2) with streamcapture.StreamCapture(sys.stdout,writer), streamcapture.StreamCapture(sys.stderr,writer): - print("This goes to stdout and is captured to logfile.txt") - print("This goes to stderr and is also captured to logfile.txt",file=sys.stderr) + print("This goes to stdout and is captured to logfile.txt") + print("This goes to stderr and is also captured to logfile.txt",file=sys.stderr) ``` In the above example, writer will be closed twice: once from the `StreamCapture(sys.stdout,...)` @@ -108,102 +108,210 @@ of the `streamcapture.Writer` was set to `2`, so that the underlying stream is only closed after 2 calls to `writer.close()`. 
""" +import io +import os, threading, platform +from types import TracebackType +from typing import Optional, Callable, Union, Type, TextIO -import os, sys, threading, platform, select class Writer: - def __init__(self,stream,count = None,lock_write = False): - """`Writer` constructor.""" - (self.stream,self.lock_write) = (stream,lock_write) - if count is None: - (self.count,self.increment) = (0,1) - else: - (self.count,self.increment) = (count,0) - self.lock = threading.Lock() - self._write = self.locked_write if lock_write else stream.write - def write_from(self,data,cap): - self._write(data) - def writer_open(self): - with self.lock: - self.count += self.increment - def close(self): - """When one is done using a `Writer`, one calls `Writer.close()`. This acquires `Writer.lock` so it is - thread-safe. Each time `Writer.close()` is called, `Writer.count` is decremented. When `Writer.count` - reaches `0`, `stream.close()` is called.""" - with self.lock: - self.count -= 1 - if self.count>0: - return - self.stream.close() - def locked_write(self,z): - with self.lock: - self.stream.write(z) + def __init__(self, stream: io.IOBase, count=None, lock_write=False): + """`Writer` constructor. + + Wrapper of a stream to which bytes may be written. Introduces an optional lock for which write which + may be enabled through `lock_write`. + + :param stream: The stream to wrap. + :param count: The starting number of users of this writer. + :param lock_write: Grab the lock before each write operation. + """ + (self.stream, self.lock_write) = (stream, lock_write) + if count is None: + (self.count, self.increment) = (0, 1) + else: + (self.count, self.increment) = (count, 0) + self.lock = threading.Lock() + self._write = self.locked_write if lock_write else stream.write + + def write_from(self, data: bytes, cap: int) -> None: + """Perform a write operation. + + :param data: The bytes to write. + :param cap: Unused. Remains for legacy purposes. 
+ """ + self._write(data) + + def writer_open(self) -> None: + """Register that the writer is used.""" + with self.lock: + self.count += self.increment + + def close(self) -> None: + """Closes the writer and the underlying stream + + When one is done using a `Writer`, one calls `Writer.close()`. This acquires `Writer.lock` so it is + thread-safe. Each time `Writer.close()` is called, `Writer.count` is decremented. When `Writer.count` + reaches `0`, `stream.close()` is called. + """ + with self.lock: + self.count -= 1 + if self.count > 0: + return + self.stream.close() + + def locked_write(self, z: bytes) -> None: + """Perform the write operation in a thread-safe manner. + + :param z: Bytes to write. + """ + with self.lock: + self.stream.write(z) + class FDCapture: - def __init__(self,fd,writer,echo=True,magic=b'\x04\x81\x00\xff'): - """`FDCapture` constructor.""" - if(hasattr(writer,'writer_open')): - writer.writer_open() - (self.active, self.writer, self.fd, self.echo, self.magic) = (True,writer,fd,echo,magic) - self.write = (lambda data: self.writer.write_from(data,self)) if hasattr(writer,'write_from') else writer.write - (self.pipe_read_fd, self.pipe_write_fd) = os.pipe() - self.dup_fd = os.dup(fd) - os.dup2(self.pipe_write_fd,fd) - self.thread = threading.Thread(target=self.printer) - self.thread.start() - def printer(self): - """This is the thread that listens to the pipe output and passes it to the writer stream.""" - try: - looping = True - while looping: - data = os.read(self.pipe_read_fd,100000) - foo = data.split(self.magic) - - if len(foo)>=2: - looping = False - - for segment in foo: - if len(segment) == 0: - # Pipe is closed - looping = False - break - self.write(segment) - if self.echo: - os.write(self.dup_fd,segment) - finally: - os.close(self.pipe_read_fd) - def close(self): - """When you want to "uncapture" a stream, use this method.""" - if not self.active: - return - self.active = False - os.write(self.fd,self.magic) - self.thread.join() - 
os.dup2(self.dup_fd,self.fd) - os.close(self.pipe_write_fd) - os.close(self.dup_fd) - - def __enter__(self): - return self - def __exit__(self,a,b,c): - self.close() + """Redirect all output from a file descriptor and write it to `writer`.""" + + active: bool + writer: Union[io.IOBase, Writer] + fd: int + echo: bool + magic: bytes + write: Callable[[bytes], int] + + pipe_read_fd: int + pipe_write_fd: int + dup_fd: int + """Placeholder filedescriptor where the stream originally wrote to.""" + thread: threading.Thread + + def __init__( + self, + fd: int, + writer: Union[io.IOBase, Writer], + echo: bool, + magic: bytes = b"\x04\x81\x00\xff", + ): + """`FDCapture` constructor. + + :param fd: The filedescriptor to capture. + :param writer: Any bytes received from `fd` are written to this writer. + :param echo: Enable to also write bytes received to `fd` as well. + :param magic: The magic packet which denotes that the capturing process should stop. + """ + if hasattr(writer, "writer_open"): + writer.writer_open() + (self.active, self.writer, self.fd, self.echo, self.magic) = (True, writer, fd, echo, magic) + self.write = ( + (lambda data: self.writer.write_from(data, self)) + if hasattr(writer, "write_from") + else writer.write + ) + (self.pipe_read_fd, self.pipe_write_fd) = os.pipe() + self.dup_fd = os.dup(fd) + os.dup2(self.pipe_write_fd, fd) + self.thread = threading.Thread(target=self.printer) + self.thread.start() + + def printer(self): + """This is the thread that listens to the pipe output and passes it to the writer stream.""" + try: + looping = True + while looping: + data = os.read(self.pipe_read_fd, 100000) + foo = data.split(self.magic) + + # magic segment was found in data + if len(foo) >= 2: + looping = False + + for segment in foo: + # Pipe is closed + if len(segment) == 0: + break + self.write(segment) + if self.echo: + os.write(self.dup_fd, segment) + finally: + os.close(self.pipe_read_fd) + + def close(self): + """When you want to "uncapture" a stream, 
use this method.""" + if not self.active: + return + self.active = False + + os.write(self.fd, self.magic) + self.thread.join() + os.dup2(self.dup_fd, self.fd) + os.close(self.pipe_write_fd) + os.close(self.dup_fd) + + def __enter__(self): + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + class StreamCapture: - def __init__(self,stream,writer,echo=True,monkeypatch=None): - """The `StreamCapture` constructor.""" - self.fdcapture = FDCapture(stream.fileno(),writer,echo) - self.stream = stream - self.monkeypatch = platform.system()=='Windows' if monkeypatch is None else monkeypatch - if self.monkeypatch: - self.oldwrite = stream.write - stream.buffer.write = lambda z: os.write(stream.fileno(), z) - def close(self): - """When you want to "uncapture" a stream, use this method.""" - self.stream.flush() - self.fdcapture.close() - if self.monkeypatch: - self.stream.write = self.oldwrite - def __enter__(self): - return self - def __exit__(self,a,b,c): - self.close() + """Interface for users to redirect a stream to another `io.IOBase`""" + + fdcapture: FDCapture + stream: Union[io.IOBase, TextIO] + monkeypatch: bool + oldwrite: Optional[Callable[[Union[bytes, str]], None]] + + def __init__( + self, + stream_to_redirect: Union[io.IOBase, TextIO], + writer: io.IOBase, + echo: bool = True, + monkeypatch: Optional[bool] = None, + ) -> None: + """The `StreamCapture` constructor. + + :param stream_to_redirect: Stream which will be redirected. + :param writer: The stream will be redirected to this writer. It must derive from io.IOBase. + :param echo: If the redirected stream should also write any output to the original stream. + :param monkeypatch: If monkeypatching is necessary. Default is None which will perform + the monkeypatch in case this is run on Windows. Otherwise, the value of monkeypatch + is used. 
+ """ + self.fdcapture = FDCapture(stream_to_redirect.fileno(), writer, echo) + self.stream = stream_to_redirect + self.monkeypatch = platform.system() == "Windows" if monkeypatch is None else monkeypatch + if self.monkeypatch: + self.oldwrite = stream_to_redirect.write + stream_to_redirect.write = lambda z: os.write( + stream_to_redirect.fileno(), z.encode() if hasattr(z, "encode") else z + ) + else: + self.oldwrite = None + + def close(self) -> None: + """When you want to "uncapture" a stream, use this method.""" + self.stream.flush() + self.fdcapture.close() + if self.monkeypatch: + self.stream.write = self.oldwrite + + def __enter__(self): + """Start the stream redirect as a contextmanager.""" + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Stop the stream redirect as a contextmanager. + + Same as running StreamCapture.close() + """ + self.close() From 9d5941bc9d35c67b1162b2d9d0c9ca0c0466c7c8 Mon Sep 17 00:00:00 2001 From: Sebastiaan la Fleur Date: Tue, 16 Jul 2024 13:45:14 +0200 Subject: [PATCH 2/3] Add missing typing for Writer, fix indent of Writer.writer_open, add missing looping = False statement in FDCapture.printer --- streamcapture/__init__.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/streamcapture/__init__.py b/streamcapture/__init__.py index e47b792..d8cf5b7 100644 --- a/streamcapture/__init__.py +++ b/streamcapture/__init__.py @@ -115,7 +115,13 @@ class Writer: - def __init__(self, stream: io.IOBase, count=None, lock_write=False): + stream: io.IOBase + count: int + increment: int + lock: threading.Lock + _write: Callable[[bytes], None] + + def __init__(self, stream: io.IOBase, count: Optional[int] = None, lock_write: bool = False): """`Writer` constructor. Wrapper of a stream to which bytes may be written. 
Introduces an optional lock for which write which @@ -143,8 +149,8 @@ def write_from(self, data: bytes, cap: int) -> None: def writer_open(self) -> None: """Register that the writer is used.""" - with self.lock: - self.count += self.increment + with self.lock: + self.count += self.increment def close(self) -> None: """Closes the writer and the underlying stream @@ -227,6 +233,7 @@ def printer(self): for segment in foo: # Pipe is closed if len(segment) == 0: + looping = False break self.write(segment) if self.echo: From efaa6cb88ef7ff49f34679dc70a409eaf99f72ef Mon Sep 17 00:00:00 2001 From: Sebastiaan la Fleur Date: Tue, 16 Jul 2024 14:01:51 +0200 Subject: [PATCH 3/3] Fix or suppress typing issues. --- streamcapture/__init__.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/streamcapture/__init__.py b/streamcapture/__init__.py index d8cf5b7..208ff66 100644 --- a/streamcapture/__init__.py +++ b/streamcapture/__init__.py @@ -119,7 +119,7 @@ class Writer: count: int increment: int lock: threading.Lock - _write: Callable[[bytes], None] + _write: Callable[[bytes], int] def __init__(self, stream: io.IOBase, count: Optional[int] = None, lock_write: bool = False): """`Writer` constructor. @@ -137,15 +137,17 @@ def __init__(self, stream: io.IOBase, count: Optional[int] = None, lock_write: b else: (self.count, self.increment) = (count, 0) self.lock = threading.Lock() - self._write = self.locked_write if lock_write else stream.write + self._write = self.locked_write if lock_write else stream.write # type: ignore[assignment] - def write_from(self, data: bytes, cap: int) -> None: + def write_from(self, data: bytes, cap: 'FDCapture') -> int: """Perform a write operation. :param data: The bytes to write. :param cap: Unused. Remains for legacy purposes. + + :return: The amount of bytes written. 
""" - self._write(data) + return self._write(data) def writer_open(self) -> None: """Register that the writer is used.""" @@ -165,13 +167,15 @@ def close(self) -> None: return self.stream.close() - def locked_write(self, z: bytes) -> None: + def locked_write(self, z: bytes) -> int: """Perform the write operation in a thread-safe manner. :param z: Bytes to write. + :return: Return the amount of bytes written """ with self.lock: - self.stream.write(z) + written = self.stream.write(z) + return written class FDCapture: @@ -208,7 +212,7 @@ def __init__( writer.writer_open() (self.active, self.writer, self.fd, self.echo, self.magic) = (True, writer, fd, echo, magic) self.write = ( - (lambda data: self.writer.write_from(data, self)) + (lambda data: self.writer.write_from(data, self)) # type: ignore[union-attr, assignment] if hasattr(writer, "write_from") else writer.write ) @@ -293,8 +297,8 @@ def __init__( self.stream = stream_to_redirect self.monkeypatch = platform.system() == "Windows" if monkeypatch is None else monkeypatch if self.monkeypatch: - self.oldwrite = stream_to_redirect.write - stream_to_redirect.write = lambda z: os.write( + self.oldwrite = stream_to_redirect.write # type: ignore[assignment] + stream_to_redirect.write = lambda z: os.write( # type: ignore[method-assign] stream_to_redirect.fileno(), z.encode() if hasattr(z, "encode") else z ) else: @@ -305,7 +309,7 @@ def close(self) -> None: self.stream.flush() self.fdcapture.close() if self.monkeypatch: - self.stream.write = self.oldwrite + self.stream.write = self.oldwrite # type: ignore[assignment,method-assign] def __enter__(self): """Start the stream redirect as a contextmanager."""