diff --git a/changes/58.feature.rst b/changes/58.feature.rst new file mode 100644 index 0000000..98bbf9c --- /dev/null +++ b/changes/58.feature.rst @@ -0,0 +1 @@ +Support for driving ``BufferedProtocol`` instances using ``sock_recv_into`` was added. diff --git a/src/gbulb/glib_events.py b/src/gbulb/glib_events.py index 64d9598..0215a37 100644 --- a/src/gbulb/glib_events.py +++ b/src/gbulb/glib_events.py @@ -659,15 +659,26 @@ def sock_recv(self, sock, nbytes, flags=0): channel = self._channel_from_socket(sock) def read_func(channel, nbytes): - return sock.recv(nbytes, flags) + if not sock._closed: + return sock.recv(nbytes, flags) return self._channel_read(channel, nbytes, read_func) + def sock_recv_into(self, sock, buf, flags=0): + channel = self._channel_from_socket(sock) + + def read_func(channel, nbytes): + if not sock._closed: + return sock.recv_into(buf, flags) + + return self._channel_read(channel, len(buf), read_func) + def sock_recvfrom(self, sock, nbytes, flags=0): channel = self._channel_from_socket(sock) def read_func(channel, nbytes): - return sock.recvfrom(nbytes, flags) + if not sock._closed: + return sock.recvfrom(nbytes, flags) return self._channel_read(channel, nbytes, read_func) @@ -675,7 +686,8 @@ def sock_sendall(self, sock, buf, flags=0): channel = self._channel_from_socket(sock) def write_func(channel, buf): - return sock.send(buf, flags) + if not sock._closed: + return sock.send(buf, flags) return self._channel_write(channel, buf, write_func) @@ -683,7 +695,8 @@ def sock_sendallto(self, sock, buf, addr, flags=0): channel = self._channel_from_socket(sock) def write_func(channel, buf): - return sock.sendto(buf, flags, addr) + if not sock._closed: + return sock.sendto(buf, flags, addr) return self._channel_write(channel, buf, write_func) diff --git a/src/gbulb/transports.py b/src/gbulb/transports.py index 2b179a5..ac5ee9b 100644 --- a/src/gbulb/transports.py +++ b/src/gbulb/transports.py @@ -1,4 +1,6 @@ +import asyncio import collections +import io import socket import subprocess import sys @@ -14,12 +16,12 @@ def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): self._loop = loop self._sock = sock - self._protocol = protocol self._server = server self._closing = False self._closing_delayed = False self._closed = False self._cancelable = set() + self.set_protocol(protocol) if sock is not None: self._loop._transports[sock.fileno()] = self @@ -90,15 +92,22 @@ def _force_close_async(self, exc): class ReadTransport(BaseTransport, transports.ReadTransport): - max_size = 256 * 1024 + max_size = io.DEFAULT_BUFFER_SIZE def __init__(self, *args, **kwargs): - BaseTransport.__init__(self, *args, **kwargs) - self._paused = False self._read_fut = None + self._read_buffer = None + self._alloc_read_buffers = False + + BaseTransport.__init__(self, *args, **kwargs) + self._loop.call_soon(self._loop_reading) + def set_protocol(self, protocol): + self._alloc_read_buffers = isinstance(protocol, asyncio.BufferedProtocol) + super().set_protocol(protocol) + def pause_reading(self): if self._closing: raise RuntimeError("Cannot pause_reading() when closing") @@ -127,12 +136,23 @@ def close(self): super().close() def _create_read_future(self, size): - return self._loop.sock_recv(self._sock, size) + if self._alloc_read_buffers: + self._read_buffer = self._protocol.get_buffer(size) + return self._loop.sock_recv_into(self._sock, self._read_buffer) + else: + return self._loop.sock_recv(self._sock, size) def _submit_read_data(self, data): - if data: - self._protocol.data_received(data) + if data != b"" and data != 0: + if self._alloc_read_buffers: + assert isinstance(data, int) # Actually `nbytes` + self._protocol.buffer_updated(data) + self._read_buffer = None + else: + assert isinstance(data, bytes) + self._protocol.data_received(data) else: + self._read_buffer = None keep_open = self._protocol.eof_received() if not keep_open: self.close() @@ -140,9 +160,9 @@ def _submit_read_data(self, data): def _loop_reading(self, fut=None): if self._paused: return - data = None try: + data = None if fut is not None: assert self._read_fut is fut or ( self._read_fut is None and self._closing @@ -157,7 +177,10 @@ def _loop_reading(self, fut=None): data = None return - if data == b"": + if data is not None: + self._submit_read_data(data) + + if data == b"" or data == 0: # No need to reschedule on end-of-file return @@ -179,9 +202,6 @@ def _loop_reading(self, fut=None): self._cancelable.add(self._read_fut) else: self._read_fut.add_done_callback(self._loop_reading) - finally: - if data is not None: - self._submit_read_data(data) class WriteTransport(BaseTransport, transports._FlowControlMixin): @@ -191,8 +211,8 @@ def __init__(self, loop, *args, **kwargs): transports._FlowControlMixin.__init__(self, None, loop) BaseTransport.__init__(self, loop, *args, **kwargs) - self._buffer = self._buffer_factory() - self._buffer_empty_callbacks = set() + self._write_buffer = self._buffer_factory() + self._drained_callbacks = set() self._write_fut = None self._eof_written = False @@ -203,7 +223,7 @@ def can_write_eof(self): return True def get_write_buffer_size(self): - return len(self._buffer) + return len(self._write_buffer) def _close_write(self): if self._write_fut is not None: @@ -213,7 +233,7 @@ def transport_write_done_callback(): self._closing_delayed = False self.close() - self._buffer_empty_callbacks.add(transport_write_done_callback) + self._drained_callbacks.add(transport_write_done_callback) def close(self): self._close_write() @@ -238,12 +258,12 @@ def _create_write_future(self, data): return self._loop.sock_sendall(self._sock, data) def _buffer_add_data(self, data): - self._buffer.extend(data) + self._write_buffer.extend(data) def _buffer_pop_data(self): - if len(self._buffer) > 0: - data = self._buffer - self._buffer = bytearray() + if len(self._write_buffer) > 0: + data = self._write_buffer + self._write_buffer = self._buffer_factory() return data else: return None @@ -264,10 +284,10 @@ def _loop_writing(self, fut=None, data=None): data = self._buffer_pop_data() if not data: - if len(self._buffer_empty_callbacks) > 0: - for callback in self._buffer_empty_callbacks: + if len(self._drained_callbacks) > 0: + for callback in self._drained_callbacks: callback() - self._buffer_empty_callbacks.clear() + self._drained_callbacks.clear() self._maybe_resume_protocol() else: @@ -358,11 +378,11 @@ def _create_write_future(self, args): def _buffer_add_data(self, args): (data, addr) = args - self._buffer.append((bytes(data), addr)) + self._write_buffer.append((bytes(data), addr)) def _buffer_pop_data(self): - if len(self._buffer) > 0: - return self._buffer.popleft() + if len(self._write_buffer) > 0: + return self._write_buffer.popleft() else: return None @@ -392,8 +412,29 @@ def __init__(self, loop, channel, protocol, waiter, extra): super().__init__(loop, None, protocol, waiter, extra) def _create_read_future(self, size): + if self._alloc_read_buffers: + self._read_buffer = self._protocol.get_buffer(size) + size = len(self._read_buffer) return self._loop._channel_read(self._channel, size) + def _submit_read_data(self, data): + assert isinstance(data, bytes) + if data != b"" and data != 0: + if self._alloc_read_buffers: + # FIXME: GLib does not actually expose the equivalent to + # `recv_into` in its channel interface, so we have to + # add an extra copy here rather than avoiding one + self._read_buffer[0 : len(data)] = data + self._protocol.buffer_updated(len(data)) + self._read_buffer = None + else: + self._protocol.data_received(data) + else: + self._read_buffer = None + keep_open = self._protocol.eof_received() + if not keep_open: + self.close() + def _force_close_async(self, exc): try: super()._force_close_async(exc)