Skip to content

Commit

Permalink
Add aggregate_function argument
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Oct 25, 2023
1 parent 253fb28 commit 8ac6c22
Show file tree
Hide file tree
Showing 2 changed files with 353 additions and 20 deletions.
91 changes: 71 additions & 20 deletions audformat/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def concat(
objs: typing.Sequence[typing.Union[pd.Series, pd.DataFrame]],
*,
overwrite: bool = False,
aggregate_function: typing.Callable[[pd.Series], typing.Any] = None,
) -> typing.Union[pd.Series, pd.DataFrame]:
r"""Concatenate objects.
Expand Down Expand Up @@ -64,18 +65,28 @@ def concat(
or one column contains ``NaN``.
If ``overwrite`` is set to ``True``,
the value of the last object in the list is kept.
If ``overwrite`` is set to ``False``,
a custom aggregation function can be provided
with ``aggregate_function``
that converts the overlapping values
into a single one.
Args:
objs: objects
overwrite: overwrite values where indices overlap
aggregate_function: function to be applied on all entries
that contain more than one data point per index.
The function gets a dataframe row as input,
and is expected to return a single value
Returns:
concatenated objects
Raises:
ValueError: if level and dtypes of object indices do not match
ValueError: if columns with the same name have different dtypes
ValueError: if values in the same position do not match
ValueError: if ``aggregate_function`` is ``None``
and values in the same position do not match
Examples:
>>> concat(
Expand All @@ -97,6 +108,15 @@ def concat(
0 0 1
>>> concat(
... [
... pd.Series([1], index=pd.Index([0])),
... pd.Series([2], index=pd.Index([0])),
... ],
... aggregate_function=np.sum,
... )
0 3
dtype: Int64
>>> concat(
... [
... pd.Series(
... [0., 1.],
... index=pd.Index(
Expand Down Expand Up @@ -194,6 +214,7 @@ def concat(

# reindex all columns to the new index
columns_reindex = {}
overlapping_values = {}
for column in columns:

# if we already have a column with that name, we have to merge them
Expand Down Expand Up @@ -233,26 +254,41 @@ def concat(
)
# We use len() here as index.empty takes a very long time
if len(intersection) > 0:
combine = pd.DataFrame(
{
'left': columns_reindex[column.name][intersection],
'right': column[intersection]
}
)
combine.dropna(inplace=True)
differ = combine['left'] != combine['right']
if np.any(differ):
max_display = 10
overlap = combine[differ]
msg_overlap = str(overlap[:max_display])
msg_tail = '\n...' \
if len(overlap) > max_display \
else ''
raise ValueError(
"Found overlapping data in column "
f"'{column.name}':\n"
f"{msg_overlap}{msg_tail}"

# Custom handling of overlapping values
if aggregate_function is not None:
if column.name not in overlapping_values:
overlapping_values[column.name] = [
columns_reindex[column.name].loc[intersection]
]
overlapping_values[column.name].append(
column.loc[intersection]
)
column = column.loc[~column.index.isin(intersection)]

else:
combine = pd.DataFrame(
{
'left':
columns_reindex[column.name][intersection],
'right':
column[intersection]
}
)
combine.dropna(inplace=True)
differ = combine['left'] != combine['right']
if np.any(differ):
max_display = 10
overlap = combine[differ]
msg_overlap = str(overlap[:max_display])
msg_tail = '\n...' \
if len(overlap) > max_display \
else ''
raise ValueError(
"Found overlapping data in column "
f"'{column.name}':\n"
f"{msg_overlap}{msg_tail}"
)

# drop NaN to avoid overwriting values from other column
column = column.dropna()
Expand All @@ -269,6 +305,21 @@ def concat(
)
columns_reindex[column.name][column.index] = column

# Apply custom aggregation function
# on collected overlapping data
# (no overlapping data is collected
# when no aggregation function is provided)
if len(overlapping_values) > 0:
for column in overlapping_values:
df = pd.concat(
overlapping_values[column],
axis=1,
ignore_index=True,
)
dtype = columns_reindex[column].dtype
columns_reindex[column] = df.apply(aggregate_function, axis=1)
columns_reindex[column] = columns_reindex[column].astype(dtype)

# Use `None` to force `{}` to return the correct index, see
# https://github.com/pandas-dev/pandas/issues/52404
df = pd.DataFrame(columns_reindex or None, index=index)
Expand Down
Loading

0 comments on commit 8ac6c22

Please sign in to comment.