From b270eb5dae8e68b3e6da3c326a0dff397b1c0570 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Mon, 28 Oct 2024 12:07:24 -0400 Subject: [PATCH] A better fix for #516. --- .../Services/EngineService.cs | 71 ++++- .../test/Serval.E2ETests/ServalApiTests.cs | 9 +- .../Services/EngineServiceTests.cs | 284 +++++++++++++++++- 3 files changed, 339 insertions(+), 25 deletions(-) diff --git a/src/Serval/src/Serval.Translation/Services/EngineService.cs b/src/Serval/src/Serval.Translation/Services/EngineService.cs index 22e5b411..1d4a9bdc 100644 --- a/src/Serval/src/Serval.Translation/Services/EngineService.cs +++ b/src/Serval/src/Serval.Translation/Services/EngineService.cs @@ -248,7 +248,13 @@ public async Task StartBuildAsync(Build build, CancellationToken cancellationTok Corpora = { parallelCorpora.Select(c => - Map(c, trainOn?.GetValueOrDefault(c.Id), pretranslate?.GetValueOrDefault(c.Id)) + Map( + c, + trainOn?.GetValueOrDefault(c.Id), + pretranslate?.GetValueOrDefault(c.Id), + trainOn is null, + pretranslate is null + ) ) } }; @@ -276,7 +282,13 @@ public async Task StartBuildAsync(Build build, CancellationToken cancellationTok Corpora = { corpora.Select(c => - Map(c, trainOn?.GetValueOrDefault(c.Id), pretranslate?.GetValueOrDefault(c.Id)) + Map( + c, + trainOn?.GetValueOrDefault(c.Id), + pretranslate?.GetValueOrDefault(c.Id), + trainOn is null, + pretranslate is null + ) ) } }; @@ -613,7 +625,13 @@ private Models.WordGraphArc Map(V1.WordGraphArc source) }; } - private V1.ParallelCorpus Map(Corpus source, TrainingCorpus? trainingCorpus, PretranslateCorpus? pretranslateCorpus) + private V1.ParallelCorpus Map( + Corpus source, + TrainingCorpus? trainingCorpus, + PretranslateCorpus? pretranslateCorpus, + bool noTrainingCorpusDefined, + bool noPretranslateCorpusDefined + ) { IEnumerable sourceFiles = source.SourceFiles.Select(Map); IEnumerable targetFiles = source.TargetFiles.Select(Map); @@ -622,12 +640,15 @@ private V1.ParallelCorpus Map(Corpus source, TrainingCorpus? trainingCorpus, Pre V1.MonolingualCorpus targetCorpus = new() { Language = source.TargetLanguage, Files = { source.TargetFiles.Select(Map) } }; - if (trainingCorpus is null || (trainingCorpus.TextIds is null && trainingCorpus.ScriptureRange is null)) + if ( + noTrainingCorpusDefined + || (trainingCorpus is not null && trainingCorpus.TextIds is null && trainingCorpus.ScriptureRange is null) + ) { sourceCorpus.TrainOnAll = true; targetCorpus.TrainOnAll = true; } - else + else if (trainingCorpus is not null) { if (trainingCorpus.TextIds is not null && trainingCorpus.ScriptureRange is not null) { @@ -663,14 +684,18 @@ private V1.ParallelCorpus Map(Corpus source, TrainingCorpus? trainingCorpus, Pre } } if ( - pretranslateCorpus is null - || (pretranslateCorpus.TextIds is null && pretranslateCorpus.ScriptureRange is null) + noPretranslateCorpusDefined + || ( + pretranslateCorpus is not null + && pretranslateCorpus.TextIds is null + && pretranslateCorpus.ScriptureRange is null + ) ) { sourceCorpus.PretranslateAll = true; targetCorpus.PretranslateAll = true; } - else + else if (pretranslateCorpus is not null) { if (pretranslateCorpus.TextIds is not null && pretranslateCorpus.ScriptureRange is not null) { @@ -713,7 +738,9 @@ pretranslateCorpus is null private V1.ParallelCorpus Map( Models.ParallelCorpus source, TrainingCorpus? trainingCorpus, - PretranslateCorpus? pretranslateCorpus + PretranslateCorpus? pretranslateCorpus, + bool noTrainingCorpusDefined, + bool noPretranslateCorpusDefined ) { string? referenceFileLocation = @@ -731,7 +758,10 @@ private V1.ParallelCorpus Map( sc, trainingCorpus?.SourceFilters?.Where(sf => sf.CorpusRef == sc.Id).FirstOrDefault(), pretranslateCorpus?.SourceFilters?.Where(sf => sf.CorpusRef == sc.Id).FirstOrDefault(), - referenceFileLocation + referenceFileLocation, + noTrainingCorpusDefined || (trainingCorpus is not null && trainingCorpus.SourceFilters is null), + noPretranslateCorpusDefined + || (pretranslateCorpus is not null && pretranslateCorpus.SourceFilters is null) ) ) }, @@ -742,7 +772,9 @@ private V1.ParallelCorpus Map( tc, trainingCorpus?.TargetFilters?.Where(sf => sf.CorpusRef == tc.Id).FirstOrDefault(), null, - referenceFileLocation + referenceFileLocation, + noTrainingCorpusDefined || (trainingCorpus is not null && trainingCorpus.TargetFilters is null), + noPretranslateCorpusDefined || pretranslateCorpus is not null // there is no pretranslate Target filter. ) ) } @@ -753,7 +785,9 @@ private V1.MonolingualCorpus Map( Models.MonolingualCorpus source, ParallelCorpusFilter? trainingFilter, ParallelCorpusFilter? pretranslateFilter, - string? referenceFileLocation + string? referenceFileLocation, + bool trainOnAll, + bool pretranslateOnAll ) { Dictionary? trainOnChapters = null; @@ -801,7 +835,10 @@ pretranslateFilter is not null Files = { source.Files.Select(Map) } }; - if (trainingFilter is null || (trainingFilter.TextIds is null && trainingFilter.ScriptureRange is null)) + if ( + trainOnAll + || (trainingFilter is not null && trainingFilter.TextIds is null && trainingFilter.ScriptureRange is null) + ) { corpus.TrainOnAll = true; } @@ -814,8 +851,12 @@ pretranslateFilter is not null } if ( - pretranslateFilter is null - || (pretranslateFilter.TextIds is null && pretranslateFilter.ScriptureRange is null) + pretranslateOnAll + || ( + pretranslateFilter is not null + && pretranslateFilter.TextIds is null + && pretranslateFilter.ScriptureRange is null + ) ) { corpus.PretranslateAll = true; diff --git a/src/Serval/test/Serval.E2ETests/ServalApiTests.cs b/src/Serval/test/Serval.E2ETests/ServalApiTests.cs index f9108934..cb4afb66 100644 --- a/src/Serval/test/Serval.E2ETests/ServalApiTests.cs +++ b/src/Serval/test/Serval.E2ETests/ServalApiTests.cs @@ -125,11 +125,16 @@ public async Task NmtBatch() _helperClient.TranslationBuildConfig.Pretranslate = [new() { CorpusId = cId2, TextIds = ["2JN.txt"] }]; await _helperClient.BuildEngineAsync(engineId); await Task.Delay(1000); - IList lTrans = await _helperClient.TranslationEnginesClient.GetAllPretranslationsAsync( + IList lTrans1 = await _helperClient.TranslationEnginesClient.GetAllPretranslationsAsync( + engineId, + cId1 + ); + Assert.That(lTrans1, Has.Count.EqualTo(0)); // should be nothing + IList lTrans2 = await _helperClient.TranslationEnginesClient.GetAllPretranslationsAsync( engineId, cId2 ); - Assert.That(lTrans, Has.Count.EqualTo(13)); // just 2 John + Assert.That(lTrans2, Has.Count.EqualTo(13)); // just 2 John } [Test] diff --git a/src/Serval/test/Serval.Translation.Tests/Services/EngineServiceTests.cs b/src/Serval/test/Serval.Translation.Tests/Services/EngineServiceTests.cs index be53d27d..891e4475 100644 --- a/src/Serval/test/Serval.Translation.Tests/Services/EngineServiceTests.cs +++ b/src/Serval/test/Serval.Translation.Tests/Services/EngineServiceTests.cs @@ -466,6 +466,126 @@ await env.Service.StartBuildAsync( ); } + [Test] + public async Task StartBuildAsync_OneEachOfMultipleCorpora() + { + var env = new TestEnvironment(); + string engineId = (await env.CreateMultipleCorporaEngineWithTextFilesAsync()).Id; + await env.Service.StartBuildAsync( + new Build + { + Id = BUILD1_ID, + EngineRef = engineId, + TrainOn = [new TrainingCorpus { CorpusRef = "corpus1" }], + Pretranslate = [new PretranslateCorpus { CorpusRef = "corpus2" }] + } + ); + _ = env.TranslationServiceClient.Received() + .StartBuildAsync( + new StartBuildRequest + { + BuildId = BUILD1_ID, + EngineId = engineId, + EngineType = "Smt", + Corpora = + { + new V1.ParallelCorpus + { + Id = "corpus1", + SourceCorpora = + { + new List + { + new() + { + Language = "es", + Files = + { + new V1.CorpusFile + { + Location = "file1.txt", + Format = FileFormat.Text, + TextId = "text1" + } + }, + PretranslateAll = false, + TrainOnAll = true + } + } + }, + TargetCorpora = + { + new List + { + new() + { + Language = "en", + Files = + { + new V1.CorpusFile + { + Location = "file2.txt", + Format = FileFormat.Text, + TextId = "text1" + } + }, + PretranslateAll = false, + TrainOnAll = true + } + } + } + }, + new V1.ParallelCorpus + { + Id = "corpus2", + SourceCorpora = + { + new List + { + new() + { + Language = "es", + Files = + { + new V1.CorpusFile + { + Location = "file3.txt", + Format = FileFormat.Text, + TextId = "text1" + } + }, + PretranslateAll = true, + TrainOnAll = false + } + } + }, + TargetCorpora = + { + new List + { + new() + { + Language = "en", + Files = + { + new V1.CorpusFile + { + Location = "file4.txt", + Format = FileFormat.Text, + TextId = "text1" + } + }, + PretranslateAll = true, + TrainOnAll = false + } + } + } + } + } + } + ); + } + [Test] public async Task StartBuildAsync_TextFilesScriptureRangeSpecified() { @@ -734,7 +854,7 @@ await env.Service.StartBuildAsync( } }, PretranslateAll = true, - TrainOnAll = true + TrainOnAll = false } } }, @@ -773,7 +893,7 @@ await env.Service.StartBuildAsync( } }, PretranslateAll = true, - TrainOnAll = true + TrainOnAll = false } } } @@ -883,6 +1003,154 @@ await env.Service.StartBuildAsync( ); } + [Test] + public async Task StartBuildAsync_ParallelCorpus_OneOfEachMultipleCorpora() + { + var env = new TestEnvironment(); + string engineId = (await env.CreateMultipleParallelCorpusEngineWithTextFilesAsync()).Id; + await env.Service.StartBuildAsync( + new Build + { + Id = BUILD1_ID, + EngineRef = engineId, + TrainOn = + [ + new TrainingCorpus + { + ParallelCorpusRef = "parallel-corpus1", + SourceFilters = new List() + { + new() + { + CorpusRef = "parallel-corpus1-source1", + TextIds = new List { "MAT" } + } + }, + TargetFilters = new List() + { + new() + { + CorpusRef = "parallel-corpus1-target1", + TextIds = new List { "MAT" } + } + } + } + ], + Pretranslate = [new PretranslateCorpus { ParallelCorpusRef = "parallel-corpus2" }] + } + ); + _ = env.TranslationServiceClient.Received() + .StartBuildAsync( + new StartBuildRequest + { + BuildId = BUILD1_ID, + EngineId = engineId, + EngineType = "Smt", + Corpora = + { + new V1.ParallelCorpus + { + Id = "parallel-corpus1", + SourceCorpora = + { + new List + { + new() + { + Id = "parallel-corpus1-source1", + Language = "es", + TrainOnTextIds = { "MAT" }, + Files = + { + new V1.CorpusFile + { + Location = "file1.txt", + Format = FileFormat.Text, + TextId = "MAT" + } + }, + PretranslateAll = false, + TrainOnAll = false + } + } + }, + TargetCorpora = + { + new List + { + new() + { + Id = "parallel-corpus1-target1", + Language = "en", + TrainOnTextIds = { "MAT" }, + Files = + { + new V1.CorpusFile + { + Location = "file2.txt", + Format = FileFormat.Text, + TextId = "MAT" + } + }, + PretranslateAll = false, + TrainOnAll = false + } + } + } + }, + new V1.ParallelCorpus + { + Id = "parallel-corpus2", + SourceCorpora = + { + new List + { + new() + { + Id = "parallel-corpus2-source1", + Language = "es", + Files = + { + new V1.CorpusFile + { + Location = "file3.txt", + Format = FileFormat.Text, + TextId = "MRK" + } + }, + PretranslateAll = true, + TrainOnAll = false + } + } + }, + TargetCorpora = + { + new List + { + new() + { + Id = "parallel-corpus2-target1", + Language = "en", + Files = + { + new V1.CorpusFile + { + Location = "file4.txt", + Format = FileFormat.Text, + TextId = "MRK" + } + }, + PretranslateAll = true, + TrainOnAll = false + } + } + } + } + } + } + ); + } + [Test] public async Task StartBuildAsync_TextIds_ParallelCorpus() { @@ -965,7 +1233,7 @@ await env.Service.StartBuildAsync( } }, PretranslateAll = true, - TrainOnAll = true + TrainOnAll = false } } }, @@ -1004,7 +1272,7 @@ await env.Service.StartBuildAsync( } }, PretranslateAll = true, - TrainOnAll = true + TrainOnAll = false } } } @@ -1098,7 +1366,7 @@ await env.Service.StartBuildAsync( } }, PretranslateAll = true, - TrainOnAll = true + TrainOnAll = false } } }, @@ -1147,7 +1415,7 @@ await env.Service.StartBuildAsync( } }, PretranslateAll = true, - TrainOnAll = true + TrainOnAll = false } } } @@ -1531,7 +1799,7 @@ await env.Service.StartBuildAsync( SourceFilters = new List() { new() { CorpusRef = "parallel-corpus1-source1", ScriptureRange = "MAT 1;MRK" } - } + }, } ] } @@ -1591,7 +1859,7 @@ await env.Service.StartBuildAsync( } }, PretranslateAll = true, - TrainOnAll = true + TrainOnAll = false } }, TargetCorpora =