Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup utils.concat() with aggregate_function #405

Merged
merged 20 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 103 additions & 25 deletions audformat/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def concat(
*,
overwrite: bool = False,
aggregate_function: typing.Callable[[pd.Series], typing.Any] = None,
aggregate_strategy: str = 'mismatch',
) -> typing.Union[pd.Series, pd.DataFrame]:
r"""Concatenate objects.

Expand Down Expand Up @@ -74,7 +75,9 @@ def concat(
Args:
objs: objects
overwrite: overwrite values where indices overlap
aggregate_function: function to aggregate overlapping values.
aggregate_function: function to aggregate overlapping values,
that cannot be joined
when ``overwrite`` is ``False``.
The function gets a :class:`pandas.Series`
with overlapping values
as input.
Expand All @@ -84,13 +87,23 @@ def concat(
or to
``tuple``
to return them as a tuple
aggregate_strategy: if ``aggregate_function`` is not ``None``,
``aggregate_strategy`` decides
when ``aggregate_function`` is applied.
``'overlap'``: apply to all samples
that have an overlapping index;
``'mismatch'``: apply to all samples
that have an overlapping index
and a different value

Returns:
concatenated objects

Raises:
ValueError: if level and dtypes of object indices do not match
ValueError: if columns with the same name have different dtypes
ValueError: if ``aggregate_strategy`` is not one of
``'overlap'``, ``'mismatch'``
ValueError: if ``aggregate_function`` is ``None``,
``overwrite`` is ``False``,
and values in the same position do not match
Expand All @@ -115,12 +128,34 @@ def concat(
0 0 1
>>> concat(
... [
... pd.Series([1], index=pd.Index([0])),
... pd.Series([2], index=pd.Index([0])),
... pd.Series([1, 1], index=pd.Index([0, 1])),
... pd.Series([1, 1], index=pd.Index([0, 1])),
... ],
... aggregate_function=np.sum,
... )
0 3
0 1
1 1
dtype: Int64
>>> concat(
... [
... pd.Series([1, 1], index=pd.Index([0, 1])),
... pd.Series([1, 2], index=pd.Index([0, 1])),
... ],
... aggregate_function=np.sum,
... )
0 1
1 3
dtype: Int64
>>> concat(
... [
... pd.Series([1, 1], index=pd.Index([0, 1])),
... pd.Series([1, 1], index=pd.Index([0, 1])),
... ],
... aggregate_function=np.sum,
... aggregate_strategy='overlap',
... )
0 2
1 2
dtype: Int64
>>> concat(
... [
Expand Down Expand Up @@ -195,6 +230,13 @@ def concat(
f3 0 days NaT 2.0 b

"""
allowed_values = ['overlap', 'mismatch']
if aggregate_strategy not in allowed_values:
raise ValueError(
"aggregate_strategy needs to be one of: "
f"{', '.join(allowed_values)}"
)

if not objs:
return pd.Series([], index=pd.Index([]), dtype='object')

Expand Down Expand Up @@ -239,7 +281,7 @@ def concat(
raise ValueError(
"Found two columns with name "
f"'{column.name}' "
"buf different dtypes:\n"
"but different dtypes:\n"
f"{dtype_1} "
"!= "
f"{dtype_2}."
Expand All @@ -253,6 +295,19 @@ def concat(

# Handle overlapping values
if not overwrite:

def collect_overlap(overlapping_values, column, index):
"""Collect overlap for aggregate function."""
if column.name not in overlapping_values:
overlapping_values[column.name] = []
overlapping_values[column.name].append(
column.loc[index]
)
column = column.loc[~column.index.isin(index)]
column = column.dropna()
return column, overlapping_values

# Apply aggregate function only to overlapping entries
intersection = intersect(
[
columns_reindex[column.name].dropna().index,
Expand All @@ -262,28 +317,51 @@ def concat(
# We use len() here as index.empty takes a very long time
if len(intersection) > 0:

# Store overlap if custom aggregate function is provided
if aggregate_function is not None:
if column.name not in overlapping_values:
overlapping_values[column.name] = []
overlapping_values[column.name].append(
column.loc[intersection]
# Apply aggregate function
# to all overlapping entries
if (
aggregate_function is not None
and aggregate_strategy == 'overlap'
):
column, overlapping_values = collect_overlap(
overlapping_values,
column,
intersection,
)
column = column.loc[~column.index.isin(intersection)]
columns_reindex[column.name][column.index] = column
continue

# Find data that differ and cannot be joined
combine = pd.DataFrame(
{
'left':
columns_reindex[column.name][intersection],
'right':
column[intersection],
}
)
combine.dropna(inplace=True)
differ = combine['left'] != combine['right']

if np.any(differ):

# Apply aggregate function
# to overlapping entries
# that do not match in value
if (
aggregate_function is not None
and aggregate_strategy == 'mismatch'
):
column, overlapping_values = collect_overlap(
overlapping_values,
column,
intersection[differ],
)
columns_reindex[column.name][column.index] = column
continue

# Raise error if values don't match and are not NaN
else:
combine = pd.DataFrame(
{
'left':
columns_reindex[column.name][intersection],
'right':
column[intersection],
}
)
combine.dropna(inplace=True)
differ = combine['left'] != combine['right']
if np.any(differ):
# Raise error if values don't match and are not NaN
else:
max_display = 10
overlap = combine[differ]
msg_overlap = str(overlap[:max_display])
Expand Down
Loading