diff --git a/src/Echo/src/EchoTranslationEngine/TranslationEngineServiceV1.cs b/src/Echo/src/EchoTranslationEngine/TranslationEngineServiceV1.cs index fb7abc66..6f40a621 100644 --- a/src/Echo/src/EchoTranslationEngine/TranslationEngineServiceV1.cs +++ b/src/Echo/src/EchoTranslationEngine/TranslationEngineServiceV1.cs @@ -82,9 +82,9 @@ await client.BuildStartedAsync( try { List pretranslationsRequests = []; - _parallelCorpusPreprocessingService.Preprocess( + await _parallelCorpusPreprocessingService.Preprocess( request.Corpora.Select(Map).ToList(), - row => { }, + row => Task.CompletedTask, (row, corpus) => { pretranslationsRequests.Add( diff --git a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs index eebdfcff..e3c31328 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs @@ -94,25 +94,27 @@ CancellationToken cancellationToken if (buildOptions is not null) buildOptionsObject = JsonSerializer.Deserialize(buildOptions); - using MemoryStream sourceStream = new(); - using MemoryStream targetStream = new(); - using MemoryStream pretranslationStream = new(); - - using StreamWriter targetTrainWriter = new(targetStream, Encoding.Default); - using StreamWriter sourceTrainWriter = new(sourceStream, Encoding.Default); - await using Utf8JsonWriter pretranslateWriter = new(pretranslationStream, PretranslateWriterOptions); + await using StreamWriter sourceTrainWriter = + new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.src.txt", cancellationToken)); + await using StreamWriter targetTrainWriter = + new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken)); + await using Utf8JsonWriter pretranslateWriter = + new( + await _sharedFileService.OpenWriteAsync($"builds/{buildId}/pretranslate.src.json", cancellationToken), + PretranslateWriterOptions + ); int trainCount = 0; int pretranslateCount = 0; pretranslateWriter.WriteStartArray(); - _parallelCorpusPreprocessingService.Preprocess( + await _parallelCorpusPreprocessingService.Preprocess( corpora, - row => + async row => { if (row.SourceSegment.Length > 0 || row.TargetSegment.Length > 0) { - sourceTrainWriter.WriteLine(row.SourceSegment); - targetTrainWriter.WriteLine(row.TargetSegment); + await sourceTrainWriter.WriteLineAsync(row.SourceSegment); + await targetTrainWriter.WriteLineAsync(row.TargetSegment); } if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) trainCount++; @@ -138,22 +140,6 @@ CancellationToken cancellationToken pretranslateWriter.WriteEndArray(); - await sourceTrainWriter.FlushAsync(cancellationToken); - await targetTrainWriter.FlushAsync(cancellationToken); - await pretranslateWriter.FlushAsync(cancellationToken); - - async Task WriteStreamAsync(MemoryStream stream, string path) - { - stream.Position = 0; - await using StreamWriter writer = new(await _sharedFileService.OpenWriteAsync(path, cancellationToken)); - await writer.WriteAsync(Encoding.Default.GetString(stream.ToArray())); - await writer.FlushAsync(cancellationToken); - } - - await WriteStreamAsync(sourceStream, $"builds/{buildId}/train.src.txt"); - await WriteStreamAsync(targetStream, $"builds/{buildId}/train.trg.txt"); - await WriteStreamAsync(pretranslationStream, $"builds/{buildId}/pretranslate.src.json"); - return (trainCount, pretranslateCount); } diff --git a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs index 1556de6d..d1fcf6a7 100644 --- a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs +++ b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs @@ -1,10 +1,12 @@ +using Nito.AsyncEx; + namespace SIL.ServiceToolkit.Utils; public interface IParallelCorpusPreprocessingService { - void Preprocess( + Task Preprocess( IReadOnlyList corpora, - Action train, + Func train, Action pretranslate, bool useKeyTerms = false ); diff --git a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs index e75a2d59..7e1d9aae 100644 --- a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs +++ b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs @@ -25,9 +25,9 @@ internal int Seed } } - public void Preprocess( + public async Task Preprocess( IReadOnlyList corpora, - Action train, + Func train, Action pretranslate, bool useKeyTerms = false ) @@ -77,7 +77,7 @@ public void Preprocess( foreach (Row row in CollapseRanges(trainingRows)) { - train(row); + await train(row); } if (useKeyTerms) @@ -93,7 +93,7 @@ public void Preprocess( IParallelTextCorpus parallelKeyTermsCorpus = sourceTermCorpus.AlignRows(targetTermCorpus); foreach (ParallelTextRow row in parallelKeyTermsCorpus) { - train(new Row(row.TextId, row.Refs, row.SourceText, row.TargetText, 1)); + await train(new Row(row.TextId, row.Refs, row.SourceText, row.TargetText, 1)); } } } diff --git a/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs b/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs index 543332e2..a58d891b 100644 --- a/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs +++ b/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs @@ -79,6 +79,7 @@ public void TestParallelCorpusPreprocessor() { if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) trainCount++; + return Task.CompletedTask; }, (row, _) => {