Skip to content

Commit

Permalink
Add audformat.Database.get()
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Oct 19, 2023
1 parent d461614 commit 02fde21
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 0 deletions.
113 changes: 113 additions & 0 deletions audformat/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,119 @@ def duration(file: str) -> pd.Timedelta:

return y

def get(
self,
schemes: typing.Union[str, typing.Sequence],
) -> pd.Series:
r"""Get labels by scheme(s).
Return all labels
that match the requested schemes.
A scheme is defined more broadly
and does not only match
schemes of the database,
but also columns with the same name
or labels of a scheme with the requested name.
Args:
schemes: scheme or sequence of scheme.
Returns:
labels
"""

def scheme_in_column(scheme, column, column_id):
# Check if scheme
# is attached to a column,
# or is part of the column name
return (
scheme in column_id
or (
column.scheme_id is not None
and scheme in column.scheme_id
)
)

def clean_y(y):
# Remove NaN and normalize dtypes
#
# TODO: at the moment we simply convert to string
# to avoid errors for different categorical data types.
# In real implementation we need to adjust those dtypes.
return y.dropna().astype('string')

requested_schemes = audeer.to_list(schemes)

print(f'{requested_schemes=}')
# Check if requested schemes
# are stored as labels in other schemes
scheme_mappings = []
for scheme_id, scheme in self.schemes.items():

print(f'{scheme_id=}')
if scheme.uses_table and scheme_id in self.misc_tables:
print('scheme stored in misc table')
# Labels stored as misc table
for column_id, column in self[scheme_id].columns.items():
print(f'{column_id=}')
for requested_scheme in requested_schemes:
if scheme_in_column(
requested_scheme, column, column_id
):
scheme_mappings.append(
(scheme_id, requested_scheme)
)
break
else:
# Labels stored in scheme
for requested_scheme in requested_schemes:
print('scheme stored in labels')
if scheme_id in scheme.labels_as_list:
scheme_mappings.append((scheme_id, requested_scheme))
break
print(f'{scheme_mappings=}')
print(f'{requested_schemes=}')

# Get data points for requested schemes
ys = []
for table_id, table in self.tables.items():
print(f'{table_id=}')
for column_id, column in table.columns.items():
print(f'{column_id=}')
print(f'{column.scheme_id=}')
print(f'{"speaker" in column_id=}')
print(f'{"speaker" in column.scheme_id=}')
# Get series based on scheme in column
if any(
[
scheme_in_column(
requested_scheme,
column,
column_id,
)
for requested_scheme in requested_schemes
]
):
print('Found requested scheme')
y = self[table_id][column_id].get()
ys.append(clean_y(y))
# Get series based on label of scheme
else:
print('Look for mapped schemes')
for (scheme_id, mapping) in scheme_mappings:
print(f'{scheme_id=}')
if scheme_in_column(scheme_id, column, column_id):
print('Found scheme')
y = self[table_id][column_id].get(map=mapping)
ys.append(clean_y(y))

index = utils.union([y.index for y in ys])
y = utils.concat(ys).loc[index]
y.name = ', '.join(requested_schemes)

return y

def map_files(
self,
func: typing.Callable[[str], str],
Expand Down
93 changes: 93 additions & 0 deletions tests/test_database_get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import numpy as np
import pandas as pd
import pytest

import audeer

import audformat
import audformat.testing


@pytest.fixture(scope='function', autouse=True)
def mono_db(tmpdir):
r"""Database with ..."""
name = 'mono-db'
path = audeer.mkdir(audeer.path(tmpdir, name))
db = audformat.Database(name)

db.schemes['age'] = audformat.Scheme('int', minimum=0)

# Misc table with speaker information
index = pd.Index(['s1', 's2', 's3'], name='speaker', dtype='string')
db['speaker'] = audformat.MiscTable(index)
db['speaker']['age'] = audformat.Column(scheme_id='age')
db['speaker']['gender'] = audformat.Column()
db['speaker']['age'].set([23, np.NaN, 59])
db['speaker']['gender'].set(['female', '', 'male'])

db.schemes['speaker'] = audformat.Scheme('str', labels='speaker')

# Filewise table with speaker information
index = audformat.filewise_index(['f1.wav', 'f2.wav', 'f3.wav'])
db['files'] = audformat.Table(index)
db['files']['channel0'] = audformat.Column(scheme_id='speaker')
db['files']['channel0'].set(['s1', 's2', 's3'])

db.save(path)
audformat.testing.create_audio_files(db, channels=1, file_duration='1s')

return db


@pytest.fixture(scope='function')
def stereo_db(tmpdir):
r"""Database with ..."""
name = 'stereo-db'
path = audeer.mkdir(audeer.path(tmpdir, name))
db = audformat.Database(name)

db.schemes['age'] = audformat.Scheme('int', minimum=0)

# Misc table with speaker information
index = pd.Index(['s1', 's2', 's3'], name='speaker', dtype='string')
db['speaker'] = audformat.MiscTable(index)
db['speaker']['age'] = audformat.Column(scheme_id='age')
db['speaker']['gender'] = audformat.Column()
db['speaker']['age'].set([23, np.NaN, 59])
db['speaker']['gender'].set(['female', '', 'male'])

db.schemes['speaker'] = audformat.Scheme('str', labels='speaker')

# Filewise table with speaker information
index = audformat.filewise_index(['f1.wav', 'f2.wav', 'f3.wav'])
db['files'] = audformat.Table(index)
db['files']['channel0'] = audformat.Column(scheme_id='speaker')
db['files']['channel1'] = audformat.Column(scheme_id='speaker')
db['files']['channel0'].set(['s1', 's2', 's3'])
db['files']['channel1'].set(['s3', 's1', 's2'])

db.save(path)
audformat.testing.create_audio_files(db, channels=2, file_duration='1s')


@pytest.mark.parametrize(
'db, schemes, expected',
[
(
'mono_db',
['sex', 'gender'],
pd.Series(
['female', '', 'male'],
index=audformat.filewise_index(['f1.wav', 'f2.wav', 'f3.wav']),
dtype='string',
name='sex, gender',
),
),
]
)
def test_database_get(request, db, schemes, expected):
db = request.getfixturevalue(db)
y = db.get(schemes)
print(f'{y=}')
print(f'{expected=}')
pd.testing.assert_series_equal(y, expected)

0 comments on commit 02fde21

Please sign in to comment.