Skip to content
This repository has been archived by the owner on Sep 13, 2024. It is now read-only.

Add support for driving BufferedProtocol instances using sock_recv_into #60

Merged
merged 8 commits into from
May 4, 2024
1 change: 1 addition & 0 deletions changes/58.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support for driving ``BufferedProtocol`` instances using ``sock_recv_into`` was added.
21 changes: 17 additions & 4 deletions src/gbulb/glib_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,31 +659,44 @@ def sock_recv(self, sock, nbytes, flags=0):
channel = self._channel_from_socket(sock)

def read_func(channel, nbytes):
return sock.recv(nbytes, flags)
if not sock._closed:
return sock.recv(nbytes, flags)

return self._channel_read(channel, nbytes, read_func)

def sock_recv_into(self, sock, buf, flags=0):
channel = self._channel_from_socket(sock)

def read_func(channel, nbytes):
if not sock._closed:
return sock.recv_into(buf, flags)

return self._channel_read(channel, len(buf), read_func)

def sock_recvfrom(self, sock, nbytes, flags=0):
channel = self._channel_from_socket(sock)

def read_func(channel, nbytes):
return sock.recvfrom(nbytes, flags)
if not sock._closed:
return sock.recvfrom(nbytes, flags)

return self._channel_read(channel, nbytes, read_func)

def sock_sendall(self, sock, buf, flags=0):
channel = self._channel_from_socket(sock)

def write_func(channel, buf):
return sock.send(buf, flags)
if not sock._closed:
return sock.send(buf, flags)

return self._channel_write(channel, buf, write_func)

def sock_sendallto(self, sock, buf, addr, flags=0):
channel = self._channel_from_socket(sock)

def write_func(channel, buf):
return sock.sendto(buf, flags, addr)
if not sock._closed:
return sock.sendto(buf, flags, addr)

return self._channel_write(channel, buf, write_func)

Expand Down
93 changes: 67 additions & 26 deletions src/gbulb/transports.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import collections
import io
import socket
import subprocess
import sys
Expand All @@ -14,12 +16,12 @@ def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None):

self._loop = loop
self._sock = sock
self._protocol = protocol
self._server = server
self._closing = False
self._closing_delayed = False
self._closed = False
self._cancelable = set()
self.set_protocol(protocol)

if sock is not None:
self._loop._transports[sock.fileno()] = self
Expand Down Expand Up @@ -90,15 +92,22 @@ def _force_close_async(self, exc):


class ReadTransport(BaseTransport, transports.ReadTransport):
max_size = 256 * 1024
max_size = io.DEFAULT_BUFFER_SIZE

def __init__(self, *args, **kwargs):
BaseTransport.__init__(self, *args, **kwargs)

self._paused = False
self._read_fut = None
self._read_buffer = None
self._alloc_read_buffers = False

BaseTransport.__init__(self, *args, **kwargs)

self._loop.call_soon(self._loop_reading)

def set_protocol(self, protocol):
self._alloc_read_buffers = isinstance(protocol, asyncio.BufferedProtocol)
super().set_protocol(protocol)

def pause_reading(self):
if self._closing:
raise RuntimeError("Cannot pause_reading() when closing")
Expand Down Expand Up @@ -127,22 +136,33 @@ def close(self):
super().close()

def _create_read_future(self, size):
return self._loop.sock_recv(self._sock, size)
if self._alloc_read_buffers:
self._read_buffer = self._protocol.get_buffer(size)
return self._loop.sock_recv_into(self._sock, self._read_buffer)
else:
return self._loop.sock_recv(self._sock, size)

def _submit_read_data(self, data):
if data:
self._protocol.data_received(data)
if data != b"" and data != 0:
if self._alloc_read_buffers:
assert isinstance(data, int) # Actually `nbytes`
self._protocol.buffer_updated(data)
self._read_buffer = None
else:
assert isinstance(data, bytes)
self._protocol.data_received(data)
else:
self._read_buffer = None
keep_open = self._protocol.eof_received()
if not keep_open:
self.close()

def _loop_reading(self, fut=None):
if self._paused:
return
data = None

