diff --git a/streamcapture/__init__.py b/streamcapture/__init__.py index d8cf5b7..208ff66 100644 --- a/streamcapture/__init__.py +++ b/streamcapture/__init__.py @@ -119,7 +119,7 @@ class Writer: count: int increment: int lock: threading.Lock - _write: Callable[[bytes], None] + _write: Callable[[bytes], int] def __init__(self, stream: io.IOBase, count: Optional[int] = None, lock_write: bool = False): """`Writer` constructor. @@ -137,15 +137,17 @@ def __init__(self, stream: io.IOBase, count: Optional[int] = None, lock_write: b else: (self.count, self.increment) = (count, 0) self.lock = threading.Lock() - self._write = self.locked_write if lock_write else stream.write + self._write = self.locked_write if lock_write else stream.write # type: ignore[assignment] - def write_from(self, data: bytes, cap: int) -> None: + def write_from(self, data: bytes, cap: 'FDCapture') -> int: """Perform a write operation. :param data: The bytes to write. :param cap: Unused. Remains for legacy purposes. + + :return: The amount of bytes written. """ - self._write(data) + return self._write(data) def writer_open(self) -> None: """Register that the writer is used.""" @@ -165,13 +167,15 @@ def close(self) -> None: return self.stream.close() - def locked_write(self, z: bytes) -> None: + def locked_write(self, z: bytes) -> int: """Perform the write operation in a thread-safe manner. :param z: Bytes to write. + :return: Return the amount of bytes written """ with self.lock: - self.stream.write(z) + written = self.stream.write(z) + return written class FDCapture: @@ -208,7 +212,7 @@ def __init__( writer.writer_open() (self.active, self.writer, self.fd, self.echo, self.magic) = (True, writer, fd, echo, magic) self.write = ( - (lambda data: self.writer.write_from(data, self)) + (lambda data: self.writer.write_from(data, self)) # type: ignore[union-attr, assignment] if hasattr(writer, "write_from") else writer.write ) @@ -293,8 +297,8 @@ def __init__( self.stream = stream_to_redirect self.monkeypatch = platform.system() == "Windows" if monkeypatch is None else monkeypatch if self.monkeypatch: - self.oldwrite = stream_to_redirect.write - stream_to_redirect.write = lambda z: os.write( + self.oldwrite = stream_to_redirect.write # type: ignore[assignment] + stream_to_redirect.write = lambda z: os.write( # type: ignore[method-assign] stream_to_redirect.fileno(), z.encode() if hasattr(z, "encode") else z ) else: @@ -305,7 +309,7 @@ def close(self) -> None: self.stream.flush() self.fdcapture.close() if self.monkeypatch: - self.stream.write = self.oldwrite + self.stream.write = self.oldwrite # type: ignore[assignment,method-assign] def __enter__(self): """Start the stream redirect as a contextmanager."""