From d30a473da5b304fcb8eadcfad5839d7bd7d35a42 Mon Sep 17 00:00:00 2001 From: AntoinePrv Date: Thu, 23 Nov 2023 17:39:19 +0100 Subject: [PATCH] Automatically convert enums from strings --- libmambapy/src/libmambapy/bindings/specs.cpp | 14 +++++++++++++- libmambapy/tests/test_specs.py | 12 ++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/libmambapy/src/libmambapy/bindings/specs.cpp b/libmambapy/src/libmambapy/bindings/specs.cpp index 5b2a41ce5e..08387ac9a3 100644 --- a/libmambapy/src/libmambapy/bindings/specs.cpp +++ b/libmambapy/src/libmambapy/bindings/specs.cpp @@ -17,6 +17,14 @@ namespace mambapy { + + template + auto enum_from_str(const pybind11::str& name) + { + auto pyenum = pybind11::type::of(); + return pyenum.attr("__members__")[name].template cast(); + } + template auto copy(const T& x) -> std::unique_ptr { @@ -52,9 +60,11 @@ namespace mambapy .value("win_64", Platform::win_64) .value("win_arm64", Platform::win_arm64) .value("zos_z", Platform::zos_z) + .def(py::init(&enum_from_str)) .def_static("parse", &platform_parse) .def_static("count", &known_platforms_count) .def_static("build_platform", &build_platform); + py::implicitly_convertible(); auto py_channel_spec = py::class_(m, "ChannelSpec"); @@ -64,7 +74,9 @@ namespace mambapy .value("Path", ChannelSpec::Type::Path) .value("PackagePath", ChannelSpec::Type::PackagePath) .value("Name", ChannelSpec::Type::Name) - .value("Unknown", ChannelSpec::Type::Unknown); + .value("Unknown", ChannelSpec::Type::Unknown) + .def(py::init(&enum_from_str)); + py::implicitly_convertible(); py_channel_spec // .def_static("parse", ChannelSpec::parse) diff --git a/libmambapy/tests/test_specs.py b/libmambapy/tests/test_specs.py index bd7a749097..94ecd0f256 100644 --- a/libmambapy/tests/test_specs.py +++ b/libmambapy/tests/test_specs.py @@ -1,5 +1,7 @@ import copy +import pytest + import libmambapy @@ -27,6 +29,11 @@ def test_platform(): assert len(Platform.__members__) == Platform.count() assert Platform.build_platform() in Platform.__members__.values() assert Platform.parse("linux-64") == Platform.linux_64 + assert Platform("linux_64") == Platform.linux_64 + + with pytest.raises(KeyError): + # No parsing, explicit name + Platform("linux-64") == Platform.linux_64 def test_channel_spec_type(): @@ -38,6 +45,7 @@ def test_channel_spec_type(): assert Type.PackagePath.name == "PackagePath" assert Type.Name.name == "Name" assert Type.Unknown.name == "Unknown" + assert Type("Name").name == "Name" def test_channel_spec(): @@ -53,6 +61,10 @@ def test_channel_spec(): assert spec.platform_filters == set() assert spec.type == ChannelSpec.Type.Unknown + # Enum cast + spec = ChannelSpec(location="conda-forge", platform_filters=set(), type="Name") + assert spec.type == ChannelSpec.Type.Name + # Parser spec = ChannelSpec.parse("conda-forge[linux-64]") assert spec.location == "conda-forge"