Skip to content

Commit

Permalink
Return dataframe or series
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Oct 19, 2023
1 parent 02fde21 commit d62c70b
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 16 deletions.
28 changes: 19 additions & 9 deletions audformat/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def duration(file: str) -> pd.Timedelta:
def get(
self,
schemes: typing.Union[str, typing.Sequence],
) -> pd.Series:
) -> typing.Union[pd.Series, pd.DataFrame]:
r"""Get labels by scheme(s).
Return all labels
Expand Down Expand Up @@ -473,7 +473,13 @@ def clean_y(y):
# TODO: at the moment we simply convert to string
# to avoid errors for different categorical data types.
# In real implementation we need to adjust those dtypes.
return y.dropna().astype('string')
# return y.astype('string')
return y

if isinstance(schemes, str):
return_series = True
else:
return_series = False

requested_schemes = audeer.to_list(schemes)

Expand Down Expand Up @@ -512,10 +518,6 @@ def clean_y(y):
for table_id, table in self.tables.items():
print(f'{table_id=}')
for column_id, column in table.columns.items():
print(f'{column_id=}')
print(f'{column.scheme_id=}')
print(f'{"speaker" in column_id=}')
print(f'{"speaker" in column.scheme_id=}')
# Get series based on scheme in column
if any(
[
Expand All @@ -541,10 +543,18 @@ def clean_y(y):
ys.append(clean_y(y))

index = utils.union([y.index for y in ys])
y = utils.concat(ys).loc[index]
y.name = ', '.join(requested_schemes)
obj = utils.concat(ys).loc[index]

return y
if not return_series and isinstance(obj, pd.Series):
obj = obj.to_frame()

if isinstance(obj, pd.DataFrame):
for requested_scheme in requested_schemes:
if requested_scheme not in obj:
# FIXME: try to get dtype from scheme
obj[requested_scheme] = pd.NA

return obj

def map_files(
self,
Expand Down
76 changes: 69 additions & 7 deletions tests/test_database_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def mono_db(tmpdir):
db['files']['channel0'] = audformat.Column(scheme_id='speaker')
db['files']['channel0'].set(['s1', 's2', 's3'])

# Filewise table with sex
index = audformat.filewise_index(['f1.wav', 'f3.wav'])
db['gender'] = audformat.Table(index)
db['gender']['sex'] = audformat.Column()
db['gender']['sex'].set(['female', 'male'])

db.save(path)
audformat.testing.create_audio_files(db, channels=1, file_duration='1s')

Expand Down Expand Up @@ -75,19 +81,75 @@ def stereo_db(tmpdir):
[
(
'mono_db',
['sex', 'gender'],
'gender',
pd.Series(
['female', '', 'male'],
index=audformat.filewise_index(['f1.wav', 'f2.wav', 'f3.wav']),
dtype='string',
name='sex, gender',
index=audformat.filewise_index(
['f1.wav', 'f2.wav', 'f3.wav']
),
dtype=pd.CategoricalDtype(
categories=['female', '', 'male'],
ordered=False,
),
name='gender',
),
),
(
'mono_db',
['sex', 'gender'],
pd.concat(
[
pd.Series(
['female', '', 'male'],
index=audformat.filewise_index(
['f1.wav', 'f2.wav', 'f3.wav']
),
dtype=pd.CategoricalDtype(
categories=['female', '', 'male'],
ordered=False,
),
name='gender',
),
pd.Series(
['female', np.NaN, 'male'],
index=audformat.filewise_index(
['f1.wav', 'f2.wav', 'f3.wav']
),
dtype='object',
name='sex',
),
],
axis=1,
)
),
(
'mono_db',
'age',
pd.Series(
[23, np.NaN, 59],
index=audformat.filewise_index(
['f1.wav', 'f2.wav', 'f3.wav']
),
dtype=pd.CategoricalDtype(
categories=[23.0, 59.0],
ordered=False,
),
name='age',
),
),
]
)
def test_database_get(request, db, schemes, expected):
db = request.getfixturevalue(db)
y = db.get(schemes)
print(f'{y=}')
print(f"{db['files']['channel0'].get(map='age')=}")
obj = db.get(schemes)
print(f'{expected=}')
pd.testing.assert_series_equal(y, expected)
print(f'{obj=}')
if isinstance(expected, pd.Series):
pd.testing.assert_series_equal(obj, expected)
else:
print(f'{expected["gender"]=}')
print(f'{obj["gender"]=}')
print(f'{expected["sex"]=}')
print(f'{obj["sex"]=}')
pd.testing.assert_frame_equal(obj, expected)

0 comments on commit d62c70b

Please sign in to comment.