diff --git a/audformat/core/utils.py b/audformat/core/utils.py index b3c82bf4..93373cbf 100644 --- a/audformat/core/utils.py +++ b/audformat/core/utils.py @@ -38,6 +38,7 @@ def concat( *, overwrite: bool = False, aggregate_function: typing.Callable[[pd.Series], typing.Any] = None, + aggregate_strategy: str = 'mismatch', ) -> typing.Union[pd.Series, pd.DataFrame]: r"""Concatenate objects. @@ -74,7 +75,9 @@ def concat( Args: objs: objects overwrite: overwrite values where indices overlap - aggregate_function: function to aggregate overlapping values. + aggregate_function: function to aggregate overlapping values, + that cannot be joined + when ``overwrite`` is ``False``. The function gets a :class:`pandas.Series` with overlapping values as input. @@ -84,6 +87,14 @@ def concat( or to ``tuple`` to return them as a tuple + aggregate_strategy: if ``aggregate_function`` is not ``None``, + ``aggregate_strategy`` decides + when ``aggregate_function`` is applied. + ``'overlap'``: apply to all samples + that have an overlapping index; + ``'mismatch'``: apply to all samples + that have an overlapping index + and a different value Returns: concatenated objects @@ -91,6 +102,8 @@ def concat( Raises: ValueError: if level and dtypes of object indices do not match ValueError: if columns with the same name have different dtypes + ValueError: if ``aggregate_strategy`` is not one of + ``'overlap'``, ``'mismatch'`` ValueError: if ``aggregate_function`` is ``None``, ``overwrite`` is ``False``, and values in the same position do not match @@ -115,12 +128,34 @@ def concat( 0 0 1 >>> concat( ... [ - ... pd.Series([1], index=pd.Index([0])), - ... pd.Series([2], index=pd.Index([0])), + ... pd.Series([1, 1], index=pd.Index([0, 1])), + ... pd.Series([1, 1], index=pd.Index([0, 1])), ... ], ... aggregate_function=np.sum, ... ) - 0 3 + 0 1 + 1 1 + dtype: Int64 + >>> concat( + ... [ + ... pd.Series([1, 1], index=pd.Index([0, 1])), + ... pd.Series([1, 2], index=pd.Index([0, 1])), + ... ], + ... aggregate_function=np.sum, + ... ) + 0 1 + 1 3 + dtype: Int64 + >>> concat( + ... [ + ... pd.Series([1, 1], index=pd.Index([0, 1])), + ... pd.Series([1, 1], index=pd.Index([0, 1])), + ... ], + ... aggregate_function=np.sum, + ... aggregate_strategy='overlap', + ... ) + 0 2 + 1 2 dtype: Int64 >>> concat( ... [ @@ -195,6 +230,13 @@ def concat( f3 0 days NaT 2.0 b """ + allowed_values = ['overlap', 'mismatch'] + if aggregate_strategy not in allowed_values: + raise ValueError( + "aggregate_strategy needs to be one of: " + f"{', '.join(allowed_values)}" + ) + if not objs: return pd.Series([], index=pd.Index([]), dtype='object') @@ -239,7 +281,7 @@ def concat( raise ValueError( "Found two columns with name " f"'{column.name}' " - "buf different dtypes:\n" + "but different dtypes:\n" f"{dtype_1} " "!= " f"{dtype_2}." 
@@ -253,6 +295,19 @@ def concat( # Handle overlapping values if not overwrite: + + def collect_overlap(overlapping_values, column, index): + """Collect overlap for aggregate function.""" + if column.name not in overlapping_values: + overlapping_values[column.name] = [] + overlapping_values[column.name].append( + column.loc[index] + ) + column = column.loc[~column.index.isin(index)] + column = column.dropna() + return column, overlapping_values + + # Apply aggregate function only to overlapping entries intersection = intersect( [ columns_reindex[column.name].dropna().index, @@ -262,28 +317,51 @@ def concat( # We use len() here as index.empty takes a very long time if len(intersection) > 0: - # Store overlap if custom aggregate function is provided - if aggregate_function is not None: - if column.name not in overlapping_values: - overlapping_values[column.name] = [] - overlapping_values[column.name].append( - column.loc[intersection] + # Apply aggregate function + # to all overlapping entries + if ( + aggregate_function is not None + and aggregate_strategy == 'overlap' + ): + column, overlapping_values = collect_overlap( + overlapping_values, + column, + intersection, ) - column = column.loc[~column.index.isin(intersection)] + columns_reindex[column.name][column.index] = column + continue + + # Find data that differ and cannot be joined + combine = pd.DataFrame( + { + 'left': + columns_reindex[column.name][intersection], + 'right': + column[intersection], + } + ) + combine.dropna(inplace=True) + differ = combine['left'] != combine['right'] + + if np.any(differ): + + # Apply aggregate function + # to overlapping entries + # that do not match in value + if ( + aggregate_function is not None + and aggregate_strategy == 'mismatch' + ): + column, overlapping_values = collect_overlap( + overlapping_values, + column, + intersection[differ], + ) + columns_reindex[column.name][column.index] = column + continue - # Raise error if values don't match and are not NaN - else: - combine = pd.DataFrame( - { - 'left': - columns_reindex[column.name][intersection], - 'right': - column[intersection], - } - ) - combine.dropna(inplace=True) - differ = combine['left'] != combine['right'] - if np.any(differ): + # Raise error if values don't match and are not NaN + else: max_display = 10 overlap = combine[differ] msg_overlap = str(overlap[:max_display]) diff --git a/tests/test_utils_concat.py b/tests/test_utils_concat.py index caaf54af..22a21485 100644 --- a/tests/test_utils_concat.py +++ b/tests/test_utils_concat.py @@ -1,3 +1,5 @@ +import re + import numpy as np import pandas as pd import pytest @@ -424,124 +426,6 @@ audformat.segmented_index(['f1', 'f2', 'f3', 'f4']), ), ), - # error: dtypes do not match - pytest.param( - [ - pd.Series([1], audformat.filewise_index('f1')), - pd.Series([1.], audformat.filewise_index('f1')), - ], - False, - None, - marks=pytest.mark.xfail(raises=ValueError), - ), - pytest.param( - [ - pd.Series( - [1, 2, 3], - index=audformat.filewise_index(['f1', 'f2', 'f3']), - ), - pd.Series( - ['a', 'b', 'a'], - index=audformat.filewise_index(['f1', 'f2', 'f3']), - dtype='category', - ), - ], - False, - None, - marks=pytest.mark.xfail(raises=ValueError), - ), - pytest.param( - [ - pd.Series( - ['a', 'b', 'a'], - index=audformat.filewise_index(['f1', 'f2', 'f3']), - ), - pd.Series( - ['a', 'b', 'a'], - index=audformat.filewise_index(['f1', 'f2', 'f3']), - dtype='category', - ), - ], - False, - None, - marks=pytest.mark.xfail(raises=ValueError), - ), - pytest.param( - [ - pd.Series( - ['a', 'b', 
'a'], - index=audformat.filewise_index(['f1', 'f2', 'f3']), - dtype='category', - ), - pd.Series( - ['a', 'b', 'c'], - index=audformat.filewise_index(['f1', 'f2', 'f3']), - dtype='category', - ), - ], - False, - None, - marks=pytest.mark.xfail(raises=ValueError), - ), - pytest.param( - [ - pd.Series( - [1.], - pd.Index(['f1'], name='idx', dtype='string'), - ), - pd.Series( # default dtype is object - [2.], - pd.MultiIndex.from_arrays([['f1']], names=['idx']), - ), - ], - False, - None, - marks=pytest.mark.xfail(raises=ValueError), - ), - # error: values do not match - pytest.param( - [ - pd.Series([1.], audformat.filewise_index('f1')), - pd.Series([2.], audformat.filewise_index('f1')), - ], - False, - None, - marks=pytest.mark.xfail(raises=ValueError), - ), - pytest.param( - [ - pd.Series([1.], pd.Index(['f1'], name='idx')), - pd.Series( - [2.], - pd.MultiIndex.from_arrays([['f1']], names=['idx']), - ), - ], - False, - None, - marks=pytest.mark.xfail(raises=ValueError), - ), - # error: index names do not match - pytest.param( - [ - pd.Series([], index=pd.Index([], name='idx1'), dtype='object'), - pd.Series([], index=pd.Index([], name='idx2'), dtype='object'), - ], - False, - None, - marks=pytest.mark.xfail(raises=ValueError), - ), - pytest.param( - [ - pd.Series([1.], pd.Index(['f1'], name='idx1')), - pd.Series( - [2.], - pd.MultiIndex.from_arrays([['f2']], names=['idx2']), - ), - ], - False, - None, - marks=pytest.mark.xfail(raises=ValueError), - ), ], ) def test_concat(objs, overwrite, expected): @@ -578,21 +462,13 @@ def test_concat(objs, overwrite, expected): np.mean, pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ), - ( - [ - pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), - pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), - ], - tuple, - pd.Series([(1, 1), (2, 2)], pd.Index(['a', 'b']), dtype='object'), - ), ( [ pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ], np.sum, - pd.Series([2, 4], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ), ( [ @@ -600,7 +476,7 @@ def test_concat(objs, overwrite, expected): pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ], np.var, - pd.Series([0, 0], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ), ( [ @@ -608,11 +484,7 @@ def test_concat(objs, overwrite, expected): pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ], lambda y: 'a', - pd.Series( - ['a', 'a'], - pd.Index(['a', 'b']), - dtype='object', - ), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ), ( [ @@ -620,11 +492,7 @@ def test_concat(objs, overwrite, expected): pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ], lambda y: 0, - pd.Series( - [0, 0], - pd.Index(['a', 'b']), - dtype='float', - ), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ), ( [ @@ -632,11 +500,7 @@ def test_concat(objs, overwrite, expected): pd.Series([1, 2], pd.Index(['a', 'b']), dtype='int'), ], lambda y: 0, - pd.Series( - [0, 0], - pd.Index(['a', 'b']), - dtype='Int64', - ), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='Int64'), ), ( [ @@ -644,11 +508,7 @@ def test_concat(objs, overwrite, expected): pd.Series([1, 2], pd.Index(['a', 'b']), dtype='int'), ], lambda y: 0.5, - pd.Series( - [0.5, 0.5], - pd.Index(['a', 'b']), - dtype='float', - ), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='Int64'), ), ( [ @@ -656,11 +516,7 @@ def test_concat(objs, overwrite, expected): pd.Series([1, 2], 
pd.Index(['a', 'b']), dtype='float'), ], lambda y: ('a', 'b'), - pd.Series( - [('a', 'b'), ('a', 'b')], - pd.Index(['a', 'b']), - dtype='object', - ), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ), ( [ @@ -669,7 +525,7 @@ def test_concat(objs, overwrite, expected): pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ], np.sum, - pd.Series([3, 6], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ), ( [ @@ -686,7 +542,7 @@ def test_concat(objs, overwrite, expected): ], np.sum, pd.Series( - [2, 4], + [1, 2], audformat.filewise_index(['a', 'b']), dtype='float', ), @@ -697,7 +553,7 @@ def test_concat(objs, overwrite, expected): pd.Series(['a', 'b'], pd.Index(['a', 'b']), dtype='string'), ], lambda y: np.char.add(y[0], y[1]), - pd.Series(['aa', 'bb'], pd.Index(['a', 'b']), dtype='string'), + pd.Series(['a', 'b'], pd.Index(['a', 'b']), dtype='string'), ), ( [ @@ -750,8 +606,8 @@ def test_concat(objs, overwrite, expected): np.sum, pd.DataFrame( { - 'A': [2, 6], - 'B': [4, 8], + 'A': [1, 3], + 'B': [2, 4], }, pd.Index(['a', 'b']), dtype='float', @@ -778,7 +634,7 @@ def test_concat(objs, overwrite, expected): np.sum, pd.DataFrame( { - 'A': [2, 6], + 'A': [1, 3], 'B': [2, 4], }, pd.Index(['a', 'b']), @@ -807,7 +663,7 @@ def test_concat(objs, overwrite, expected): pd.DataFrame( { 'A': [1, 3], - 'B': [4, 8], + 'B': [2, 4], }, pd.Index(['a', 'b']), dtype='float', @@ -834,7 +690,7 @@ def test_concat(objs, overwrite, expected): np.sum, pd.DataFrame( { - 'A': [2, 6], + 'A': [1, 3], 'B': [2, 4], }, pd.Index(['a', 'b']), @@ -862,7 +718,7 @@ def test_concat(objs, overwrite, expected): np.sum, pd.DataFrame( { - 'B': [4, 8], + 'B': [2, 4], 'A': [1, 3], }, pd.Index(['a', 'b']), @@ -870,15 +726,6 @@ def test_concat(objs, overwrite, expected): ), ), # different values - pytest.param( - [ - pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), - pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), - ], - None, - None, - marks=pytest.mark.xfail(raises=ValueError), - ), ( [ pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), @@ -950,29 +797,6 @@ def test_concat(objs, overwrite, expected): lambda y: y[2], pd.Series([3, 4], pd.Index(['a', 'b']), dtype='float'), ), - pytest.param( - [ - pd.DataFrame( - { - 'A': [1, 3], - 'B': [2, 4], - }, - pd.Index(['a', 'b']), - dtype='float', - ), - pd.DataFrame( - { - 'A': [2, 4], - 'B': [3, 5], - }, - pd.Index(['a', 'b']), - dtype='float', - ), - ], - None, - None, - marks=pytest.mark.xfail(raises=ValueError), - ), ( [ pd.DataFrame( @@ -1287,8 +1111,8 @@ def test_concat(objs, overwrite, expected): np.sum, pd.DataFrame( { - 'A': [4, 3, 2, 5], - 'B': [4, 3, 2, 5], + 'A': [4, 3, 1, 5], + 'B': [4, 3, 1, 5], 'C': [np.NaN, 2, 1, 2], }, index=pd.Index(['a', 'b', 'c', 'd']), @@ -1298,7 +1122,10 @@ def test_concat(objs, overwrite, expected): ] ) def test_concat_aggregate_function(objs, aggregate_function, expected): - obj = audformat.utils.concat(objs, aggregate_function=aggregate_function) + obj = audformat.utils.concat( + objs, + aggregate_function=aggregate_function, + ) if isinstance(obj, pd.Series): pd.testing.assert_series_equal(obj, expected) else: @@ -1306,61 +1133,709 @@ def test_concat_aggregate_function(objs, aggregate_function, expected): @pytest.mark.parametrize( - 'objs, aggregate_function, expected', + 'objs, aggregate_function, aggregate_strategy, expected', [ - # empty ( - [], - None, - pd.Series([], pd.Index([]), dtype='object'), + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + 
pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ], + tuple, + 'overlap', + pd.Series([(1, 1), (2, 2)], pd.Index(['a', 'b']), dtype='object'), ), - # identical values ( [ pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ], - None, + tuple, + 'mismatch', pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), ), ( [ pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2, 3], pd.Index(['a', 'b', 'c']), dtype='float'), + ], + tuple, + 'overlap', + pd.Series( + [(1., 1.), (2., 2.), 3.], + pd.Index(['a', 'b', 'c']), + dtype='object', + ), + ), + ( + [ pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2, 3], pd.Index(['a', 'b', 'c']), dtype='float'), ], - np.mean, - pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + tuple, + 'mismatch', + pd.Series([1, 2, 3], pd.Index(['a', 'b', 'c']), dtype='float'), ), - # different values ( [ + pd.Series([1, 2, 3], pd.Index(['a', 'b', 'c']), dtype='float'), pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), - pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), ], - None, - pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), + tuple, + 'overlap', + pd.Series( + [(1., 1.), (2., 2.), 3.], + pd.Index(['a', 'b', 'c']), + dtype='object', + ), ), ( [ + pd.Series([1, 2, 3], pd.Index(['a', 'b', 'c']), dtype='float'), pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), - pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), ], - np.mean, - pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), + tuple, + 'mismatch', + pd.Series([1, 2, 3], pd.Index(['a', 'b', 'c']), dtype='float'), ), - ] -) -def test_concat_overwrite_aggregate_function( - objs, - aggregate_function, - expected, -): - obj = audformat.utils.concat( - objs, - overwrite=True, - aggregate_function=aggregate_function, - ) - if isinstance(obj, pd.Series): - pd.testing.assert_series_equal(obj, expected) - else: - pd.testing.assert_frame_equal(obj, expected) + ( + [ + pd.Series([2, 3], pd.Index(['b', 'c']), dtype='float'), + pd.Series([1, 2, 3], pd.Index(['a', 'b', 'c']), dtype='float'), + ], + tuple, + 'overlap', + pd.Series( + [(2., 2.), (3., 3.), 1.], + pd.Index(['b', 'c', 'a']), + dtype='object', + ), + ), + ( + [ + pd.Series([2, 3], pd.Index(['b', 'c']), dtype='float'), + pd.Series([1, 2, 3], pd.Index(['a', 'b', 'c']), dtype='float'), + ], + tuple, + 'mismatch', + pd.Series([2, 3, 1], pd.Index(['b', 'c', 'a']), dtype='float'), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ], + np.sum, + 'overlap', + pd.Series([2, 4], pd.Index(['a', 'b']), dtype='float'), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ], + np.sum, + 'mismatch', + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ], + np.var, + 'overlap', + pd.Series([0, 0], pd.Index(['a', 'b']), dtype='float'), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ], + lambda y: 'a', + 'overlap', + pd.Series( + ['a', 'a'], + pd.Index(['a', 'b']), + dtype='object', + ), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ], + lambda y: 0, + 'overlap', + pd.Series( + 
[0, 0], + pd.Index(['a', 'b']), + dtype='float', + ), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='int'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='int'), + ], + lambda y: 0, + 'overlap', + pd.Series( + [0, 0], + pd.Index(['a', 'b']), + dtype='Int64', + ), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='int'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='int'), + ], + lambda y: 0.5, + 'overlap', + pd.Series( + [0.5, 0.5], + pd.Index(['a', 'b']), + dtype='float', + ), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ], + lambda y: ('a', 'b'), + 'overlap', + pd.Series( + [('a', 'b'), ('a', 'b')], + pd.Index(['a', 'b']), + dtype='object', + ), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ], + np.sum, + 'overlap', + pd.Series([3, 6], pd.Index(['a', 'b']), dtype='float'), + ), + ( + [ + pd.Series( + [1, 2], + audformat.filewise_index(['a', 'b']), + dtype='float', + ), + pd.Series( + [1, 2], + audformat.filewise_index(['a', 'b']), + dtype='float', + ), + ], + np.sum, + 'overlap', + pd.Series( + [2, 4], + audformat.filewise_index(['a', 'b']), + dtype='float', + ), + ), + ( + [ + pd.Series(['a', 'b'], pd.Index(['a', 'b']), dtype='string'), + pd.Series(['a', 'b'], pd.Index(['a', 'b']), dtype='string'), + ], + lambda y: np.char.add(y[0], y[1]), + 'overlap', + pd.Series(['aa', 'bb'], pd.Index(['a', 'b']), dtype='string'), + ), + ( + [ + pd.DataFrame( + { + 'A': [1, 3], + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + pd.DataFrame( + { + 'A': [1, 3], + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + ], + None, + 'overlap', + pd.DataFrame( + { + 'A': [1, 3], + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + ), + ( + [ + pd.DataFrame( + { + 'A': [1, 3], + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + pd.DataFrame( + { + 'A': [1, 3], + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + ], + np.sum, + 'overlap', + pd.DataFrame( + { + 'A': [2, 6], + 'B': [4, 8], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + ), + ( + [ + pd.DataFrame( + { + 'A': [1, 3], + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + pd.DataFrame( + { + 'A': [1, 3], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + ], + np.sum, + 'overlap', + pd.DataFrame( + { + 'A': [2, 6], + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + ), + ( + [ + pd.DataFrame( + { + 'A': [1, 3], + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + pd.DataFrame( + { + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + ], + np.sum, + 'overlap', + pd.DataFrame( + { + 'A': [1, 3], + 'B': [4, 8], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + ), + ( + [ + pd.DataFrame( + { + 'A': [1, 3], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + pd.DataFrame( + { + 'A': [1, 3], + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + ], + np.sum, + 'overlap', + pd.DataFrame( + { + 'A': [2, 6], + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + ), + ( + [ + pd.DataFrame( + { + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + pd.DataFrame( + { + 'A': [1, 3], + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + ], + np.sum, + 'overlap', + pd.DataFrame( + { + 'B': [4, 8], + 'A': [1, 3], + 
}, + pd.Index(['a', 'b']), + dtype='float', + ), + ), + ], +) +def test_concat_aggregate_function_aggregate( + objs, + aggregate_function, + aggregate_strategy, + expected, +): + obj = audformat.utils.concat( + objs, + aggregate_function=aggregate_function, + aggregate_strategy=aggregate_strategy, + ) + if isinstance(obj, pd.Series): + pd.testing.assert_series_equal(obj, expected) + else: + pd.testing.assert_frame_equal(obj, expected) + + +@pytest.mark.parametrize( + 'objs, aggregate_function, expected', + [ + # empty + ( + [], + None, + pd.Series([], pd.Index([]), dtype='object'), + ), + # identical values + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ], + None, + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ], + np.mean, + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ), + # different values + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), + ], + None, + pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), + ], + np.mean, + pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), + ), + ] +) +def test_concat_overwrite_aggregate_function( + objs, + aggregate_function, + expected, +): + obj = audformat.utils.concat( + objs, + overwrite=True, + aggregate_function=aggregate_function, + ) + if isinstance(obj, pd.Series): + pd.testing.assert_series_equal(obj, expected) + else: + pd.testing.assert_frame_equal(obj, expected) + + +@pytest.mark.parametrize( + 'objs, aggregate_function, aggregate_strategy, ' + 'expected_error, expected_error_msg', + [ + # wrong aggregate_strategy argument + ( + [], + None, + 'non-existent', + ValueError, + "aggregate_strategy needs to be one of: overlap, mismatch", + ), + # dtypes do not match + ( + [ + pd.Series([1], audformat.filewise_index('f1')), + pd.Series([1.], audformat.filewise_index('f1')), + ], + None, + 'overlap', + ValueError, + ( + "Found two columns with name 'None' but different dtypes:\n" + "Int64 != float64." 
+ ), + ), + ( + [ + pd.Series( + [1, 2, 3], + index=audformat.filewise_index(['f1', 'f2', 'f3']), + ), + pd.Series( + ['a', 'b', 'a'], + index=audformat.filewise_index(['f1', 'f2', 'f3']), + dtype='category', + ), + ], + None, + 'overlap', + ValueError, + re.escape( + "Found two columns with name 'None' but different dtypes:\n" + "Int64 != CategoricalDtype(categories=['a', 'b']" + ), + ), + ( + [ + pd.Series( + ['a', 'b', 'a'], + index=audformat.filewise_index(['f1', 'f2', 'f3']), + dtype='string', + ), + pd.Series( + ['a', 'b', 'a'], + index=audformat.filewise_index(['f1', 'f2', 'f3']), + dtype='category', + ), + ], + None, + 'overlap', + ValueError, + re.escape( + "Found two columns with name 'None' but different dtypes:\n" + "string != CategoricalDtype(categories=['a', 'b']" + ), + ), + ( + [ + pd.Series( + ['a', 'b', 'a'], + index=audformat.filewise_index(['f1', 'f2', 'f3']), + dtype='category', + ), + pd.Series( + ['a', 'b', 'c'], + index=audformat.filewise_index(['f1', 'f2', 'f3']), + dtype='category', + ), + ], + None, + 'overlap', + ValueError, + ( + "Found two columns with name 'None' but different dtypes:\n" + r"CategoricalDtype\(categories=\['a', 'b'\]," + ".*" + r"!= CategoricalDtype\(categories=\['a', 'b', 'c'\]" + ), + ), + # values do not match + ( + [ + pd.DataFrame( + { + 'A': [1, 3], + 'B': [2, 4], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + pd.DataFrame( + { + 'A': [2, 4], + 'B': [3, 5], + }, + pd.Index(['a', 'b']), + dtype='float', + ), + ], + None, + 'mismatch', + ValueError, + ( + "Found overlapping data in column 'A':\n " + "left right\n" + "a 1.0 2.0\n" + "b 3.0 4.0" + ), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), + ], + None, + 'mismatch', + ValueError, + ( + "Found overlapping data in column 'None':\n " + "left right\n" + "a 1.0 2.0\n" + "b 2.0 3.0" + ), + ), + ( + [ + pd.Series([1.], audformat.filewise_index('f1')), + pd.Series([2.], audformat.filewise_index('f1')), + ], + None, + 'overlap', + ValueError, + ( + "Found overlapping data in column 'None':\n " + "left right\n" + "file \n" + "f1 1.0 2.0" + ), + ), + ( + [ + pd.Series([1.], pd.Index(['f1'], name='idx')), + pd.Series( + [2.], + pd.MultiIndex.from_arrays([['f1']], names=['idx']), + ), + ], + None, + 'overlap', + ValueError, + ( + "Found overlapping data in column 'None':\n " + "left right\n" + "idx \n" + "f1 1.0 2.0" + ), + ), + # index names do not match + ( + [ + pd.Series( + [1.], + pd.Index(['f1'], name='idx', dtype='string'), + ), + pd.Series( # default dtype is object + [2.], + pd.MultiIndex.from_arrays([['f1']], names=['idx']), + ), + ], + None, + 'overlap', + ValueError, + re.escape( + "Levels and dtypes of all objects must match. " + "Found different level dtypes: ['str', 'object']." + ), + ), + ( + [ + pd.Series([], index=pd.Index([], name='idx1'), dtype='object'), + pd.Series([], index=pd.Index([], name='idx2'), dtype='object'), + ], + None, + 'overlap', + ValueError, + re.escape( + "Levels and dtypes of all objects must match. " + "Found different level names: ['idx1', 'idx2']." + ), + ), + ( + [ + pd.Series([1.], pd.Index(['f1'], name='idx1')), + pd.Series( + [2.], + pd.MultiIndex.from_arrays([['f2']], names=['idx2']), + ), + ], + None, + 'overlap', + ValueError, + re.escape( + "Levels and dtypes of all objects must match. " + "Found different level names: ['idx1', 'idx2']." 
+ ), + ), + ], +) +def test_concat_errors( + objs, + aggregate_function, + aggregate_strategy, + expected_error, + expected_error_msg, +): + with pytest.raises(expected_error, match=expected_error_msg): + audformat.utils.concat( + objs, + aggregate_function=aggregate_function, + aggregate_strategy=aggregate_strategy, + )
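Below the patch, a minimal usage sketch of the new ``aggregate_strategy`` argument, assuming ``audformat`` (with this patch applied), ``numpy``, and ``pandas`` are available. The expected outputs in the comments follow the doctest examples added to ``concat`` above; the ``[2, 3]`` result for ``'overlap'`` is inferred from the new code path that aggregates every overlapping entry, not only mismatching ones.

import numpy as np
import pandas as pd

import audformat


y1 = pd.Series([1, 1], index=pd.Index([0, 1]))
y2 = pd.Series([1, 2], index=pd.Index([0, 1]))

# Default strategy 'mismatch':
# the aggregate function is applied only at index 1,
# where the overlapping values differ (1 vs 2);
# the matching value at index 0 is kept as is
print(audformat.utils.concat([y1, y2], aggregate_function=np.sum))
# 0    1
# 1    3
# dtype: Int64

# Strategy 'overlap':
# the aggregate function is applied to every overlapping entry,
# independent of whether the values agree
print(
    audformat.utils.concat(
        [y1, y2],
        aggregate_function=np.sum,
        aggregate_strategy='overlap',
    )
)
# 0    2
# 1    3
# dtype: Int64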