Skip to content

Commit

Permalink
Add audformat.utils.join_labels() (#66)
Browse files Browse the repository at this point in the history
* Add audformat.utils.join_labels()

* Add missing doc entry

* Raise error if only dict given

* Fix typo

* Remove check for dict

* Allow for overwriting dict labels

* Remove check for dict labels

Co-authored-by: Johannes Wagner <[email protected]>

* Check for different label dict types

* Switch to ValueError

Co-authored-by: Johannes Wagner <[email protected]>
  • Loading branch information
hagenw and frankenjoe authored May 6, 2021
1 parent 022865e commit bc23b7f
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 0 deletions.
82 changes: 82 additions & 0 deletions audformat/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
filewise_index,
segmented_index,
)
from audformat.core.scheme import Scheme


def concat(
Expand Down Expand Up @@ -286,6 +287,87 @@ def intersect(
return index


def join_labels(
labels: typing.Sequence[typing.Union[typing.List, typing.Dict]],
):
r"""Combine scheme labels.
This might be helpful,
if you would like to combine two databases
that have the same scheme,
but with different labels:
.. code-block:: python
labels = audformat.utils.join_labels(
[
db.schemes['scheme'].labels,
db_new.schemes['scheme'].labels,
]
)
db.schemes['scheme'].replace_labels(labels)
db_new.schemes['scheme'].replace_labels(labels)
db.update(db_new)
Args:
labels: sequence of labels to join.
For dictionary labels,
labels further to the right
can overwrite previous labels
Returns:
joined labels
Raises:
ValueError: if labels are of different type
ValueError: if label type is not ``list`` or ``dict``
Example:
>>> join_labels([{'a': 0, 'b': 1}, {'b': 2, 'c': 2}])
{'a': 0, 'b': 2, 'c': 2}
"""
if len(labels) == 0:
return labels

if not isinstance(labels, list):
labels = list(labels)

label_type = type(labels[0])
joined_labels = labels[0]

for label in labels[1:]:
if type(label) != label_type:
raise ValueError(
f"Labels are of different type:\n"
f"{label_type}\n"
f"!=\n"
f"{type(label)}"
)

if label_type == dict:
for label in labels[1:]:
for key, value in label.items():
if key not in joined_labels or joined_labels[key] != value:
joined_labels[key] = value
elif label_type == list:
joined_labels = list(
set(list(joined_labels) + audeer.flatten_list(labels[1:]))
)
joined_labels = sorted(audeer.flatten_list(joined_labels))
else:
raise ValueError(
f"Supported label types are 'list' and 'dict', "
f"but your is '{label_type}'"
)

# Check if joined labels have a valid format,
# e.g. {0: {'age': 20}, '0': {'age': 30}} is not allowed
Scheme(labels=joined_labels)

return joined_labels


def map_language(language: str) -> str:
r"""Map language to ISO 639-3.
Expand Down
1 change: 1 addition & 0 deletions audformat/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from audformat.core.utils import (
concat,
intersect,
join_labels,
map_language,
read_csv,
to_filewise_index,
Expand Down
6 changes: 6 additions & 0 deletions docs/api-utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ intersect
.. autofunction:: intersect


join_labels
-----------

.. autofunction:: join_labels


map_language
------------

Expand Down
74 changes: 74 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,80 @@ def test_intersect(objs, expected):
)


@pytest.mark.parametrize(
'labels, expected',
[
(
[],
[],
),
(
(['a'], ['b']),
['a', 'b'],
),
(
(['a'], ['b', 'c']),
['a', 'b', 'c'],
),
(
(['a'], ['a']),
['a'],
),
(
[{'a': 0}],
{'a': 0},
),
(
[{'a': 0}, {'b': 1}],
{'a': 0, 'b': 1},
),
(
[{'a': 0}, {'b': 1, 'c': 2}],
{'a': 0, 'b': 1, 'c': 2},
),
(
[{'a': 0, 'b': 1}, {'b': 1, 'c': 2}],
{'a': 0, 'b': 1, 'c': 2},
),
(
[{'a': 0, 'b': 1}, {'b': 2, 'c': 2}],
{'a': 0, 'b': 2, 'c': 2},
),
(
[{'a': 0}, {'a': 1}, {'a': 2}],
{'a': 2},
),
pytest.param(
['a', 'b', 'c'],
[],
marks=pytest.mark.xfail(raises=ValueError),
),
pytest.param(
('a', 'b', 'c'),
[],
marks=pytest.mark.xfail(raises=ValueError),
),
pytest.param(
[{'a': 0, 'b': 1}, ['c']],
[],
marks=pytest.mark.xfail(raises=ValueError),
),
pytest.param(
[['a', 'b'], ['b', 'c'], 'd'],
[],
marks=pytest.mark.xfail(raises=ValueError),
),
pytest.param(
[{0: {'age': 20}}, {'0': {'age': 30}}],
[],
marks=pytest.mark.xfail(raises=ValueError),
),
]
)
def test_join_labels(labels, expected):
assert utils.join_labels(labels) == expected


@pytest.mark.parametrize(
'language, expected',
[
Expand Down

0 comments on commit bc23b7f

Please sign in to comment.