Skip to content

Commit

Permalink
Add test for error
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Oct 20, 2023
1 parent 335d7bd commit 2a222fc
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions tests/test_database_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,6 @@ def stereo_db(tmpdir):
def test_database_get(request, db, schemes, expected):
db = request.getfixturevalue(db)
df = db.get(schemes)
print(f'{expected=}')
print(f'{df=}')
pd.testing.assert_frame_equal(df, expected)


Expand Down Expand Up @@ -669,6 +667,26 @@ def test_database_get_limit_search(
):
db = request.getfixturevalue(db)
df = db.get(schemes, tables=tables, splits=splits)
print(f'{expected=}')
print(f'{df=}')
pd.testing.assert_frame_equal(df, expected)


def test_database_get_errors():

# Scheme with different categorical dtypes
db = audformat.Database('db')
db.schemes['label'] = audformat.Scheme('int', labels=[0, 1])
db['speaker'] = audformat.MiscTable(
pd.Index(['s1', 's2'], dtype='string', name='speaker')
)
db['speaker']['label'] = audformat.Column()
db['speaker']['label'].set([1.0, 1.0])
db['files'] = audformat.Table(audformat.filewise_index(['f1', 'f2']))
db['files']['label'] = audformat.Column(scheme_id='label')
db['files']['label'].set([0, 1])
db['other'] = audformat.Table(audformat.filewise_index(['f1', 'f2']))
db.schemes['speaker'] = audformat.Scheme('str', labels='speaker')
db['other']['speaker'] = audformat.Column(scheme_id='speaker')
db['other']['speaker'].set(['s1', 's2'])
error_msg = 'All categorical data must have the same dtype.'
with pytest.raises(ValueError, match=error_msg):
db.get('label')

0 comments on commit 2a222fc

Please sign in to comment.