Skip to content

Commit

Permalink
Support pickling Record-s
Browse files Browse the repository at this point in the history
Closes #451
  • Loading branch information
vmarkovtsev committed Feb 8, 2023
1 parent 7df9812 commit d53779b
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 14 deletions.
2 changes: 1 addition & 1 deletion asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1024,4 +1024,4 @@ def _create_record(object mapping, tuple elems):
return rec


Record = <object>record.ApgRecord_InitTypes()
Record, RecordDescriptor = record.ApgRecord_InitTypes()
2 changes: 1 addition & 1 deletion asyncpg/protocol/record/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ cimport cpython

cdef extern from "record/recordobj.h":

cpython.PyTypeObject *ApgRecord_InitTypes() except NULL
tuple ApgRecord_InitTypes()

int ApgRecord_CheckExact(object)
object ApgRecord_New(type, object, int)
Expand Down
130 changes: 124 additions & 6 deletions asyncpg/protocol/record/recordobj.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ static PyObject * record_new_items_iter(PyObject *);
static ApgRecordObject *free_list[ApgRecord_MAXSAVESIZE];
static int numfree[ApgRecord_MAXSAVESIZE];

static PyObject *record_reconstruct_obj;
static PyObject *record_desc_reconstruct_obj;

static PyMethodDef record_desc_methods[];

