Skip to content

Commit

Permalink
Handle dtype conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Oct 25, 2023
1 parent 8ac6c22 commit 957a577
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
6 changes: 5 additions & 1 deletion audformat/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,11 @@ def concat(
)
dtype = columns_reindex[column].dtype
columns_reindex[column] = df.apply(aggregate_function, axis=1)
columns_reindex[column] = columns_reindex[column].astype(dtype)
# Restore the original dtype if possible
try:
columns_reindex[column] = columns_reindex[column].astype(dtype)
except (TypeError, ValueError):
pass

# Use `None` to force `{}` return the correct index, see
# https://github.com/pandas-dev/pandas/issues/52404
Expand Down
60 changes: 60 additions & 0 deletions tests/test_utils_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,66 @@ def test_concat(objs, overwrite, expected):
np.var,
pd.Series([0, 0], pd.Index(['a', 'b']), dtype='float'),
),
(
[
pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'),
pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'),
],
lambda x: 'a',
pd.Series(
['a', 'a'],
pd.Index(['a', 'b']),
dtype='object',
),
),
(
[
pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'),
pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'),
],
lambda x: 0,
pd.Series(
[0, 0],
pd.Index(['a', 'b']),
dtype='float',
),
),
(
[
pd.Series([1, 2], pd.Index(['a', 'b']), dtype='int'),
pd.Series([1, 2], pd.Index(['a', 'b']), dtype='int'),
],
lambda x: 0,
pd.Series(
[0, 0],
pd.Index(['a', 'b']),
dtype='Int64',
),
),
(
[
pd.Series([1, 2], pd.Index(['a', 'b']), dtype='int'),
pd.Series([1, 2], pd.Index(['a', 'b']), dtype='int'),
],
lambda x: 0.5,
pd.Series(
[0.5, 0.5],
pd.Index(['a', 'b']),
dtype='float',
),
),
(
[
pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'),
pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'),
],
lambda x: ('a', 'b'),
pd.Series(
[('a', 'b'), ('a', 'b')],
pd.Index(['a', 'b']),
dtype='object',
),
),
(
[
pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'),
Expand Down

0 comments on commit 957a577

Please sign in to comment.