diff --git a/src/Machine/src/SIL.Machine.AspNetCore/Services/CorpusService.cs b/src/Machine/src/SIL.Machine.AspNetCore/Services/CorpusService.cs index b48e9724..635bbff5 100644 --- a/src/Machine/src/SIL.Machine.AspNetCore/Services/CorpusService.cs +++ b/src/Machine/src/SIL.Machine.AspNetCore/Services/CorpusService.cs @@ -26,7 +26,7 @@ public IEnumerable CreateTextCorpora(IReadOnlyList file break; case FileFormat.Paratext: - corpora.Add(new ParatextBackupTextCorpus(file.Location)); + corpora.Add(new ParatextBackupTextCorpus(file.Location, includeAllText: true)); break; } } diff --git a/src/Machine/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs b/src/Machine/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs index fd9f6217..8372c28e 100644 --- a/src/Machine/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs +++ b/src/Machine/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs @@ -128,7 +128,7 @@ CancellationToken cancellationToken continue; int skipCount = 0; - foreach (Row?[] rows in AlignCorpora(sourceTextCorpora, targetTextCorpus)) + foreach (Row?[] rows in AlignTrainCorpus(sourceTextCorpora, targetTextCorpus)) { if (skipCount > 0) { @@ -153,26 +153,6 @@ CancellationToken cancellationToken if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) trainCount++; } - - Row? pretranslateRow = rows[0]; - if ( - pretranslateRow is not null - && IsInPretranslate(pretranslateRow, corpus) - && pretranslateRow.SourceSegment.Length > 0 - && pretranslateRow.TargetSegment.Length == 0 - ) - { - pretranslateWriter.WriteStartObject(); - pretranslateWriter.WriteString("corpusId", corpus.Id); - pretranslateWriter.WriteString("textId", pretranslateRow.TextId); - pretranslateWriter.WriteStartArray("refs"); - foreach (object rowRef in pretranslateRow.Refs) - pretranslateWriter.WriteStringValue(rowRef.ToString()); - pretranslateWriter.WriteEndArray(); - pretranslateWriter.WriteString("translation", pretranslateRow.SourceSegment); - pretranslateWriter.WriteEndObject(); - pretranslateCount++; - } } if ((bool?)buildOptionsObject?["use_key_terms"] ?? true) @@ -190,6 +170,23 @@ pretranslateRow is not null } } } + + foreach (Row row in AlignPretranslateCorpus(sourceTextCorpora[0], targetTextCorpus)) + { + if (IsInPretranslate(row, corpus) && row.SourceSegment.Length > 0 && row.TargetSegment.Length == 0) + { + pretranslateWriter.WriteStartObject(); + pretranslateWriter.WriteString("corpusId", corpus.Id); + pretranslateWriter.WriteString("textId", row.TextId); + pretranslateWriter.WriteStartArray("refs"); + foreach (object rowRef in row.Refs) + pretranslateWriter.WriteStringValue(rowRef.ToString()); + pretranslateWriter.WriteEndArray(); + pretranslateWriter.WriteString("translation", row.SourceSegment); + pretranslateWriter.WriteEndObject(); + pretranslateCount++; + } + } } pretranslateWriter.WriteEndArray(); @@ -244,13 +241,13 @@ private static bool IsIncluded( private static bool IsInChapters(IReadOnlyDictionary> bookChapters, object rowRef) { - if (rowRef is not VerseRef vr) + if (rowRef is not ScriptureRef sr) return false; - return bookChapters.TryGetValue(vr.Book, out HashSet? chapters) - && (chapters.Contains(vr.ChapterNum) || chapters.Count == 0); + return bookChapters.TryGetValue(sr.Book, out HashSet? chapters) + && (chapters.Contains(sr.ChapterNum) || chapters.Count == 0); } - private static IEnumerable AlignCorpora(IReadOnlyList srcCorpora, ITextCorpus trgCorpus) + private static IEnumerable AlignTrainCorpus(IReadOnlyList srcCorpora, ITextCorpus trgCorpus) { if (trgCorpus.IsScripture()) { @@ -332,7 +329,7 @@ private static bool IsInChapters(IReadOnlyDictionary> bookC { yield return new( vrefs.First().Book, - vrefs.Order().Cast().ToArray(), + vrefs.Order().Select(v => new ScriptureRef(v)).Cast().ToArray(), srcSegBuffer.ToString(), trgSegBuffer.ToString(), rowCount @@ -355,7 +352,7 @@ private static bool IsInChapters(IReadOnlyDictionary> bookC { yield return new( vrefs.First().Book, - vrefs.Order().Cast().ToArray(), + vrefs.Order().Select(v => new ScriptureRef(v)).Cast().ToArray(), srcSegBuffer.ToString(), trgSegBuffer.ToString(), rowCount @@ -365,6 +362,50 @@ private static bool IsInChapters(IReadOnlyDictionary> bookC } } + private static IEnumerable AlignPretranslateCorpus(ITextCorpus srcCorpus, ITextCorpus trgCorpus) + { + int rowCount = 0; + StringBuilder srcSegBuffer = new(); + StringBuilder trgSegBuffer = new(); + List refs = []; + string textId = ""; + foreach (ParallelTextRow row in srcCorpus.AlignRows(trgCorpus, allSourceRows: true)) + { + if (!row.IsTargetRangeStart && row.IsTargetInRange) + { + refs.AddRange(row.Refs); + if (row.SourceText.Length > 0) + { + if (srcSegBuffer.Length > 0) + srcSegBuffer.Append(' '); + srcSegBuffer.Append(row.SourceText); + } + rowCount++; + } + else + { + if (rowCount > 0) + { + yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1); + textId = ""; + srcSegBuffer.Clear(); + trgSegBuffer.Clear(); + refs.Clear(); + rowCount = 0; + } + + textId = row.TextId; + refs.AddRange(row.Refs); + srcSegBuffer.Append(row.SourceText); + trgSegBuffer.Append(row.TargetText); + rowCount++; + } + } + + if (rowCount > 0) + yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1); + } + private record Row( string TextId, IReadOnlyList Refs, diff --git a/src/Machine/test/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs b/src/Machine/test/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs index 9416d25e..5fafaeec 100644 --- a/src/Machine/test/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs +++ b/src/Machine/test/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs @@ -43,7 +43,7 @@ public async Task RunAsync_TrainOnAll() public async Task RunAsync_TrainOnTextIds() { using TestEnvironment env = new(); - Corpus corpus1 = env.DefaultTextFileCorpus with { TrainOnTextIds = new HashSet { "textId1" } }; + Corpus corpus1 = env.DefaultTextFileCorpus with { TrainOnTextIds = ["textId1"] }; await env.RunBuildJobAsync(corpus1); @@ -72,7 +72,7 @@ public async Task RunAsync_PretranslateAll() public async Task RunAsync_PretranslateTextIds() { using TestEnvironment env = new(); - Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateTextIds = new HashSet { "textId1" } }; + Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateTextIds = ["textId1"] }; await env.RunBuildJobAsync(corpus1); @@ -178,7 +178,7 @@ public async Task RunAsync_MixedSource_Paratext() Assert.That(trgCount, Is.EqualTo(1)); Assert.That(termCount, Is.EqualTo(0)); }); - Assert.That(await env.GetPretranslateCountAsync(), Is.EqualTo(8)); + Assert.That(await env.GetPretranslateCountAsync(), Is.EqualTo(12)); } [Test] @@ -247,8 +247,8 @@ public TestEnvironment() TargetLanguage = "en", PretranslateAll = false, TrainOnAll = false, - PretranslateTextIds = new HashSet(), - TrainOnTextIds = new HashSet(), + PretranslateTextIds = [], + TrainOnTextIds = [], SourceFiles = [TextFile("source1")], TargetFiles = [TextFile("target1")] }; @@ -260,8 +260,8 @@ public TestEnvironment() TargetLanguage = "en", PretranslateAll = false, TrainOnAll = false, - PretranslateTextIds = new HashSet(), - TrainOnTextIds = new HashSet(), + PretranslateTextIds = [], + TrainOnTextIds = [], SourceFiles = [TextFile("source1"), TextFile("source2")], TargetFiles = [TextFile("target1")] }; @@ -273,8 +273,8 @@ public TestEnvironment() TargetLanguage = "en", PretranslateAll = false, TrainOnAll = false, - PretranslateTextIds = new HashSet(), - TrainOnTextIds = new HashSet(), + PretranslateTextIds = [], + TrainOnTextIds = [], SourceFiles = [ParatextFile("pt-source1")], TargetFiles = [ParatextFile("pt-target1")] }; @@ -286,8 +286,8 @@ public TestEnvironment() TargetLanguage = "en", PretranslateAll = false, TrainOnAll = false, - PretranslateTextIds = new HashSet(), - TrainOnTextIds = new HashSet(), + PretranslateTextIds = [], + TrainOnTextIds = [], SourceFiles = [ParatextFile("pt-source1"), ParatextFile("pt-source2")], TargetFiles = [ParatextFile("pt-target1")] };