class LABLoader:
    """LAB loader

    Each .lab file contains the segments of a single audio file, one segment
    per line, in the following format:

        start end label

    e.g.

        0.0 12.3456 sing
        12.3456 15.0 nosing
        ...

    Parameters
    ----------
    path : str
        Path to LAB file with mandatory {uri} placeholder
        (e.g. "/path/to/{uri}.lab"). Every placeholder found in `path` is
        filled with the value stored under the same key in the protocol file.

    Raises
    ------
    ValueError
        If `path` does not contain the {uri} placeholder.

    Notes
    -----
    The accompanying setup.py change lists this class as
    ".lab = pyannote.database.loader:LABLoader", next to the UEM, CTM and MAP
    loaders.
    """

    def __init__(self, path: Text = None):
        super().__init__()

        self.path = str(path)

        # string.Formatter().parse yields
        # (literal_text, field_name, format_spec, conversion) tuples, with
        # field_name set to None for chunks of pure literal text. Collecting
        # the names with a comprehension (rather than unpacking zip(*...))
        # also copes with paths that yield no tuple at all, such as "", so
        # that they fail with the explicit message below.
        self.placeholders_ = {
            field_name
            for _, field_name, _, _ in string.Formatter().parse(self.path)
            if field_name is not None
        }
        if "uri" not in self.placeholders_:
            raise ValueError("`path` must contain the {uri} placeholder.")

    def __call__(self, file: "ProtocolFile") -> "Annotation":
        """Load the LAB annotation of one protocol file.

        Parameters
        ----------
        file : ProtocolFile
            Must provide a value for "uri" and for every other placeholder
            found in `path`.

        Returns
        -------
        annotation : Annotation
            Result of `load_lab` on the formatted path, with `uri` forwarded.
        """

        uri = file["uri"]

        # only fetch the keys that the path template actually needs
        sub_file = {key: file[key] for key in self.placeholders_}
        return load_lab(self.path.format(**sub_file), uri=uri)
def load_lab(path, uri: str = None) -> "Annotation":
    """Load LAB file

    Parameters
    ----------
    path : `str`
        Path to LAB file. Each line is read as three whitespace-separated
        fields: start, end and label.
    uri : `str`, optional
        Passed as `uri` to the returned annotation.

    Returns
    -------
    annotation : `pyannote.core.Annotation`
        One labelled segment per line of the file; the row index is used as
        track name.
    """

    names = ["start", "end", "label"]
    dtype = {"start": float, "end": float, "label": str}
    data = pd.read_csv(
        path,
        names=names,
        dtype=dtype,
        # regex separator replaces the deprecated `delim_whitespace=True`
        sep=r"\s+",
        # keep labels such as "NA" or "null" as strings instead of
        # letting pandas turn them into NaN
        keep_default_na=False,
    )

    annotation = Annotation(uri=uri)
    for row in data.itertuples():
        annotation[Segment(row.start, row.end), row.Index] = row.label

    return annotation