diff --git a/luigi/__init__.py b/luigi/__init__.py index 874014211a..f0fe6d468a 100644 --- a/luigi/__init__.py +++ b/luigi/__init__.py @@ -40,9 +40,9 @@ DateIntervalParameter, TimeDeltaParameter, IntParameter, FloatParameter, BoolParameter, PathParameter, TaskParameter, EnumParameter, DictParameter, ListParameter, TupleParameter, EnumListParameter, - NumericalParameter, ChoiceParameter, OptionalParameter, OptionalStrParameter, - OptionalIntParameter, OptionalFloatParameter, OptionalBoolParameter, OptionalPathParameter, - OptionalDictParameter, OptionalListParameter, OptionalTupleParameter, + NumericalParameter, ChoiceParameter, ChoiceListParameter, OptionalParameter, + OptionalStrParameter, OptionalIntParameter, OptionalFloatParameter, OptionalBoolParameter, + OptionalPathParameter, OptionalDictParameter, OptionalListParameter, OptionalTupleParameter, OptionalChoiceParameter, OptionalNumericalParameter, ) @@ -66,9 +66,9 @@ 'FloatParameter', 'BoolParameter', 'PathParameter', 'TaskParameter', 'ListParameter', 'TupleParameter', 'EnumParameter', 'DictParameter', 'EnumListParameter', 'configuration', 'interface', 'local_target', 'run', 'build', 'event', 'Event', - 'NumericalParameter', 'ChoiceParameter', 'OptionalParameter', 'OptionalStrParameter', - 'OptionalIntParameter', 'OptionalFloatParameter', 'OptionalBoolParameter', 'OptionalPathParameter', - 'OptionalDictParameter', 'OptionalListParameter', 'OptionalTupleParameter', + 'NumericalParameter', 'ChoiceParameter', 'ChoiceListParameter', 'OptionalParameter', + 'OptionalStrParameter', 'OptionalIntParameter', 'OptionalFloatParameter', 'OptionalBoolParameter', + 'OptionalPathParameter', 'OptionalDictParameter', 'OptionalListParameter', 'OptionalTupleParameter', 'OptionalChoiceParameter', 'OptionalNumericalParameter', 'LuigiStatusCode', '__version__', ] diff --git a/luigi/parameter.py b/luigi/parameter.py index f7f137a6d1..babe2ca51a 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -1540,6 +1540,52 @@ def normalize(self, var): var=var, choices=self._choices)) +class ChoiceListParameter(ChoiceParameter): + """ + A parameter which takes two values: + 1. an instance of :class:`~collections.Iterable` and + 2. the class of the variables to convert to. + + Values are taken to be a list, i.e. order is preserved, duplicates may occur, and empty list is possible. + + In the task definition, use + + .. code-block:: python + + class MyTask(luigi.Task): + my_param = luigi.ChoiceListParameter(choices=['foo', 'bar', 'baz'], var_type=str) + + At the command line, use + + .. code-block:: console + + $ luigi --module my_tasks MyTask --my-param foo,bar + + Consider using :class:`~luigi.EnumListParameter` for a typed, structured + alternative. This class can perform the same role when all choices are the + same type and transparency of parameter value on the command line is + desired. + """ + + _sep = ',' + + def __init__(self, *args, **kwargs): + super(ChoiceListParameter, self).__init__(*args, **kwargs) + + def parse(self, s): + values = [] if s == '' else s.split(self._sep) + return self.normalize(map(self._var_type, values)) + + def normalize(self, var): + values = [] + for v in var: + values.append(super().normalize(v)) + return tuple(values) + + def serialize(self, values): + return self._sep.join(values) + + class OptionalChoiceParameter(OptionalParameterMixin, ChoiceParameter): """Class to parse optional choice parameters.""" diff --git a/test/parameter_test.py b/test/parameter_test.py index 155e9b61ac..625a1f200d 100644 --- a/test/parameter_test.py +++ b/test/parameter_test.py @@ -310,6 +310,25 @@ def test_enum_list_param_invalid(self): def test_enum_list_param_missing(self): self.assertRaises(ParameterException, lambda: luigi.parameter.EnumListParameter()) + def test_choice_list_param_valid(self): + p = luigi.parameter.ChoiceListParameter(choices=["1", "2", "3"]) + self.assertEqual((), p.parse('')) + self.assertEqual(("1",), p.parse('1')) + self.assertEqual(("1", "3"), p.parse('1,3')) + + def test_choice_list_param_invalid(self): + p = luigi.parameter.ChoiceListParameter(choices=["1", "2", "3"]) + self.assertRaises(ValueError, lambda: p.parse('1,4')) + + def test_invalid_choice_type(self): + self.assertRaises( + AssertionError, + lambda: luigi.ChoiceListParameter(var_type=int, choices=[1, 2, "3"]), + ) + + def test_choice_list_param_missing(self): + self.assertRaises(ParameterException, lambda: luigi.parameter.ChoiceListParameter()) + def test_tuple_serialize_parse(self): a = luigi.TupleParameter() b_tuple = ((1, 2), (3, 4)) @@ -469,6 +488,13 @@ class FooWithDefault(luigi.Task): self.assertEqual(FooWithDefault().args, p.parse('C')) + def test_choice_list(self): + class Foo(luigi.Task): + args = luigi.ChoiceListParameter(var_type=str, choices=["1", "2", "3"]) + + p = luigi.ChoiceListParameter(var_type=str, choices=["3", "2", "1"]) + self.assertEqual(hash(Foo(args=("3",)).args), hash(p.parse("3"))) + def test_dict(self): class Foo(luigi.Task): args = luigi.parameter.DictParameter()