From d62c70bca9f6032e1105a83576a313ac2545a89f Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Thu, 19 Oct 2023 14:00:02 +0200 Subject: [PATCH] Return dataframe or series --- audformat/core/database.py | 28 +++++++++----- tests/test_database_get.py | 76 ++++++++++++++++++++++++++++++++++---- 2 files changed, 88 insertions(+), 16 deletions(-) diff --git a/audformat/core/database.py b/audformat/core/database.py index 28e8b3ca..57135fff 100644 --- a/audformat/core/database.py +++ b/audformat/core/database.py @@ -436,7 +436,7 @@ def duration(file: str) -> pd.Timedelta: def get( self, schemes: typing.Union[str, typing.Sequence], - ) -> pd.Series: + ) -> typing.Union[pd.Series, pd.DataFrame]: r"""Get labels by scheme(s). Return all labels @@ -473,7 +473,13 @@ def clean_y(y): # TODO: at the moment we simply convert to string # to avoid errors for different categorical data types. # In real implementation we need to adjust those dtypes. - return y.dropna().astype('string') + # return y.astype('string') + return y + + if isinstance(schemes, str): + return_series = True + else: + return_series = False requested_schemes = audeer.to_list(schemes) @@ -512,10 +518,6 @@ def clean_y(y): for table_id, table in self.tables.items(): print(f'{table_id=}') for column_id, column in table.columns.items(): - print(f'{column_id=}') - print(f'{column.scheme_id=}') - print(f'{"speaker" in column_id=}') - print(f'{"speaker" in column.scheme_id=}') # Get series based on scheme in column if any( [ @@ -541,10 +543,18 @@ def clean_y(y): ys.append(clean_y(y)) index = utils.union([y.index for y in ys]) - y = utils.concat(ys).loc[index] - y.name = ', '.join(requested_schemes) + obj = utils.concat(ys).loc[index] - return y + if not return_series and isinstance(obj, pd.Series): + obj = obj.to_frame() + + if isinstance(obj, pd.DataFrame): + for requested_scheme in requested_schemes: + if requested_scheme not in obj: + # FIXME: try to get dtype from scheme + obj[requested_scheme] = pd.NA + + return obj def map_files( self, diff --git a/tests/test_database_get.py b/tests/test_database_get.py index 6675fbd7..37eaef1c 100644 --- a/tests/test_database_get.py +++ b/tests/test_database_get.py @@ -33,6 +33,12 @@ def mono_db(tmpdir): db['files']['channel0'] = audformat.Column(scheme_id='speaker') db['files']['channel0'].set(['s1', 's2', 's3']) + # Filewise table with sex + index = audformat.filewise_index(['f1.wav', 'f3.wav']) + db['gender'] = audformat.Table(index) + db['gender']['sex'] = audformat.Column() + db['gender']['sex'].set(['female', 'male']) + db.save(path) audformat.testing.create_audio_files(db, channels=1, file_duration='1s') @@ -75,19 +81,75 @@ def stereo_db(tmpdir): [ ( 'mono_db', - ['sex', 'gender'], + 'gender', pd.Series( ['female', '', 'male'], - index=audformat.filewise_index(['f1.wav', 'f2.wav', 'f3.wav']), - dtype='string', - name='sex, gender', + index=audformat.filewise_index( + ['f1.wav', 'f2.wav', 'f3.wav'] + ), + dtype=pd.CategoricalDtype( + categories=['female', '', 'male'], + ordered=False, + ), + name='gender', + ), + ), + ( + 'mono_db', + ['sex', 'gender'], + pd.concat( + [ + pd.Series( + ['female', '', 'male'], + index=audformat.filewise_index( + ['f1.wav', 'f2.wav', 'f3.wav'] + ), + dtype=pd.CategoricalDtype( + categories=['female', '', 'male'], + ordered=False, + ), + name='gender', + ), + pd.Series( + ['female', np.NaN, 'male'], + index=audformat.filewise_index( + ['f1.wav', 'f2.wav', 'f3.wav'] + ), + dtype='object', + name='sex', + ), + ], + axis=1, + ) + ), + ( + 'mono_db', + 'age', + pd.Series( + [23, np.NaN, 59], + index=audformat.filewise_index( + ['f1.wav', 'f2.wav', 'f3.wav'] + ), + dtype=pd.CategoricalDtype( + categories=[23.0, 59.0], + ordered=False, + ), + name='age', ), ), ] ) def test_database_get(request, db, schemes, expected): db = request.getfixturevalue(db) - y = db.get(schemes) - print(f'{y=}') + print(f"{db['files']['channel0'].get(map='age')=}") + obj = db.get(schemes) print(f'{expected=}') - pd.testing.assert_series_equal(y, expected) + print(f'{obj=}') + if isinstance(expected, pd.Series): + pd.testing.assert_series_equal(obj, expected) + else: + print(f'{expected["gender"]=}') + print(f'{obj["gender"]=}') + print(f'{expected["sex"]=}') + print(f'{obj["sex"]=}') + pd.testing.assert_frame_equal(obj, expected)