Skip to content

Commit

Permalink
Preprocess should be async
Browse files Browse the repository at this point in the history
Make is so that write async can be called multiple times on S3Writer.
Never have the S3 buffer grow above max size
Update machine to 3.5.1
  • Loading branch information
johnml1135 committed Nov 27, 2024
1 parent 1df752c commit 244d66f
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ await client.BuildStartedAsync(
try
{
List<InsertPretranslationsRequest> pretranslationsRequests = [];
_parallelCorpusPreprocessingService.Preprocess(
await _parallelCorpusPreprocessingService.Preprocess(
request.Corpora.Select(Map).ToList(),
row => { },
row => Task.CompletedTask,
(row, corpus) =>
{
pretranslationsRequests.Add(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
<PackageReference Include="Hangfire.Mongo" Version="1.10.8" />
<PackageReference Include="Microsoft.AspNetCore.Mvc.NewtonsoftJson" Version="8.0.8" />
<PackageReference Include="Microsoft.Extensions.Http.Polly" Version="8.0.8" />
<PackageReference Include="SIL.Machine" Version="3.5.0" Condition="!Exists('..\..\..\..\..\machine\src\SIL.Machine\SIL.Machine.csproj')" />
<PackageReference Include="SIL.Machine.Morphology.HermitCrab" Version="3.5.0" Condition="!Exists('..\..\..\..\..\machine\src\SIL.Machine.Morphology.HermitCrab\SIL.Machine.Morphology.HermitCrab.csproj')" />
<PackageReference Include="SIL.Machine.Translation.Thot" Version="3.5.0" Condition="!Exists('..\..\..\..\..\machine\src\SIL.Machine.Translation.Thot\SIL.Machine.Translation.Thot.csproj')" />
<PackageReference Include="SIL.Machine" Version="3.5.1" Condition="!Exists('..\..\..\..\..\machine\src\SIL.Machine\SIL.Machine.csproj')" />
<PackageReference Include="SIL.Machine.Morphology.HermitCrab" Version="3.5.1" Condition="!Exists('..\..\..\..\..\machine\src\SIL.Machine.Morphology.HermitCrab\SIL.Machine.Morphology.HermitCrab.csproj')" />
<PackageReference Include="SIL.Machine.Translation.Thot" Version="3.5.1" Condition="!Exists('..\..\..\..\..\machine\src\SIL.Machine.Translation.Thot\SIL.Machine.Translation.Thot.csproj')" />
<PackageReference Include="SIL.WritingSystems" Version="14.1.1" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ CancellationToken cancellationToken
JsonObject? buildOptionsObject = null;
if (buildOptions is not null)
buildOptionsObject = JsonSerializer.Deserialize<JsonObject>(buildOptions);

await using StreamWriter sourceTrainWriter =
new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.src.txt", cancellationToken));
await using StreamWriter targetTrainWriter =
new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken));

await using Stream pretranslateStream = await _sharedFileService.OpenWriteAsync(
$"builds/{buildId}/pretranslate.src.json",
cancellationToken
Expand All @@ -107,14 +107,14 @@ CancellationToken cancellationToken
int trainCount = 0;
int pretranslateCount = 0;
pretranslateWriter.WriteStartArray();
_parallelCorpusPreprocessingService.Preprocess(
await _parallelCorpusPreprocessingService.Preprocess(
corpora,
row =>
async row =>
{
if (row.SourceSegment.Length > 0 || row.TargetSegment.Length > 0)
{
sourceTrainWriter.Write($"{row.SourceSegment}\n");
targetTrainWriter.Write($"{row.TargetSegment}\n");
await sourceTrainWriter.WriteAsync($"{row.SourceSegment}\n");
await targetTrainWriter.WriteAsync($"{row.TargetSegment}\n");
}
if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0)
trainCount++;
Expand All @@ -134,6 +134,8 @@ CancellationToken cancellationToken
pretranslateWriter.WriteEndObject();
pretranslateCount++;
}
if (pretranslateWriter.BytesPending > 1024 * 1024)
pretranslateWriter.FlushAsync();
},
(bool?)buildOptionsObject?["use_key_terms"] ?? true
);
Expand Down
88 changes: 53 additions & 35 deletions src/Machine/src/Serval.Machine.Shared/Services/S3WriteStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ ILoggerFactory loggerFactory
private readonly List<UploadPartResponse> _uploadResponses = new List<UploadPartResponse>();
private readonly ILogger<S3WriteStream> _logger = loggerFactory.CreateLogger<S3WriteStream>();

private readonly Stream _stream = new MemoryStream();
private int _bytesWritten = 0;

public const int MaxPartSize = 5 * 1024 * 1024;

public override bool CanRead => false;
Expand All @@ -23,7 +26,7 @@ ILoggerFactory loggerFactory

public override bool CanWrite => true;

public override long Length => 0;
public override long Length => _stream.Length;

public override long Position
{
Expand All @@ -48,47 +51,60 @@ public override async ValueTask WriteAsync(
CancellationToken cancellationToken = default
)
{
try
{
using Stream stream = buffer.AsStream();
// S3 buckets can only be written to in chunks of MaxPartSize
// therefore, break it into chunks, resetting the stream each time

int bytesWritten = 0;
while (buffer.Length + _stream.Position > MaxPartSize)
{
int toWrite = MaxPartSize - (int)_stream.Position;
await _stream.WriteAsync(buffer[..toWrite], cancellationToken);
await UploadPartAsync(cancellationToken);
buffer = buffer[toWrite..];
}
// save the remaining buffer for future calls
await _stream.WriteAsync(buffer, cancellationToken);
}

while (stream.Length > bytesWritten)
{
int partNumber = _uploadResponses.Count + 1;
UploadPartRequest request =
new()
{
BucketName = _bucketName,
Key = _key,
UploadId = _uploadId,
PartNumber = partNumber,
InputStream = stream,
PartSize = MaxPartSize
};
request.StreamTransferProgress += new EventHandler<StreamTransferProgressArgs>(
(_, e) =>
{
_logger.LogDebug(
"Transferred {e.TransferredBytes}/{e.TotalBytes}",
e.TransferredBytes,
e.TotalBytes
);
}
);
UploadPartResponse response = await _client.UploadPartAsync(request, cancellationToken);
if (response.HttpStatusCode != HttpStatusCode.OK)
private async Task UploadPartAsync(CancellationToken cancellationToken = default)
{
if (_stream.Length == 0)
return;
try
{
_stream.Position = 0;
int partNumber = _uploadResponses.Count + 1;
UploadPartRequest request =
new()
{
throw new HttpRequestException(
$"Tried to upload part {partNumber} of upload {_uploadId} to {_bucketName}/{_key} but received response code {response.HttpStatusCode}"
BucketName = _bucketName,
Key = _key,
UploadId = _uploadId,
PartNumber = partNumber,
InputStream = _stream,
PartSize = MaxPartSize
};
request.StreamTransferProgress += new EventHandler<StreamTransferProgressArgs>(
(_, e) =>
{
_logger.LogDebug(
"Transferred {e.TransferredBytes}/{e.TotalBytes}",
e.TransferredBytes,
e.TotalBytes
);
}
);
UploadPartResponse response = await _client.UploadPartAsync(request, cancellationToken);
if (response.HttpStatusCode != HttpStatusCode.OK)
{
throw new HttpRequestException(
$"Tried to upload part {partNumber} of upload {_uploadId} to {_bucketName}/{_key} but received response code {response.HttpStatusCode}"
);
}

_uploadResponses.Add(response);
_uploadResponses.Add(response);

bytesWritten += MaxPartSize;
}
_bytesWritten += MaxPartSize;
_stream.SetLength(0);
}
catch (Exception e)
{
Expand All @@ -104,6 +120,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc

protected override void Dispose(bool disposing)
{
UploadPartAsync().WaitAndUnwrapException();
try
{
if (disposing)
Expand Down Expand Up @@ -164,6 +181,7 @@ protected override void Dispose(bool disposing)

public override async ValueTask DisposeAsync()
{
await UploadPartAsync();
try
{
if (_uploadResponses.Count == 0)
Expand Down
2 changes: 1 addition & 1 deletion src/Serval/src/Serval.Shared/Serval.Shared.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
<PackageReference Include="Grpc.Core.Api" Version="2.65.0" />
<PackageReference Include="Grpc.HealthCheck" Version="2.65.0" />
<PackageReference Include="Grpc.Net.ClientFactory" Version="2.65.0" />
<PackageReference Include="SIL.Machine" Version="3.5.0" Condition="!Exists('..\..\..\..\..\machine\src\SIL.Machine\SIL.Machine.csproj')" />
<PackageReference Include="SIL.Machine" Version="3.5.1" Condition="!Exists('..\..\..\..\..\machine\src\SIL.Machine\SIL.Machine.csproj')" />
<PackageReference Include="Microsoft.FeatureManagement.AspNetCore" Version="3.5.0" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
<PackageReference Include="SIL.WritingSystems" Version="14.1.1" />
<PackageReference Include="System.Text.RegularExpressions" Version="4.3.1" />
<PackageReference Include="SIL.Scripture" Version="12.0.1"/>
<PackageReference Include="SIL.Machine" Version="3.5.0" Condition="!Exists('..\..\..\..\..\machine\src\SIL.Machine\SIL.Machine.csproj')" />
<PackageReference Include="SIL.Machine" Version="3.5.1" Condition="!Exists('..\..\..\..\..\machine\src\SIL.Machine\SIL.Machine.csproj')" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
using Nito.AsyncEx;

namespace SIL.ServiceToolkit.Utils;

public interface IParallelCorpusPreprocessingService
{
void Preprocess(
Task Preprocess(
IReadOnlyList<ParallelCorpus> corpora,
Action<Row> train,
Func<Row, Task> train,
Action<Row, ParallelCorpus> pretranslate,
bool useKeyTerms = false
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ internal int Seed
}
}

public void Preprocess(
public async Task Preprocess(
IReadOnlyList<ParallelCorpus> corpora,
Action<Row> train,
Func<Row, Task> train,
Action<Row, ParallelCorpus> pretranslate,
bool useKeyTerms = false
)
Expand Down Expand Up @@ -77,7 +77,7 @@ public void Preprocess(

foreach (Row row in CollapseRanges(trainingRows))
{
train(row);
await train(row);
}

if (useKeyTerms)
Expand All @@ -93,7 +93,7 @@ public void Preprocess(
IParallelTextCorpus parallelKeyTermsCorpus = sourceTermCorpus.AlignRows(targetTermCorpus);
foreach (ParallelTextRow row in parallelKeyTermsCorpus)
{
train(new Row(row.TextId, row.Refs, row.SourceText, row.TargetText, 1));
await train(new Row(row.TextId, row.Refs, row.SourceText, row.TargetText, 1));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class ParallelCorpusPreprocessingServiceTests
);

[Test]
public void TestParallelCorpusPreprocessor()
public async Task TestParallelCorpusPreprocessor()
{
ParallelCorpusPreprocessingService processor = new(new CorpusService());
List<ParallelCorpus> corpora =
Expand Down Expand Up @@ -73,12 +73,13 @@ public void TestParallelCorpusPreprocessor()
];
int trainCount = 0;
int pretranslateCount = 0;
processor.Preprocess(
await processor.Preprocess(
corpora,
row =>
{
if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0)
trainCount++;
return Task.CompletedTask;
},
(row, _) =>
{
Expand Down

0 comments on commit 244d66f

Please sign in to comment.