diff --git a/docs/source/changelog/changelog_3.0.rst b/docs/source/changelog/changelog_3.0.rst index d6428a76..45fc1462 100644 --- a/docs/source/changelog/changelog_3.0.rst +++ b/docs/source/changelog/changelog_3.0.rst @@ -5,7 +5,15 @@ 3.0 Changelog ************* -3.1.1 +3.1.3 +----- + +- Fixed an issue where silence probability being zero was not correctly removing silence +- Compatibility with kalpy v0.6.5 +- Added API functionality for verifying transcripts with interjection words in alignment +- Fixed an error in fine tuning that generated nonsensical boundaries + +3.1.2 ----- - Fixed a bug where hidden files and folders would be parsed as corpus data @@ -13,6 +21,10 @@ - Fixed a rare crash in training when a job would not have utterances assigned to it - Fixed a bug where MFA would mistakenly report a dictionary and acoustic model phones did not match for older versions +3.1.1 +----- + +- Fixed an issue with TextGrids missing intervals 3.1.0 ----- diff --git a/montreal_forced_aligner/acoustic_modeling/lda.py b/montreal_forced_aligner/acoustic_modeling/lda.py index fa9c8c09..dbd3e454 100644 --- a/montreal_forced_aligner/acoustic_modeling/lda.py +++ b/montreal_forced_aligner/acoustic_modeling/lda.py @@ -148,7 +148,7 @@ def __init__(self, args: CalcLdaMlltArguments): self.model_path = args.model_path self.lda_options = args.lda_options - def _run(self) -> typing.Generator[int]: + def _run(self) -> None: """Run the function""" # Estimating MLLT with self.session() as session, thread_logger( diff --git a/montreal_forced_aligner/acoustic_modeling/trainer.py b/montreal_forced_aligner/acoustic_modeling/trainer.py index c663e0d3..3847981e 100644 --- a/montreal_forced_aligner/acoustic_modeling/trainer.py +++ b/montreal_forced_aligner/acoustic_modeling/trainer.py @@ -87,7 +87,7 @@ def __init__(self, args: TransitionAccArguments): self.working_directory = args.working_directory self.model_path = args.model_path - def _run(self) -> typing.Generator[typing.Tuple[int, str]]: + def _run(self) -> None: """Run the function""" with self.session() as session: diff --git a/montreal_forced_aligner/acoustic_modeling/triphone.py b/montreal_forced_aligner/acoustic_modeling/triphone.py index b0d6434f..33c75fcf 100644 --- a/montreal_forced_aligner/acoustic_modeling/triphone.py +++ b/montreal_forced_aligner/acoustic_modeling/triphone.py @@ -80,7 +80,7 @@ def __init__(self, args: ConvertAlignmentsArguments): self.ali_paths = args.ali_paths self.new_ali_paths = args.new_ali_paths - def _run(self) -> typing.Generator[typing.Tuple[int, int]]: + def _run(self) -> None: """Run the function""" with self.session() as session, thread_logger( "kalpy.train", self.log_path, job_name=self.job_name diff --git a/montreal_forced_aligner/alignment/base.py b/montreal_forced_aligner/alignment/base.py index 0511f775..35fe5634 100644 --- a/montreal_forced_aligner/alignment/base.py +++ b/montreal_forced_aligner/alignment/base.py @@ -345,7 +345,9 @@ def align(self, workflow_name=None) -> None: """Run the aligner""" self.alignment_mode = True self.initialize_database() - self.create_new_current_workflow(WorkflowType.alignment, workflow_name) + wf = self.current_workflow + if wf is None: + self.create_new_current_workflow(WorkflowType.alignment, workflow_name) wf = self.current_workflow if wf.done: logger.info("Alignment already done, skipping.") @@ -383,11 +385,11 @@ def align(self, workflow_name=None) -> None: assert self.alignment_model_path.suffix == ".mdl" logger.info("Performing second-pass alignment...") self.align_utterances() - self.collect_alignments() - if self.use_phone_model: - self.transcribe(WorkflowType.phone_transcription) - elif self.fine_tune: - self.fine_tune_alignments() + self.collect_alignments() + if self.use_phone_model: + self.transcribe(WorkflowType.phone_transcription) + elif self.fine_tune: + self.fine_tune_alignments() with self.session() as session: session.query(CorpusWorkflow).filter(CorpusWorkflow.id == wf.id).update( @@ -1062,7 +1064,7 @@ def fine_tune_alignments(self) -> None: Fine tune aligned boundaries to millisecond precision """ logger.info("Fine tuning alignments...") - begin = time.time() + all_begin = time.time() with self.session() as session: arguments = self.fine_tune_arguments() update_mappings = [] @@ -1110,7 +1112,7 @@ def fine_tune_alignments(self) -> None: ) session.commit() self.export_frame_shift = round(self.export_frame_shift / 10, 4) - logger.debug(f"Fine tuning alignments took {time.time() - begin:.3f} seconds") + logger.debug(f"Fine tuning alignments took {time.time() - all_begin:.3f} seconds") def fine_tune_arguments(self) -> List[FineTuneArguments]: """ @@ -1137,6 +1139,9 @@ def fine_tune_arguments(self) -> List[FineTuneArguments]: options = self.pitch_options options["frame_shift"] = 1 pitch_computer = PitchComputer(**options) + align_options = self.align_options + # align_options['transition_scale'] = align_options['transition_scale'] / 10 + align_options["acoustic_scale"] = 1.0 for j in self.jobs: log_path = self.working_log_directory.joinpath(f"fine_tune.{j.id}.log") args.append( @@ -1149,7 +1154,7 @@ def fine_tune_arguments(self) -> List[FineTuneArguments]: lexicon_compiler, self.model_path, self.tree_path, - self.align_options, + align_options, phone_to_group_mapping, self.mfcc_computer.frame_shift, ) diff --git a/montreal_forced_aligner/alignment/multiprocessing.py b/montreal_forced_aligner/alignment/multiprocessing.py index 2b963385..b4fab296 100644 --- a/montreal_forced_aligner/alignment/multiprocessing.py +++ b/montreal_forced_aligner/alignment/multiprocessing.py @@ -48,6 +48,7 @@ PhoneType, PronunciationProbabilityCounter, WordType, + WorkflowType, ) from montreal_forced_aligner.db import ( CorpusWorkflow, @@ -60,9 +61,15 @@ TextFile, Utterance, Word, + WordInterval, ) from montreal_forced_aligner.exceptions import AlignmentCollectionError, AlignmentExportError -from montreal_forced_aligner.helper import mfa_open, split_phone_position +from montreal_forced_aligner.helper import ( + align_words, + fix_unk_words, + mfa_open, + split_phone_position, +) from montreal_forced_aligner.textgrid import construct_textgrid_output from montreal_forced_aligner.utils import thread_logger @@ -80,9 +87,10 @@ "AlignmentExtractionArguments", "ExportTextGridArguments", "AlignFunction", - "AnalyzeAlignmentsFunction", "AlignArguments", + "AnalyzeAlignmentsFunction", "AnalyzeAlignmentsArguments", + "AnalyzeTranscriptsFunction", "AccStatsFunction", "AccStatsArguments", "CompileTrainGraphsFunction", @@ -410,6 +418,16 @@ def _run(self): .filter(CorpusWorkflow.current == True) # noqa .first() ) + interjection_costs = {} + if workflow.workflow_type is WorkflowType.transcript_verification: + interjection_words = ( + session.query(Word).filter(Word.word_type == WordType.interjection).all() + ) + if interjection_words: + max_count = max(x.count for x in interjection_words) + for w in interjection_words: + cost = max_count / w.count + interjection_costs[w.word] = cost if self.use_g2p: text_column = Utterance.normalized_character_text else: @@ -420,12 +438,18 @@ def _run(self): lexicon = self.lexicon_compilers[d.id] else: lexicon = d.lexicon_compiler + if workflow.workflow_type is WorkflowType.transcript_verification: + if interjection_words and d.oov_word not in interjection_costs: + interjection_costs[d.oov_word] = min(interjection_costs.values()) + # interjection_costs[d.cutoff_word] = min(interjection_costs.values()) compiler = TrainingGraphCompiler( self.model_path, self.tree_path, lexicon, - lexicon.word_table, use_g2p=self.use_g2p, + batch_size=1000 + if workflow.workflow_type is not WorkflowType.transcript_verification + else 500, ) graph_logger.debug(f"Set up took {time.time() - begin} seconds") query = ( @@ -442,7 +466,9 @@ def _run(self): compiler.export_graphs( fst_ark_path, query, - # callback=self.callback + # callback=self.callback, + interjection_words=interjection_costs, + # cutoff_pattern = d.cutoff_word ) graph_logger.debug(f"Total compilation time: {time.time() - begin} seconds") del compiler @@ -717,6 +743,77 @@ def _run(self): ) +class AnalyzeTranscriptsFunction(KaldiFunction): + """ + Multiprocessing function for analyzing alignments. + + See Also + -------- + :meth:`.CorpusAligner.analyze_alignments` + Main function that calls this function in parallel + :meth:`.CorpusAligner.calculate_speech_post_arguments` + Job method for generating arguments for this function + :kaldi_src:`lattice-to-post` + Relevant Kaldi binary + :kaldi_src:`weight-silence-post` + Relevant Kaldi binary + + Parameters + ---------- + args: :class:`~montreal_forced_aligner.alignment.multiprocessing.CalculateSpeechPostArguments` + Arguments for the function + """ + + def __init__(self, args: AnalyzeAlignmentsArguments): + super().__init__(args) + self.model_path = args.model_path + self.align_options = args.align_options + + def _run(self): + """Run the function""" + + with self.session() as session: + job: Job = ( + session.query(Job) + .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries)) + .filter(Job.id == self.job_name) + .first() + ) + workflow = ( + session.query(CorpusWorkflow) + .filter(CorpusWorkflow.current == True) # noqa + .first() + ) + query = session.query(Utterance).filter( + Utterance.job_id == job.id, Utterance.alignment_log_likelihood != None # noqa + ) + for utterance in query: + word_intervals = [ + x.as_ctm() + for x in ( + session.query(WordInterval) + .join(WordInterval.word) + .filter( + WordInterval.utterance_id == utterance.id, + WordInterval.workflow_id == workflow.id, + Word.word_type != WordType.silence, + WordInterval.end - WordInterval.begin > 0.03, + ) + .options( + joinedload(WordInterval.word, innerjoin=True), + ) + .order_by(WordInterval.begin) + ) + ] + if not word_intervals: + continue + extra_duration, wer, aligned_duration = align_words( + utterance.normalized_text.split(), word_intervals, "", debug=True + ) + transcript = " ".join(x.label for x in word_intervals) + self.callback((utterance.id, wer, extra_duration, transcript)) + + class FineTuneFunction(KaldiFunction): """ Multiprocessing function for fine tuning alignment. @@ -739,11 +836,11 @@ def __init__(self, args: FineTuneArguments): self.frame_shift_seconds = args.original_frame_shift self.new_frame_shift_seconds = 0.001 - self.feature_padding_factor = 4 + self.feature_padding_factor = 3 self.padding = round(self.frame_shift_seconds, 3) self.splice_frames = 3 - def _run(self) -> typing.Generator[typing.Tuple[int, float]]: + def _run(self): """Run the function""" with self.session() as session, thread_logger( "kalpy.align", self.log_path, job_name=self.job_name @@ -783,7 +880,6 @@ def _run(self) -> typing.Generator[typing.Tuple[int, float]]: self.model_path, self.tree_path, self.lexicon_compiler, - self.lexicon_compiler.word_table, ) for d_id in job.dictionary_ids: utterance_query = ( @@ -852,18 +948,18 @@ def _run(self) -> typing.Generator[typing.Tuple[int, float]]: {"id": interval.id, "begin": interval.begin, "end": interval.end} ) continue - segment_begin = max(round(interval.begin - self.padding, 4), 0) + end_padding = round(self.frame_shift_seconds * 1.5, 3) + prev_padding = round(self.frame_shift_seconds * 1.5, 3) + segment_begin = max(round(interval.begin - prev_padding, 4), 0) feature_segment_begin = max( round( - interval.begin - (self.padding * self.feature_padding_factor), 4 + interval.begin - (prev_padding * self.feature_padding_factor), 4 ), 0, ) - segment_end = min(round(interval.begin + self.padding, 4), utterance.end) + segment_end = round(min(interval.begin + end_padding, interval.end), 3) feature_segment_end = min( - round( - interval.begin + (self.padding * self.feature_padding_factor), 4 - ), + round(interval.begin + (end_padding * self.feature_padding_factor), 4), utterance.end, ) begin_offset = round(segment_begin - feature_segment_begin, 4) @@ -875,7 +971,6 @@ def _run(self) -> typing.Generator[typing.Tuple[int, float]]: train_graph = compiler.compile_fst(text) - prev_label = phone feats = self.mfcc_computer.compute_mfccs_for_export( segment, compress=False ) @@ -904,17 +999,27 @@ def _run(self) -> typing.Generator[typing.Tuple[int, float]]: ) feats = FloatMatrix(sub_matrix) alignment = aligner.align_utterance(train_graph, feats) + if alignment is None: + aligner.acoustic_scale = 0.1 + alignment = aligner.align_utterance(train_graph, feats) + aligner.acoustic_scale = 1.0 ctm_intervals = alignment.generate_ctm( - aligner.transition_model, self.lexicon_compiler.phone_table + aligner.transition_model, + self.lexicon_compiler.phone_table, + frame_shift=0.001, ) interval_mapping.append( { "id": interval.id, - "begin": round(ctm_intervals[1].begin + feature_segment_begin, 4), + "begin": round( + ctm_intervals[1].begin + feature_segment_begin + begin_offset, + 4, + ), "end": interval.end, "label": phone_mapping[ctm_intervals[1].label], } ) + prev_label = phone deletions = [] while True: for i in range(len(interval_mapping) - 1): @@ -1234,17 +1339,21 @@ def _run(self) -> None: else: utts = ( - session.query(Utterance.id, Utterance.begin, Utterance.end) + session.query( + Utterance.id, Utterance.begin, Utterance.end, Utterance.normalized_text + ) .join(Utterance.speaker) .filter(Utterance.job_id == self.job_name) .filter(Speaker.dictionary_id == d.id) ) - for u_id, begin, end in utts: + for u_id, begin, end, text in utts: utterance_times[u_id] = (begin, end) + utterance_texts[u_id] = text if self.lexicon_compilers and d.id in self.lexicon_compilers: lexicon_compiler = self.lexicon_compilers[d.id] else: lexicon_compiler = d.lexicon_compiler + if self.transcription: lat_path = job.construct_path(workflow.working_directory, "lat", "ark", d.id) if not lat_path.exists(): @@ -1316,13 +1425,18 @@ def _run(self) -> None: utterance = int(alignment.utterance_id.split("-")[-1]) found_utterances.add(utterance) try: + text = utterance_texts.get(utterance, None) ctm = lexicon_compiler.phones_to_pronunciations( alignment.words, intervals, transcription=False, - text=utterance_texts.get(utterance, None), + text=text, ) ctm.update_utterance_boundaries(*utterance_times[utterance]) + if text is not None: + ctm.word_intervals = fix_unk_words( + text.split(), ctm.word_intervals, lexicon_compiler + ) extraction_logger.debug(f"Processed {utterance}") self.callback((utterance, d.id, ctm)) except Exception: diff --git a/montreal_forced_aligner/alignment/pretrained.py b/montreal_forced_aligner/alignment/pretrained.py index c7ea9df1..dedaf300 100644 --- a/montreal_forced_aligner/alignment/pretrained.py +++ b/montreal_forced_aligner/alignment/pretrained.py @@ -17,6 +17,7 @@ from sqlalchemy.orm import Session from montreal_forced_aligner.abc import TopLevelMfaWorker +from montreal_forced_aligner.alignment.multiprocessing import AnalyzeTranscriptsFunction from montreal_forced_aligner.data import PhoneType, WorkflowType from montreal_forced_aligner.db import ( CorpusWorkflow, @@ -25,6 +26,7 @@ Phone, Speaker, Utterance, + bulk_update, ) from montreal_forced_aligner.exceptions import KaldiProcessingError from montreal_forced_aligner.helper import ( @@ -39,10 +41,9 @@ update_utterance_intervals, ) from montreal_forced_aligner.transcription.transcriber import TranscriberMixin -from montreal_forced_aligner.utils import log_kaldi_errors +from montreal_forced_aligner.utils import log_kaldi_errors, run_kaldi_function if TYPE_CHECKING: - from montreal_forced_aligner.abc import MetaDict __all__ = ["PretrainedAligner", "DictionaryTrainer"] @@ -322,6 +323,34 @@ def align_one_utterance(self, utterance: Utterance, session: Session) -> None: ) update_utterance_intervals(session, utterance, workflow.id, ctm) + def verify_transcripts(self, workflow_name=None) -> None: + self.initialize_database() + self.create_new_current_workflow(WorkflowType.transcript_verification, workflow_name) + wf = self.current_workflow + if wf.done: + logger.info("Transcript verification already done, skipping.") + return + self.setup() + self.write_lexicon_information(write_disambiguation=True) + super().align() + + arguments = self.analyze_alignments_arguments() + update_mappings = [] + for utt_id, word_error_rate, duration_deviation, transcript in run_kaldi_function( + AnalyzeTranscriptsFunction, arguments, total_count=self.num_current_utterances + ): + update_mappings.append( + { + "id": utt_id, + "word_error_rate": word_error_rate, + "duration_deviation": duration_deviation, + "transcription_text": transcript, + } + ) + with self.session() as session: + bulk_update(session, Utterance, update_mappings) + session.commit() + def align(self, workflow_name=None) -> None: """Run the aligner""" self.initialize_database() diff --git a/montreal_forced_aligner/corpus/base.py b/montreal_forced_aligner/corpus/base.py index 33ae6aa9..2ee44541 100644 --- a/montreal_forced_aligner/corpus/base.py +++ b/montreal_forced_aligner/corpus/base.py @@ -711,8 +711,8 @@ def normalize_text(self) -> None: word_key += 1 max_mapping_ids[1] = word_key - 1 for w_id, m_id, d_id, w, wt in words: - if wt is WordType.oov: - existing_oovs[(d_id, w)] = {"id": w_id, "count": 0} + if wt is WordType.oov and w not in self.specials_set: + existing_oovs[(d_id, w)] = {"id": w_id, "count": 0, "included": False} continue word_indexes[(d_id, w)] = w_id word_mapping_ids[(d_id, w)] = m_id @@ -752,7 +752,9 @@ def normalize_text(self) -> None: result["oovs"] = " ".join(sorted(oovs)) else: for w in result["normalized_text"].split(): - if (dict_id, w) not in word_indexes: + if (dict_id, w) in existing_oovs: + existing_oovs[(dict_id, w)]["count"] += 1 + elif (dict_id, w) not in word_indexes: if (dict_id, w) not in word_insert_mappings: word_insert_mappings[(dict_id, w)] = { "id": word_key, @@ -761,6 +763,7 @@ def normalize_text(self) -> None: "mapping_id": word_key - 1, "count": 0, "dictionary_id": dict_id, + "included": False, } pronunciation_insert_mappings.append( { @@ -886,6 +889,7 @@ def normalize_text(self) -> None: "word_type": WordType.oov, "mapping_id": word_key - 1, "count": 0, + "included": False, "dictionary_id": dict_id, } pronunciation_insert_mappings.append( diff --git a/montreal_forced_aligner/data.py b/montreal_forced_aligner/data.py index b388fa5d..b3b63657 100644 --- a/montreal_forced_aligner/data.py +++ b/montreal_forced_aligner/data.py @@ -409,6 +409,18 @@ class WorkflowType(enum.Enum): g2p = 11 language_model_training = 12 tokenizer_training = 13 + transcript_verification = 14 + + @classmethod + def alignment_workflows(cls): + return { + cls.alignment, + cls.online_alignment, + cls.transcript_verification, + cls.transcription, + cls.per_speaker_transcription, + cls.phone_transcription, + } class WordType(enum.Enum): diff --git a/montreal_forced_aligner/db.py b/montreal_forced_aligner/db.py index 374d03bf..92f7c432 100644 --- a/montreal_forced_aligner/db.py +++ b/montreal_forced_aligner/db.py @@ -487,8 +487,15 @@ def lexicon_compiler(self): oov_phone=self.oov_phone, position_dependent_phones=self.position_dependent_phones, ) - lexicon_compiler.load_l_from_file(self.lexicon_fst_path) - lexicon_compiler.load_l_align_from_file(self.align_lexicon_path) + if self.lexicon_disambig_fst_path.exists(): + lexicon_compiler.load_l_from_file(self.lexicon_disambig_fst_path) + lexicon_compiler.disambiguation = True + elif self.lexicon_fst_path.exists(): + lexicon_compiler.load_l_from_file(self.lexicon_fst_path) + if self.align_lexicon_disambig_path.exists(): + lexicon_compiler.load_l_align_from_file(self.align_lexicon_disambig_path) + elif self.align_lexicon_path.exists(): + lexicon_compiler.load_l_align_from_file(self.align_lexicon_path) lexicon_compiler.word_table = self.word_table lexicon_compiler.phone_table = self.phone_table return lexicon_compiler @@ -1533,7 +1540,11 @@ def aligned_word_intervals(self) -> typing.List[CtmInterval]: """ Word intervals from :attr:`montreal_forced_aligner.data.WorkflowType.alignment` """ - return [x.as_ctm() for x in self.word_intervals] + return [ + x.as_ctm() + for x in self.word_intervals + if x.workflow.workflow_type in [WorkflowType.alignment, WorkflowType.online_alignment] + ] @property def transcribed_phone_intervals(self) -> typing.List[CtmInterval]: @@ -1543,7 +1554,12 @@ def transcribed_phone_intervals(self) -> typing.List[CtmInterval]: return [ x.as_ctm() for x in self.phone_intervals - if x.workflow.workflow_type is WorkflowType.transcription + if x.workflow.workflow_type + in [ + WorkflowType.transcription, + WorkflowType.per_speaker_transcription, + WorkflowType.transcript_verification, + ] ] @property @@ -1554,7 +1570,12 @@ def transcribed_word_intervals(self) -> typing.List[CtmInterval]: return [ x.as_ctm() for x in self.word_intervals - if x.workflow.workflow_type is WorkflowType.transcription + if x.workflow.workflow_type + in [ + WorkflowType.transcription, + WorkflowType.per_speaker_transcription, + WorkflowType.transcript_verification, + ] ] @property diff --git a/montreal_forced_aligner/dictionary/multispeaker.py b/montreal_forced_aligner/dictionary/multispeaker.py index 075001c0..80c55bb3 100644 --- a/montreal_forced_aligner/dictionary/multispeaker.py +++ b/montreal_forced_aligner/dictionary/multispeaker.py @@ -969,6 +969,7 @@ def calculate_disambiguation(self) -> None: words = ( session.query(Word) .filter(Word.dictionary_id == d.id) + .filter(Word.included == True) # noqa .options(selectinload(Word.pronunciations)) ) for w in words: @@ -978,16 +979,16 @@ def calculate_disambiguation(self) -> None: subsequences.add(tuple(pron)) pron = pron[:-1] last_used = collections.defaultdict(int) - for p_id, pron in ( - session.query(Pronunciation.id, Pronunciation.pronunciation) - .join(Pronunciation.word) - .filter(Word.dictionary_id == d.id) - ): - pron = tuple(pron.split()) - if pron in subsequences: - last_used[pron] += 1 + for w in words: + for p in w.pronunciations: + pron = p.pronunciation + pron = tuple(pron.split()) + if pron in subsequences: + last_used[pron] += 1 - update_pron_objs.append({"id": p_id, "disambiguation": last_used[pron]}) + update_pron_objs.append( + {"id": p.id, "disambiguation": last_used[pron]} + ) if last_used: d.max_disambiguation_symbol = max( @@ -1584,14 +1585,14 @@ def find_all_cutoffs(self) -> None: logger.info("Finding all cutoffs...") initial_brackets = re.escape("".join(x[0] for x in self.brackets)) final_brackets = re.escape("".join(x[1] for x in self.brackets)) - pronunciation_mapping = {} - word_mapping = {} cutoff_identifier = re.sub( rf"[{initial_brackets}{final_brackets}]", "", self.cutoff_word ) max_ids = collections.defaultdict(int) max_pron_id = session.query(sqlalchemy.func.max(Pronunciation.id)).scalar() max_word_id = session.query(sqlalchemy.func.max(Word.id)).scalar() + new_word_mapping = {} + new_pronunciation_mapping = [] for d_id, max_id in ( session.query(Dictionary.id, sqlalchemy.func.max(Word.mapping_id)) .join(Word.dictionary) @@ -1599,58 +1600,31 @@ def find_all_cutoffs(self) -> None: ): max_ids[d_id] = max_id for d_id in self.dictionary_lookup.values(): - pronunciation_mapping[d_id] = collections.defaultdict(list) - word_mapping[d_id] = {} + pronunciation_mapping = collections.defaultdict(set) + word_mapping = {} + max_id = ( + session.query(sqlalchemy.func.max(Word.mapping_id)) + .join(Word.dictionary) + .filter(Dictionary.id == d_id) + ).first()[0] words = ( session.query(Word.mapping_id, Word.word, Pronunciation.pronunciation) .join(Pronunciation.word) - .filter(Word.dictionary_id == d_id) - ) - for m_id, w, pron in words: - pronunciation_mapping[d_id][w].append(pron) - word_mapping[d_id][w] = m_id - new_word_mapping = [] - new_pronunciation_mapping = [] - utterances = ( - session.query( - Utterance.id, - Speaker.dictionary_id, - Utterance.normalized_text, - ) - .join(Utterance.speaker) - .filter( - Utterance.normalized_text.regexp_match( - f"[{initial_brackets}]({cutoff_identifier}|hes)" + .filter( + Word.dictionary_id == d_id, + Word.count > 1, + Word.word_type == WordType.speech, ) ) - ) - utterance_mapping = [] - for u_id, dict_id, normalized_text in utterances: - text = normalized_text.split() - modified = False - for i, word in enumerate(text): - m = re.match( - f"^[{initial_brackets}]({cutoff_identifier}|hes(itation)?)([-_](?P[^{final_brackets}]+))?[{final_brackets}]$", - word, - ) - if not m: - continue - next_word = m.group("word") - if next_word not in word_mapping[dict_id]: - if i != len(text) - 1: - next_word = text[i + 1] - if ( - next_word is None - or next_word not in pronunciation_mapping[dict_id] - or self.oov_phone in pronunciation_mapping[dict_id][next_word] - or self.optional_silence_phone in pronunciation_mapping[dict_id][next_word] - ): - continue - new_word = f"{self.cutoff_word[:-1]}-{next_word}{self.cutoff_word[-1]}" - if new_word not in word_mapping[dict_id]: + for m_id, w, pron in words: + word_mapping[w] = m_id + pronunciation_mapping[w].add(pron) + new_word = f"{self.cutoff_word[:-1]}-{w}{self.cutoff_word[-1]}" + if new_word not in new_word_mapping: max_word_id += 1 - max_ids[dict_id] += 1 + max_id += 1 max_pron_id += 1 + pronunciation_mapping[new_word].add(self.oov_phone) new_pronunciation_mapping.append( { "id": max_pron_id, @@ -1658,44 +1632,73 @@ def find_all_cutoffs(self) -> None: "word_id": max_word_id, } ) - prons = pronunciation_mapping[dict_id][next_word] - pronunciation_mapping[dict_id][new_word] = [] - for p in prons: - p = p.split() - for pi in range(len(p)): - new_p = " ".join(p[: pi + 1]) - if new_p in pronunciation_mapping[dict_id][new_word]: - continue - pronunciation_mapping[dict_id][new_word].append(new_p) - max_pron_id += 1 - new_pronunciation_mapping.append( - { - "id": max_pron_id, - "pronunciation": new_p, - "word_id": max_word_id, - } - ) - new_word_mapping.append( + new_word_mapping[new_word] = { + "id": max_word_id, + "word": new_word, + "dictionary_id": d_id, + "mapping_id": max_id, + "word_type": WordType.cutoff, + } + word_mapping[new_word] = max_id + p = pron.split() + for pi in range(len(p)): + new_p = " ".join(p[: pi + 1]) + if new_p in pronunciation_mapping[new_word]: + continue + pronunciation_mapping[new_word].add(new_p) + max_pron_id += 1 + new_pronunciation_mapping.append( { - "id": max_word_id, - "word": new_word, - "dictionary_id": dict_id, - "mapping_id": max_ids[dict_id], - "word_type": WordType.cutoff, + "id": max_pron_id, + "pronunciation": new_p, + "word_id": new_word_mapping[new_word]["id"], } ) - word_mapping[dict_id][new_word] = max_ids[dict_id] - text[i] = new_word - modified = True - if modified: - utterance_mapping.append( - { - "id": u_id, - "normalized_text": " ".join(text), - } + utterances = ( + session.query( + Utterance.id, + Utterance.normalized_text, + ) + .join(Utterance.speaker) + .filter( + Speaker.dictionary_id == d_id, + Utterance.normalized_text.regexp_match( + f"[{initial_brackets}]({cutoff_identifier}|hes)" + ), ) + ) + utterance_mapping = [] + for u_id, normalized_text in utterances: + text = normalized_text.split() + modified = False + for i, word in enumerate(text): + m = re.match( + f"^[{initial_brackets}]({cutoff_identifier}|hes(itation)?)([-_](?P[^{final_brackets}]+))?[{final_brackets}]$", + word, + ) + if not m: + continue + next_word = m.group("word") + new_word = f"{self.cutoff_word[:-1]}-{next_word}{self.cutoff_word[-1]}" + if ( + next_word is None + or next_word not in word_mapping + or self.oov_phone in pronunciation_mapping[next_word] + or self.optional_silence_phone in pronunciation_mapping[next_word] + or new_word not in word_mapping + ): + continue + text[i] = new_word + modified = True + if modified: + utterance_mapping.append( + { + "id": u_id, + "normalized_text": " ".join(text), + } + ) session.bulk_insert_mappings( - Word, new_word_mapping, return_defaults=False, render_nulls=True + Word, new_word_mapping.values(), return_defaults=False, render_nulls=True ) session.bulk_insert_mappings( Pronunciation, new_pronunciation_mapping, return_defaults=False, render_nulls=True @@ -1779,7 +1782,7 @@ def build_lexicon_compiler( lexicon_compiler.phone_table = self.phone_table else: lexicon_compiler = acoustic_model.lexicon_compiler - lexicon_compiler.disambiguation = disambiguation + lexicon_compiler.disambiguation = disambiguation query = ( session.query(Word, Pronunciation) .join(Pronunciation.word) diff --git a/montreal_forced_aligner/helper.py b/montreal_forced_aligner/helper.py index 6eae5109..0cc56fd9 100644 --- a/montreal_forced_aligner/helper.py +++ b/montreal_forced_aligner/helper.py @@ -25,6 +25,8 @@ from rich.theme import Theme if TYPE_CHECKING: + from kalpy.fstext.lexicon import LexiconCompiler + from montreal_forced_aligner.abc import MetaDict from montreal_forced_aligner.data import CtmInterval @@ -767,6 +769,168 @@ def align_phones( return score, phone_error_rate, errors +def fix_unk_words( + ref: List[str], + test: List[CtmInterval], + lexicon_compiler: LexiconCompiler, +) -> Tuple[float, float, Dict[Tuple[str, str], int]]: + """ + Align phones based on how much they overlap and their phone label, with the ability to specify a custom mapping for + different phone labels to be scored as if they're the same phone + + Parameters + ---------- + ref: list[:class:`~montreal_forced_aligner.data.CtmInterval`] + List of CTM intervals as reference + test: list[:class:`~montreal_forced_aligner.data.CtmInterval`] + List of CTM intervals to compare to reference + lexicon_compiler: LexiconCompiler + Lexicon compiler to use for evaluating the identity of OOV items + + Returns + ------- + float + Extra duration of new words + float + Word error rate + float + Aligned duration of found words + """ + + from kalpy.gmm.data import WordCtmInterval + + def score_func(ref, test): + ref_label = ref + if isinstance(ref_label, WordCtmInterval): + ref_label = ref_label.label + test_label = test + if isinstance(test_label, WordCtmInterval): + test_label = test_label.label + if ref_label == test_label: + return 0 + if ( + test_label == lexicon_compiler.silence_word + or ref_label == lexicon_compiler.silence_word + ): + return -10 + if lexicon_compiler.to_int(ref_label) == lexicon_compiler.to_int(test_label): + return 0 + return -2 + + alignments = pairwise2.align.globalcs( + ref, test, score_func, -2, -2, gap_char=["-"], one_alignment_only=True + ) + output_ctm = [] + for a in alignments: + for i, sa in enumerate(a.seqA): + sb = a.seqB[i] + if sa == "-": + output_ctm.append(sb) + elif sb == "-": + continue + else: + if sa != sb.label and sb.label == lexicon_compiler.oov_word: + sb.label = sa + output_ctm.append(sb) + return output_ctm + + +def align_words( + ref: List[str], + test: List[CtmInterval], + silence_word: str, + ignored_words: typing.Set[str] = None, + debug: bool = False, +) -> Tuple[float, float, Dict[Tuple[str, str], int]]: + """ + Align phones based on how much they overlap and their phone label, with the ability to specify a custom mapping for + different phone labels to be scored as if they're the same phone + + Parameters + ---------- + ref: list[:class:`~montreal_forced_aligner.data.CtmInterval`] + List of CTM intervals as reference + test: list[:class:`~montreal_forced_aligner.data.CtmInterval`] + List of CTM intervals to compare to reference + silence_word: str + Silence word (these are ignored in the final calculation) + ignored_words: set[str], optional + Words that should be ignored in score calculations (silence phone is automatically added) + debug: bool, optional + Flag for logging extra information about alignments + + Returns + ------- + float + Extra duration of new words + float + Word error rate + float + Aligned duration of found words + """ + + from montreal_forced_aligner.data import CtmInterval + + if ignored_words is None: + ignored_words = set() + if not isinstance(ignored_words, set): + ignored_words = set(ignored_words) + + def score_func(ref, test): + ref_label = ref + if isinstance(ref_label, CtmInterval): + ref_label = ref_label.label + test_label = test + if isinstance(test_label, CtmInterval): + test_label = test_label.label + if ref_label == test_label: + return 0 + if test_label == silence_word or ref_label == silence_word: + return -10 + return -2 + + alignments = pairwise2.align.globalcs( + ref, test, score_func, -2, -2, gap_char=["-"], one_alignment_only=True + ) + num_insertions = 0 + num_deletions = 0 + num_substitutions = 0 + + ignored_words.add(silence_word) + extra_duration = 0 + aligned_duration = 0 + for a in alignments: + for i, sa in enumerate(a.seqA): + sb = a.seqB[i] + if sa == "-": + if sb.label not in ignored_words: + num_insertions += 1 + extra_duration += sb.end - sb.begin + else: + continue + elif sb == "-": + if sa not in ignored_words: + num_deletions += 1 + else: + continue + else: + if sa in ignored_words: + continue + if sa != sb.label: + num_substitutions += 1 + else: + aligned_duration += sb.end - sb.begin + word_error_rate = (num_insertions + num_deletions + (2 * num_substitutions)) / len(ref) + if debug: + import logging + + logger = logging.getLogger("mfa") + logger.debug( + f"{pairwise2.format_alignment(*alignments[0])}\nExtra word duration: {extra_duration}\nWER: {word_error_rate}" + ) + return extra_duration, word_error_rate, aligned_duration + + def format_probability(probability_value: float) -> float: """Format a probability to have two decimal places and be between 0.01 and 0.99""" return min(max(round(probability_value, 2), 0.01), 0.99) diff --git a/montreal_forced_aligner/ivector/multiprocessing.py b/montreal_forced_aligner/ivector/multiprocessing.py index c42c0072..54b8cf1f 100644 --- a/montreal_forced_aligner/ivector/multiprocessing.py +++ b/montreal_forced_aligner/ivector/multiprocessing.py @@ -2,7 +2,6 @@ from __future__ import annotations import os -import typing from pathlib import Path from _kalpy.gmm import DiagGmm @@ -89,7 +88,7 @@ def __init__(self, args: GmmGselectArguments): self.dubm_model = args.dubm_model self.ivector_options = args.ivector_options - def _run(self) -> typing.Generator[None]: + def _run(self) -> None: """Run the function""" with self.session() as session, thread_logger( "kalpy.ivector", self.log_path, job_name=self.job_name @@ -170,7 +169,7 @@ def __init__(self, args: GaussToPostArguments): self.dubm_model = args.dubm_model self.ivector_options = args.ivector_options - def _run(self) -> typing.Generator[None]: + def _run(self) -> None: """Run the function""" modified_posterior_scale = ( self.ivector_options["posterior_scale"] * self.ivector_options["subsample"] diff --git a/montreal_forced_aligner/language_modeling/multiprocessing.py b/montreal_forced_aligner/language_modeling/multiprocessing.py index c5ba0eb2..93f08a34 100644 --- a/montreal_forced_aligner/language_modeling/multiprocessing.py +++ b/montreal_forced_aligner/language_modeling/multiprocessing.py @@ -116,7 +116,7 @@ def __init__(self, args: TrainLmArguments): self.order = args.order self.oov_word = args.oov_word - def _run(self) -> typing.Generator[bool]: + def _run(self) -> None: """Run the function""" with self.session() as session, mfa_open(self.log_path, "w") as log_file: word_query = session.query(Word.word).filter( @@ -193,7 +193,7 @@ def __init__(self, args: TrainLmArguments): self.symbols_path = args.symbols_path self.order = args.order - def _run(self) -> typing.Generator[bool]: + def _run(self) -> None: """Run the function""" with self.session() as session, mfa_open(self.log_path, "w") as log_file: if config.USE_POSTGRES: @@ -273,7 +273,7 @@ def __init__(self, args: TrainSpeakerLmArguments): self.target_num_ngrams = args.target_num_ngrams self.hclg_options = args.hclg_options - def _run(self) -> typing.Generator[bool]: + def _run(self) -> None: """Run the function""" with self.session() as session, mfa_open(self.log_path, "w") as log_file: job: Job = ( diff --git a/montreal_forced_aligner/online/alignment.py b/montreal_forced_aligner/online/alignment.py index 611b7dab..6db5c2ca 100644 --- a/montreal_forced_aligner/online/alignment.py +++ b/montreal_forced_aligner/online/alignment.py @@ -79,7 +79,6 @@ def align_utterance_online( acoustic_model.alignment_model_path, acoustic_model.tree_path, lexicon_compiler, - lexicon_compiler.word_table, ) if utterance.mfccs is None: utterance.generate_mfccs(acoustic_model.mfcc_computer) diff --git a/montreal_forced_aligner/tokenization/tokenizer.py b/montreal_forced_aligner/tokenization/tokenizer.py index 5e673cec..0327efca 100644 --- a/montreal_forced_aligner/tokenization/tokenizer.py +++ b/montreal_forced_aligner/tokenization/tokenizer.py @@ -196,7 +196,7 @@ def __init__(self, args: TokenizerArguments): super().__init__(args) self.rewriter = args.rewriter - def _run(self) -> typing.Generator: + def _run(self) -> None: """Run the function""" with self.session() as session: utterances = session.query(Utterance.id, Utterance.normalized_text).filter( diff --git a/montreal_forced_aligner/transcription/multiprocessing.py b/montreal_forced_aligner/transcription/multiprocessing.py index 131c2869..9bf09ac3 100644 --- a/montreal_forced_aligner/transcription/multiprocessing.py +++ b/montreal_forced_aligner/transcription/multiprocessing.py @@ -354,7 +354,7 @@ def __init__(self, args: CreateHclgArguments): self.tree_path = args.tree_path self.hclg_options = args.hclg_options - def _run(self) -> typing.Generator[typing.Tuple[bool, str]]: + def _run(self) -> None: """Run the function""" with thread_logger("kalpy.decode_graph", self.log_path, job_name=self.job_name): hclg_path = self.working_directory.joinpath(f"HCLG.{self.job_name}.fst") @@ -943,7 +943,7 @@ def __init__(self, args: PerSpeakerDecodeArguments): self.method = args.method self.word_symbols_paths = {} - def _run(self) -> typing.Generator[typing.Tuple[int, str]]: + def _run(self) -> None: """Run the function""" with self.session() as session, thread_logger( "kalpy.decode", self.log_path, job_name=self.job_name diff --git a/montreal_forced_aligner/vad/multiprocessing.py b/montreal_forced_aligner/vad/multiprocessing.py index 1c2d3a5c..aa6d43c1 100644 --- a/montreal_forced_aligner/vad/multiprocessing.py +++ b/montreal_forced_aligner/vad/multiprocessing.py @@ -170,7 +170,6 @@ def segment_utterance_transcript( acoustic_model.alignment_model_path, acoustic_model.tree_path, lexicon_compiler, - lexicon_compiler.word_table, ) if utterance.cmvn_string: cmvn = read_kaldi_object(DoubleMatrix, utterance.cmvn_string) diff --git a/tests/conftest.py b/tests/conftest.py index 9e4a62f4..b576214f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1035,3 +1035,178 @@ def bad_topology_path(config_directory): @pytest.fixture(scope="session") def test_align_config(): return {"beam": 100, "retry_beam": 400} + + +@pytest.fixture(scope="session") +def reference_transcripts(): + return { + "mfa_cutoff": " montreal montreal forced aligner aligner", + "mfa_whatscalled": "montreal forced what's called aligner", + "mfa_uhuh": "montreal uh uh uh uh uh uh forced aligner", + "mfa_uhum": "montreal forced uh um uh hm hm um forced aligner", + "mfa_michael": "montreal forced aligner", + "mfa_kmg": "montreal forced aligner", + "mfa_falsetto": "montreal forced aligner", + "mfa_whisper": "montreal forced aligner", + "mfa_exaggerated": "montreal forced aligner", + "mfa_breathy": "montreal forced aligner", + "mfa_creaky": "montreal forced aligner", + "mfa_long": "montreal forced aligner", + "mfa_hes": "montreal aligner", + "mfa_longstop": "this is a long stop", + "mfa_putty": "m f a is like putty", + "mfa_puddy": "m f a is like puddy", + "mfa_puttynorm": "m f a is like putty", + "mfa_pooty": "m f a is like pooty", + "mfa_bottle": "m f a is like bottle", + "mfa_patty": "m f a is like patty", + "mfa_buddy": "m f a is like buddy", + "mfa_apex": "m f a is like apex", + "mfa_poofy": "m f a is like poofy", + "mfa_reallylong": "m f a is like this is so many words right here", + "mfa_internalsil": " ", + "mfa_her": "montreal forced aligner i hardly know her", + "mfa_er": "montreal forced aligner i hardly know 'er", + "mfa_erpause": "montreal forced aligner i hardly know 'er", + "mfa_cutoffprogressive": " uh montreal forced hm aligner aligner", + "mfa_affectation": "montreal forced aligner", + "mfa_crossword": "but um montreal but um but montreal forced aligner", + "mfa_registershift": "montreal forced forced aligner", + "falsetto": "this is all very high pitched", + "falsetto2": "i really don't know how people talk like this", + "whisper": "this is all very whispered", + "whisper2": "there's gonna be no voiced speech whatsoever here", + "mfa_uh": "montreal uh forced aligner", + "mfa_um": "montreal forced um aligner", + "mfa_youknow": "you know montreal forced aligner", + "mfa_unk": "montreal forced aligner", + "mfa_words": "montreal forced aligner word another word here's some more word word word word word", + "mfa_surround": "this one montreal is going to be forced very bad aligner but what are you gonna do", + "mfa_breaths": "montreal forced aligner", + "mfa_laughter": "[laughter] montreal [laughter] forced [laughter] aligner [laughter]", + "mfa_the": "the montreal forced aligner", + "mfa_thenorm": "the montreal forced aligner", + "mfa_thestop": "this is the montreal forced aligner", + "mfa_theapprox": "this is the montreal forced aligner", + "mfa_thez": "this is the montreal forced aligner", + "mfa_thea": "this is a montreal forced aligner", + "mfa_theinitialstop": "the montreal forced aligner", + "mfa_theother": "this is the other montreal forced aligner", + "mfa_thoughts": "i have a thousand thoughts about that thing", + } + + +@pytest.fixture(scope="session") +def filler_insertion_utterances(): + return [ + "mfa_michael", + "mfa_uh", + "mfa_um", + "mfa_youknow", + "mfa_unk", + "mfa_words", + "mfa_surround", + "mfa_breaths", + "mfa_laughter", + "mfa_cutoffprogressive", + "mfa_uhuh", + "mfa_uhum", + "mfa_whatscalled", + "mfa_cutoff", + "mfa_exaggerated", + ] + + +@pytest.fixture(scope="session") +def putty_utterances(): + return [ + "mfa_putty", + "mfa_puddy", + "mfa_puttynorm", + "mfa_pooty", + "mfa_bottle", + "mfa_patty", + "mfa_buddy", + "mfa_apex", + "mfa_poofy", + "mfa_reallylong", + ] + + +@pytest.fixture(scope="session") +def register_utterances(): + return [ + "mfa_michael", + "mfa_kmg", + "mfa_falsetto", + "mfa_whisper", + "mfa_exaggerated", + "mfa_breathy", + "mfa_creaky", + "mfa_registershift", + "falsetto", + "falsetto2", + "whisper", + "whisper2", + ] + + +@pytest.fixture(scope="session") +def pronunciation_variation_utterances(): + return [ + "mfa_crossword", + "mfa_her", + "mfa_er", + "mfa_erpause", + "mfa_the", + "mfa_thenorm", + "mfa_thestop", + "mfa_theapprox", + "mfa_thez", + "mfa_theinitialstop", + "mfa_theother", + "mfa_thoughts", + ] + + +@pytest.fixture(scope="session") +def cutoff_utterances(): + return [ + "mfa_cutoff", + "mfa_cutoffprogressive", + "mfa_internalsil", + "mfa_longstop", + "mfa_long", + "mfa_hes", + ] + + +@pytest.fixture(scope="session") +def filler_insertion_corpus(filler_insertion_utterances, corpus_root_dir, wav_dir, lab_dir): + path = corpus_root_dir.joinpath("test_filler_insertion") + path.mkdir(exist_ok=True, parents=True) + speaker_name = "michael" + s_dir = path.joinpath(speaker_name) + s_dir.mkdir(exist_ok=True, parents=True) + transcript = "montreal forced aligner" + for name in filler_insertion_utterances: + shutil.copyfile(wav_dir.joinpath(name + ".flac"), s_dir.joinpath(name + ".flac")) + with mfa_open(s_dir.joinpath(name + ".lab"), "w") as f: + f.write(transcript) + return path + + +@pytest.fixture(scope="session") +def pronunciation_variation_corpus( + pronunciation_variation_utterances, corpus_root_dir, wav_dir, lab_dir, reference_transcripts +): + path = corpus_root_dir.joinpath("test_pronunciation_variation") + path.mkdir(exist_ok=True, parents=True) + speaker_name = "michael" + s_dir = path.joinpath(speaker_name) + s_dir.mkdir(exist_ok=True, parents=True) + for name in pronunciation_variation_utterances: + shutil.copyfile(wav_dir.joinpath(name + ".flac"), s_dir.joinpath(name + ".flac")) + with mfa_open(s_dir.joinpath(name + ".lab"), "w") as f: + f.write(reference_transcripts[name]) + return path diff --git a/tests/data/dictionaries/english_us_mfa_reduced.dict b/tests/data/dictionaries/english_us_mfa_reduced.dict index 677c244d..b0caf335 100644 --- a/tests/data/dictionaries/english_us_mfa_reduced.dict +++ b/tests/data/dictionaries/english_us_mfa_reduced.dict @@ -17,6 +17,7 @@ a 0.16 0.15 1.93 0.92 ej a 0.99 0.09 1.63 0.95 ə acoustic 0.99 0.17 1.0 1.0 ə kʰ ʉː s tʲ ɪ k against 0.16 0.06 0.71 1.16 ɡ ɛ n s t +pitched 0.99 0.14 1.0 1.0 pʰ ɪ tʃ t against 0.13 0.07 0.74 1.16 ə ɡ ɪ n s against 0.99 0.1 0.55 1.12 ə ɡ ɛ n s t all 0.99 0.1 2.49 0.86 ɑː ɫ @@ -30,14 +31,36 @@ and 0.17 0.28 4.09 0.54 æ n d and 0.99 0.02 3.18 0.73 æ n and 0.82 0.01 0.3 1.1 ɪ n and 0.78 0.03 0.6 1.06 n̩ +high 0.99 0.07 0.56 1.04 h aj animal 0.99 0.29 1.14 0.95 æ ɲ ɪ m ə ɫ are 0.99 0.07 1.55 0.9 ɑ ɹ are 0.6 0.05 0.19 1.12 ɚ be 0.99 0.03 0.58 1.04 bʲ i +about 0.78 0.1 0.28 1.1 ə b aw +about 0.89 0.19 0.32 1.1 b aw +about 0.99 0.08 0.27 1.09 ə b aw t +about 0.71 0.08 0.29 1.08 b aw t +other 0.99 0.1 2.53 0.9 ɐ ð ɚ +other 0.31 0.02 3.11 0.86 ɐ d̪ ɚ +thoughts 0.17 0.14 1.0 1.0 θ ɒː t s +thoughts 0.17 0.14 1.0 1.0 t̪ ɒː t s +thoughts 0.17 0.14 1.0 1.0 t̪ ɑː t s +thoughts 0.99 0.33 0.88 1.04 θ ɑː t s +thing 0.1 0.11 0.99 1.0 θ ɪ ŋ +thing 0.99 0.28 1.21 0.98 θ ɪ n +thing 0.09 0.21 1.09 0.97 t̪ ɪ ŋ +thing 0.74 0.12 0.22 1.1 t̪ ɪ n bunch 0.99 0.1 0.96 1.01 b ɐ n tʃ but 0.99 0.04 1.14 0.98 b ɐ t +but 0.47 0.29 0.44 1.05 b ɐ ɾ but 0.47 0.02 4.36 0.75 b ɐ by 0.99 0.08 0.85 1.03 b aj +thousand 0.13 0.32 0.92 1.04 θ aw z n̩ +thousand 0.22 0.05 0.92 1.03 t̪ aw z n̩ +thousand 0.39 0.03 0.79 1.06 θ aw z ə n +thousand 0.99 0.1 0.81 1.03 t̪ aw z ə n +thousand 0.09 0.09 0.97 1.02 t̪ aw z n̩ d +thousand 0.17 0.26 0.88 1.06 θ aw z ə n d can't 0.69 0.04 1.19 0.98 cʰ æ n t can't 0.99 0.05 1.35 0.97 cʰ æ n cares 0.99 0.09 0.9 1.06 cʰ ɛ ɹ z @@ -107,7 +130,22 @@ hey 0.69 0.11 1.09 0.98 ej hey 0.99 0.26 4.0 0.55 ç ej hopefully 0.99 0.09 1.38 0.78 h ow p f ə ʎ i i 0.99 0.03 3.2 0.81 aj +hardly 0.99 0.02 0.85 1.03 h ɑ ɹ d ʎ i +her 0.99 0.13 3.09 0.86 h ɝ +her 0.89 0.08 0.1 1.07 ɚ +how 0.99 0.02 4.86 0.71 h aw +talk 0.99 0.15 1.03 1.0 tʰ ɑ k +talk 0.77 0.07 1.84 0.9 tʰ ɒ k +'er 0.99 0.08 0.1 1.07 ɚ +people 0.09 0.42 1.45 0.95 pʰ i p ə ɫ +people 0.99 0.07 1.09 0.99 pʰ i p ɫ̩ +another 0.99 0.19 1.5 0.94 ə n ɐ ð ɚ +another 0.65 0.06 2.17 0.85 ə n ɐ d̪ ɚ +another 0.04 0.09 0.99 1.01 n ɐ d̪ ɚ +another 0.09 0.05 1.84 0.69 æ n ɐ ð ɚ +another 0.51 0.01 1.12 0.98 n ɐ ð ɚ i'm 0.99 0.06 5.23 0.69 aj m +know 0.99 0.17 0.09 1.06 n ow in 0.4 0.06 0.62 1.05 n̩ in 0.99 0.02 0.85 1.02 ɪ n instead 0.99 0.23 1.49 0.89 ɪ n s t ɛ d @@ -127,6 +165,7 @@ justice 0.99 0.59 1.53 0.83 dʒ ɐ s tʲ ɪ s kinda 0.99 0.11 0.97 1.02 cʰ aj n ə kinda 0.99 0.11 0.98 1.02 cʰ aj n d ə league 0.99 0.17 0.81 1.06 ʎ iː ɡ +whispered 0.99 0.43 1.0 1.0 w ɪ s p ɚ d less 0.99 0.15 0.83 1.04 l ɛ s levels 0.99 0.22 0.81 1.08 l ɛ v ə ɫ z like 0.99 0.11 0.76 1.02 l aj k @@ -136,11 +175,29 @@ looked 0.99 0.01 0.59 1.06 l ʊ k t lot 0.99 0.02 0.2 1.1 l ɑ lot 0.88 0.01 0.14 1.12 l ɑ t lower 0.99 0.12 0.52 1.11 l ow ɚ +putty 0.99 0.14 1.0 1.0 pʰ ɐ tʲ i +bottle 0.12 0.14 1.0 1.0 b ɑ t ə ɫ +bottle 0.99 0.03 0.79 1.06 b ɑ t ɫ̩ +patty 0.12 0.14 1.0 1.0 pʰ æ tʲ i +patty 0.99 0.03 0.88 1.03 pʰ æ ɾʲ i +apex 0.99 0.14 1.0 1.0 ej pʰ ɛ k s +buddy 0.99 0.09 1.49 0.67 b ɐ dʲ i +buddy 0.99 0.09 1.49 0.67 b ɐ ɾʲ i +putty 0.99 0.14 1.0 1.0 pʰ ɐ ɾʲ i +puddy 0.99 0.14 1.0 1.0 pʰ ɐ dʲ i +pooty 0.99 0.14 1.0 1.0 pʰ ʉː tʲ i +poofy 0.99 0.14 1.0 1.0 pʰ ʉː fʲ i +m 0.99 0.35 0.82 1.02 ɛ m +bad 0.67 0.07 0.86 1.01 b æ d +bad 0.99 0.29 0.44 1.05 b æ ɾ me 0.99 0.23 0.26 1.06 mʲ i more 0.99 0.12 0.73 1.03 m ɒː ɹ n 0.36 0.07 0.85 1.08 ɪ n +voiced 0.99 0.09 0.96 1.03 v ɔj s t n 0.99 0.19 1.23 0.93 ɛ n +f 0.99 0.13 0.73 1.05 ɛ f no 0.99 0.11 4.15 0.75 n ow +stop 0.99 0.03 1.72 0.91 s t ɑ p not 0.9 0.03 0.74 1.02 n ɑ t not 0.31 0.01 1.04 1.0 n ɑ nothing 0.51 0.41 1.69 0.89 n ɐ θ ɪ ŋ @@ -178,6 +235,11 @@ recording 0.99 0.22 0.89 1.04 ɹ ɪ kʰ ɒ ɹ dʲ ɪ ŋ recording 0.8 0.47 1.32 0.85 ɹ ɪ kʰ ɒ ɹ ɾʲ ɪ n reinforce 0.99 0.17 1.0 1.0 ɹ i ɪ n f ɒ ɹ s run 0.99 0.17 1.16 0.98 ɹ ɐ n +don't 0.18 0.11 7.0 0.67 d ow n t +don't 0.99 0.03 2.58 0.93 d ow n +don't 0.23 0.02 1.31 0.98 d ow +don't 0.02 0.03 1.64 0.85 d ə n +don't 0.02 0.35 1.44 0.91 d n̩ say 0.99 0.24 0.84 1.01 s ej saying 0.08 0.43 0.79 1.06 s ej ɪ ŋ saying 0.99 0.03 0.24 1.06 s ej ɪ n @@ -197,6 +259,7 @@ some 0.29 0.11 1.08 0.99 s m̩ sound 0.6 0.33 0.65 1.09 s aw n d sound 0.99 0.21 0.52 1.1 s aw n sounds 0.99 0.15 1.02 1.0 s aw n d z +whatsoever 0.99 0.14 1.0 1.0 w ɐ t s ow ɛ v ɚ sounds 0.47 0.04 0.79 1.07 s aw n z special 0.82 0.23 0.82 1.08 s p ɛ ʃ ə ɫ special 0.99 0.19 1.59 0.87 s p ɛ ʃ ɫ̩ @@ -237,6 +300,10 @@ to 0.99 0.09 0.65 1.05 tʰ ʊ to 0.27 0.09 0.07 1.12 ə to 0.5 0.04 0.13 1.14 t ə top 0.99 0.1 1.46 0.93 tʰ ɑ p +many 0.62 0.02 1.41 0.96 mʲ ɪ ɲ i +many 0.99 0.04 3.53 0.79 m ɛ ɲ i +many 0.03 0.32 1.42 0.77 m ɛ ɾ̃ i +many 0.03 0.07 0.95 1.03 mʲ ɪ ɾ̃ i twenty 0.22 0.09 1.45 0.91 tʷ ɛ n tʲ i twenty 0.99 0.02 1.92 0.89 tʷ ɛ ɲ i uh 0.03 0.81 1.34 0.96 ə @@ -256,6 +323,8 @@ we're 0.48 0.12 3.51 0.62 w ɪ ɹ welcome 0.99 0.28 3.07 0.6 w ɛ ɫ k ə m where 0.45 0.06 0.92 1.01 w ɚ where 0.99 0.02 5.47 0.57 w ɛ ɹ +right 0.99 0.17 1.03 1.0 ɹ aj t +right 0.52 0.16 1.14 0.99 ɹ aj who 0.99 0.15 3.59 0.7 ç ʉː with 0.99 0.05 0.37 1.13 w ɪ θ words 0.99 0.31 0.46 1.08 w ɝ d z @@ -266,3 +335,10 @@ you 0.99 0.08 1.36 0.97 j ʉː you 0.12 0.48 0.99 1.0 j ə [bracketed] 0.99 0.17 1.0 1.0 spn [laughter] 0.99 0.17 1.0 1.0 spn +aligner 0.99 0.14 1.0 1.0 ə l aj n ɚ +montreal 0.99 0.14 1.0 1.0 m ɑ n tʲ ɹ i ɒ ɫ +forced 0.99 0.04 0.85 1.05 f ɒ ɹ s t +hm 0.67 0.09 0.97 1.02 h ə m +hm 0.99 0.57 1.98 0.5 ə m +hm 0.99 0.32 0.99 1.01 m̩ +mmm 0.99 0.32 0.99 1.01 m̩ diff --git a/tests/data/wav/falsetto.flac b/tests/data/wav/falsetto.flac new file mode 100644 index 00000000..50d42cf1 Binary files /dev/null and b/tests/data/wav/falsetto.flac differ diff --git a/tests/data/wav/falsetto2.flac b/tests/data/wav/falsetto2.flac new file mode 100644 index 00000000..23cc510c Binary files /dev/null and b/tests/data/wav/falsetto2.flac differ diff --git a/tests/data/wav/mfa_a.flac b/tests/data/wav/mfa_a.flac new file mode 100644 index 00000000..dd19b3e6 Binary files /dev/null and b/tests/data/wav/mfa_a.flac differ diff --git a/tests/data/wav/mfa_affectation.flac b/tests/data/wav/mfa_affectation.flac new file mode 100644 index 00000000..db77491e Binary files /dev/null and b/tests/data/wav/mfa_affectation.flac differ diff --git a/tests/data/wav/mfa_apex.flac b/tests/data/wav/mfa_apex.flac new file mode 100644 index 00000000..9b7dc7bd Binary files /dev/null and b/tests/data/wav/mfa_apex.flac differ diff --git a/tests/data/wav/mfa_bottle.flac b/tests/data/wav/mfa_bottle.flac new file mode 100644 index 00000000..66cdafd0 Binary files /dev/null and b/tests/data/wav/mfa_bottle.flac differ diff --git a/tests/data/wav/mfa_breaths.flac b/tests/data/wav/mfa_breaths.flac new file mode 100644 index 00000000..ed4b8115 Binary files /dev/null and b/tests/data/wav/mfa_breaths.flac differ diff --git a/tests/data/wav/mfa_breathy.flac b/tests/data/wav/mfa_breathy.flac new file mode 100644 index 00000000..e9f3f89c Binary files /dev/null and b/tests/data/wav/mfa_breathy.flac differ diff --git a/tests/data/wav/mfa_buddy.flac b/tests/data/wav/mfa_buddy.flac new file mode 100644 index 00000000..cd7991ce Binary files /dev/null and b/tests/data/wav/mfa_buddy.flac differ diff --git a/tests/data/wav/mfa_creaky.flac b/tests/data/wav/mfa_creaky.flac new file mode 100644 index 00000000..041f30ac Binary files /dev/null and b/tests/data/wav/mfa_creaky.flac differ diff --git a/tests/data/wav/mfa_crossword.flac b/tests/data/wav/mfa_crossword.flac new file mode 100644 index 00000000..6dbf5dbe Binary files /dev/null and b/tests/data/wav/mfa_crossword.flac differ diff --git a/tests/data/wav/mfa_cutoff.flac b/tests/data/wav/mfa_cutoff.flac new file mode 100644 index 00000000..f3c0396d Binary files /dev/null and b/tests/data/wav/mfa_cutoff.flac differ diff --git a/tests/data/wav/mfa_cutoffprogressive.flac b/tests/data/wav/mfa_cutoffprogressive.flac new file mode 100644 index 00000000..48349571 Binary files /dev/null and b/tests/data/wav/mfa_cutoffprogressive.flac differ diff --git a/tests/data/wav/mfa_er.flac b/tests/data/wav/mfa_er.flac new file mode 100644 index 00000000..8623f8a0 Binary files /dev/null and b/tests/data/wav/mfa_er.flac differ diff --git a/tests/data/wav/mfa_erpause.flac b/tests/data/wav/mfa_erpause.flac new file mode 100644 index 00000000..f9159231 Binary files /dev/null and b/tests/data/wav/mfa_erpause.flac differ diff --git a/tests/data/wav/mfa_exaggerated.flac b/tests/data/wav/mfa_exaggerated.flac new file mode 100644 index 00000000..16ecc617 Binary files /dev/null and b/tests/data/wav/mfa_exaggerated.flac differ diff --git a/tests/data/wav/mfa_falsetto.flac b/tests/data/wav/mfa_falsetto.flac new file mode 100644 index 00000000..179f1ebf Binary files /dev/null and b/tests/data/wav/mfa_falsetto.flac differ diff --git a/tests/data/wav/mfa_her.flac b/tests/data/wav/mfa_her.flac new file mode 100644 index 00000000..fef2f06c Binary files /dev/null and b/tests/data/wav/mfa_her.flac differ diff --git a/tests/data/wav/mfa_hes.flac b/tests/data/wav/mfa_hes.flac new file mode 100644 index 00000000..8247a3b0 Binary files /dev/null and b/tests/data/wav/mfa_hes.flac differ diff --git a/tests/data/wav/mfa_internalsil.flac b/tests/data/wav/mfa_internalsil.flac new file mode 100644 index 00000000..b6b966ed Binary files /dev/null and b/tests/data/wav/mfa_internalsil.flac differ diff --git a/tests/data/wav/mfa_kmg.flac b/tests/data/wav/mfa_kmg.flac new file mode 100644 index 00000000..127edff3 Binary files /dev/null and b/tests/data/wav/mfa_kmg.flac differ diff --git a/tests/data/wav/mfa_laughter.flac b/tests/data/wav/mfa_laughter.flac new file mode 100644 index 00000000..842e3422 Binary files /dev/null and b/tests/data/wav/mfa_laughter.flac differ diff --git a/tests/data/wav/mfa_long.flac b/tests/data/wav/mfa_long.flac new file mode 100644 index 00000000..bacd769c Binary files /dev/null and b/tests/data/wav/mfa_long.flac differ diff --git a/tests/data/wav/mfa_longstop.flac b/tests/data/wav/mfa_longstop.flac new file mode 100644 index 00000000..0f461d9b Binary files /dev/null and b/tests/data/wav/mfa_longstop.flac differ diff --git a/tests/data/wav/mfa_michael.flac b/tests/data/wav/mfa_michael.flac new file mode 100644 index 00000000..132dc877 Binary files /dev/null and b/tests/data/wav/mfa_michael.flac differ diff --git a/tests/data/wav/mfa_patty.flac b/tests/data/wav/mfa_patty.flac new file mode 100644 index 00000000..74483301 Binary files /dev/null and b/tests/data/wav/mfa_patty.flac differ diff --git a/tests/data/wav/mfa_poofy.flac b/tests/data/wav/mfa_poofy.flac new file mode 100644 index 00000000..a9840512 Binary files /dev/null and b/tests/data/wav/mfa_poofy.flac differ diff --git a/tests/data/wav/mfa_pooty.flac b/tests/data/wav/mfa_pooty.flac new file mode 100644 index 00000000..ed204086 Binary files /dev/null and b/tests/data/wav/mfa_pooty.flac differ diff --git a/tests/data/wav/mfa_puddy.flac b/tests/data/wav/mfa_puddy.flac new file mode 100644 index 00000000..72b3e295 Binary files /dev/null and b/tests/data/wav/mfa_puddy.flac differ diff --git a/tests/data/wav/mfa_putty.flac b/tests/data/wav/mfa_putty.flac new file mode 100644 index 00000000..03b29bd3 Binary files /dev/null and b/tests/data/wav/mfa_putty.flac differ diff --git a/tests/data/wav/mfa_puttynorm.flac b/tests/data/wav/mfa_puttynorm.flac new file mode 100644 index 00000000..45914c38 Binary files /dev/null and b/tests/data/wav/mfa_puttynorm.flac differ diff --git a/tests/data/wav/mfa_reallylong.flac b/tests/data/wav/mfa_reallylong.flac new file mode 100644 index 00000000..5b677185 Binary files /dev/null and b/tests/data/wav/mfa_reallylong.flac differ diff --git a/tests/data/wav/mfa_registershift.flac b/tests/data/wav/mfa_registershift.flac new file mode 100644 index 00000000..47514aab Binary files /dev/null and b/tests/data/wav/mfa_registershift.flac differ diff --git a/tests/data/wav/mfa_surround.flac b/tests/data/wav/mfa_surround.flac new file mode 100644 index 00000000..445f223e Binary files /dev/null and b/tests/data/wav/mfa_surround.flac differ diff --git a/tests/data/wav/mfa_the.flac b/tests/data/wav/mfa_the.flac new file mode 100644 index 00000000..d77ced80 Binary files /dev/null and b/tests/data/wav/mfa_the.flac differ diff --git a/tests/data/wav/mfa_theapprox.flac b/tests/data/wav/mfa_theapprox.flac new file mode 100644 index 00000000..7a13388e Binary files /dev/null and b/tests/data/wav/mfa_theapprox.flac differ diff --git a/tests/data/wav/mfa_theinitialstop.flac b/tests/data/wav/mfa_theinitialstop.flac new file mode 100644 index 00000000..0ed886ac Binary files /dev/null and b/tests/data/wav/mfa_theinitialstop.flac differ diff --git a/tests/data/wav/mfa_thenorm.flac b/tests/data/wav/mfa_thenorm.flac new file mode 100644 index 00000000..5e1bfd02 Binary files /dev/null and b/tests/data/wav/mfa_thenorm.flac differ diff --git a/tests/data/wav/mfa_theother.flac b/tests/data/wav/mfa_theother.flac new file mode 100644 index 00000000..323205f7 Binary files /dev/null and b/tests/data/wav/mfa_theother.flac differ diff --git a/tests/data/wav/mfa_thestop.flac b/tests/data/wav/mfa_thestop.flac new file mode 100644 index 00000000..809920f8 Binary files /dev/null and b/tests/data/wav/mfa_thestop.flac differ diff --git a/tests/data/wav/mfa_thez.flac b/tests/data/wav/mfa_thez.flac new file mode 100644 index 00000000..2444537a Binary files /dev/null and b/tests/data/wav/mfa_thez.flac differ diff --git a/tests/data/wav/mfa_thoughts.flac b/tests/data/wav/mfa_thoughts.flac new file mode 100644 index 00000000..f36423dd Binary files /dev/null and b/tests/data/wav/mfa_thoughts.flac differ diff --git a/tests/data/wav/mfa_uh.flac b/tests/data/wav/mfa_uh.flac new file mode 100644 index 00000000..a0813273 Binary files /dev/null and b/tests/data/wav/mfa_uh.flac differ diff --git a/tests/data/wav/mfa_uhuh.flac b/tests/data/wav/mfa_uhuh.flac new file mode 100644 index 00000000..b104406f Binary files /dev/null and b/tests/data/wav/mfa_uhuh.flac differ diff --git a/tests/data/wav/mfa_uhum.flac b/tests/data/wav/mfa_uhum.flac new file mode 100644 index 00000000..6a5550be Binary files /dev/null and b/tests/data/wav/mfa_uhum.flac differ diff --git a/tests/data/wav/mfa_um.flac b/tests/data/wav/mfa_um.flac new file mode 100644 index 00000000..8652704e Binary files /dev/null and b/tests/data/wav/mfa_um.flac differ diff --git a/tests/data/wav/mfa_unk.flac b/tests/data/wav/mfa_unk.flac new file mode 100644 index 00000000..150ad1fd Binary files /dev/null and b/tests/data/wav/mfa_unk.flac differ diff --git a/tests/data/wav/mfa_whatscalled.flac b/tests/data/wav/mfa_whatscalled.flac new file mode 100644 index 00000000..54edc091 Binary files /dev/null and b/tests/data/wav/mfa_whatscalled.flac differ diff --git a/tests/data/wav/mfa_whisper.flac b/tests/data/wav/mfa_whisper.flac new file mode 100644 index 00000000..7ec49ee3 Binary files /dev/null and b/tests/data/wav/mfa_whisper.flac differ diff --git a/tests/data/wav/mfa_words.flac b/tests/data/wav/mfa_words.flac new file mode 100644 index 00000000..57d0dd55 Binary files /dev/null and b/tests/data/wav/mfa_words.flac differ diff --git a/tests/data/wav/mfa_youknow.flac b/tests/data/wav/mfa_youknow.flac new file mode 100644 index 00000000..7a933214 Binary files /dev/null and b/tests/data/wav/mfa_youknow.flac differ diff --git a/tests/data/wav/whisper.flac b/tests/data/wav/whisper.flac new file mode 100644 index 00000000..3275dd54 Binary files /dev/null and b/tests/data/wav/whisper.flac differ diff --git a/tests/data/wav/whisper2.flac b/tests/data/wav/whisper2.flac new file mode 100644 index 00000000..21184106 Binary files /dev/null and b/tests/data/wav/whisper2.flac differ diff --git a/tests/test_alignment_pretrained.py b/tests/test_alignment_pretrained.py index 93be9adb..9eef3cb6 100644 --- a/tests/test_alignment_pretrained.py +++ b/tests/test_alignment_pretrained.py @@ -3,7 +3,16 @@ from montreal_forced_aligner.alignment import PretrainedAligner from montreal_forced_aligner.data import WordType, WorkflowType -from montreal_forced_aligner.db import PhoneInterval, Utterance, Word, WordInterval +from montreal_forced_aligner.db import ( + File, + Phone, + PhoneInterval, + Utterance, + Word, + WordInterval, + bulk_update, +) +from montreal_forced_aligner.helper import align_words def test_align_sick( @@ -20,7 +29,7 @@ def test_align_sick( acoustic_model_path=english_acoustic_model, oov_count_threshold=1, dither=0, - **test_align_config + **test_align_config, ) a.align() assert a.dither == 0 @@ -55,7 +64,7 @@ def test_align_sick_mfa( dictionary_path=english_us_mfa_dictionary, acoustic_model_path=english_mfa_acoustic_model, oov_count_threshold=1, - **test_align_config + **test_align_config, ) a.align() export_directory = os.path.join(temp_dir, "test_align_mfa_export") @@ -89,7 +98,7 @@ def test_align_one( debug=True, verbose=True, clean=True, - **test_align_config + **test_align_config, ) a.initialize_database() a.create_new_current_workflow(WorkflowType.online_alignment) @@ -126,3 +135,148 @@ def test_align_one( assert len(utterance.phone_intervals) > 0 a.cleanup() a.clean_working_directory() + + +def test_no_silence( + english_us_mfa_reduced_dict, + english_mfa_acoustic_model, + pronunciation_variation_corpus, + temp_dir, + test_align_config, + db_setup, +): + a = PretrainedAligner( + corpus_directory=pronunciation_variation_corpus, + dictionary_path=english_us_mfa_reduced_dict, + acoustic_model_path=english_mfa_acoustic_model, + debug=True, + verbose=True, + clean=True, + silence_probability=0.0, + **test_align_config, + ) + a.initialize_database() + a.create_new_current_workflow(WorkflowType.online_alignment) + a.setup() + with a.session() as session: + utterance = ( + session.query(Utterance) + .join(Utterance.file) + .filter(File.name == "mfa_erpause") + .first() + ) + assert utterance.alignment_log_likelihood is None + assert utterance.features is not None + assert len(utterance.phone_intervals) == 0 + print(a.silence_probability) + print(a.lexicon_compilers[1].silence_probability) + a.lexicon_compilers[1]._fst.set_input_symbols(a.lexicon_compilers[1].phone_table) + a.lexicon_compilers[1]._fst.set_output_symbols(a.lexicon_compilers[1].word_table) + print(a.lexicon_compilers[1]._fst) + a.align_one_utterance(utterance, session) + + with a.session() as session: + utterance = ( + session.query(Utterance) + .join(Utterance.file) + .filter(File.name == "mfa_erpause") + .first() + ) + silence_count = ( + session.query(PhoneInterval) + .join(PhoneInterval.phone) + .filter(Phone.phone == a.optional_silence_phone) + .count() + ) + assert silence_count == 0 + assert utterance.alignment_log_likelihood is not None + assert len(utterance.phone_intervals) > 0 + assert ( + len( + [x for x in utterance.phone_intervals if x.phone.phone != a.optional_silence_phone] + ) + > 0 + ) + + a.cleanup() + a.clean_working_directory() + + +def test_transcript_verification( + filler_insertion_corpus, + english_us_mfa_reduced_dict, + english_mfa_acoustic_model, + temp_dir, + db_setup, + reference_transcripts, +): + a = PretrainedAligner( + corpus_directory=filler_insertion_corpus, + dictionary_path=english_us_mfa_reduced_dict, + acoustic_model_path=english_mfa_acoustic_model, + boost_silence=3.0, + acoustic_scale=0.0833, + self_loop_scale=1.0, + transition_scale=1.0, + use_cutoff_model=True, + uses_speaker_adaptation=False, + ) + a.initialize_database() + a.create_new_current_workflow(WorkflowType.transcript_verification) + a.setup() + with a.session() as session: + update_mappings = [] + counts = {"uh": 10, "um": 10, "hm": 1} + for w in session.query(Word).filter(Word.word.in_(["uh", "um", "hm"])): + update_mappings.append( + {"id": w.id, "word_type": WordType.interjection, "count": counts[w.word]} + ) + bulk_update(session, Word, update_mappings) + session.commit() + a.verify_transcripts() + export_directory = os.path.join(temp_dir, "test_transcript_verification_export") + shutil.rmtree(export_directory, ignore_errors=True) + a.export_files(export_directory) + successes = [] + with a.session() as session: + utterances = session.query(Utterance).all() + for utterance in utterances: + if utterance.file_name in {"mfa_breaths"}: + continue + print("FILE:", utterance.file_name) + print("REFERENCE:", reference_transcripts[utterance.file_name]) + print("ORIGINAL: ", utterance.normalized_text) + word_intervals = ( + session.query(WordInterval) + .join(WordInterval.word) + .filter( + WordInterval.utterance_id == utterance.id, + Word.word_type != WordType.silence, + WordInterval.end - WordInterval.begin > 0.03, + ) + .order_by(WordInterval.begin) + ) + generated = " ".join(x.word.word for x in word_intervals) + extra_duration, wer, aligned_duration = align_words( + utterance.normalized_text.split(), + [x.as_ctm() for x in word_intervals], + "", + debug=True, + ) + + print("FILE:", utterance.file_name) + print("LOG LIKELIHOOD:", utterance.alignment_log_likelihood) + print("DURATION DEVIATION:", utterance.duration_deviation) + if reference_transcripts[utterance.file_name] != generated: + print("VERIFIED: ", generated) + print(wer, extra_duration) + else: + successes.append(utterance.file_name) + if reference_transcripts[utterance.file_name] == utterance.normalized_text: + assert utterance.duration_deviation == 0 + assert utterance.word_error_rate == 0 + else: + assert utterance.duration_deviation > 0 + assert utterance.word_error_rate > 0 + + print(f"Successful: {successes} of {len(utterances)}") diff --git a/tests/test_commandline_align.py b/tests/test_commandline_align.py index 00abdb00..b06bf7a8 100644 --- a/tests/test_commandline_align.py +++ b/tests/test_commandline_align.py @@ -2,6 +2,7 @@ import os import click.testing +import pytest from praatio import textgrid as tgio from montreal_forced_aligner.command_line.mfa import mfa_cli @@ -760,6 +761,7 @@ def test_swedish_mfa( assert len(tg.tierNames) == 2 +@pytest.mark.skip def test_acoustic_g2p_model( basic_corpus_dir, acoustic_model_dir,