diff --git a/src/Serval/src/Serval.Translation/Services/EngineService.cs b/src/Serval/src/Serval.Translation/Services/EngineService.cs index 6027c9ee..c5dd60dc 100644 --- a/src/Serval/src/Serval.Translation/Services/EngineService.cs +++ b/src/Serval/src/Serval.Translation/Services/EngineService.cs @@ -1,4 +1,4 @@ -using MassTransit.Mediator; +using MassTransit.Mediator; using Serval.Translation.V1; namespace Serval.Translation.Services; @@ -227,8 +227,19 @@ public async Task StartBuildAsync(Build build, CancellationToken cancellationTok StartBuildRequest request; if (engine.ParallelCorpora.Any()) { - var trainOn = build.TrainOn?.ToDictionary(c => c.ParallelCorpusRef!); - var pretranslate = build.Pretranslate?.ToDictionary(c => c.ParallelCorpusRef!); + Dictionary? trainOn = build.TrainOn?.ToDictionary(c => c.ParallelCorpusRef!); + Dictionary? pretranslate = build.Pretranslate?.ToDictionary(c => + c.ParallelCorpusRef! + ); + IReadOnlyList parallelCorpora = engine + .ParallelCorpora.Where(pc => + trainOn == null + || trainOn.ContainsKey(pc.Id) + || pretranslate == null + || pretranslate.ContainsKey(pc.Id) + ) + .ToList(); + request = new StartBuildRequest { EngineType = engine.Type, @@ -236,22 +247,26 @@ public async Task StartBuildAsync(Build build, CancellationToken cancellationTok BuildId = build.Id, Corpora = { - engine.ParallelCorpora.Select(c => - Map( - c, - trainOn?.GetValueOrDefault(c.Id), - pretranslate?.GetValueOrDefault(c.Id), - trainOn is null, - pretranslate is null - ) + parallelCorpora.Select(c => + Map(c, trainOn?.GetValueOrDefault(c.Id), pretranslate?.GetValueOrDefault(c.Id)) ) } }; } else { - var pretranslate = build.Pretranslate?.ToDictionary(c => c.CorpusRef!); - var trainOn = build.TrainOn?.ToDictionary(c => c.CorpusRef!); + Dictionary? trainOn = build.TrainOn?.ToDictionary(c => c.CorpusRef!); + Dictionary? pretranslate = build.Pretranslate?.ToDictionary(c => + c.CorpusRef! + ); + IReadOnlyList corpora = engine + .Corpora.Where(c => + trainOn == null + || trainOn.ContainsKey(c.Id) + || pretranslate == null + || pretranslate.ContainsKey(c.Id) + ) + .ToList(); request = new StartBuildRequest { @@ -260,14 +275,8 @@ pretranslate is null BuildId = build.Id, Corpora = { - engine.Corpora.Select(c => - Map( - c, - trainOn?.GetValueOrDefault(c.Id), - pretranslate?.GetValueOrDefault(c.Id), - trainOn is null, - pretranslate is null - ) + corpora.Select(c => + Map(c, trainOn?.GetValueOrDefault(c.Id), pretranslate?.GetValueOrDefault(c.Id)) ) } }; @@ -604,13 +613,7 @@ private Models.WordGraphArc Map(V1.WordGraphArc source) }; } - private V1.ParallelCorpus Map( - Corpus source, - TrainingCorpus? trainingCorpus, - PretranslateCorpus? pretranslateCorpus, - bool noTrainingCorpusDefined, - bool noPretranslateCorpusDefined - ) + private V1.ParallelCorpus Map(Corpus source, TrainingCorpus? trainingCorpus, PretranslateCorpus? pretranslateCorpus) { IEnumerable sourceFiles = source.SourceFiles.Select(Map); IEnumerable targetFiles = source.TargetFiles.Select(Map); @@ -619,10 +622,7 @@ bool noPretranslateCorpusDefined V1.MonolingualCorpus targetCorpus = new() { Language = source.TargetLanguage, Files = { source.TargetFiles.Select(Map) } }; - if ( - noTrainingCorpusDefined - || (trainingCorpus is not null && trainingCorpus.TextIds is null && trainingCorpus.ScriptureRange is null) - ) + if (trainingCorpus is null || (trainingCorpus.TextIds is null && trainingCorpus.ScriptureRange is null)) { sourceCorpus.TrainOnAll = true; targetCorpus.TrainOnAll = true; @@ -663,12 +663,8 @@ bool noPretranslateCorpusDefined } } if ( - noPretranslateCorpusDefined - || ( - pretranslateCorpus is not null - && pretranslateCorpus.TextIds is null - && pretranslateCorpus.ScriptureRange is null - ) + pretranslateCorpus is null + || (pretranslateCorpus.TextIds is null && pretranslateCorpus.ScriptureRange is null) ) { sourceCorpus.PretranslateAll = true; @@ -717,9 +713,7 @@ pretranslateCorpus is not null private V1.ParallelCorpus Map( Models.ParallelCorpus source, TrainingCorpus? trainingCorpus, - PretranslateCorpus? pretranslateCorpus, - bool noTrainingCorpusDefined, - bool noPretranslateCorpusDefined + PretranslateCorpus? pretranslateCorpus ) { string? referenceFileLocation = @@ -737,9 +731,7 @@ bool noPretranslateCorpusDefined sc, trainingCorpus?.SourceFilters?.Where(sf => sf.CorpusRef == sc.Id).FirstOrDefault(), pretranslateCorpus?.SourceFilters?.Where(sf => sf.CorpusRef == sc.Id).FirstOrDefault(), - referenceFileLocation, - noTrainingCorpusDefined, - noPretranslateCorpusDefined + referenceFileLocation ) ) }, @@ -750,9 +742,7 @@ bool noPretranslateCorpusDefined tc, trainingCorpus?.TargetFilters?.Where(sf => sf.CorpusRef == tc.Id).FirstOrDefault(), null, - referenceFileLocation, - noTrainingCorpusDefined, - noPretranslateCorpusDefined + referenceFileLocation ) ) } @@ -763,9 +753,7 @@ private V1.MonolingualCorpus Map( Models.MonolingualCorpus source, ParallelCorpusFilter? trainingFilter, ParallelCorpusFilter? pretranslateFilter, - string? referenceFileLocation, - bool noTrainingCorpusDefined, - bool noPretranslateCorpusDefined + string? referenceFileLocation ) { Dictionary? trainOnChapters = null; @@ -813,10 +801,7 @@ pretranslateFilter is not null Files = { source.Files.Select(Map) } }; - if ( - noTrainingCorpusDefined - || (trainingFilter is not null && trainingFilter.TextIds is null && trainingFilter.ScriptureRange is null) - ) + if (trainingFilter is null || (trainingFilter.TextIds is null && trainingFilter.ScriptureRange is null)) { corpus.TrainOnAll = true; } @@ -829,12 +814,8 @@ pretranslateFilter is not null } if ( - noPretranslateCorpusDefined - || ( - pretranslateFilter is not null - && pretranslateFilter.TextIds is null - && pretranslateFilter.ScriptureRange is null - ) + pretranslateFilter is null + || (pretranslateFilter.TextIds is null && pretranslateFilter.ScriptureRange is null) ) { corpus.PretranslateAll = true; diff --git a/src/Serval/test/Serval.Translation.Tests/Services/EngineServiceTests.cs b/src/Serval/test/Serval.Translation.Tests/Services/EngineServiceTests.cs index a71e8908..be53d27d 100644 --- a/src/Serval/test/Serval.Translation.Tests/Services/EngineServiceTests.cs +++ b/src/Serval/test/Serval.Translation.Tests/Services/EngineServiceTests.cs @@ -392,6 +392,80 @@ await env.Service.StartBuildAsync( ); } + [Test] + public async Task StartBuildAsync_OneOfMultipleCorpora() + { + var env = new TestEnvironment(); + string engineId = (await env.CreateMultipleCorporaEngineWithTextFilesAsync()).Id; + await env.Service.StartBuildAsync( + new Build + { + Id = BUILD1_ID, + EngineRef = engineId, + TrainOn = [new TrainingCorpus { CorpusRef = "corpus1" }], + Pretranslate = [new PretranslateCorpus { CorpusRef = "corpus1" }] + } + ); + _ = env.TranslationServiceClient.Received() + .StartBuildAsync( + new StartBuildRequest + { + BuildId = BUILD1_ID, + EngineId = engineId, + EngineType = "Smt", + Corpora = + { + new V1.ParallelCorpus + { + Id = "corpus1", + SourceCorpora = + { + new List + { + new() + { + Language = "es", + Files = + { + new V1.CorpusFile + { + Location = "file1.txt", + Format = FileFormat.Text, + TextId = "text1" + } + }, + PretranslateAll = true, + TrainOnAll = true + } + } + }, + TargetCorpora = + { + new List + { + new() + { + Language = "en", + Files = + { + new V1.CorpusFile + { + Location = "file2.txt", + Format = FileFormat.Text, + TextId = "text1" + } + }, + PretranslateAll = true, + TrainOnAll = true + } + } + } + } + } + } + ); + } + [Test] public async Task StartBuildAsync_TextFilesScriptureRangeSpecified() { @@ -709,6 +783,106 @@ await env.Service.StartBuildAsync( ); } + [Test] + public async Task StartBuildAsync_ParallelCorpus_OneOfMultipleCorpora() + { + var env = new TestEnvironment(); + string engineId = (await env.CreateMultipleParallelCorpusEngineWithTextFilesAsync()).Id; + await env.Service.StartBuildAsync( + new Build + { + Id = BUILD1_ID, + EngineRef = engineId, + TrainOn = + [ + new TrainingCorpus + { + ParallelCorpusRef = "parallel-corpus1", + SourceFilters = new List() + { + new() + { + CorpusRef = "parallel-corpus1-source1", + TextIds = new List { "MAT" } + } + }, + TargetFilters = new List() + { + new() + { + CorpusRef = "parallel-corpus1-target1", + TextIds = new List { "MAT" } + } + } + } + ], + Pretranslate = [new PretranslateCorpus { ParallelCorpusRef = "parallel-corpus1" }] + } + ); + _ = env.TranslationServiceClient.Received() + .StartBuildAsync( + new StartBuildRequest + { + BuildId = BUILD1_ID, + EngineId = engineId, + EngineType = "Smt", + Corpora = + { + new V1.ParallelCorpus + { + Id = "parallel-corpus1", + SourceCorpora = + { + new List + { + new() + { + Id = "parallel-corpus1-source1", + Language = "es", + TrainOnTextIds = { "MAT" }, + Files = + { + new V1.CorpusFile + { + Location = "file1.txt", + Format = FileFormat.Text, + TextId = "MAT" + } + }, + PretranslateAll = true, + TrainOnAll = false + } + } + }, + TargetCorpora = + { + new List + { + new() + { + Id = "parallel-corpus1-target1", + Language = "en", + TrainOnTextIds = { "MAT" }, + Files = + { + new V1.CorpusFile + { + Location = "file2.txt", + Format = FileFormat.Text, + TextId = "MAT" + } + }, + PretranslateAll = true, + TrainOnAll = false + } + } + } + } + } + } + ); + } + [Test] public async Task StartBuildAsync_TextIds_ParallelCorpus() { @@ -1706,6 +1880,75 @@ public async Task CreateEngineWithTextFilesAsync() return engine; } + public async Task CreateMultipleCorporaEngineWithTextFilesAsync() + { + var engine = new Engine + { + Id = "engine1", + Owner = "owner1", + SourceLanguage = "es", + TargetLanguage = "en", + Type = "Smt", + Corpora = new Models.Corpus[] + { + new() + { + Id = "corpus1", + SourceLanguage = "es", + TargetLanguage = "en", + SourceFiles = + [ + new() + { + Id = "file1", + Filename = "file1.txt", + Format = Shared.Contracts.FileFormat.Text, + TextId = "text1" + } + ], + TargetFiles = + [ + new() + { + Id = "file2", + Filename = "file2.txt", + Format = Shared.Contracts.FileFormat.Text, + TextId = "text1" + } + ], + }, + new() + { + Id = "corpus2", + SourceLanguage = "es", + TargetLanguage = "en", + SourceFiles = + [ + new() + { + Id = "file3", + Filename = "file3.txt", + Format = Shared.Contracts.FileFormat.Text, + TextId = "text1" + } + ], + TargetFiles = + [ + new() + { + Id = "file4", + Filename = "file4.txt", + Format = Shared.Contracts.FileFormat.Text, + TextId = "text1" + } + ], + } + } + }; + await Engines.InsertAsync(engine); + return engine; + } + public async Task CreateEngineWithParatextProjectAsync() { var engine = new Engine @@ -1840,6 +2083,107 @@ public async Task CreateParallelCorpusEngineWithTextFilesAsync() return engine; } + public async Task CreateMultipleParallelCorpusEngineWithTextFilesAsync() + { + var engine = new Engine + { + Id = "engine1", + Owner = "owner1", + SourceLanguage = "es", + TargetLanguage = "en", + Type = "Smt", + ParallelCorpora = new Models.ParallelCorpus[] + { + new() + { + Id = "parallel-corpus1", + SourceCorpora = new List() + { + new() + { + Id = "parallel-corpus1-source1", + Name = "", + Language = "es", + Files = + [ + new() + { + Id = "file1", + Filename = "file1.txt", + Format = Shared.Contracts.FileFormat.Text, + TextId = "MAT" + } + ] + } + }, + TargetCorpora = new List() + { + new() + { + Id = "parallel-corpus1-target1", + Name = "", + Language = "en", + Files = + [ + new() + { + Id = "file2", + Filename = "file2.txt", + Format = Shared.Contracts.FileFormat.Text, + TextId = "MAT" + } + ] + } + } + }, + new() + { + Id = "parallel-corpus2", + SourceCorpora = new List() + { + new() + { + Id = "parallel-corpus2-source1", + Name = "", + Language = "es", + Files = + [ + new() + { + Id = "file3", + Filename = "file3.txt", + Format = Shared.Contracts.FileFormat.Text, + TextId = "MRK" + } + ] + } + }, + TargetCorpora = new List() + { + new() + { + Id = "parallel-corpus2-target1", + Name = "", + Language = "en", + Files = + [ + new() + { + Id = "file4", + Filename = "file4.txt", + Format = Shared.Contracts.FileFormat.Text, + TextId = "MRK" + } + ] + } + } + } + } + }; + await Engines.InsertAsync(engine); + return engine; + } + public async Task CreateParallelCorpusEngineWithParatextProjectAsync() { var engine = new Engine