From 2a222fc61bb8207500608f76937e72b36850fccf Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Fri, 20 Oct 2023 14:08:53 +0200 Subject: [PATCH] Add test for error --- tests/test_database_get.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/test_database_get.py b/tests/test_database_get.py index 78431ea5..72882686 100644 --- a/tests/test_database_get.py +++ b/tests/test_database_get.py @@ -459,8 +459,6 @@ def stereo_db(tmpdir): def test_database_get(request, db, schemes, expected): db = request.getfixturevalue(db) df = db.get(schemes) - print(f'{expected=}') - print(f'{df=}') pd.testing.assert_frame_equal(df, expected) @@ -669,6 +667,26 @@ def test_database_get_limit_search( ): db = request.getfixturevalue(db) df = db.get(schemes, tables=tables, splits=splits) - print(f'{expected=}') - print(f'{df=}') pd.testing.assert_frame_equal(df, expected) + + +def test_database_get_errors(): + + # Scheme with different categorical dtypes + db = audformat.Database('db') + db.schemes['label'] = audformat.Scheme('int', labels=[0, 1]) + db['speaker'] = audformat.MiscTable( + pd.Index(['s1', 's2'], dtype='string', name='speaker') + ) + db['speaker']['label'] = audformat.Column() + db['speaker']['label'].set([1.0, 1.0]) + db['files'] = audformat.Table(audformat.filewise_index(['f1', 'f2'])) + db['files']['label'] = audformat.Column(scheme_id='label') + db['files']['label'].set([0, 1]) + db['other'] = audformat.Table(audformat.filewise_index(['f1', 'f2'])) + db.schemes['speaker'] = audformat.Scheme('str', labels='speaker') + db['other']['speaker'] = audformat.Column(scheme_id='speaker') + db['other']['speaker'].set(['s1', 's2']) + error_msg = 'All categorical data must have the same dtype.' + with pytest.raises(ValueError, match=error_msg): + db.get('label')