Skip to content

Commit

Permalink
[Stdlib] Add PythonObject.__contains__
Browse files Browse the repository at this point in the history
Signed-off-by: rd4com <[email protected]>
  • Loading branch information
rd4com committed Jul 16, 2024
1 parent 4837a9e commit b81e384
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 5 deletions.
10 changes: 10 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,16 @@ future and `StringSlice.__len__` now does return the Unicode codepoints length.
# pwd.struct_passwd(pw_name='root', pw_passwd='*', pw_uid=0, pw_gid=0,
# pw_gecos='System Administrator', pw_dir='/var/root', pw_shell='/bin/zsh')
- Added `PythonObject.__contains__`.
([PR #3101](https://github.com/modularml/mojo/pull/3101) by [@rd4com](https://github.com/rd4com))
Example usage:
```mojo
x = PythonObject([1,2,3])
if 1 in x:
print("1 in x")
```

### 🦋 Changed
Expand Down
6 changes: 1 addition & 5 deletions docs/manual/python/types.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,13 @@
" var py_set = Python.evaluate('set([2, 3, 5, 7, 11])')\n",
" var num_items = len(py_set)\n",
" print(num_items, \" items in set.\") # prints \"5 items in set\"\n",
" print(py_set.__contains__(6)) # prints \"False\""
" print(6 in py_set) # prints \"False\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TODO: You should be able to use the expression `6 in py_set`. However, because\n",
"of the way `PythonObject` currently works, you need to call the \n",
"`__contains__()` method directly.\n",
"\n",
"Some Mojo APIs handle `PythonObject` just fine, but sometimes you'll need to \n",
"explicitly convert a Python value into a native Mojo value. \n",
"\n",
Expand Down
10 changes: 10 additions & 0 deletions stdlib/src/python/_cpython.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,16 @@ struct CPython:
self._inc_total_rc()
return r

fn PyObject_HasAttrString(
inout self,
obj: PyObjectPtr,
name: StringRef,
) -> Int:
var r = self.lib.get_function[
fn (PyObjectPtr, DTypePointer[DType.uint8]) -> Int
]("PyObject_HasAttrString")(obj, name.data)
return r

fn PyObject_GetAttrString(
inout self,
obj: PyObjectPtr,
Expand Down
22 changes: 22 additions & 0 deletions stdlib/src/python/object.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,28 @@ struct PythonObject(
"""
return self._call_zero_arg_method("__invert__")

fn __contains__(self, rhs: PythonObject) raises -> Bool:
"""Contains dunder.
Calls the underlying object's `__contains__` method.
Args:
rhs: Right hand value.
Returns:
True if rhs is in self.
"""
# TODO: replace/optimize with c-python function.
# TODO: implement __getitem__ step for cpython membership test operator.
var cpython = _get_global_python_itf().cpython()
if cpython.PyObject_HasAttrString(self.py_object, "__contains__"):
return self._call_single_arg_method("__contains__", rhs).__bool__()
for v in self:
if v[] == rhs:
return True
return False


fn _get_ptr_as_int(self) -> Int:
return self.py_object._get_ptr_as_int()

Expand Down
40 changes: 40 additions & 0 deletions stdlib/test/python/test_python_cpython.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
# XFAIL: asan && !system-darwin
# RUN: %mojo %s

from python import Python, PythonObject
from testing import assert_equal, assert_false, assert_raises, assert_true


def test_PyObject_HasAttrString(inout python: Python):
var Cpython_env = python.impl._cpython

var the_object = PythonObject(0)
var result = Cpython_env[].PyObject_HasAttrString(
the_object.py_object, "__contains__"
)
assert_equal(0, result)

the_object = PythonObject([1, 2, 3])
result = Cpython_env[].PyObject_HasAttrString(
the_object.py_object, "__contains__"
)
assert_equal(1, result)
_ = the_object


def main():
# initializing Python instance calls init_python
var python = Python()
test_PyObject_HasAttrString(python)
26 changes: 26 additions & 0 deletions stdlib/test/python/test_python_object.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,31 @@ fn test_none() raises:
assert_true(n is None)


def test_contains_dunder():
with assert_raises(contains="'int' object is not iterable"):
var z = PythonObject(0)
_ = 5 in z

var x = PythonObject([1.1, 2.2])
assert_true(1.1 in x)
assert_false(3.3 in x)

x = PythonObject(["Hello", "World"])
assert_true("World" in x)

x = PythonObject((1.5, 2))
assert_true(1.5 in x)
assert_false(3.5 in x)

var y = Dict[PythonObject, PythonObject]()
y["A"] = "A"
y["B"] = 5
x = PythonObject(y)
assert_true("A" in x)
assert_false("C" in x)
assert_true("B" in x)


def main():
# initializing Python instance calls init_python
var python = Python()
Expand All @@ -411,3 +436,4 @@ def main():
test_setitem()
test_dict()
test_none()
test_contains_dunder()

0 comments on commit b81e384

Please sign in to comment.