From 79777de566cef6acd191c66ddaf7c151ad717767 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Tue, 31 Oct 2023 17:56:04 +0000 Subject: [PATCH] Feature: Add support for __getattr__ (#234) --- docs/guide/classes.md | 23 ++++++++++++----------- example/classes.pyi | 9 +++++++++ example/classes.zig | 16 ++++++++++++++++ pydust/src/functions.zig | 1 + pydust/src/types/obj.zig | 8 ++++++++ test/test_classes.py | 8 ++++++++ 6 files changed, 54 insertions(+), 11 deletions(-) diff --git a/docs/guide/classes.md b/docs/guide/classes.md index ee422db2..3388880b 100644 --- a/docs/guide/classes.md +++ b/docs/guide/classes.md @@ -173,17 +173,18 @@ const inquiry = fn(*Self) !bool; ### Type Methods -| Method | Signature | -| :--------- | :--------------------------------------- | -| `__init__` | `#!zig fn() void` | -| `__init__` | `#!zig fn(*Self) !void` | -| `__init__` | `#!zig fn(*Self, CallArgs) !void` | -| `__del__` | `#!zig fn(*Self) void` | -| `__repr__` | `#!zig fn(*Self) !py.PyString` | -| `__str__` | `#!zig fn(*Self) !py.PyString` | -| `__call__` | `#!zig fn(*Self, CallArgs) !py.PyObject` | -| `__iter__` | `#!zig fn(*Self) !object` | -| `__next__` | `#!zig fn(*Self) !?object` | +| Method | Signature | +| :------------ | :--------------------------------------- | +| `__init__` | `#!zig fn() void` | +| `__init__` | `#!zig fn(*Self) !void` | +| `__init__` | `#!zig fn(*Self, CallArgs) !void` | +| `__del__` | `#!zig fn(*Self) void` | +| `__repr__` | `#!zig fn(*Self) !py.PyString` | +| `__str__` | `#!zig fn(*Self) !py.PyString` | +| `__call__` | `#!zig fn(*Self, CallArgs) !py.PyObject` | +| `__iter__` | `#!zig fn(*Self) !object` | +| `__next__` | `#!zig fn(*Self) !?object` | +| `__getattr__` | `#!zig fn(*Self, object) !?object` | ### Sequence Methods diff --git a/example/classes.pyi b/example/classes.pyi index 59eb051a..2f024a83 100644 --- a/example/classes.pyi +++ b/example/classes.pyi @@ -23,6 +23,15 @@ class Counter: count: ... +class GetAttr: + def __init__(self, /): + pass + def __getattribute__(self, name, /): + """ + Return getattr(self, name). + """ + ... + class Hash: def __init__(self, x, /): pass diff --git a/example/classes.zig b/example/classes.zig index 7da9ef9b..d7897b10 100644 --- a/example/classes.zig +++ b/example/classes.zig @@ -172,6 +172,22 @@ pub const Callable = py.class(struct { } }); +pub const GetAttr = py.class(struct { + const Self = @This(); + + pub fn __init__(self: *Self) void { + _ = self; + } + + pub fn __getattr__(self: *const Self, attr: py.PyString) !py.PyObject { + const name = try attr.asSlice(); + if (std.mem.eql(u8, name, "number")) { + return py.create(42); + } + return py.object(self).getAttribute(name); + } +}); + comptime { py.rootmodule(@This()); } diff --git a/pydust/src/functions.zig b/pydust/src/functions.zig index 55ef9166..c67befa4 100644 --- a/pydust/src/functions.zig +++ b/pydust/src/functions.zig @@ -81,6 +81,7 @@ pub const BinaryOperators = std.ComptimeStringMap(c_int, .{ .{ "__matmul__", ffi.Py_nb_matrix_multiply }, .{ "__imatmul__", ffi.Py_nb_inplace_matrix_multiply }, .{ "__getitem__", ffi.Py_mp_subscript }, + .{ "__getattr__", ffi.Py_tp_getattro }, }); // TODO(marko): Move this somewhere. diff --git a/pydust/src/types/obj.zig b/pydust/src/types/obj.zig index 0c1c0a28..6dd2cc5f 100644 --- a/pydust/src/types/obj.zig +++ b/pydust/src/types/obj.zig @@ -55,6 +55,14 @@ pub const PyObject = extern struct { return .{ .py = ffi.PyObject_GetAttr(self.py, attrStr.obj.py) orelse return PyError.PyRaised }; } + /// Returns a new reference to the attribute of the object using default lookup semantics. + pub fn getAttribute(self: PyObject, attrName: []const u8) !py.PyObject { + const attrStr = try py.PyString.create(attrName); + defer attrStr.decref(); + + return .{ .py = ffi.PyObject_GenericGetAttr(self.py, attrStr.obj.py) orelse return PyError.PyRaised }; + } + /// Returns a new reference to the attribute of the object. pub fn getAs(self: PyObject, comptime T: type, attrName: []const u8) !T { return try py.as(T, try self.get(attrName)); diff --git a/test/test_classes.py b/test/test_classes.py index 1a7c5319..18ff19df 100644 --- a/test/test_classes.py +++ b/test/test_classes.py @@ -98,3 +98,11 @@ def test_refcnt(): rc = sys.getrefcount(classes) classes.Hash(42) assert sys.getrefcount(classes) == rc + + +def test_getattr(): + c = classes.GetAttr() + assert c.number == 42 + with pytest.raises(AttributeError) as exc_info: + c.attr + assert str(exc_info.value) == "'example.classes.GetAttr' object has no attribute 'attr'"