diff --git a/audformat/core/database.py b/audformat/core/database.py index 4162785d..ca6a261b 100644 --- a/audformat/core/database.py +++ b/audformat/core/database.py @@ -27,6 +27,8 @@ from audformat.core.errors import BadKeyError from audformat.core.errors import TableExistsError from audformat.core.index import filewise_index +from audformat.core.index import is_filewise_index +from audformat.core.index import is_segmented_index from audformat.core.media import Media from audformat.core.rater import Rater from audformat.core.scheme import Scheme @@ -867,6 +869,21 @@ def scheme_in_column(scheme_id, column, column_id): original_column_names=original_column_names, aggregate_function=aggregate_function, ) + # Expand filewise labels to segments + if is_segmented_index(obj) and is_filewise_index(additional_obj): + print(f"{obj=}") + print(f"{obj.index.get_level_values(define.IndexField.FILE)=}") + common_files = obj.index.get_level_values(define.IndexField.FILE) + # Problem of next line: + # squeezes [f1, f1, f2] to [f1, f2] + # so we need to separate intersection from listing files + # common_files = utils.intersect([additional_obj.index, files]) + print(f"{files=}") + additional_obj = additional_obj.loc[files] + print(f"{additional_obj=}") + print(f"{obj.loc[files]=}") + print(f"{obj.loc[files].index=}") + additional_obj.index = obj.loc[files].index objs.append(additional_obj) if len(objs) > 1: obj = utils.concat(objs) diff --git a/tests/test_database_get.py b/tests/test_database_get.py index 83ebd312..ec17c6eb 100644 --- a/tests/test_database_get.py +++ b/tests/test_database_get.py @@ -23,6 +23,7 @@ def mono_db(tmpdir): db.schemes["age"] = audformat.Scheme("int", minimum=0) db.schemes["height"] = audformat.Scheme("float") db.schemes["int"] = audformat.Scheme("int") + db.schemes["partial"] = audformat.Scheme("str") db.schemes["rating"] = audformat.Scheme("int", labels=[0, 1, 2]) db.schemes["regression"] = audformat.Scheme("float") db.schemes["selection"] = audformat.Scheme("int", labels=[0, 1]) @@ -103,6 +104,8 @@ def mono_db(tmpdir): db["files.sub"]["text"].set("a") db["files.sub"]["numbers"] = audformat.Column(scheme_id="int") db["files.sub"]["numbers"].set(0) + db["files.sub"]["partial"] = audformat.Column(scheme_id="partial") + db["files.sub"]["partial"].set("a") index = audformat.filewise_index(["f1.wav", "f3.wav"]) db["other"] = audformat.Table(index) @@ -542,6 +545,69 @@ def wrong_scheme_labels_db(tmpdir): dtype="float", ), ), + ( + "mono_db", + "regression", + ["speaker"], + pd.concat( + [ + pd.Series( + [0.3, 0.2, 0.6, 0.4], + index=audformat.segmented_index( + ["f1.wav", "f1.wav", "f1.wav", "f2.wav"], + [0, 0.1, 0.3, 0], + [0.2, 0.2, 0.5, 0.7], + ), + dtype="float", + name="regression", + ), + pd.Series( + ["s1", "s1", "s1", "s2"], + index=audformat.segmented_index( + ["f1.wav", "f1.wav", "f1.wav", "f2.wav"], + [0, 0.1, 0.3, 0], + [0.2, 0.2, 0.5, 0.7], + ), + dtype=pd.CategoricalDtype( + ["s1", "s2", "s3"], + ordered=False, + ), + name="speaker", + ), + ], + axis=1, + ), + ), + ( + "mono_db", + "regression", + ["partial"], + pd.concat( + [ + pd.Series( + [0.3, 0.2, 0.6, 0.4], + index=audformat.segmented_index( + ["f1.wav", "f1.wav", "f1.wav", "f2.wav"], + [0, 0.1, 0.3, 0], + [0.2, 0.2, 0.5, 0.7], + ), + dtype="float", + name="regression", + ), + pd.Series( + ["a", "a", "a", None], + index=audformat.segmented_index( + ["f1.wav", "f1.wav", "f1.wav", "f2.wav"], + [0, 0.1, 0.3, 0], + [0.2, 0.2, 0.5, 0.7], + ), + dtype="string", + name="partial", + ), + ], + axis=1, + ), + ), ( "mono_db", "selection",