try:
data = None
if fut is not None:
assert self._read_fut is fut or (
self._read_fut is None and self._closing
Expand All @@ -157,7 +177,10 @@ def _loop_reading(self, fut=None):
data = None
return

if data == b"":
if data is not None:
self._submit_read_data(data)

if data == b"" or data == 0:
# No need to reschedule on end-of-file
return

Expand All @@ -179,9 +202,6 @@ def _loop_reading(self, fut=None):
self._cancelable.add(self._read_fut)
else:
self._read_fut.add_done_callback(self._loop_reading)
finally:
if data is not None:
self._submit_read_data(data)


class WriteTransport(BaseTransport, transports._FlowControlMixin):
Expand All @@ -191,8 +211,8 @@ def __init__(self, loop, *args, **kwargs):
transports._FlowControlMixin.__init__(self, None, loop)
BaseTransport.__init__(self, loop, *args, **kwargs)

self._buffer = self._buffer_factory()
self._buffer_empty_callbacks = set()
self._write_buffer = self._buffer_factory()
self._drained_callbacks = set()
self._write_fut = None
self._eof_written = False

Expand All @@ -203,7 +223,7 @@ def can_write_eof(self):
return True

def get_write_buffer_size(self):
return len(self._buffer)
return len(self._write_buffer)

def _close_write(self):
if self._write_fut is not None:
Expand All @@ -213,7 +233,7 @@ def transport_write_done_callback():
self._closing_delayed = False
self.close()

self._buffer_empty_callbacks.add(transport_write_done_callback)
self._drained_callbacks.add(transport_write_done_callback)

def close(self):
self._close_write()
Expand All @@ -238,12 +258,12 @@ def _create_write_future(self, data):
return self._loop.sock_sendall(self._sock, data)

def _buffer_add_data(self, data):
self._buffer.extend(data)
self._write_buffer.extend(data)

def _buffer_pop_data(self):
if len(self._buffer) > 0:
data = self._buffer
self._buffer = bytearray()
if len(self._write_buffer) > 0:
data = self._write_buffer
self._write_buffer = self._buffer_factory()
return data
else:
return None
Expand All @@ -264,10 +284,10 @@ def _loop_writing(self, fut=None, data=None):
data = self._buffer_pop_data()

if not data:
if len(self._buffer_empty_callbacks) > 0:
for callback in self._buffer_empty_callbacks:
if len(self._drained_callbacks) > 0:
for callback in self._drained_callbacks:
callback()
self._buffer_empty_callbacks.clear()
self._drained_callbacks.clear()

self._maybe_resume_protocol()
else:
Expand Down Expand Up @@ -358,11 +378,11 @@ def _create_write_future(self, args):
def _buffer_add_data(self, args):
(data, addr) = args

self._buffer.append((bytes(data), addr))
self._write_buffer.append((bytes(data), addr))

def _buffer_pop_data(self):
if len(self._buffer) > 0:
return self._buffer.popleft()
if len(self._write_buffer) > 0:
return self._write_buffer.popleft()
else:
return None

Expand Down Expand Up @@ -392,8 +412,29 @@ def __init__(self, loop, channel, protocol, waiter, extra):
super().__init__(loop, None, protocol, waiter, extra)

def _create_read_future(self, size):
if self._alloc_read_buffers:
self._read_buffer = self._protocol.get_buffer(size)
size = len(self._read_buffer)
return self._loop._channel_read(self._channel, size)

def _submit_read_data(self, data):
assert isinstance(data, bytes)
if data != b"" and data != 0:
if self._alloc_read_buffers:
# FIXME: GLib does not actually expose the equivalent to
# `recv_into` in its channel interface, so we have to
# add an extra copy here rather than avoiding one
self._read_buffer[0 : len(data)] = data
self._protocol.buffer_updated(len(data))
self._read_buffer = None
else:
self._protocol.data_received(data)
else:
self._read_buffer = None
keep_open = self._protocol.eof_received()
if not keep_open:
self.close()

def _force_close_async(self, exc):
try:
super()._force_close_async(exc)
Expand Down