diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs index 849a9f08..062da3ec 100644 --- a/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs @@ -178,14 +178,14 @@ public static IMachineBuilder AddHangfireJobServer( switch (engineType) { case EngineType.SmtTransfer: - builder.Services.AddSingleton(); - builder.AddThotSmtModel().AddTransferEngine().AddUnigramTruecaser(); + builder.AddThot(); queues.Add("smt_transfer"); break; case EngineType.Nmt: queues.Add("nmt"); break; case EngineType.Statistical: + builder.AddThot(); queues.Add("statistical"); break; default: diff --git a/src/Machine/src/Serval.Machine.Shared/Services/NmtPreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/NmtPreprocessBuildJob.cs index 816302ca..abbebf6f 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/NmtPreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/NmtPreprocessBuildJob.cs @@ -22,6 +22,71 @@ ILanguageTagService languageTagService { private readonly ILanguageTagService _languageTagService = languageTagService; + protected override int WriteInferences(Utf8JsonWriter inferenceWriter, ParallelCorpus corpus) + { + (MonolingualCorpus Corpus, ITextCorpus TextCorpus)[] sourceCorpora = corpus + .SourceCorpora.SelectMany(c => CorpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc))) + .ToArray(); + (MonolingualCorpus Corpus, ITextCorpus TextCorpus)[] targetCorpora = corpus + .TargetCorpora.SelectMany(c => CorpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc))) + .ToArray(); + + int inferenceCount = 0; + + ITextCorpus targetCorpus = targetCorpora.Length > 0 ? targetCorpora[0].TextCorpus : new DictionaryTextCorpus(); + ITextCorpus? sourcePretranslateCorpus = sourceCorpora + .Select(sc => + { + ITextCorpus textCorpus = sc.TextCorpus; + if (sc.Corpus.InferenceTextIds is not null) + { + textCorpus = textCorpus.FilterTexts( + sc.Corpus.InferenceTextIds.Except(sc.Corpus.TrainOnTextIds ?? new()) + ); + } + return textCorpus.Where(row => + row.Ref is not ScriptureRef sr + || sc.Corpus.InferenceChapters is null + || ( + IsInChapters(sr, sc.Corpus.InferenceChapters) + && !IsInChapters(sr, sc.Corpus.TrainOnChapters ?? new()) + ) + ); + }) + .ToArray() + .FirstOrDefault(); + + if (sourcePretranslateCorpus != null) + { + foreach (Row row in AlignInferenceCorpus(sourcePretranslateCorpus, targetCorpus)) + { + if (row.SourceSegment.Length > 0 && (row.TargetSegment.Length == 0 || !targetCorpus.Any())) + WriteRow(inferenceWriter, corpus.Id, row.TextId, row.Refs, row.SourceSegment); + inferenceCount++; + } + } + return inferenceCount; + } + + private static void WriteRow( + Utf8JsonWriter writer, + string corpusId, + string textId, + IReadOnlyList refs, + string translation + ) + { + writer.WriteStartObject(); + writer.WriteString("corpusId", corpusId); + writer.WriteString("textId", textId); + writer.WriteStartArray("refs"); + foreach (object rowRef in refs) + writer.WriteStringValue(rowRef.ToString()); + writer.WriteEndArray(); + writer.WriteString("translation", translation); + writer.WriteEndObject(); + } + protected override bool ResolveLanguageCodeForBaseModel(string languageCode, out string resolvedCode) { return _languageTagService.ConvertToFlores200Code(languageCode, out resolvedCode); diff --git a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs index 4d149857..2988556e 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs @@ -3,12 +3,13 @@ public class PreprocessBuildJob : HangfireBuildJob> where TEngine : ITrainingEngine { - private static readonly JsonWriterOptions PretranslateWriterOptions = new() { Indented = true }; + private static readonly JsonWriterOptions InferenceWriterOptions = new() { Indented = true }; internal BuildJobRunnerType TrainJobRunnerType { get; init; } = BuildJobRunnerType.ClearML; protected readonly ISharedFileService SharedFileService; protected readonly ICorpusService CorpusService; + protected string InferenceFilename { get; init; } = "pretranslate.src.json"; private int _seed = 1234; private Random _random; @@ -113,169 +114,185 @@ CancellationToken cancellationToken await using StreamWriter targetTrainWriter = new(await SharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken)); - await using Stream pretranslateStream = await SharedFileService.OpenWriteAsync( - $"builds/{buildId}/pretranslate.src.json", + await using Stream inferenceStream = await SharedFileService.OpenWriteAsync( + $"builds/{buildId}/{InferenceFilename}", cancellationToken ); - await using Utf8JsonWriter pretranslateWriter = new(pretranslateStream, PretranslateWriterOptions); + await using Utf8JsonWriter inferenceWriter = new(inferenceStream, InferenceWriterOptions); int trainCount = 0; - int pretranslateCount = 0; - pretranslateWriter.WriteStartArray(); + int inferenceCount = 0; + inferenceWriter.WriteStartArray(); foreach (ParallelCorpus corpus in corpora) { - (MonolingualCorpus Corpus, ITextCorpus TextCorpus)[] sourceCorpora = corpus - .SourceCorpora.SelectMany(c => CorpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc))) - .ToArray(); - ITextCorpus[] sourceTrainingCorpora = sourceCorpora - .Select(sc => - { - ITextCorpus textCorpus = sc.TextCorpus; - if (sc.Corpus.TrainOnTextIds is not null) - textCorpus = textCorpus.FilterTexts(sc.Corpus.TrainOnTextIds); - return textCorpus.Where(row => - row.Ref is not ScriptureRef sr - || sc.Corpus.TrainOnChapters is null - || IsInChapters(sr, sc.Corpus.TrainOnChapters) - ); - }) - .ToArray(); - ITextCorpus? sourcePretranslateCorpus = sourceCorpora - .Select(sc => - { - ITextCorpus textCorpus = sc.TextCorpus; - if (sc.Corpus.InferenceTextIds is not null) - { - textCorpus = textCorpus.FilterTexts( - sc.Corpus.InferenceTextIds.Except(sc.Corpus.TrainOnTextIds ?? new()) - ); - } - return textCorpus.Where(row => - row.Ref is not ScriptureRef sr - || sc.Corpus.InferenceChapters is null - || ( - IsInChapters(sr, sc.Corpus.InferenceChapters) - && !IsInChapters(sr, sc.Corpus.TrainOnChapters ?? new()) - ) - ); - }) - .ToArray() - .FirstOrDefault(); + trainCount += await WriteTrainingAsync(sourceTrainWriter, targetTrainWriter, corpus, buildOptionsObject); + inferenceCount += WriteInferences(inferenceWriter, corpus); + } - (MonolingualCorpus Corpus, ITextCorpus TextCorpus)[] targetCorpora = corpus - .TargetCorpora.SelectMany(c => CorpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc))) - .ToArray(); - ITextCorpus[] targetTrainingCorpora = targetCorpora - .Select(tc => - { - ITextCorpus textCorpus = tc.TextCorpus; - if (tc.Corpus.TrainOnTextIds is not null) - textCorpus = textCorpus.FilterTexts(tc.Corpus.TrainOnTextIds); - return textCorpus.Where(row => - row.Ref is not ScriptureRef sr - || tc.Corpus.TrainOnChapters is null - || IsInChapters(sr, tc.Corpus.TrainOnChapters) - ); - }) - .ToArray(); + inferenceWriter.WriteEndArray(); - if (sourceCorpora.Length == 0) - continue; + return (trainCount, inferenceCount); + } + + protected virtual async Task WriteTrainingAsync( + StreamWriter sourceTrainWriter, + StreamWriter targetTrainWriter, + ParallelCorpus corpus, + JsonObject? buildOptionsObject + ) + { + (MonolingualCorpus Corpus, ITextCorpus TextCorpus)[] sourceCorpora = corpus + .SourceCorpora.SelectMany(c => CorpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc))) + .ToArray(); + (MonolingualCorpus Corpus, ITextCorpus TextCorpus)[] targetCorpora = corpus + .TargetCorpora.SelectMany(c => CorpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc))) + .ToArray(); + + int trainCount = 0; + ITextCorpus[] sourceTrainingCorpora = sourceCorpora + .Select(sc => FilterCorpus(sc, sc.Corpus.TrainOnTextIds, sc.Corpus.TrainOnChapters)) + .ToArray(); + + ITextCorpus[] targetTrainingCorpora = targetCorpora + .Select(tc => FilterCorpus(tc, tc.Corpus.TrainOnTextIds, tc.Corpus.TrainOnChapters)) + .ToArray(); - int skipCount = 0; - foreach (Row?[] rows in AlignTrainCorpus(sourceTrainingCorpora, targetTrainingCorpora)) + if (sourceCorpora.Length == 0) + return trainCount; + + int skipCount = 0; + foreach (Row?[] rows in AlignTrainCorpus(sourceTrainingCorpora, targetTrainingCorpora)) + { + if (skipCount > 0) { - if (skipCount > 0) - { - skipCount--; - continue; - } + skipCount--; + continue; + } - Row[] trainRows = rows.Where(r => r is not null).Cast().ToArray(); - if (trainRows.Length > 0) + Row[] trainRows = rows.Where(r => r is not null).Cast().ToArray(); + if (trainRows.Length > 0) + { + Row row = trainRows[0]; + if (rows.Length > 1) { - Row row = trainRows[0]; - if (rows.Length > 1) + Row[] nonEmptyRows = trainRows.Where(r => r.SourceSegment.Length > 0).ToArray(); + Row[] targetNonEmptyRows = nonEmptyRows.Where(r => r.TargetSegment.Length > 0).ToArray(); + if (targetNonEmptyRows.Length > 0) + nonEmptyRows = targetNonEmptyRows; + if (nonEmptyRows.Length > 0) { - Row[] nonEmptyRows = trainRows.Where(r => r.SourceSegment.Length > 0).ToArray(); - Row[] targetNonEmptyRows = nonEmptyRows.Where(r => r.TargetSegment.Length > 0).ToArray(); - if (targetNonEmptyRows.Length > 0) - nonEmptyRows = targetNonEmptyRows; - if (nonEmptyRows.Length > 0) + nonEmptyRows = nonEmptyRows + .GroupBy(r => r.SourceSegment) + .Select(group => group.First()) + .ToArray(); { nonEmptyRows = nonEmptyRows .GroupBy(r => r.SourceSegment) .Select(group => group.First()) .ToArray(); - { - nonEmptyRows = nonEmptyRows - .GroupBy(r => r.SourceSegment) - .Select(group => group.First()) - .ToArray(); - row = nonEmptyRows[_random.Next(nonEmptyRows.Length)]; - } + row = nonEmptyRows[_random.Next(nonEmptyRows.Length)]; } } - - await sourceTrainWriter.WriteAsync($"{row.SourceSegment}\n"); - await targetTrainWriter.WriteAsync($"{row.TargetSegment}\n"); - skipCount = row.RowCount - 1; - if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) - trainCount++; } + + await sourceTrainWriter.WriteAsync($"{row.SourceSegment}\n"); + await targetTrainWriter.WriteAsync($"{row.TargetSegment}\n"); + skipCount = row.RowCount - 1; + if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) + trainCount++; } + } - if ((bool?)buildOptionsObject?["use_key_terms"] ?? true) + if ((bool?)buildOptionsObject?["use_key_terms"] ?? true) + { + ITextCorpus? sourceTermCorpus = CorpusService + .CreateTermCorpora(corpus.SourceCorpora.SelectMany(sc => sc.Files).ToList()) + .FirstOrDefault(); + ITextCorpus? targetTermCorpus = CorpusService + .CreateTermCorpora(corpus.TargetCorpora.SelectMany(tc => tc.Files).ToList()) + .FirstOrDefault(); + if (sourceTermCorpus is not null && targetTermCorpus is not null) { - ITextCorpus? sourceTermCorpus = CorpusService - .CreateTermCorpora(corpus.SourceCorpora.SelectMany(sc => sc.Files).ToList()) - .FirstOrDefault(); - ITextCorpus? targetTermCorpus = CorpusService - .CreateTermCorpora(corpus.TargetCorpora.SelectMany(tc => tc.Files).ToList()) - .FirstOrDefault(); - if (sourceTermCorpus is not null && targetTermCorpus is not null) + IParallelTextCorpus parallelKeyTermsCorpus = sourceTermCorpus.AlignRows(targetTermCorpus); + foreach (ParallelTextRow row in parallelKeyTermsCorpus) { - IParallelTextCorpus parallelKeyTermsCorpus = sourceTermCorpus.AlignRows(targetTermCorpus); - foreach (ParallelTextRow row in parallelKeyTermsCorpus) - { - await sourceTrainWriter.WriteAsync($"{row.SourceText}\n"); - await targetTrainWriter.WriteAsync($"{row.TargetText}\n"); - trainCount++; - } + await sourceTrainWriter.WriteAsync($"{row.SourceText}\n"); + await targetTrainWriter.WriteAsync($"{row.TargetText}\n"); + trainCount++; } } - void WriteRow(Utf8JsonWriter writer, string textId, IReadOnlyList refs, string translation) + } + return trainCount; + } + + protected virtual int WriteInferences(Utf8JsonWriter inferenceWriter, ParallelCorpus corpus) + { + return 0; + } + + protected static IEnumerable AlignInferenceCorpus(ITextCorpus srcCorpus, ITextCorpus trgCorpus) + { + int rowCount = 0; + StringBuilder srcSegBuffer = new(); + StringBuilder trgSegBuffer = new(); + List refs = []; + string textId = ""; + + srcCorpus = srcCorpus.Transform(CleanSegment); + trgCorpus = trgCorpus.Transform(CleanSegment); + + foreach (ParallelTextRow row in srcCorpus.AlignRows(trgCorpus, allSourceRows: true)) + { + if (!row.IsTargetRangeStart && row.IsTargetInRange) { - writer.WriteStartObject(); - writer.WriteString("corpusId", corpus.Id); - writer.WriteString("textId", textId); - writer.WriteStartArray("refs"); - foreach (object rowRef in refs) - writer.WriteStringValue(rowRef.ToString()); - writer.WriteEndArray(); - writer.WriteString("translation", translation); - writer.WriteEndObject(); - pretranslateCount++; + refs.AddRange(row.TargetRefs); + if (row.SourceText.Length > 0) + { + if (srcSegBuffer.Length > 0) + srcSegBuffer.Append(' '); + srcSegBuffer.Append(row.SourceText); + } + rowCount++; } - - ITextCorpus targetCorpus = - targetCorpora.Length > 0 ? targetCorpora[0].TextCorpus : new DictionaryTextCorpus(); - if (sourcePretranslateCorpus != null) + else { - foreach (Row row in AlignPretranslateCorpus(sourcePretranslateCorpus, targetCorpus)) + if (rowCount > 0) { - if (row.SourceSegment.Length > 0 && (row.TargetSegment.Length == 0 || !targetCorpus.Any())) - WriteRow(pretranslateWriter, row.TextId, row.Refs, row.SourceSegment); + yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1); + textId = ""; + srcSegBuffer.Clear(); + trgSegBuffer.Clear(); + refs.Clear(); + rowCount = 0; } + + textId = row.TextId; + refs.AddRange(row.TargetRefs); + srcSegBuffer.Append(row.SourceText); + trgSegBuffer.Append(row.TargetText); + rowCount++; } } - pretranslateWriter.WriteEndArray(); + if (rowCount > 0) + yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1); + } - return (trainCount, pretranslateCount); + protected static ITextCorpus FilterCorpus( + (MonolingualCorpus Corpus, ITextCorpus TextCorpus) corpus, + HashSet? filterOnTextIds, + Dictionary>? filterOnChapters + ) + { + ITextCorpus textCorpus = corpus.TextCorpus; + if (filterOnTextIds is not null) + textCorpus = textCorpus.FilterTexts(filterOnTextIds); + return textCorpus.Where(row => + row.Ref is not ScriptureRef sr || filterOnChapters is null || IsInChapters(sr, filterOnChapters) + ); } - private static bool IsInChapters(ScriptureRef sr, Dictionary> selection) + protected static bool IsInChapters(ScriptureRef sr, Dictionary> selection) { return selection.TryGetValue(sr.Book, out HashSet? chapters) && chapters != null @@ -423,55 +440,6 @@ IReadOnlyList trgCorpora } } - private static IEnumerable AlignPretranslateCorpus(ITextCorpus srcCorpus, ITextCorpus trgCorpus) - { - int rowCount = 0; - StringBuilder srcSegBuffer = new(); - StringBuilder trgSegBuffer = new(); - List refs = []; - string textId = ""; - - srcCorpus = srcCorpus.Transform(CleanSegment); - trgCorpus = trgCorpus.Transform(CleanSegment); - - foreach (ParallelTextRow row in srcCorpus.AlignRows(trgCorpus, allSourceRows: true)) - { - if (!row.IsTargetRangeStart && row.IsTargetInRange) - { - refs.AddRange(row.TargetRefs); - if (row.SourceText.Length > 0) - { - if (srcSegBuffer.Length > 0) - srcSegBuffer.Append(' '); - srcSegBuffer.Append(row.SourceText); - } - rowCount++; - } - else - { - if (rowCount > 0) - { - if (trgSegBuffer.Length == 0) - yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1); - textId = ""; - srcSegBuffer.Clear(); - trgSegBuffer.Clear(); - refs.Clear(); - rowCount = 0; - } - - textId = row.TextId; - refs.AddRange(row.TargetRefs); - srcSegBuffer.Append(row.SourceText); - trgSegBuffer.Append(row.TargetText); - rowCount++; - } - } - - if (rowCount > 0) - yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1); - } - protected record Row( string TextId, IReadOnlyList Refs, @@ -486,7 +454,7 @@ protected virtual bool ResolveLanguageCodeForBaseModel(string languageCode, out return true; } - private static TextRow CleanSegment(TextRow row) + protected static TextRow CleanSegment(TextRow row) { if (row.Text == "...") row.Segment = []; diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformOutboxMessageHandler.cs b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformOutboxMessageHandler.cs index 3722336f..b063b8f7 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformOutboxMessageHandler.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformOutboxMessageHandler.cs @@ -52,6 +52,8 @@ await _client.BuildRestartingAsync( ); break; case ServalWordAlignmentPlatformOutboxConstants.InsertInferences: + var jsonSerializerOptions = new JsonSerializerOptions(JsonSerializerOptions); + jsonSerializerOptions.Converters.Add(new WordAlignmentJsonConverter()); IAsyncEnumerable wordAlignments = JsonSerializer .DeserializeAsyncEnumerable( contentStream!, @@ -108,3 +110,31 @@ await _client.IncrementTrainEngineCorpusSizeAsync( } } } + +public class WordAlignmentJsonConverter : JsonConverter +{ + public override object Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + switch (reader.TokenType) + { + case JsonTokenType.True: + return true; + case JsonTokenType.False: + return false; + case JsonTokenType.Number when reader.TryGetInt64(out long l): + return l; + case JsonTokenType.Number: + return reader.GetDouble(); + case JsonTokenType.String: + var str = reader.GetString(); + if (SIL.Machine.Corpora.AlignedWordPair.TryParse(str, out var alignedWordPair)) + return alignedWordPair; + return str!; + default: + throw new JsonException(); + } + } + + public override void Write(Utf8JsonWriter writer, object objectToWrite, JsonSerializerOptions options) => + JsonSerializer.Serialize(writer, objectToWrite, objectToWrite.GetType(), options); +} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalPostprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalPostprocessBuildJob.cs index 91e6da70..3a7960e9 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalPostprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalPostprocessBuildJob.cs @@ -26,6 +26,44 @@ IOptionsMonitor engineOptions private readonly IOptionsMonitor _engineOptions = engineOptions; private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; + protected override async Task DoWorkAsync( + string engineId, + string buildId, + (int, double) data, + string? buildOptions, + CancellationToken cancellationToken + ) + { + (int corpusSize, double confidence) = data; + + await using ( + Stream wordAlignmentStream = await SharedFileService.OpenReadAsync( + $"builds/{buildId}/word_alignment_outputs.json", + cancellationToken + ) + ) + { + await PlatformService.InsertInferencesAsync(engineId, wordAlignmentStream, cancellationToken); + } + + int additionalCorpusSize = await SaveModelAsync(engineId, buildId); + await DataAccessContext.WithTransactionAsync( + async (ct) => + { + await PlatformService.BuildCompletedAsync( + buildId, + corpusSize + additionalCorpusSize, + Math.Round(confidence, 2, MidpointRounding.AwayFromZero), + ct + ); + await BuildJobService.BuildJobFinishedAsync(engineId, buildId, buildComplete: true, ct); + }, + cancellationToken: CancellationToken.None + ); + + Logger.LogInformation("Build completed ({0}).", buildId); + } + protected override async Task SaveModelAsync(string engineId, string buildId) { IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId); diff --git a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs index 94fbd202..20a9b6d7 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs @@ -1,20 +1,67 @@ namespace Serval.Machine.Shared.Services; -public class WordAlignmentPreprocessBuildJob( - IEnumerable platformServices, - IRepository engines, - IDataAccessContext dataAccessContext, - ILogger logger, - IBuildJobService buildJobService, - ISharedFileService sharedFileService, - ICorpusService corpusService -) - : PreprocessBuildJob( - platformServices.First(ps => ps.EngineGroup == EngineGroup.WordAlignment), - engines, - dataAccessContext, - logger, - buildJobService, - sharedFileService, - corpusService - ) { } +public class WordAlignmentPreprocessBuildJob : PreprocessBuildJob +{ + public WordAlignmentPreprocessBuildJob( + IEnumerable platformServices, + IRepository engines, + IDataAccessContext dataAccessContext, + ILogger logger, + IBuildJobService buildJobService, + ISharedFileService sharedFileService, + ICorpusService corpusService + ) + : base( + platformServices.First(ps => ps.EngineGroup == EngineGroup.WordAlignment), + engines, + dataAccessContext, + logger, + buildJobService, + sharedFileService, + corpusService + ) + { + InferenceFilename = "word_alignment_inputs.json"; + } + + protected override int WriteInferences(Utf8JsonWriter inferenceWriter, ParallelCorpus corpus) + { + (MonolingualCorpus Corpus, ITextCorpus TextCorpus)[] sourceCorpora = corpus + .SourceCorpora.SelectMany(c => CorpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc))) + .ToArray(); + (MonolingualCorpus Corpus, ITextCorpus TextCorpus)[] targetCorpora = corpus + .TargetCorpora.SelectMany(c => CorpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc))) + .ToArray(); + + ITextCorpus[] sourceInferenceCorpora = sourceCorpora + .Select(sc => FilterCorpus(sc, sc.Corpus.TrainOnTextIds, sc.Corpus.TrainOnChapters)) + .ToArray(); + + ITextCorpus[] targetInferenceCorpora = targetCorpora + .Select(tc => FilterCorpus(tc, tc.Corpus.TrainOnTextIds, tc.Corpus.TrainOnChapters)) + .ToArray(); + + int inferenceCount = 0; + foreach (Row row in AlignInferenceCorpus(sourceInferenceCorpora[0], targetInferenceCorpora[0])) + { + if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) + WriteRow(inferenceWriter, corpus.Id, row); + inferenceCount++; + } + return inferenceCount; + } + + private static void WriteRow(Utf8JsonWriter writer, string corpusId, Row row) + { + writer.WriteStartObject(); + writer.WriteString("corpusId", corpusId); + writer.WriteString("textId", row.TextId); + writer.WriteStartArray("refs"); + foreach (object rowRef in row.Refs) + writer.WriteStringValue(rowRef.ToString()); + writer.WriteEndArray(); + writer.WriteString("source", row.SourceSegment); + writer.WriteString("target", row.TargetSegment); + writer.WriteEndObject(); + } +}