static size_t MAX_RECORD_SIZE = (
((size_t)PY_SSIZE_T_MAX - sizeof(ApgRecordObject) - sizeof(PyObject *))
/ sizeof(PyObject *)
Expand Down Expand Up @@ -575,14 +580,14 @@ record_repr(ApgRecordObject *v)


static PyObject *
record_values(PyObject *o, PyObject *args)
record_values(PyObject *o, PyObject *Py_UNUSED(unused))
{
return record_iter(o);
}


static PyObject *
record_keys(PyObject *o, PyObject *args)
record_keys(PyObject *o, PyObject *Py_UNUSED(unused))
{
if (!ApgRecord_Check(o)) {
PyErr_BadInternalCall();
Expand All @@ -594,7 +599,7 @@ record_keys(PyObject *o, PyObject *args)


static PyObject *
record_items(PyObject *o, PyObject *args)
record_items(PyObject *o, PyObject *Py_UNUSED(unused))
{
if (!ApgRecord_Check(o)) {
PyErr_BadInternalCall();
Expand Down Expand Up @@ -657,12 +662,71 @@ static PyMappingMethods record_as_mapping = {
0 /* mp_ass_subscript */
};

static PyMethodDef record_methods[];

static PyObject *
record_reduce(ApgRecordObject *o, PyObject *Py_UNUSED(unused))
{
PyObject *value = PyTuple_New(2);
if (value == NULL) {
return NULL;
}
Py_ssize_t len = Py_SIZE(o);
PyObject *state = PyTuple_New(1 + len);
if (state == NULL) {
Py_DECREF(value);
return NULL;
}
PyTuple_SET_ITEM(value, 0, record_reconstruct_obj);
Py_INCREF(record_reconstruct_obj);
PyTuple_SET_ITEM(value, 1, state);
PyTuple_SET_ITEM(state, 0, (PyObject *)o->desc);
Py_INCREF(o->desc);
for (Py_ssize_t i = 0; i < len; i++) {
PyObject *item = ApgRecord_GET_ITEM(o, i);
PyTuple_SET_ITEM(state, i + 1, item);
Py_INCREF(item);
}
return value;
}

static PyObject *
record_reconstruct(PyObject *Py_UNUSED(unused), PyObject *args)
{
if (!PyTuple_CheckExact(args)) {
return NULL;
}
Py_ssize_t len = PyTuple_GET_SIZE(args);
if (len < 2) {
return NULL;
}
len--;
ApgRecordDescObject *desc = (ApgRecordDescObject *)PyTuple_GET_ITEM(args, 0);
if (!ApgRecordDesc_CheckExact(desc)) {
return NULL;
}
if (PyObject_Length(desc->mapping) != len) {
return NULL;
}
PyObject *record = ApgRecord_New(&ApgRecord_Type, (PyObject *)desc, len);
if (record == NULL) {
return NULL;
}
for (Py_ssize_t i = 0; i < len; i++) {
PyObject *item = PyTuple_GET_ITEM(args, i + 1);
ApgRecord_SET_ITEM(record, i, item);
Py_INCREF(item);
}
return record;
}

static PyMethodDef record_methods[] = {
{"values", (PyCFunction)record_values, METH_NOARGS},
{"keys", (PyCFunction)record_keys, METH_NOARGS},
{"items", (PyCFunction)record_items, METH_NOARGS},
{"get", (PyCFunction)record_get, METH_VARARGS},
{"__reduce__", (PyCFunction)record_reduce, METH_NOARGS},
{"__reconstruct__", (PyCFunction)record_reconstruct, METH_VARARGS | METH_STATIC},
{NULL, NULL} /* sentinel */
};

Expand Down Expand Up @@ -942,7 +1006,7 @@ record_new_items_iter(PyObject *seq)
}


PyTypeObject *
PyObject *
ApgRecord_InitTypes(void)
{
if (PyType_Ready(&ApgRecord_Type) < 0) {
Expand All @@ -961,7 +1025,22 @@ ApgRecord_InitTypes(void)
return NULL;
}

return &ApgRecord_Type;
record_reconstruct_obj = PyCFunction_New(
&record_methods[5], (PyObject *)&ApgRecord_Type
);
record_desc_reconstruct_obj = PyCFunction_New(
&record_desc_methods[1], (PyObject *)&ApgRecordDesc_Type
);

PyObject *types = PyTuple_New(2);
if (types == NULL) {
return NULL;
}
PyTuple_SET_ITEM(types, 0, (PyObject *)&ApgRecord_Type);
Py_INCREF(&ApgRecord_Type);
PyTuple_SET_ITEM(types, 1, (PyObject *)&ApgRecordDesc_Type);
Py_INCREF(&ApgRecordDesc_Type);
return types;
}


Expand All @@ -987,15 +1066,54 @@ record_desc_traverse(ApgRecordDescObject *o, visitproc visit, void *arg)
}


static PyObject *record_desc_reduce(ApgRecordDescObject *o, PyObject *Py_UNUSED(unused))
{
PyObject *value = PyTuple_New(2);
if (value == NULL) {
return NULL;
}
PyObject *state = PyTuple_New(2);
if (state == NULL) {
Py_DECREF(value);
return NULL;
}
PyTuple_SET_ITEM(value, 0, record_desc_reconstruct_obj);
Py_INCREF(record_desc_reconstruct_obj);
PyTuple_SET_ITEM(value, 1, state);
PyTuple_SET_ITEM(state, 0, o->mapping);
Py_INCREF(o->mapping);
PyTuple_SET_ITEM(state, 1, o->keys);
Py_INCREF(o->keys);
return value;
}


static PyObject *record_desc_reconstruct(PyObject *Py_UNUSED(unused), PyObject *args)
{
if (PyTuple_GET_SIZE(args) != 2) {
return NULL;
}
return ApgRecordDesc_New(PyTuple_GET_ITEM(args, 0), PyTuple_GET_ITEM(args, 1));
}


static PyMethodDef record_desc_methods[] = {
{"__reduce__", (PyCFunction)record_desc_reduce, METH_NOARGS},
{"__reconstruct__", (PyCFunction)record_desc_reconstruct, METH_VARARGS | METH_STATIC},
{NULL, NULL} /* sentinel */
};


PyTypeObject ApgRecordDesc_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
.tp_name = "RecordDescriptor",
.tp_name = "asyncpg.protocol.protocol.RecordDescriptor",
.tp_basicsize = sizeof(ApgRecordDescObject),
.tp_dealloc = (destructor)record_desc_dealloc,
.tp_getattro = PyObject_GenericGetAttr,
.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
.tp_traverse = (traverseproc)record_desc_traverse,
.tp_iter = PyObject_SelfIter,
.tp_methods = record_desc_methods,
};


Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/record/recordobj.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ extern PyTypeObject ApgRecordDesc_Type;
#define ApgRecord_GET_ITEM(op, i) \
(((ApgRecordObject *)(op))->ob_item[i])

PyTypeObject *ApgRecord_InitTypes(void);
PyObject *ApgRecord_InitTypes(void);
PyObject *ApgRecord_New(PyTypeObject *, PyObject *, Py_ssize_t);
PyObject *ApgRecordDesc_New(PyObject *, PyObject *);

Expand Down
13 changes: 8 additions & 5 deletions tests/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,6 @@ def test_record_get(self):
self.assertEqual(r.get('nonexistent'), None)
self.assertEqual(r.get('nonexistent', 'default'), 'default')

def test_record_not_pickleable(self):
r = Record(R_A, (42,))
with self.assertRaises(Exception):
pickle.dumps(r)

def test_record_empty(self):
r = Record(None, ())
self.assertEqual(r, ())
Expand Down Expand Up @@ -575,3 +570,11 @@ class MyRecordBad:
'record_class is expected to be a subclass of asyncpg.Record',
):
await self.connect(record_class=MyRecordBad)

def test_record_pickle(self):
r = pickle.loads(pickle.dumps(Record(R_AB, (42, 43))))
self.assertEqual(len(r), 2)
self.assertEqual(r[0], 42)
self.assertEqual(r[1], 43)
self.assertEqual(r['a'], 42)
self.assertEqual(r['b'], 43)

0 comments on commit d53779b

Please sign in to comment.