Skip to content

Commit

Permalink
Automatically convert enums from strings
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoinePrv committed Nov 23, 2023
1 parent e73d239 commit d30a473
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
14 changes: 13 additions & 1 deletion libmambapy/src/libmambapy/bindings/specs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@

namespace mambapy
{

template <typename Enum>
auto enum_from_str(const pybind11::str& name)
{
auto pyenum = pybind11::type::of<Enum>();
return pyenum.attr("__members__")[name].template cast<Enum>();
}

template <typename T>
auto copy(const T& x) -> std::unique_ptr<T>
{
Expand Down Expand Up @@ -52,9 +60,11 @@ namespace mambapy
.value("win_64", Platform::win_64)
.value("win_arm64", Platform::win_arm64)
.value("zos_z", Platform::zos_z)
.def(py::init(&enum_from_str<Platform>))
.def_static("parse", &platform_parse)
.def_static("count", &known_platforms_count)
.def_static("build_platform", &build_platform);
py::implicitly_convertible<py::str, Platform>();

auto py_channel_spec = py::class_<ChannelSpec>(m, "ChannelSpec");

Expand All @@ -64,7 +74,9 @@ namespace mambapy
.value("Path", ChannelSpec::Type::Path)
.value("PackagePath", ChannelSpec::Type::PackagePath)
.value("Name", ChannelSpec::Type::Name)
.value("Unknown", ChannelSpec::Type::Unknown);
.value("Unknown", ChannelSpec::Type::Unknown)
.def(py::init(&enum_from_str<ChannelSpec::Type>));
py::implicitly_convertible<py::str, ChannelSpec::Type>();

py_channel_spec //
.def_static("parse", ChannelSpec::parse)
Expand Down
12 changes: 12 additions & 0 deletions libmambapy/tests/test_specs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import copy

import pytest

import libmambapy


Expand Down Expand Up @@ -27,6 +29,11 @@ def test_platform():
assert len(Platform.__members__) == Platform.count()
assert Platform.build_platform() in Platform.__members__.values()
assert Platform.parse("linux-64") == Platform.linux_64
assert Platform("linux_64") == Platform.linux_64

with pytest.raises(KeyError):
# No parsing, explicit name
Platform("linux-64") == Platform.linux_64


def test_channel_spec_type():
Expand All @@ -38,6 +45,7 @@ def test_channel_spec_type():
assert Type.PackagePath.name == "PackagePath"
assert Type.Name.name == "Name"
assert Type.Unknown.name == "Unknown"
assert Type("Name").name == "Name"


def test_channel_spec():
Expand All @@ -53,6 +61,10 @@ def test_channel_spec():
assert spec.platform_filters == set()
assert spec.type == ChannelSpec.Type.Unknown

# Enum cast
spec = ChannelSpec(location="conda-forge", platform_filters=set(), type="Name")
assert spec.type == ChannelSpec.Type.Name

# Parser
spec = ChannelSpec.parse("conda-forge[linux-64]")
assert spec.location == "conda-forge"
Expand Down

0 comments on commit d30a473

Please sign in to comment.