
Speed up utils.concat() with aggregate_function #405

Merged: 20 commits merged into main from speed-up-aggregate-function on Nov 21, 2023

Conversation

@hagenw (Member) commented Nov 8, 2023

This adds the aggregate_strategy argument to audformat.utils.concat() to specify when aggregate_function is applied. aggregate_strategy='overlap' keeps the old behavior: the aggregate function is applied to every index entry that is contained in more than one object. The new aggregate_strategy='mismatch' applies the aggregate function only to index entries whose values cannot be joined. For large databases with a lot of overlapping but matching data, as we encountered when testing #399, aggregate_strategy='mismatch' can be faster by more than a factor of 10.
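
A minimal sketch of the difference, based on the behavior described above (results shown as comments):

import numpy as np
import pandas as pd

import audformat

index = audformat.filewise_index(['f1', 'f2'])
y1 = pd.Series([1, 1], index)
y2 = pd.Series([1, 2], index)

# 'overlap' (old behavior): apply np.sum to every index entry
# that is contained in more than one object
audformat.utils.concat(
    [y1, y2],
    aggregate_function=np.sum,
    aggregate_strategy='overlap',
)
# f1 -> sum(1, 1) = 2, f2 -> sum(1, 2) = 3

# 'mismatch': apply np.sum only to entries whose values differ
audformat.utils.concat(
    [y1, y2],
    aggregate_function=np.sum,
    aggregate_strategy='mismatch',
)
# f1 -> 1 (values match), f2 -> sum(1, 2) = 3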

[Image: benchmark results]

hagenw requested a review from ChristianGeng November 8, 2023 15:52
hagenw marked this pull request as draft November 8, 2023 15:56
@hagenw (Member Author) commented Nov 8, 2023

@ChristianGeng sorry for requesting the review already.

I first need to update the tests and double-check that the new behavior is really what we want.
I have marked the pull request as draft until it's ready.

@frankenjoe (Collaborator)

> So, I think in practice we will gain more by applying aggregate_function only to non-matching entries.

Ok, so it means that if you apply e.g. sum() and you have matching entries, the result will now change, right?

@hagenw (Member Author) commented Nov 9, 2023

The implementation I propose here makes it much harder to predict what you get back: if two different columns randomly have the same label, aggregate_function will not be applied. So I would recommend staying with the current approach of applying aggregate_function to all index entries that occur in more than one column.

@hagenw (Member Author) commented Nov 9, 2023

> The implementation I propose here makes it much harder to predict what you get back: if two different columns randomly have the same label, aggregate_function will not be applied. So I would recommend staying with the current approach of applying aggregate_function to all index entries that occur in more than one column.

My comment was too pessimistic. The new changes will still apply aggregate_function to all values of the overlapping columns. It is only skipped if the two columns can be joined completely without the need for aggregate_function.
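
As a rough sketch (not the actual implementation; can_be_joined is a hypothetical helper), "can be joined" means the objects agree on all shared index entries:

import pandas as pd

def can_be_joined(y1: pd.Series, y2: pd.Series) -> bool:
    # values can be joined if they agree
    # on every index entry the two objects share
    shared = y1.index.intersection(y2.index)
    return y1.loc[shared].equals(y2.loc[shared])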

codecov bot commented Nov 9, 2023

Codecov Report

Merging #405 (aa0cc73) into main (fd9c481) will not change coverage. The report is 1 commit behind head on main. The diff coverage is 100.0%.

Impacted files:

File                     Coverage  Δ
audformat/core/utils.py  100.0%    <100.0%> (ø)

hagenw requested a review from frankenjoe November 9, 2023 11:14
hagenw marked this pull request as ready for review November 9, 2023 11:15
@frankenjoe (Collaborator)

> My comment was too pessimistic. The new changes will still apply aggregate_function to all values of the overlapping columns. It is only skipped if the two columns can be joined completely without the need for aggregate_function.

But it still means that sum() will return something different than before, right?

@frankenjoe (Collaborator) commented Nov 9, 2023

Here's a simple example:

import numpy as np
import pandas as pd

import audformat

index = audformat.filewise_index(['f1', 'f2'])
y = pd.Series([1, 2], index)
audformat.utils.concat([y, y], aggregate_function=np.sum)

returns

file
f1    1
f2    2
dtype: Int64

Before it returned:

file
f1    2
f2    4
dtype: Int64

@frankenjoe (Collaborator)

This might not be an issue since aggregate_function is a new feature. I just wonder if it can lead to unexpected behavior.

@hagenw (Member Author) commented Nov 9, 2023

Yes, so far there was no release, so we can change aggregate_function however we want. We just need to decide which behavior we would like to have.

The behavior currently in main applies aggregate_function to all samples that are stored in different columns, whereas the behavior proposed in this pull request applies aggregate_function only to samples that come from columns that cannot be joined. I think the new behavior makes more sense, as otherwise the output of aggregate_function depends on whether labels are stored in several tables or not (e.g. having an unbalanced and a balanced test set where the latter is a subset of the former).
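
As an illustration (a hypothetical sketch, reusing the imports from the example above; the series contents are made up), consider a balanced test set that repeats labels of an unbalanced one:

unbalanced = pd.Series(
    [0, 1, 1],
    audformat.filewise_index(['f1', 'f2', 'f3']),
)
# the balanced set is a subset with identical labels
balanced = pd.Series(
    [0, 1],
    audformat.filewise_index(['f1', 'f2']),
)

# behavior in main: np.sum is applied to f1 and f2,
# although both objects agree on their values
# proposed behavior: the values are joined and np.sum is skipped
audformat.utils.concat([unbalanced, balanced], aggregate_function=np.sum)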

@frankenjoe (Collaborator) commented Nov 10, 2023

In principle I agree. I still find it strange that the following can now happen:

index = audformat.filewise_index(['f1', 'f2'])
y1 = pd.Series([1, 1], index)
y2 = pd.Series([1, 2], index)
audformat.utils.concat([y1, y2], aggregate_function=np.sum)

returns

file
f1    2
f2    3
dtype: Int64

but:

audformat.utils.concat([y1[:1], y2[:1]], aggregate_function=np.sum)

returns

file
f1    1
dtype: Int64

@frankenjoe (Collaborator)

So maybe we need an additional argument to control when the aggregate_function should be applied?

@hagenw (Member Author) commented Nov 15, 2023

I agree that your example looks very counter-intuitive to a user.

I'm also not really in favor of adding another argument. Maybe we should simply close this pull request.
In #399 I solved the issue by first trying to join without an aggregate function and only using it when really needed:

try:
    # first try to join without an aggregate function
    y = audformat.utils.concat(ys)
except ValueError:
    # fall back to the aggregate function
    # for values that cannot be joined
    y = audformat.utils.concat(ys, aggregate_function=aggregate_function)

@frankenjoe (Collaborator)

> by first trying to join without an aggregate function and only using it when really needed:

In the worst case, when the error is raised late, does it mean the function is executed more or less twice?

> I'm also not really in favor of adding another argument.

Mhh, ok. To me it seems to make sense. It could have three values: 1. apply always, 2. apply to all duplicates, 3. apply only to non-matching duplicates.

@hagenw (Member Author) commented Nov 15, 2023

> > by first trying to join without an aggregate function and only using it when really needed:
>
> In the worst case, when the error is raised late, does it mean the function is executed more or less twice?

Good point, I hadn't thought about that.
Yes, if the first columns can be joined without error but the last column raises one, the function will indeed be executed more or less twice.

> > I'm also not really in favor of adding another argument.
>
> Mhh, ok. To me it seems to make sense. It could have three values: 1. apply always, 2. apply to all duplicates, 3. apply only to non-matching duplicates.

OK, I will have a look into it.

hagenw marked this pull request as draft November 15, 2023 11:05
hagenw marked this pull request as ready for review November 15, 2023 12:00
@hagenw (Member Author) commented Nov 15, 2023

I have now added the aggregate argument, which selects between 'always' and 'non-matching' for when to apply aggregate_function.

Your suggestion of 'duplicates' we cannot add, as non-matching labels always have to be adjusted; otherwise we cannot join the columns.

hagenw marked this pull request as draft November 15, 2023 12:07
@hagenw (Member Author) commented Nov 15, 2023

Sorry, there is still something not behaving as expected; I will first take another look.

hagenw marked this pull request as ready for review November 15, 2023 12:26
@hagenw (Member Author) commented Nov 15, 2023

I fixed the remaining problem.

Now we have the following behavior:

concat(
    [
        pd.Series([1, 1], index=pd.Index([0, 1])),
        pd.Series([1, 1], index=pd.Index([0, 1])),
    ],
    aggregate_function=np.sum,
    aggregate='always',
)

returns

0    2
1    2
dtype: Int64

concat(
    [
        pd.Series([1, 1], index=pd.Index([0, 1])),
        pd.Series([1, 1], index=pd.Index([0, 1])),
    ],
    aggregate_function=np.sum,
    aggregate='non-matching',
)

returns

0    1
1    1
dtype: Int64

concat(
    [
        pd.Series([1, 1], index=pd.Index([0, 1])),
        pd.Series([1, 2], index=pd.Index([0, 1])),
    ],
    aggregate_function=np.sum,
    aggregate='non-matching',
)

returns

0    1
1    3
dtype: Int64

There is a third option missing, but I also think it is too hard to grasp:

concat(
    [
        pd.Series([1, 1], index=pd.Index([0, 1])),
        pd.Series([1, 1], index=pd.Index([0, 1])),
    ],
    aggregate_function=np.sum,
    aggregate='third-option',
)

returns

0    1
1    1
dtype: Int64

concat(
    [
        pd.Series([1, 1], index=pd.Index([0, 1])),
        pd.Series([1, 2], index=pd.Index([0, 1])),
    ],
    aggregate_function=np.sum,
    aggregate='third-option',
)

returns

0    2
1    3
dtype: Int64

@hagenw (Member Author) commented Nov 15, 2023

I just realised that the third option is actually what we want to have in audformat.Database.get(). I will add it under the name 'when-non-matching' and rename 'non-matching' to 'only-non-matching'.

@frankenjoe (Collaborator) commented Nov 15, 2023

I think it should be 'not-matching', or maybe better 'mismatch'. But I am also not sure I would understand the difference between 'when-not-matching' and 'only-not-matching'.

In the docstring we usually talk about duplicates. So maybe we could rename the options to 'always', 'duplicates', and 'mismatch'?

@frankenjoe (Collaborator)

I.e., we then have:

  • 'always': apply to every file / segment
  • 'duplicates': apply when files / segments overlap
  • 'mismatch': apply when files / segments overlap and their values do not match

hagenw marked this pull request as draft November 15, 2023 20:21
@hagenw (Member Author) commented Nov 16, 2023

For 'always' I'm not completely sure what to expect. Should it first expand the columns with NaN to make sure they have the same number of entries and then apply the aggregate function, or should it work on the given number of values that are available for each index entry?

E.g. for

audformat.utils.concat(
    [
        pd.Series([1, 1], index=pd.Index(['b', 'c'])),
        pd.Series([2, 3, 4], index=pd.Index(['a', 'b', 'c'])),
    ],
    aggregate_function=lambda y: y[0],
    aggregate='always',
)

the first solution would return

a    NaN
b    1
c    1
dtype: Int64

and the second solution

b    1
c    1
a    2
dtype: Int64

The other question is: do we really need the option to apply the aggregate function to non-overlapping entries?


> But I am also not sure I would understand the difference between 'when-not-matching' and 'only-not-matching'.

I guess that's the reason why it is not covered by your three cases ;)
My current solution in #399 is covered neither by 'duplicates' nor by 'mismatch', but would require 'duplicates-when-mismatch'.

Let's assume we have one case where all values match, and another where one value differs:

Matching case:

objs = [
    pd.Series([1, 1]),
    pd.Series([1, 1]),
]

Non-matching case:

objs = [
    pd.Series([1, 1]),
    pd.Series([1, 2]),
]

Here is what the three different options for aggregate would return with aggregate_function=np.sum:

case          'duplicates'  'mismatch'  'duplicates-when-mismatch'
matching      [2, 2]        [1, 1]      [1, 1]
non-matching  [2, 3]        [1, 3]      [2, 3]
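
To make the table concrete, here is a minimal sketch that reproduces it (using the strategy names from this discussion rather than the final API, and assuming two pandas Series with identical indices):

import numpy as np
import pandas as pd

def aggregate(y1, y2, strategy, func=np.sum):
    stacked = pd.concat([y1, y2], axis=1)
    aggregated = stacked.apply(func, axis=1)  # func over each entry's values
    mismatch = y1 != y2
    if strategy == 'duplicates':
        # apply to every overlapping entry
        return aggregated
    elif strategy == 'mismatch':
        # apply only to entries whose values differ
        return y1.where(~mismatch, aggregated)
    elif strategy == 'duplicates-when-mismatch':
        # apply to all overlapping entries,
        # but only if at least one value differs
        return aggregated if mismatch.any() else y1

For example, aggregate(pd.Series([1, 1]), pd.Series([1, 2]), 'mismatch') returns [1, 3], whereas 'duplicates-when-mismatch' returns [2, 3].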

@frankenjoe (Collaborator)

> For 'always' I'm not completely sure what to expect. Should it first expand the columns with NaN to make sure they have the same number of entries and then apply the aggregate function, or should it work on the given number of values that are available for each index entry?

The latter; for segments with overlap we also do not expand, I think.

> The other question is: do we really need the option to apply the aggregate function to non-overlapping entries?

If it is too complicated to implement we can skip it. But you could use it, for instance, to count the number of overlaps for every segment.

> I guess that's the reason why it is not covered by your three cases ;)
> My current solution in #399 is covered neither by 'duplicates' nor by 'mismatch', but would require 'duplicates-when-mismatch'.

Mhh, I am still not sure I understand what 'duplicates-when-mismatch' does :) Is it that, as soon as there is a value mismatch for one overlapping segment, the aggregate function is applied to all overlapping segments? Why exactly do we need this case?

Btw: 'overlap' would probably be a better name than 'duplicates', as the latter implies that the values also have to match, which is not what we want.

@hagenw (Member Author) commented Nov 17, 2023

> Is it that, as soon as there is a value mismatch for one overlapping segment, the aggregate function is applied to all overlapping segments? Why exactly do we need this case?

Exactly. For audformat.Database.get() we want to apply the aggregate function only if needed, i.e. when a column has at least one value with a mismatch. My main reason for applying it to all values of the column in that case is to better indicate to the user which parts have overlap (e.g. when tuple is chosen as aggregate function), as a label might be the same just by chance, e.g. the same gender for the speakers of the left and right channel. That can still happen, but it is less likely.

The alternative would be to always use 'duplicates', but this would slow down get(), as we have a lot of tables that simply repeat entries found in other tables.
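
For instance (a hypothetical sketch; the expected result is indicated as a comment), choosing tuple as aggregate function makes the overlap directly visible to the user:

audformat.utils.concat(
    [
        pd.Series(['female'], audformat.filewise_index(['f1'])),
        pd.Series(['male'], audformat.filewise_index(['f1'])),
    ],
    aggregate_function=tuple,
)
# f1 would then hold the tuple ('female', 'male'),
# which exposes the overlapping entry and its conflicting labels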

@frankenjoe (Collaborator)

Mhh, ok. But shouldn't we give the user the option to decide which strategy is used?

@hagenw (Member Author) commented Nov 17, 2023

Good idea; then we set the default to 'mismatch' in Database.get() and the user can change it if really needed.
OK, then let's stay with the current three options and rename 'duplicates' to 'overlap' as you suggested.

@frankenjoe (Collaborator)

I think here we could also set the default to 'mismatch', as this is the most common use case.

hagenw marked this pull request as ready for review November 17, 2023 11:44
@hagenw (Member Author) commented Nov 17, 2023

I have now implemented the suggested changes, but skipped 'always', as the desired result was not achievable with the current code. If we still want it, we should open an issue for it.
I selected 'mismatch' as the default behavior.

@frankenjoe (Collaborator)

Very cool, looks like we are done here.

frankenjoe merged commit 9a7b944 into main Nov 21, 2023
10 checks passed
frankenjoe deleted the speed-up-aggregate-function branch November 21, 2023 18:01