diff --git a/audformat/core/database.py b/audformat/core/database.py index 9aa7e55f..28e8b3ca 100644 --- a/audformat/core/database.py +++ b/audformat/core/database.py @@ -433,6 +433,119 @@ def duration(file: str) -> pd.Timedelta: return y + def get( + self, + schemes: typing.Union[str, typing.Sequence], + ) -> pd.Series: + r"""Get labels by scheme(s). + + Return all labels + that match the requested schemes. + A scheme is defined more broadly + and does not only match + schemes of the database, + but also columns with the same name + or labels of a scheme with the requested name. + + Args: + schemes: scheme or sequence of scheme. + + Returns: + labels + + """ + + def scheme_in_column(scheme, column, column_id): + # Check if scheme + # is attached to a column, + # or is part of the column name + return ( + scheme in column_id + or ( + column.scheme_id is not None + and scheme in column.scheme_id + ) + ) + + def clean_y(y): + # Remove NaN and normalize dtypes + # + # TODO: at the moment we simply convert to string + # to avoid errors for different categorical data types. + # In real implementation we need to adjust those dtypes. + return y.dropna().astype('string') + + requested_schemes = audeer.to_list(schemes) + + print(f'{requested_schemes=}') + # Check if requested schemes + # are stored as labels in other schemes + scheme_mappings = [] + for scheme_id, scheme in self.schemes.items(): + + print(f'{scheme_id=}') + if scheme.uses_table and scheme_id in self.misc_tables: + print('scheme stored in misc table') + # Labels stored as misc table + for column_id, column in self[scheme_id].columns.items(): + print(f'{column_id=}') + for requested_scheme in requested_schemes: + if scheme_in_column( + requested_scheme, column, column_id + ): + scheme_mappings.append( + (scheme_id, requested_scheme) + ) + break + else: + # Labels stored in scheme + for requested_scheme in requested_schemes: + print('scheme stored in labels') + if scheme_id in scheme.labels_as_list: + scheme_mappings.append((scheme_id, requested_scheme)) + break + print(f'{scheme_mappings=}') + print(f'{requested_schemes=}') + + # Get data points for requested schemes + ys = [] + for table_id, table in self.tables.items(): + print(f'{table_id=}') + for column_id, column in table.columns.items(): + print(f'{column_id=}') + print(f'{column.scheme_id=}') + print(f'{"speaker" in column_id=}') + print(f'{"speaker" in column.scheme_id=}') + # Get series based on scheme in column + if any( + [ + scheme_in_column( + requested_scheme, + column, + column_id, + ) + for requested_scheme in requested_schemes + ] + ): + print('Found requested scheme') + y = self[table_id][column_id].get() + ys.append(clean_y(y)) + # Get series based on label of scheme + else: + print('Look for mapped schemes') + for (scheme_id, mapping) in scheme_mappings: + print(f'{scheme_id=}') + if scheme_in_column(scheme_id, column, column_id): + print('Found scheme') + y = self[table_id][column_id].get(map=mapping) + ys.append(clean_y(y)) + + index = utils.union([y.index for y in ys]) + y = utils.concat(ys).loc[index] + y.name = ', '.join(requested_schemes) + + return y + def map_files( self, func: typing.Callable[[str], str], diff --git a/tests/test_database_get.py b/tests/test_database_get.py new file mode 100644 index 00000000..6675fbd7 --- /dev/null +++ b/tests/test_database_get.py @@ -0,0 +1,93 @@ +import numpy as np +import pandas as pd +import pytest + +import audeer + +import audformat +import audformat.testing + + +@pytest.fixture(scope='function', autouse=True) +def mono_db(tmpdir): + r"""Database with ...""" + name = 'mono-db' + path = audeer.mkdir(audeer.path(tmpdir, name)) + db = audformat.Database(name) + + db.schemes['age'] = audformat.Scheme('int', minimum=0) + + # Misc table with speaker information + index = pd.Index(['s1', 's2', 's3'], name='speaker', dtype='string') + db['speaker'] = audformat.MiscTable(index) + db['speaker']['age'] = audformat.Column(scheme_id='age') + db['speaker']['gender'] = audformat.Column() + db['speaker']['age'].set([23, np.NaN, 59]) + db['speaker']['gender'].set(['female', '', 'male']) + + db.schemes['speaker'] = audformat.Scheme('str', labels='speaker') + + # Filewise table with speaker information + index = audformat.filewise_index(['f1.wav', 'f2.wav', 'f3.wav']) + db['files'] = audformat.Table(index) + db['files']['channel0'] = audformat.Column(scheme_id='speaker') + db['files']['channel0'].set(['s1', 's2', 's3']) + + db.save(path) + audformat.testing.create_audio_files(db, channels=1, file_duration='1s') + + return db + + +@pytest.fixture(scope='function') +def stereo_db(tmpdir): + r"""Database with ...""" + name = 'stereo-db' + path = audeer.mkdir(audeer.path(tmpdir, name)) + db = audformat.Database(name) + + db.schemes['age'] = audformat.Scheme('int', minimum=0) + + # Misc table with speaker information + index = pd.Index(['s1', 's2', 's3'], name='speaker', dtype='string') + db['speaker'] = audformat.MiscTable(index) + db['speaker']['age'] = audformat.Column(scheme_id='age') + db['speaker']['gender'] = audformat.Column() + db['speaker']['age'].set([23, np.NaN, 59]) + db['speaker']['gender'].set(['female', '', 'male']) + + db.schemes['speaker'] = audformat.Scheme('str', labels='speaker') + + # Filewise table with speaker information + index = audformat.filewise_index(['f1.wav', 'f2.wav', 'f3.wav']) + db['files'] = audformat.Table(index) + db['files']['channel0'] = audformat.Column(scheme_id='speaker') + db['files']['channel1'] = audformat.Column(scheme_id='speaker') + db['files']['channel0'].set(['s1', 's2', 's3']) + db['files']['channel1'].set(['s3', 's1', 's2']) + + db.save(path) + audformat.testing.create_audio_files(db, channels=2, file_duration='1s') + + +@pytest.mark.parametrize( + 'db, schemes, expected', + [ + ( + 'mono_db', + ['sex', 'gender'], + pd.Series( + ['female', '', 'male'], + index=audformat.filewise_index(['f1.wav', 'f2.wav', 'f3.wav']), + dtype='string', + name='sex, gender', + ), + ), + ] +) +def test_database_get(request, db, schemes, expected): + db = request.getfixturevalue(db) + y = db.get(schemes) + print(f'{y=}') + print(f'{expected=}') + pd.testing.assert_series_equal(y, expected)