From 60304f8353b70d22e8e8f1639349c37550c3d905 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Wed, 27 Nov 2024 08:58:58 -0600 Subject: [PATCH] Update machine to 3.5.1 and small bug (#546) Preprocess should be async Make is so that write async can be called multiple times on S3Writer. Never have the S3 buffer grow above max size Update machine to 3.5.1 --- .../TranslationEngineServiceV1.cs | 5 +- .../Serval.Machine.Shared.csproj | 6 +- .../Services/PreprocessBuildJob.cs | 14 +-- .../Services/S3WriteStream.cs | 88 +++++++++++-------- .../src/Serval.Shared/Serval.Shared.csproj | 2 +- .../SIL.ServiceToolkit.csproj | 2 +- .../IParallelCorpusPreprocessingService.cs | 8 +- .../ParallelCorpusPreprocessingService.cs | 12 +-- .../ParallelCorpusProcessingServiceTests.cs | 6 +- 9 files changed, 84 insertions(+), 59 deletions(-) diff --git a/src/Echo/src/EchoTranslationEngine/TranslationEngineServiceV1.cs b/src/Echo/src/EchoTranslationEngine/TranslationEngineServiceV1.cs index fb7abc66..720a0126 100644 --- a/src/Echo/src/EchoTranslationEngine/TranslationEngineServiceV1.cs +++ b/src/Echo/src/EchoTranslationEngine/TranslationEngineServiceV1.cs @@ -82,9 +82,9 @@ await client.BuildStartedAsync( try { List pretranslationsRequests = []; - _parallelCorpusPreprocessingService.Preprocess( + await _parallelCorpusPreprocessingService.Preprocess( request.Corpora.Select(Map).ToList(), - row => { }, + row => Task.CompletedTask, (row, corpus) => { pretranslationsRequests.Add( @@ -97,6 +97,7 @@ await client.BuildStartedAsync( Translation = row.SourceSegment } ); + return Task.CompletedTask; }, false ); diff --git a/src/Machine/src/Serval.Machine.Shared/Serval.Machine.Shared.csproj b/src/Machine/src/Serval.Machine.Shared/Serval.Machine.Shared.csproj index f9eea0c5..4206b29e 100644 --- a/src/Machine/src/Serval.Machine.Shared/Serval.Machine.Shared.csproj +++ b/src/Machine/src/Serval.Machine.Shared/Serval.Machine.Shared.csproj @@ -36,9 +36,9 @@ - - - + + + diff --git a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs index 46baa68d..831a6ad0 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs @@ -93,11 +93,11 @@ CancellationToken cancellationToken JsonObject? buildOptionsObject = null; if (buildOptions is not null) buildOptionsObject = JsonSerializer.Deserialize(buildOptions); + await using StreamWriter sourceTrainWriter = new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.src.txt", cancellationToken)); await using StreamWriter targetTrainWriter = new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken)); - await using Stream pretranslateStream = await _sharedFileService.OpenWriteAsync( $"builds/{buildId}/pretranslate.src.json", cancellationToken @@ -107,19 +107,19 @@ CancellationToken cancellationToken int trainCount = 0; int pretranslateCount = 0; pretranslateWriter.WriteStartArray(); - _parallelCorpusPreprocessingService.Preprocess( + await _parallelCorpusPreprocessingService.Preprocess( corpora, - row => + async row => { if (row.SourceSegment.Length > 0 || row.TargetSegment.Length > 0) { - sourceTrainWriter.Write($"{row.SourceSegment}\n"); - targetTrainWriter.Write($"{row.TargetSegment}\n"); + await sourceTrainWriter.WriteAsync($"{row.SourceSegment}\n"); + await targetTrainWriter.WriteAsync($"{row.TargetSegment}\n"); } if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) trainCount++; }, - (row, corpus) => + async (row, corpus) => { if (row.SourceSegment.Length > 0 && row.TargetSegment.Length == 0) { @@ -134,6 +134,8 @@ CancellationToken cancellationToken pretranslateWriter.WriteEndObject(); pretranslateCount++; } + if (pretranslateWriter.BytesPending > 1024 * 1024) + await pretranslateWriter.FlushAsync(); }, (bool?)buildOptionsObject?["use_key_terms"] ?? true ); diff --git a/src/Machine/src/Serval.Machine.Shared/Services/S3WriteStream.cs b/src/Machine/src/Serval.Machine.Shared/Services/S3WriteStream.cs index 4b623d6d..e1ba3494 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/S3WriteStream.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/S3WriteStream.cs @@ -15,6 +15,9 @@ ILoggerFactory loggerFactory private readonly List _uploadResponses = new List(); private readonly ILogger _logger = loggerFactory.CreateLogger(); + private readonly Stream _stream = new MemoryStream(); + private int _bytesWritten = 0; + public const int MaxPartSize = 5 * 1024 * 1024; public override bool CanRead => false; @@ -23,7 +26,7 @@ ILoggerFactory loggerFactory public override bool CanWrite => true; - public override long Length => 0; + public override long Length => _stream.Length; public override long Position { @@ -48,47 +51,60 @@ public override async ValueTask WriteAsync( CancellationToken cancellationToken = default ) { - try - { - using Stream stream = buffer.AsStream(); + // S3 buckets can only be written to in chunks of MaxPartSize + // therefore, break it into chunks, resetting the stream each time - int bytesWritten = 0; + while (buffer.Length + _stream.Position > MaxPartSize) + { + int toWrite = MaxPartSize - (int)_stream.Position; + await _stream.WriteAsync(buffer[..toWrite], cancellationToken); + await UploadPartAsync(cancellationToken); + buffer = buffer[toWrite..]; + } + // save the remaining buffer for future calls + await _stream.WriteAsync(buffer, cancellationToken); + } - while (stream.Length > bytesWritten) - { - int partNumber = _uploadResponses.Count + 1; - UploadPartRequest request = - new() - { - BucketName = _bucketName, - Key = _key, - UploadId = _uploadId, - PartNumber = partNumber, - InputStream = stream, - PartSize = MaxPartSize - }; - request.StreamTransferProgress += new EventHandler( - (_, e) => - { - _logger.LogDebug( - "Transferred {e.TransferredBytes}/{e.TotalBytes}", - e.TransferredBytes, - e.TotalBytes - ); - } - ); - UploadPartResponse response = await _client.UploadPartAsync(request, cancellationToken); - if (response.HttpStatusCode != HttpStatusCode.OK) + private async Task UploadPartAsync(CancellationToken cancellationToken = default) + { + if (_stream.Length == 0) + return; + try + { + _stream.Position = 0; + int partNumber = _uploadResponses.Count + 1; + UploadPartRequest request = + new() { - throw new HttpRequestException( - $"Tried to upload part {partNumber} of upload {_uploadId} to {_bucketName}/{_key} but received response code {response.HttpStatusCode}" + BucketName = _bucketName, + Key = _key, + UploadId = _uploadId, + PartNumber = partNumber, + InputStream = _stream, + PartSize = MaxPartSize + }; + request.StreamTransferProgress += new EventHandler( + (_, e) => + { + _logger.LogDebug( + "Transferred {e.TransferredBytes}/{e.TotalBytes}", + e.TransferredBytes, + e.TotalBytes ); } + ); + UploadPartResponse response = await _client.UploadPartAsync(request, cancellationToken); + if (response.HttpStatusCode != HttpStatusCode.OK) + { + throw new HttpRequestException( + $"Tried to upload part {partNumber} of upload {_uploadId} to {_bucketName}/{_key} but received response code {response.HttpStatusCode}" + ); + } - _uploadResponses.Add(response); + _uploadResponses.Add(response); - bytesWritten += MaxPartSize; - } + _bytesWritten += MaxPartSize; + _stream.SetLength(0); } catch (Exception e) { @@ -104,6 +120,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc protected override void Dispose(bool disposing) { + UploadPartAsync().WaitAndUnwrapException(); try { if (disposing) @@ -164,6 +181,7 @@ protected override void Dispose(bool disposing) public override async ValueTask DisposeAsync() { + await UploadPartAsync(); try { if (_uploadResponses.Count == 0) diff --git a/src/Serval/src/Serval.Shared/Serval.Shared.csproj b/src/Serval/src/Serval.Shared/Serval.Shared.csproj index 75ccbd9b..0e504535 100644 --- a/src/Serval/src/Serval.Shared/Serval.Shared.csproj +++ b/src/Serval/src/Serval.Shared/Serval.Shared.csproj @@ -19,7 +19,7 @@ - + diff --git a/src/ServiceToolkit/src/SIL.ServiceToolkit/SIL.ServiceToolkit.csproj b/src/ServiceToolkit/src/SIL.ServiceToolkit/SIL.ServiceToolkit.csproj index f9476b69..a64c5d85 100644 --- a/src/ServiceToolkit/src/SIL.ServiceToolkit/SIL.ServiceToolkit.csproj +++ b/src/ServiceToolkit/src/SIL.ServiceToolkit/SIL.ServiceToolkit.csproj @@ -17,7 +17,7 @@ - + diff --git a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs index 1556de6d..1be70d5e 100644 --- a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs +++ b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs @@ -1,11 +1,13 @@ +using Nito.AsyncEx; + namespace SIL.ServiceToolkit.Utils; public interface IParallelCorpusPreprocessingService { - void Preprocess( + Task Preprocess( IReadOnlyList corpora, - Action train, - Action pretranslate, + Func train, + Func pretranslate, bool useKeyTerms = false ); } diff --git a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs index e75a2d59..25d6b55c 100644 --- a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs +++ b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs @@ -25,10 +25,10 @@ internal int Seed } } - public void Preprocess( + public async Task Preprocess( IReadOnlyList corpora, - Action train, - Action pretranslate, + Func train, + Func pretranslate, bool useKeyTerms = false ) { @@ -77,7 +77,7 @@ public void Preprocess( foreach (Row row in CollapseRanges(trainingRows)) { - train(row); + await train(row); } if (useKeyTerms) @@ -93,7 +93,7 @@ public void Preprocess( IParallelTextCorpus parallelKeyTermsCorpus = sourceTermCorpus.AlignRows(targetTermCorpus); foreach (ParallelTextRow row in parallelKeyTermsCorpus) { - train(new Row(row.TextId, row.Refs, row.SourceText, row.TargetText, 1)); + await train(new Row(row.TextId, row.Refs, row.SourceText, row.TargetText, 1)); } } } @@ -106,7 +106,7 @@ public void Preprocess( foreach (Row row in CollapseRanges(pretranslateCorpus.ToArray())) { - pretranslate(row, corpus); + await pretranslate(row, corpus); } } } diff --git a/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs b/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs index 543332e2..033467f4 100644 --- a/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs +++ b/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs @@ -13,7 +13,7 @@ public class ParallelCorpusPreprocessingServiceTests ); [Test] - public void TestParallelCorpusPreprocessor() + public async Task TestParallelCorpusPreprocessor() { ParallelCorpusPreprocessingService processor = new(new CorpusService()); List corpora = @@ -73,17 +73,19 @@ public void TestParallelCorpusPreprocessor() ]; int trainCount = 0; int pretranslateCount = 0; - processor.Preprocess( + await processor.Preprocess( corpora, row => { if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) trainCount++; + return Task.CompletedTask; }, (row, _) => { if (row.SourceSegment.Length > 0 && row.TargetSegment.Length == 0) pretranslateCount++; + return Task.CompletedTask; }, false );