From c760b43a2f92396dc7a9a8805cf0fab3594dc6e3 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Tue, 12 Nov 2024 15:01:56 -0500 Subject: [PATCH] Broken, that is, less broken --- docker-compose.yml | 4 + .../Serval.Machine.EngineServer/Program.cs | 3 + .../IEndpointRouteBuilderExtensions.cs | 7 ++ .../IMachineBuilderExtensions.cs | 82 +++++++++++++-- .../Configuration/StatisticalEngineOptions.cs | 12 +++ .../Models/MonolingualCorpus.cs | 4 +- .../Services/NmtHangfireBuildJobFactory.cs | 2 +- .../Services/PostprocessBuildJob.cs | 18 ++-- .../Services/PreprocessBuildJob.cs | 99 +++++++------------ .../ServalTranslationEngineServiceV1.cs | 12 ++- .../ServalWordAlignmentEngineServiceV1.cs | 12 ++- .../SmtTransferPostprocessBuildJob.cs | 2 +- .../StatisticalHangfireBuildJobFactory.cs | 39 ++++++++ .../StatisticalPostprocessBuildJob.cs | 50 ++++++++++ .../Services/StatisticalTrainBuildJob.cs | 21 ++++ .../StatsiticalClearMLBuildJobFactory.cs | 49 +++++++++ .../WordAlignmentPreprocessBuildJob.cs | 20 ++++ .../Services/NmtEngineServiceTests.cs | 6 +- .../Services/PreprocessBuildJobTests.cs | 22 ++--- .../Services/SmtTransferEngineServiceTests.cs | 2 +- .../test/Serval.E2ETests/ServalApiTests.cs | 5 +- 21 files changed, 362 insertions(+), 109 deletions(-) create mode 100644 src/Machine/src/Serval.Machine.Shared/Configuration/StatisticalEngineOptions.cs create mode 100644 src/Machine/src/Serval.Machine.Shared/Services/StatisticalHangfireBuildJobFactory.cs create mode 100644 src/Machine/src/Serval.Machine.Shared/Services/StatisticalPostprocessBuildJob.cs create mode 100644 src/Machine/src/Serval.Machine.Shared/Services/StatisticalTrainBuildJob.cs create mode 100644 src/Machine/src/Serval.Machine.Shared/Services/StatsiticalClearMLBuildJobFactory.cs create mode 100644 src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs diff --git a/docker-compose.yml b/docker-compose.yml index 93db1add..e808daa1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -100,6 +100,8 @@ services: - BuildJob__ClearML__0__DockerImage=${MACHINE_PY_IMAGE:-ghcr.io/sillsdev/machine.py:latest} - BuildJob__ClearML__1__Queue=lambert_24gb.cpu_only - BuildJob__ClearML__1__DockerImage=${MACHINE_PY_CPU_IMAGE:-ghcr.io/sillsdev/machine.py:latest.cpu_only} + - BuildJob__ClearML__2__Queue=lambert_24gb.cpu_only + - BuildJob__ClearML__2__DockerImage=${MACHINE_PY_CPU_IMAGE:-ghcr.io/sillsdev/machine.py:latest.cpu_only} - SharedFile__Uri=s3://silnlp/docker-compose/ - "SharedFile__S3AccessKeyId=${AWS_ACCESS_KEY_ID:?access key needed}" - "SharedFile__S3SecretAccessKey=${AWS_SECRET_ACCESS_KEY:?secret key needed}" @@ -146,6 +148,8 @@ services: - BuildJob__ClearML__0__DockerImage=${MACHINE_PY_IMAGE:-ghcr.io/sillsdev/machine.py:latest} - BuildJob__ClearML__1__Queue=lambert_24gb.cpu_only - BuildJob__ClearML__1__DockerImage=${MACHINE_PY_CPU_IMAGE:-ghcr.io/sillsdev/machine.py:latest.cpu_only} + - BuildJob__ClearML__2__Queue=lambert_24gb.cpu_only + - BuildJob__ClearML__2__DockerImage=${MACHINE_PY_CPU_IMAGE:-ghcr.io/sillsdev/machine.py:latest.cpu_only} - SharedFile__Uri=s3://silnlp/docker-compose/ - "SharedFile__S3AccessKeyId=${AWS_ACCESS_KEY_ID:?access key needed}" - "SharedFile__S3SecretAccessKey=${AWS_SECRET_ACCESS_KEY:?secret key needed}" diff --git a/src/Machine/src/Serval.Machine.EngineServer/Program.cs b/src/Machine/src/Serval.Machine.EngineServer/Program.cs index b03f6575..bb3f5242 100644 --- a/src/Machine/src/Serval.Machine.EngineServer/Program.cs +++ b/src/Machine/src/Serval.Machine.EngineServer/Program.cs @@ -11,6 +11,8 @@ .AddMongoDataAccess() .AddMongoHangfireJobClient() .AddServalTranslationEngineService() + .AddServalWordAlignmentEngineService() + .AddServalPlatformService() .AddModelCleanupService() .AddMessageOutboxDeliveryService() .AddClearMLService(); @@ -36,6 +38,7 @@ var app = builder.Build(); app.MapServalTranslationEngineService(); +app.MapServalWordAlignmentEngineService(); app.MapHangfireDashboard(); app.Run(); diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/IEndpointRouteBuilderExtensions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/IEndpointRouteBuilderExtensions.cs index 107de6c2..1392a791 100644 --- a/src/Machine/src/Serval.Machine.Shared/Configuration/IEndpointRouteBuilderExtensions.cs +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/IEndpointRouteBuilderExtensions.cs @@ -9,4 +9,11 @@ public static IEndpointRouteBuilder MapServalTranslationEngineService(this IEndp return builder; } + + public static IEndpointRouteBuilder MapServalWordAlignmentEngineService(this IEndpointRouteBuilder builder) + { + builder.MapGrpcService(); + + return builder; + } } diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs index 3568ed66..ca969485 100644 --- a/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs @@ -165,9 +165,12 @@ public static IMachineBuilder AddHangfireJobServer( IEnumerable? engineTypes = null ) { - engineTypes ??= + engineTypes ??= ( builder.Configuration.GetSection("TranslationEngines").Get() - ?? [EngineType.SmtTransfer, EngineType.Nmt]; + ?? [EngineType.SmtTransfer, EngineType.Nmt] + ).Concat( + builder.Configuration.GetSection("WordAlignmentEngines").Get() ?? [EngineType.Statistical] + ); var queues = new List(); foreach (EngineType engineType in engineTypes.Distinct()) { @@ -181,6 +184,11 @@ public static IMachineBuilder AddHangfireJobServer( case EngineType.Nmt: queues.Add("nmt"); break; + case EngineType.Statistical: + queues.Add("statistical"); + break; + default: + throw new ArgumentOutOfRangeException(engineType.ToString()); } } @@ -321,10 +329,21 @@ public static IMachineBuilder AddServalPlatformService( new MethodName { Service = "serval.translation.v1.TranslationPlatformApi", - Method = "UpdateBuildStatus" + Method = "UpdateTranslationBuildStatus" } } }, + new MethodConfig + { + Names = + { + new MethodName + { + Service = "serval.word_alignment.v1.WordAlignmentPlatformApi", + Method = "UpdateWordAlignmentBuildStatus" + } + } + } } }; }); @@ -344,7 +363,6 @@ public static IMachineBuilder AddServalTranslationEngineService( options.Interceptors.Add(); options.Interceptors.Add(); }); - builder.AddServalPlatformService(connectionString); engineTypes ??= builder.Configuration.GetSection("TranslationEngines").Get() @@ -354,20 +372,67 @@ public static IMachineBuilder AddServalTranslationEngineService( switch (engineType) { case EngineType.SmtTransfer: - builder.Services.AddSingleton(); - builder.Services.AddHostedService(); - builder.AddThotSmtModel().AddTransferEngine().AddUnigramTruecaser(); + builder.AddThot(); builder.Services.AddScoped(); break; case EngineType.Nmt: builder.Services.AddScoped(); break; + default: + throw new ArgumentOutOfRangeException(engineType.ToString()); + } + } + + return builder; + } + + public static IMachineBuilder AddServalWordAlignmentEngineService( + this IMachineBuilder builder, + string? connectionString = null, + IEnumerable? engineTypes = null + ) + { + builder.Services.AddGrpc(options => + { + options.Interceptors.Add(); + options.Interceptors.Add(); + options.Interceptors.Add(); + }); + + engineTypes ??= + builder.Configuration.GetSection("WordAlignmentEngines").Get() ?? [EngineType.Statistical]; + + foreach (EngineType engineType in engineTypes.Distinct()) + { + switch (engineType) + { + case EngineType.Statistical: + builder.AddThot(); + builder.Services.AddScoped(); + break; + default: + throw new ArgumentOutOfRangeException(engineType.ToString()); } } return builder; } + public static IMachineBuilder AddThot(this IMachineBuilder builder) + { + try + { + builder.Services.AddSingleton(); + builder.Services.AddHostedService(); + builder.AddThotSmtModel().AddTransferEngine().AddUnigramTruecaser(); + } + catch (ArgumentException) + { + // if this has already been run, don't run it again + } + return builder; + } + public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder, string? smtTransferEngineDir = null) { builder.Services.AddScoped, TranslationBuildJobService>(); @@ -376,6 +441,8 @@ public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder, s builder.Services.AddScoped(); builder.Services.AddScoped(); builder.Services.AddScoped(); + builder.Services.AddScoped(); + builder.Services.AddSingleton(); builder.Services.AddSingleton(x => x.GetRequiredService()); builder.Services.AddHostedService(p => p.GetRequiredService()); @@ -383,6 +450,7 @@ public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder, s builder.Services.AddScoped(); builder.Services.AddScoped(); builder.Services.AddScoped(); + builder.Services.AddScoped(); if (smtTransferEngineDir is null) { diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/StatisticalEngineOptions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/StatisticalEngineOptions.cs new file mode 100644 index 00000000..68254e4c --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/StatisticalEngineOptions.cs @@ -0,0 +1,12 @@ +namespace Serval.Machine.Shared.Configuration; + +public class StatisticalEngineOptions +{ + public const string Key = "StatisticalEngine"; + + public string EnginesDir { get; set; } = "word_alignment_engines"; + public TimeSpan EngineCommitFrequency { get; set; } = TimeSpan.FromMinutes(5); + public TimeSpan InactiveEngineTimeout { get; set; } = TimeSpan.FromMinutes(10); + public TimeSpan SaveModelTimeout { get; set; } = TimeSpan.FromMinutes(5); + public TimeSpan EngineCommitTimeout { get; set; } = TimeSpan.FromMinutes(2); +} diff --git a/src/Machine/src/Serval.Machine.Shared/Models/MonolingualCorpus.cs b/src/Machine/src/Serval.Machine.Shared/Models/MonolingualCorpus.cs index 2b4a1612..1c737583 100644 --- a/src/Machine/src/Serval.Machine.Shared/Models/MonolingualCorpus.cs +++ b/src/Machine/src/Serval.Machine.Shared/Models/MonolingualCorpus.cs @@ -7,6 +7,6 @@ public record MonolingualCorpus public required IReadOnlyList Files { get; set; } public HashSet? TrainOnTextIds { get; set; } public Dictionary>? TrainOnChapters { get; set; } - public HashSet? PretranslateTextIds { get; set; } - public Dictionary>? PretranslateChapters { get; set; } + public HashSet? InferenceTextIds { get; set; } + public Dictionary>? InferenceChapters { get; set; } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/NmtHangfireBuildJobFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/NmtHangfireBuildJobFactory.cs index 0c54d65b..bcfc5014 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/NmtHangfireBuildJobFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/NmtHangfireBuildJobFactory.cs @@ -19,7 +19,7 @@ public Job CreateJob(string engineId, string buildId, BuildStage stage, object? buildOptions ), BuildStage.Postprocess - => CreateJob( + => CreateJob, (int, double)>( engineId, buildId, "nmt", diff --git a/src/Machine/src/Serval.Machine.Shared/Services/PostprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/PostprocessBuildJob.cs index 794487c5..6ed68c4c 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/PostprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/PostprocessBuildJob.cs @@ -1,21 +1,15 @@ namespace Serval.Machine.Shared.Services; -public class PostprocessBuildJob( +public class PostprocessBuildJob( IPlatformService platformService, - IRepository engines, + IRepository engines, IDataAccessContext dataAccessContext, - IBuildJobService buildJobService, - ILogger logger, + IBuildJobService buildJobService, + ILogger> logger, ISharedFileService sharedFileService, IOptionsMonitor options -) - : HangfireBuildJob( - platformService, - engines, - dataAccessContext, - buildJobService, - logger - ) +) : HangfireBuildJob(platformService, engines, dataAccessContext, buildJobService, logger) + where TEngine : ITrainingEngine { protected ISharedFileService SharedFileService { get; } = sharedFileService; private readonly BuildJobOptions _buildJobOptions = options.CurrentValue; diff --git a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs index 8453cf51..98b729e5 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs @@ -7,8 +7,8 @@ public class PreprocessBuildJob : HangfireBuildJob WriteDataFilesAsync( + protected virtual async Task<(int TrainCount, int InferenceCount)> WriteDataFilesAsync( string buildId, IReadOnlyList corpora, string? buildOptions, @@ -109,11 +109,11 @@ CancellationToken cancellationToken if (buildOptions is not null) buildOptionsObject = JsonSerializer.Deserialize(buildOptions); await using StreamWriter sourceTrainWriter = - new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.src.txt", cancellationToken)); + new(await SharedFileService.OpenWriteAsync($"builds/{buildId}/train.src.txt", cancellationToken)); await using StreamWriter targetTrainWriter = - new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken)); + new(await SharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken)); - await using Stream pretranslateStream = await _sharedFileService.OpenWriteAsync( + await using Stream pretranslateStream = await SharedFileService.OpenWriteAsync( $"builds/{buildId}/pretranslate.src.json", cancellationToken ); @@ -125,58 +125,21 @@ CancellationToken cancellationToken foreach (ParallelCorpus corpus in corpora) { (MonolingualCorpus Corpus, ITextCorpus TextCorpus)[] sourceCorpora = corpus - .SourceCorpora.SelectMany(c => _corpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc))) + .SourceCorpora.SelectMany(c => CorpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc))) .ToArray(); ITextCorpus[] sourceTrainingCorpora = sourceCorpora - .Select(sc => - { - ITextCorpus textCorpus = sc.TextCorpus; - if (sc.Corpus.TrainOnTextIds is not null) - textCorpus = textCorpus.FilterTexts(sc.Corpus.TrainOnTextIds); - return textCorpus.Where(row => - row.Ref is not ScriptureRef sr - || sc.Corpus.TrainOnChapters is null - || IsInChapters(sr, sc.Corpus.TrainOnChapters) - ); - }) + .Select(sc => FilterCorpus(sc.TextCorpus, sc.Corpus.TrainOnTextIds, sc.Corpus.TrainOnChapters)) .ToArray(); ITextCorpus? sourcePretranslateCorpus = sourceCorpora - .Select(sc => - { - ITextCorpus textCorpus = sc.TextCorpus; - if (sc.Corpus.PretranslateTextIds is not null) - { - textCorpus = textCorpus.FilterTexts( - sc.Corpus.PretranslateTextIds.Except(sc.Corpus.TrainOnTextIds ?? new()) - ); - } - return textCorpus.Where(row => - row.Ref is not ScriptureRef sr - || sc.Corpus.PretranslateChapters is null - || ( - IsInChapters(sr, sc.Corpus.PretranslateChapters) - && !IsInChapters(sr, sc.Corpus.TrainOnChapters ?? new()) - ) - ); - }) + .Select(sc => FilterCorpus(sc.TextCorpus, sc.Corpus.InferenceTextIds, sc.Corpus.InferenceChapters)) .ToArray() .FirstOrDefault(); (MonolingualCorpus Corpus, ITextCorpus TextCorpus)[] targetCorpora = corpus - .TargetCorpora.SelectMany(c => _corpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc))) + .TargetCorpora.SelectMany(c => CorpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc))) .ToArray(); ITextCorpus[] targetTrainingCorpora = targetCorpora - .Select(tc => - { - ITextCorpus textCorpus = tc.TextCorpus; - if (tc.Corpus.TrainOnTextIds is not null) - textCorpus = textCorpus.FilterTexts(tc.Corpus.TrainOnTextIds); - return textCorpus.Where(row => - row.Ref is not ScriptureRef sr - || tc.Corpus.TrainOnChapters is null - || IsInChapters(sr, tc.Corpus.TrainOnChapters) - ); - }) + .Select(tc => FilterCorpus(tc.TextCorpus, tc.Corpus.TrainOnTextIds, tc.Corpus.TrainOnChapters)) .ToArray(); if (sourceCorpora.Length == 0) @@ -227,10 +190,10 @@ row.Ref is not ScriptureRef sr if ((bool?)buildOptionsObject?["use_key_terms"] ?? true) { - ITextCorpus? sourceTermCorpus = _corpusService + ITextCorpus? sourceTermCorpus = CorpusService .CreateTermCorpora(corpus.SourceCorpora.SelectMany(sc => sc.Files).ToList()) .FirstOrDefault(); - ITextCorpus? targetTermCorpus = _corpusService + ITextCorpus? targetTermCorpus = CorpusService .CreateTermCorpora(corpus.TargetCorpora.SelectMany(tc => tc.Files).ToList()) .FirstOrDefault(); if (sourceTermCorpus is not null && targetTermCorpus is not null) @@ -275,11 +238,25 @@ void WriteRow(Utf8JsonWriter writer, string textId, IReadOnlyList refs, return (trainCount, pretranslateCount); } - private static bool IsInChapters(ScriptureRef sr, Dictionary> selection) + protected static ITextCorpus FilterCorpus( + ITextCorpus corpus, + HashSet? textIds, + IDictionary>? chapters + ) { - return selection.TryGetValue(sr.Book, out HashSet? chapters) - && chapters != null - && (chapters.Count == 0 || chapters.Contains(sr.ChapterNum)); + if (textIds is not null) + corpus = corpus.FilterTexts(textIds); + if (chapters is not null) + { + corpus = corpus.Where(row => + row.Ref is not ScriptureRef sr + || ( + chapters.TryGetValue(sr.Book, out HashSet? chapterSet) + && (chapterSet.Count == 0 || chapterSet.Contains(sr.ChapterNum)) + ) + ); + } + return corpus; } protected override async Task CleanupAsync( @@ -293,7 +270,7 @@ JobCompletionStatus completionStatus { try { - await _sharedFileService.DeleteAsync($"builds/{buildId}/"); + await SharedFileService.DeleteAsync($"builds/{buildId}/"); } catch (Exception e) { @@ -302,7 +279,7 @@ JobCompletionStatus completionStatus } } - private static IEnumerable AlignTrainCorpus( + protected static IEnumerable AlignTrainCorpus( IReadOnlyList srcCorpora, IReadOnlyList trgCorpora ) @@ -341,7 +318,7 @@ IReadOnlyList trgCorpora .Where(rows => rows.Any(r => r.SourceSegment.Length > 0 || r.TargetSegment.Length > 0)); } - private static IEnumerable AlignScripture(ITextCorpus srcCorpus, ITextCorpus trgCorpus) + protected static IEnumerable AlignScripture(ITextCorpus srcCorpus, ITextCorpus trgCorpus) { int rowCount = 0; StringBuilder srcSegBuffer = new(); @@ -472,7 +449,7 @@ private static IEnumerable AlignPretranslateCorpus(ITextCorpus srcCorpus, I yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1); } - private record Row( + protected record Row( string TextId, IReadOnlyList Refs, string SourceSegment, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationEngineServiceV1.cs b/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationEngineServiceV1.cs index e93c86b7..db83db72 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationEngineServiceV1.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationEngineServiceV1.cs @@ -173,7 +173,9 @@ private ITranslationEngineService GetEngineService(string engineTypeStr) { if (_engineServices.TryGetValue(GetEngineType(engineTypeStr), out ITranslationEngineService? service)) return service; - throw new RpcException(new Status(StatusCode.InvalidArgument, "The engine type is invalid.")); + throw new RpcException( + new Status(StatusCode.InvalidArgument, $"The engine type {engineTypeStr} is not supported.") + ); } private static EngineType GetEngineType(string engineTypeStr) @@ -181,7 +183,9 @@ private static EngineType GetEngineType(string engineTypeStr) engineTypeStr = engineTypeStr[0].ToString().ToUpperInvariant() + engineTypeStr[1..]; if (System.Enum.TryParse(engineTypeStr, out EngineType engineType)) return engineType; - throw new RpcException(new Status(StatusCode.InvalidArgument, "The engine type is invalid.")); + throw new RpcException( + new Status(StatusCode.InvalidArgument, $"The engine type {engineTypeStr} is not supported.") + ); } private static Translation.V1.TranslationResult Map(SIL.Machine.Translation.TranslationResult source) @@ -307,8 +311,8 @@ private static Models.MonolingualCorpus Map(Translation.V1.MonolingualCorpus sou Files = source.Files.Select(Map).ToList(), TrainOnChapters = trainingFilter == FilterChoice.Chapters ? trainOnChapters : null, TrainOnTextIds = trainingFilter == FilterChoice.TextIds ? trainOnTextIds : null, - PretranslateChapters = pretranslateFilter == FilterChoice.Chapters ? pretranslateChapters : null, - PretranslateTextIds = pretranslateFilter == FilterChoice.TextIds ? pretranslateTextIds : null + InferenceChapters = pretranslateFilter == FilterChoice.Chapters ? pretranslateChapters : null, + InferenceTextIds = pretranslateFilter == FilterChoice.TextIds ? pretranslateTextIds : null }; return corpus; } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs index 5ef409ab..b8b2686c 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs @@ -101,7 +101,9 @@ private IWordAlignmentEngineService GetEngineService(string engineTypeStr) { if (_engineServices.TryGetValue(GetEngineType(engineTypeStr), out IWordAlignmentEngineService? service)) return service; - throw new RpcException(new Status(StatusCode.InvalidArgument, "The engine type is invalid.")); + throw new RpcException( + new Status(StatusCode.InvalidArgument, $"The engine type {engineTypeStr} is not supported.") + ); } private static EngineType GetEngineType(string engineTypeStr) @@ -109,7 +111,9 @@ private static EngineType GetEngineType(string engineTypeStr) engineTypeStr = engineTypeStr[0].ToString().ToUpperInvariant() + engineTypeStr[1..]; if (System.Enum.TryParse(engineTypeStr, out EngineType engineType)) return engineType; - throw new RpcException(new Status(StatusCode.InvalidArgument, "The engine type is invalid.")); + throw new RpcException( + new Status(StatusCode.InvalidArgument, $"The engine type {engineTypeStr} is not supported.") + ); } private static WordAlignmentResult Map(TranslationResult source) @@ -168,8 +172,8 @@ private static Models.MonolingualCorpus Map(WordAlignment.V1.MonolingualCorpus s Files = source.Files.Select(Map).ToList(), TrainOnChapters = trainingFilter == FilterChoice.Chapters ? trainOnChapters : null, TrainOnTextIds = trainingFilter == FilterChoice.TextIds ? trainOnTextIds : null, - PretranslateChapters = pretranslateFilter == FilterChoice.Chapters ? pretranslateChapters : null, - PretranslateTextIds = pretranslateFilter == FilterChoice.TextIds ? pretranslateTextIds : null + InferenceChapters = pretranslateFilter == FilterChoice.Chapters ? pretranslateChapters : null, + InferenceTextIds = pretranslateFilter == FilterChoice.TextIds ? pretranslateTextIds : null }; } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPostprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPostprocessBuildJob.cs index 9e26fcdb..bcc16790 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPostprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPostprocessBuildJob.cs @@ -14,7 +14,7 @@ public class SmtTransferPostprocessBuildJob( IOptionsMonitor buildOptions, IOptionsMonitor engineOptions ) - : PostprocessBuildJob( + : PostprocessBuildJob( platformService, engines, dataAccessContext, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalHangfireBuildJobFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalHangfireBuildJobFactory.cs new file mode 100644 index 00000000..7a22fef4 --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalHangfireBuildJobFactory.cs @@ -0,0 +1,39 @@ +using static Serval.Machine.Shared.Services.HangfireBuildJobRunner; + +namespace Serval.Machine.Shared.Services; + +public class StatisticalHangfireBuildJobFactory : IHangfireBuildJobFactory +{ + public EngineType EngineType => EngineType.Statistical; + + public Job CreateJob(string engineId, string buildId, BuildStage stage, object? data, string? buildOptions) + { + return stage switch + { + BuildStage.Preprocess + => CreateJob>( + engineId, + buildId, + "statistical", + data, + buildOptions + ), + BuildStage.Postprocess + => CreateJob( + engineId, + buildId, + "statistical", + data, + buildOptions + ), + BuildStage.Train + => CreateJob( + engineId, + buildId, + "statistical", + buildOptions + ), + _ => throw new ArgumentException("Unknown build stage.", nameof(stage)), + }; + } +} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalPostprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalPostprocessBuildJob.cs new file mode 100644 index 00000000..16c6ebe2 --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalPostprocessBuildJob.cs @@ -0,0 +1,50 @@ +namespace Serval.Machine.Shared.Services; + +public class StatisticalPostprocessBuildJob( + IPlatformService platformService, + IRepository engines, + IDataAccessContext dataAccessContext, + IBuildJobService buildJobService, + ILogger logger, + ISharedFileService sharedFileService, + IDistributedReaderWriterLockFactory lockFactory, + ISmtModelFactory smtModelFactory, + IOptionsMonitor buildOptions, + IOptionsMonitor engineOptions +) + : PostprocessBuildJob( + platformService, + engines, + dataAccessContext, + buildJobService, + logger, + sharedFileService, + buildOptions + ) +{ + private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; + private readonly IOptionsMonitor _engineOptions = engineOptions; + private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; + + protected override async Task SaveModelAsync(string engineId, string buildId) + { + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId); + return await @lock.WriterLockAsync( + async ct => + { + await using ( + Stream engineStream = await SharedFileService.OpenReadAsync($"builds/{buildId}/model.tar.gz", ct) + ) + { + await _smtModelFactory.UpdateEngineFromAsync( + Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId), + engineStream, + ct + ); + } + return 0; + }, + _engineOptions.CurrentValue.SaveModelTimeout + ); + } +} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalTrainBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalTrainBuildJob.cs new file mode 100644 index 00000000..86936694 --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalTrainBuildJob.cs @@ -0,0 +1,21 @@ +namespace Serval.Machine.Shared.Services; + +public class StatisticalTrainBuildJob( + IPlatformService platformService, + IRepository engines, + IDataAccessContext dataAccessContext, + IBuildJobService buildJobService, + ILogger logger +) : HangfireBuildJob(platformService, engines, dataAccessContext, buildJobService, logger) +{ + protected override Task DoWorkAsync( + string engineId, + string buildId, + object? data, + string? buildOptions, + CancellationToken cancellationToken + ) + { + throw new NotImplementedException(); + } +} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/StatsiticalClearMLBuildJobFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/StatsiticalClearMLBuildJobFactory.cs new file mode 100644 index 00000000..1e104b2c --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Services/StatsiticalClearMLBuildJobFactory.cs @@ -0,0 +1,49 @@ +namespace Serval.Machine.Shared.Services; + +public class StatisticalClearMLBuildJobFactory( + ISharedFileService sharedFileService, + IRepository engines +) : IClearMLBuildJobFactory +{ + private readonly ISharedFileService _sharedFileService = sharedFileService; + private readonly IRepository _engines = engines; + + public EngineType EngineType => EngineType.Statistical; + + public async Task CreateJobScriptAsync( + string engineId, + string buildId, + string modelType, + BuildStage stage, + object? data = null, + string? buildOptions = null, + CancellationToken cancellationToken = default + ) + { + if (stage == BuildStage.Train) + { + WordAlignmentEngine? engine = await _engines.GetAsync(e => e.EngineId == engineId, cancellationToken); + if (engine is null) + throw new InvalidOperationException("The engine does not exist."); + + Uri sharedFileUri = _sharedFileService.GetBaseUri(); + string baseUri = sharedFileUri.GetComponents(UriComponents.SchemeAndServer, UriFormat.Unescaped); + string folder = sharedFileUri.GetComponents(UriComponents.Path, UriFormat.Unescaped); + return "from machine.jobs.build_word_alignment_model import run\n" + + "args = {\n" + + $" 'model_type': '{modelType}',\n" + + $" 'engine_id': '{engineId}',\n" + + $" 'build_id': '{buildId}',\n" + + $" 'shared_file_uri': '{baseUri}',\n" + + $" 'shared_file_folder': '{folder}',\n" + + (buildOptions is not null ? $" 'build_options': '''{buildOptions}''',\n" : "") + + $" 'clearml': True\n" + + "}\n" + + "run(args)\n"; + } + else + { + throw new ArgumentException("Unknown build stage.", nameof(stage)); + } + } +} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs new file mode 100644 index 00000000..4d2d4cc4 --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs @@ -0,0 +1,20 @@ +namespace Serval.Machine.Shared.Services; + +public class WordAlignmentPreprocessBuildJob( + IPlatformService platformService, + IRepository engines, + IDataAccessContext dataAccessContext, + ILogger logger, + IBuildJobService buildJobService, + ISharedFileService sharedFileService, + ICorpusService corpusService +) + : PreprocessBuildJob( + platformService, + engines, + dataAccessContext, + logger, + buildJobService, + sharedFileService, + corpusService + ) { } diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtEngineServiceTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtEngineServiceTests.cs index 88989842..6b5b5427 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtEngineServiceTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtEngineServiceTests.cs @@ -305,16 +305,16 @@ public override object ActivateJob(Type jobType) new LanguageTagService() ); } - if (jobType == typeof(PostprocessBuildJob)) + if (jobType == typeof(PostprocessBuildJob)) { var buildJobOptions = Substitute.For>(); buildJobOptions.CurrentValue.Returns(new BuildJobOptions()); - return new PostprocessBuildJob( + return new PostprocessBuildJob( _env.PlatformService, _env.Engines, new MemoryDataAccessContext(), _env.BuildJobService, - Substitute.For>(), + Substitute.For>>(), _env.SharedFileService, buildJobOptions ); diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs index fed897c2..52332074 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs @@ -372,7 +372,7 @@ public async Task ParallelCorpusLogic() new() { } } }, - PretranslateChapters = new() + InferenceChapters = new() { { "1CH", @@ -537,7 +537,7 @@ public TestEnvironment() Language = "es", Files = [TextFile("source1")], TrainOnTextIds = [], - PretranslateTextIds = [] + InferenceTextIds = [] } }, TargetCorpora = new List() @@ -564,8 +564,8 @@ public TestEnvironment() Files = [TextFile("source1"), TextFile("source2")], TrainOnTextIds = null, TrainOnChapters = null, - PretranslateTextIds = null, - PretranslateChapters = null, + InferenceTextIds = null, + InferenceChapters = null, } }, TargetCorpora = new List() @@ -592,7 +592,7 @@ public TestEnvironment() Language = "es", Files = [ParatextFile("pt-source1")], TrainOnTextIds = [], - PretranslateTextIds = [] + InferenceTextIds = [] } }, TargetCorpora = new List() @@ -618,7 +618,7 @@ public TestEnvironment() Language = "es", Files = [ParatextFile("pt-source1")], TrainOnTextIds = null, - PretranslateTextIds = null + InferenceTextIds = null }, new() { @@ -626,7 +626,7 @@ public TestEnvironment() Language = "es", Files = [ParatextFile("pt-source2")], TrainOnTextIds = null, - PretranslateTextIds = null + InferenceTextIds = null } }, TargetCorpora = new List() @@ -835,7 +835,7 @@ public static ParallelCorpus TextFileCorpus( Language = "es", Files = [TextFile("source1")], TrainOnTextIds = trainOnTextIds, - PretranslateTextIds = pretranslateTextIds + InferenceTextIds = pretranslateTextIds } }, TargetCorpora = new List() @@ -864,7 +864,7 @@ public static ParallelCorpus TextFileCorpus(string sourceLanguage, string target Language = sourceLanguage, Files = [TextFile("source1")], TrainOnTextIds = [], - PretranslateTextIds = [] + InferenceTextIds = [] } }, TargetCorpora = new List() @@ -896,7 +896,7 @@ public ParallelCorpus ParatextCorpus( Language = "es", Files = [ParatextFile("pt-source1")], TrainOnChapters = trainOnChapters, - PretranslateChapters = pretranslateChapters + InferenceChapters = pretranslateChapters } }, TargetCorpora = new List() @@ -925,7 +925,7 @@ public ParallelCorpus ParatextCorpus(HashSet? trainOnTextIds, HashSet() diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs index 20859115..ac53467e 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs @@ -51,7 +51,7 @@ await env.Service.StartBuildAsync( Language = "es", Files = [], TrainOnTextIds = null, - PretranslateTextIds = null + InferenceTextIds = null } }, TargetCorpora = new List() diff --git a/src/Serval/test/Serval.E2ETests/ServalApiTests.cs b/src/Serval/test/Serval.E2ETests/ServalApiTests.cs index 06da1d1d..0d7eaabe 100644 --- a/src/Serval/test/Serval.E2ETests/ServalApiTests.cs +++ b/src/Serval/test/Serval.E2ETests/ServalApiTests.cs @@ -477,9 +477,10 @@ public async Task ParatextProjectNmtJobAsync() [Test] public async Task GetWordAlignment() { - string engineId = await _helperClient.CreateNewEngineAsync("statistical", "es", "en", "STAT1"); + string engineId = await _helperClient.CreateNewEngineAsync("Statistical", "es", "en", "STAT1"); string[] books = ["1JN.txt", "2JN.txt", "3JN.txt"]; - await _helperClient.AddTextCorpusToEngineAsync(engineId, books, "es", "en", false); + ParallelCorpusConfig train_corpus = await _helperClient.MakeParallelTextCorpus(books, "es", "en", false); + await _helperClient.AddParallelTextCorpusToEngineAsync(engineId, train_corpus, false); await _helperClient.BuildEngineAsync(engineId); WordAlignmentResult tResult = await _helperClient.WordAlignmentEnginesClient.GetWordAlignmentAsync( engineId,