From 5df95505950ce8102d760a0ea99586e24184cb57 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Mon, 10 Jun 2024 09:32:39 -0400 Subject: [PATCH 1/4] SMT on ClearML: * Replace CPU, GPU types with just Hangfire vs ClearML as well as engine type * Allow each engine type to have it's own queue and docker image * SMT build defaults on ClearML * NMT local train removed * Use .zip for SMT model moving * Update cleanup script for SMT models * Download and upload model in factory --- .../Configuration/BuildJobOptions.cs | 3 +- .../Configuration/ClearMLBuildQueue.cs | 9 + .../Configuration/ClearMLOptions.cs | 3 - .../IMachineBuilderExtensions.cs | 120 ++--- .../IServiceCollectionExtensions.cs | 2 + src/SIL.Machine.AspNetCore/Models/Build.cs | 13 +- .../Models/TranslationEngine.cs | 1 + .../Services/BuildJobService.cs | 72 +-- .../Services/ClearMLBuildJobRunner.cs | 27 +- .../Services/ClearMLHealthCheck.cs | 24 +- .../Services/ClearMLMonitorService.cs | 75 ++- .../Services/ClearMLService.cs | 12 +- .../Services/HangfireBuildJobRunner.cs | 10 +- .../Services/IBuildJobRunner.cs | 10 +- .../Services/IBuildJobService.cs | 26 +- .../Services/IClearMLBuildJobFactory.cs | 3 +- .../Services/IClearMLQueueService.cs | 6 + .../Services/IClearMLService.cs | 5 +- .../Services/IHangfireBuildJobFactory.cs | 2 +- .../Services/ISmtModelFactory.cs | 2 + .../Services/ITranslationEngineService.cs | 2 +- .../Services/ModelCleanupService.cs | 25 +- .../Services/NmtClearMLBuildJobFactory.cs | 11 +- .../Services/NmtEngineService.cs | 34 +- .../Services/NmtHangfireBuildJobFactory.cs | 7 +- .../Services/NmtPostprocessBuildJob.cs | 44 +- .../Services/NmtPreprocessBuildJob.cs | 440 +----------------- .../Services/NmtTrainBuildJob.cs | 142 ------ .../Services/PostprocessBuildJob.cs | 57 +++ .../Services/PreprocessBuildJob.cs | 431 +++++++++++++++++ .../Services/S3WriteStream.cs | 107 +++-- .../ServalTranslationEngineServiceV1.cs | 7 +- .../Services/SmtTransferBuildJob.cs | 137 ------ .../SmtTransferClearMLBuildJobFactory.cs | 52 +++ .../Services/SmtTransferEngineService.cs | 29 +- .../SmtTransferHangfireBuildJobFactory.cs | 16 +- .../SmtTransferPostprocessBuildJob.cs | 80 ++++ .../Services/SmtTransferPreprocessBuildJob.cs | 20 + .../Services/SmtTransferTrainBuildJob.cs | 91 ++++ .../Services/ThotSmtModelFactory.cs | 30 +- src/SIL.Machine.AspNetCore/Usings.cs | 2 - .../appsettings.Development.json | 4 +- .../appsettings.json | 18 +- .../appsettings.Development.json | 4 +- .../appsettings.json | 18 +- .../Services/ClearMLServiceTests.cs | 2 +- .../Services/ModelCleanupServiceTests.cs | 2 + .../NmtClearMLBuildJobFactoryTests.cs | 19 +- .../Services/NmtEngineServiceTests.cs | 75 ++- .../Services/NmtPreprocessBuildJobTests.cs | 80 +++- .../Services/SmtTransferEngineServiceTests.cs | 234 +++++++--- tests/SIL.Machine.AspNetCore.Tests/Usings.cs | 2 - 52 files changed, 1446 insertions(+), 1201 deletions(-) create mode 100644 src/SIL.Machine.AspNetCore/Configuration/ClearMLBuildQueue.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/IClearMLQueueService.cs delete mode 100644 src/SIL.Machine.AspNetCore/Services/NmtTrainBuildJob.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/PostprocessBuildJob.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs delete mode 100644 src/SIL.Machine.AspNetCore/Services/SmtTransferBuildJob.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/SmtTransferClearMLBuildJobFactory.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/SmtTransferPostprocessBuildJob.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/SmtTransferPreprocessBuildJob.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs diff --git a/src/SIL.Machine.AspNetCore/Configuration/BuildJobOptions.cs b/src/SIL.Machine.AspNetCore/Configuration/BuildJobOptions.cs index d761ac4d0..b6dc8c328 100644 --- a/src/SIL.Machine.AspNetCore/Configuration/BuildJobOptions.cs +++ b/src/SIL.Machine.AspNetCore/Configuration/BuildJobOptions.cs @@ -4,6 +4,5 @@ public class BuildJobOptions { public const string Key = "BuildJob"; - public Dictionary Runners { get; set; } = - new() { { BuildJobType.Cpu, BuildJobRunner.Hangfire }, { BuildJobType.Gpu, BuildJobRunner.ClearML } }; + public IList ClearML { get; set; } = new List(); } diff --git a/src/SIL.Machine.AspNetCore/Configuration/ClearMLBuildQueue.cs b/src/SIL.Machine.AspNetCore/Configuration/ClearMLBuildQueue.cs new file mode 100644 index 000000000..9ffe279d8 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Configuration/ClearMLBuildQueue.cs @@ -0,0 +1,9 @@ +namespace SIL.Machine.AspNetCore.Configuration; + +public class ClearMLBuildQueue +{ + public TranslationEngineType TranslationEngineType { get; set; } + public string ModelType { get; set; } = ""; + public string Queue { get; set; } = "default"; + public string DockerImage { get; set; } = ""; +} diff --git a/src/SIL.Machine.AspNetCore/Configuration/ClearMLOptions.cs b/src/SIL.Machine.AspNetCore/Configuration/ClearMLOptions.cs index e8fc08994..fdc9e0238 100644 --- a/src/SIL.Machine.AspNetCore/Configuration/ClearMLOptions.cs +++ b/src/SIL.Machine.AspNetCore/Configuration/ClearMLOptions.cs @@ -4,13 +4,10 @@ public class ClearMLOptions { public const string Key = "ClearML"; - public string Queue { get; set; } = "default"; public string AccessKey { get; set; } = ""; public string SecretKey { get; set; } = ""; public bool BuildPollingEnabled { get; set; } = false; public TimeSpan BuildPollingTimeout { get; set; } = TimeSpan.FromSeconds(10); - public string ModelType { get; set; } = "huggingface"; public string RootProject { get; set; } = "Machine"; public string Project { get; set; } = "dev"; - public string DockerImage { get; set; } = ""; } diff --git a/src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs b/src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs index 4f8e10982..6183c476d 100644 --- a/src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs +++ b/src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs @@ -64,6 +64,21 @@ public static IMachineBuilder AddSharedFileOptions(this IMachineBuilder builder, return builder; } + public static IMachineBuilder AddBuildJobOptions( + this IMachineBuilder builder, + Action configureOptions + ) + { + builder.Services.Configure(configureOptions); + return builder; + } + + public static IMachineBuilder AddBuildJobOptions(this IMachineBuilder builder, IConfiguration config) + { + builder.Services.Configure(config); + return builder; + } + public static IMachineBuilder AddThotSmtModel(this IMachineBuilder builder) { if (builder.Configuration is null) @@ -131,26 +146,6 @@ public static IMachineBuilder AddClearMLService(this IMachineBuilder builder, st return builder; } - private static IMachineBuilder AddClearMLBuildJobRunner(this IMachineBuilder builder) - { - builder.Services.AddScoped(); - builder.Services.AddScoped(); - builder.Services.AddSingleton(); - builder.Services.AddHostedService(p => p.GetRequiredService()); - - return builder; - } - - private static IMachineBuilder AddHangfireBuildJobRunner(this IMachineBuilder builder) - { - builder.Services.AddScoped(); - - builder.Services.AddScoped(); - builder.Services.AddScoped(); - - return builder; - } - private static MongoStorageOptions GetMongoStorageOptions() { var mongoStorageOptions = new MongoStorageOptions @@ -200,6 +195,7 @@ public static IMachineBuilder AddHangfireJobServer( switch (engineType) { case TranslationEngineType.SmtTransfer: + builder.Services.AddSingleton(); builder.AddThotSmtModel().AddTransferEngine().AddUnigramTruecaser(); queues.Add("smt_transfer"); break; @@ -252,7 +248,7 @@ await c.Indexes.CreateOrUpdateAsync( ); await c.Indexes.CreateOrUpdateAsync( new CreateIndexModel( - Builders.IndexKeys.Ascending(e => e.CurrentBuild!.JobRunner) + Builders.IndexKeys.Ascending(e => e.CurrentBuild!.BuildJobRunner) ) ); } @@ -360,49 +356,38 @@ public static IMachineBuilder AddServalTranslationEngineService( return builder; } - public static IMachineBuilder AddBuildJobService( - this IMachineBuilder builder, - Action configureOptions - ) + public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder, string? smtTransferEngineDir = null) { - builder.Services.Configure(configureOptions); - var options = new BuildJobOptions(); - configureOptions(options); - return builder.AddBuildJobService(options); - } + builder.Services.AddScoped(); - public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder, IConfiguration config) - { - builder.Services.Configure(config); - var buildJobOptions = new BuildJobOptions(); - config.GetSection(BuildJobOptions.Key).Bind(buildJobOptions); - return builder.AddBuildJobService(buildJobOptions); - } + builder.Services.AddScoped(); + builder.Services.AddScoped(); + builder.Services.AddScoped(); + builder.Services.AddSingleton(); + builder.Services.AddSingleton(x => x.GetRequiredService()); + builder.Services.AddHostedService(p => p.GetRequiredService()); - public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder) - { - if (builder.Configuration is null) - { - builder.AddBuildJobService(o => { }); - } - else - { - builder.AddBuildJobService(builder.Configuration.GetSection(BuildJobOptions.Key)); + builder.Services.AddScoped(); + builder.Services.AddScoped(); + builder.Services.AddScoped(); + if (smtTransferEngineDir is null) + { var smtTransferEngineOptions = new SmtTransferEngineOptions(); - builder.Configuration.GetSection(SmtTransferEngineOptions.Key).Bind(smtTransferEngineOptions); - string? driveLetter = Path.GetPathRoot(smtTransferEngineOptions.EnginesDir)?[..1]; - if (driveLetter is null) - throw new InvalidOperationException("SMT Engine directory is required"); - // add health check for disk storage capacity - builder - .Services.AddHealthChecks() - .AddDiskStorageHealthCheck( - x => x.AddDrive(driveLetter, 1_000), // 1GB - "SMT Engine Storage Capacity", - HealthStatus.Degraded - ); + builder.Configuration?.GetSection(SmtTransferEngineOptions.Key).Bind(smtTransferEngineOptions); + smtTransferEngineDir = smtTransferEngineOptions.EnginesDir; } + string? driveLetter = Path.GetPathRoot(smtTransferEngineDir)?[..1]; + if (driveLetter is null) + throw new InvalidOperationException("SMT Engine directory is required"); + // add health check for disk storage capacity + builder + .Services.AddHealthChecks() + .AddDiskStorageHealthCheck( + x => x.AddDrive(driveLetter, 1_000), // 1GB + "SMT Engine Storage Capacity", + HealthStatus.Degraded + ); return builder; } @@ -412,23 +397,4 @@ public static IMachineBuilder AddModelCleanupService(this IMachineBuilder builde builder.Services.AddHostedService(); return builder; } - - private static IMachineBuilder AddBuildJobService(this IMachineBuilder builder, BuildJobOptions options) - { - builder.Services.AddScoped(); - - foreach (BuildJobRunner runnerType in options.Runners.Values.Distinct()) - { - switch (runnerType) - { - case BuildJobRunner.ClearML: - builder.AddClearMLBuildJobRunner(); - break; - case BuildJobRunner.Hangfire: - builder.AddHangfireBuildJobRunner(); - break; - } - } - return builder; - } } diff --git a/src/SIL.Machine.AspNetCore/Configuration/IServiceCollectionExtensions.cs b/src/SIL.Machine.AspNetCore/Configuration/IServiceCollectionExtensions.cs index 742f64f76..21642d82d 100644 --- a/src/SIL.Machine.AspNetCore/Configuration/IServiceCollectionExtensions.cs +++ b/src/SIL.Machine.AspNetCore/Configuration/IServiceCollectionExtensions.cs @@ -27,6 +27,7 @@ public static IMachineBuilder AddMachine(this IServiceCollection services, IConf builder.AddSharedFileOptions(o => { }); builder.AddSmtTransferEngineOptions(o => { }); builder.AddClearMLOptions(o => { }); + builder.AddBuildJobOptions(o => { }); } else { @@ -34,6 +35,7 @@ public static IMachineBuilder AddMachine(this IServiceCollection services, IConf builder.AddSharedFileOptions(configuration.GetSection(SharedFileOptions.Key)); builder.AddSmtTransferEngineOptions(configuration.GetSection(SmtTransferEngineOptions.Key)); builder.AddClearMLOptions(configuration.GetSection(ClearMLOptions.Key)); + builder.AddBuildJobOptions(configuration.GetSection(BuildJobOptions.Key)); } return builder; } diff --git a/src/SIL.Machine.AspNetCore/Models/Build.cs b/src/SIL.Machine.AspNetCore/Models/Build.cs index 40eef7969..89a4bdc83 100644 --- a/src/SIL.Machine.AspNetCore/Models/Build.cs +++ b/src/SIL.Machine.AspNetCore/Models/Build.cs @@ -8,18 +8,25 @@ public enum BuildJobState Canceling } -public enum BuildJobRunner +public enum BuildJobRunnerType { Hangfire, ClearML } +public enum BuildStage +{ + Preprocess, + Train, + Postprocess +} + public record Build { public required string BuildId { get; init; } public required BuildJobState JobState { get; init; } public required string JobId { get; init; } - public required BuildJobRunner JobRunner { get; init; } - public required string Stage { get; init; } + public required BuildJobRunnerType BuildJobRunner { get; init; } + public required BuildStage Stage { get; init; } public string? Options { get; set; } } diff --git a/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs b/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs index 7cd8a918c..cedb504c6 100644 --- a/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs +++ b/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs @@ -5,6 +5,7 @@ public record TranslationEngine : IEntity public string Id { get; set; } = ""; public int Revision { get; set; } = 1; public required string EngineId { get; init; } + public required TranslationEngineType Type { get; init; } public required string SourceLanguage { get; init; } public required string TargetLanguage { get; init; } public required bool IsModelPersisted { get; init; } diff --git a/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs b/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs index c86c77c76..406474283 100644 --- a/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs +++ b/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs @@ -1,23 +1,10 @@ namespace SIL.Machine.AspNetCore.Services; -public class BuildJobService : IBuildJobService +public class BuildJobService(IEnumerable runners, IRepository engines) + : IBuildJobService { - private readonly Dictionary _runnersByJobType; - private readonly Dictionary _runners; - private readonly IRepository _engines; - - public BuildJobService( - IEnumerable runners, - IRepository engines, - IOptions options - ) - { - _runners = runners.ToDictionary(r => r.Type); - _runnersByJobType = new Dictionary(); - foreach (KeyValuePair kvp in options.Value.Runners) - _runnersByJobType.Add(kvp.Key, _runners[kvp.Value]); - _engines = engines; - } + private readonly Dictionary _runners = runners.ToDictionary(r => r.Type); + private readonly IRepository _engines = engines; public Task IsEngineBuilding(string engineId, CancellationToken cancellationToken = default) { @@ -25,12 +12,12 @@ public Task IsEngineBuilding(string engineId, CancellationToken cancellati } public Task> GetBuildingEnginesAsync( - BuildJobRunner runner, + BuildJobRunnerType runner, CancellationToken cancellationToken = default ) { return _engines.GetAllAsync( - e => e.CurrentBuild != null && e.CurrentBuild.JobRunner == runner, + e => e.CurrentBuild != null && e.CurrentBuild.BuildJobRunner == runner, cancellationToken ); } @@ -49,58 +36,49 @@ public Task> GetBuildingEnginesAsync( } public async Task CreateEngineAsync( - IEnumerable jobTypes, string engineId, string? name = null, CancellationToken cancellationToken = default ) { - foreach (BuildJobType jobType in jobTypes) + foreach (BuildJobRunnerType runnerType in _runners.Keys) { - IBuildJobRunner runner = _runnersByJobType[jobType]; + IBuildJobRunner runner = _runners[runnerType]; await runner.CreateEngineAsync(engineId, name, cancellationToken); } } - public async Task DeleteEngineAsync( - IEnumerable jobTypes, - string engineId, - CancellationToken cancellationToken = default - ) + public async Task DeleteEngineAsync(string engineId, CancellationToken cancellationToken = default) { - foreach (BuildJobType jobType in jobTypes) + foreach (BuildJobRunnerType runnerType in _runners.Keys) { - IBuildJobRunner runner = _runnersByJobType[jobType]; + IBuildJobRunner runner = _runners[runnerType]; await runner.DeleteEngineAsync(engineId, cancellationToken); } } public async Task StartBuildJobAsync( - BuildJobType jobType, - TranslationEngineType engineType, + BuildJobRunnerType runnerType, string engineId, string buildId, - string stage, + BuildStage stage, object? data = null, string? buildOptions = null, CancellationToken cancellationToken = default ) { - if ( - !await _engines.ExistsAsync( - e => - e.EngineId == engineId - && (e.CurrentBuild == null || e.CurrentBuild.JobState != BuildJobState.Canceling), - cancellationToken - ) - ) - { + TranslationEngine? engine = await _engines.GetAsync( + e => + e.EngineId == engineId + && (e.CurrentBuild == null || e.CurrentBuild.JobState != BuildJobState.Canceling), + cancellationToken + ); + if (engine is null) return false; - } - IBuildJobRunner runner = _runnersByJobType[jobType]; + IBuildJobRunner runner = _runners[runnerType]; string jobId = await runner.CreateJobAsync( - engineType, + engine.Type, engineId, buildId, stage, @@ -119,7 +97,7 @@ await _engines.UpdateAsync( { BuildId = buildId, JobId = jobId, - JobRunner = runner.Type, + BuildJobRunner = runner.Type, Stage = stage, JobState = BuildJobState.Pending, Options = buildOptions @@ -127,7 +105,7 @@ await _engines.UpdateAsync( ), cancellationToken: cancellationToken ); - await runner.EnqueueJobAsync(jobId, cancellationToken); + await runner.EnqueueJobAsync(jobId, engine.Type, cancellationToken); return true; } catch @@ -149,7 +127,7 @@ await _engines.UpdateAsync( if (engine is null || engine.CurrentBuild is null) return (null, BuildJobState.None); - IBuildJobRunner runner = _runners[engine.CurrentBuild.JobRunner]; + IBuildJobRunner runner = _runners[engine.CurrentBuild.BuildJobRunner]; if (engine.CurrentBuild.JobState is BuildJobState.Pending) { diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLBuildJobRunner.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLBuildJobRunner.cs index 31b656b32..234f05e3b 100644 --- a/src/SIL.Machine.AspNetCore/Services/ClearMLBuildJobRunner.cs +++ b/src/SIL.Machine.AspNetCore/Services/ClearMLBuildJobRunner.cs @@ -2,14 +2,18 @@ public class ClearMLBuildJobRunner( IClearMLService clearMLService, - IEnumerable buildJobFactories + IEnumerable buildJobFactories, + IOptionsMonitor options ) : IBuildJobRunner { private readonly IClearMLService _clearMLService = clearMLService; private readonly Dictionary _buildJobFactories = buildJobFactories.ToDictionary(f => f.EngineType); - public BuildJobRunner Type => BuildJobRunner.ClearML; + private readonly Dictionary _options = + options.CurrentValue.ClearML.ToDictionary(o => o.TranslationEngineType); + + public BuildJobRunnerType Type => BuildJobRunnerType.ClearML; public async Task CreateEngineAsync( string engineId, @@ -31,7 +35,7 @@ public async Task CreateJobAsync( TranslationEngineType engineType, string engineId, string buildId, - string stage, + BuildStage stage, object? data = null, string? buildOptions = null, CancellationToken cancellationToken = default @@ -48,12 +52,19 @@ public async Task CreateJobAsync( string script = await buildJobFactory.CreateJobScriptAsync( engineId, buildId, + _options[engineType].ModelType, stage, data, buildOptions, cancellationToken ); - return await _clearMLService.CreateTaskAsync(buildId, projectId, script, cancellationToken); + return await _clearMLService.CreateTaskAsync( + buildId, + projectId, + script, + _options[engineType].DockerImage, + cancellationToken + ); } public Task DeleteJobAsync(string jobId, CancellationToken cancellationToken = default) @@ -61,9 +72,13 @@ public Task DeleteJobAsync(string jobId, CancellationToken cancellationTok return _clearMLService.DeleteTaskAsync(jobId, cancellationToken); } - public Task EnqueueJobAsync(string jobId, CancellationToken cancellationToken = default) + public Task EnqueueJobAsync( + string jobId, + TranslationEngineType engineType, + CancellationToken cancellationToken = default + ) { - return _clearMLService.EnqueueTaskAsync(jobId, cancellationToken); + return _clearMLService.EnqueueTaskAsync(jobId, _options[engineType].Queue, cancellationToken); } public Task StopJobAsync(string jobId, CancellationToken cancellationToken = default) diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLHealthCheck.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLHealthCheck.cs index 698b737a6..0cce9b7d1 100644 --- a/src/SIL.Machine.AspNetCore/Services/ClearMLHealthCheck.cs +++ b/src/SIL.Machine.AspNetCore/Services/ClearMLHealthCheck.cs @@ -3,12 +3,15 @@ namespace SIL.Machine.AspNetCore.Services; public class ClearMLHealthCheck( IClearMLAuthenticationService clearMLAuthenticationService, IHttpClientFactory httpClientFactory, - IOptionsMonitor options + IOptionsMonitor buildJobOptions ) : IHealthCheck { private readonly HttpClient _httpClient = httpClientFactory.CreateClient("ClearML-NoRetry"); - private readonly IOptionsMonitor _options = options; private readonly IClearMLAuthenticationService _clearMLAuthenticationService = clearMLAuthenticationService; + private readonly ISet _queuesMonitored = buildJobOptions + .CurrentValue.ClearML.Select(x => x.Queue) + .ToHashSet(); + private int _numConsecutiveFailures = 0; private readonly AsyncLock _lock = new AsyncLock(); @@ -21,10 +24,11 @@ public async Task CheckHealthAsync( { if (!await PingAsync(cancellationToken)) return HealthCheckResult.Unhealthy("ClearML is unresponsive"); - if (!await WorkersAreAssignedToQueue(cancellationToken)) + IReadOnlySet queuesWithoutWorkers = await QueuesWithoutWorkers(cancellationToken); + if (queuesWithoutWorkers.Count > 0) { return HealthCheckResult.Unhealthy( - $"No ClearML agents are available for configured queue \"{_options.CurrentValue.Queue}\"" + $"No ClearML agents are available for configured queues: {string.Join(", ", queuesWithoutWorkers)}" ); } @@ -70,8 +74,9 @@ public async Task PingAsync(CancellationToken cancellationToken = default) return result is not null; } - public async Task WorkersAreAssignedToQueue(CancellationToken cancellationToken = default) + public async Task> QueuesWithoutWorkers(CancellationToken cancellationToken = default) { + HashSet queuesWithoutWorkers = _queuesMonitored.ToHashSet(); JsonObject? result = await CallAsync("workers", "get_all", new JsonObject(), cancellationToken); JsonNode? workers_node = result?["data"]?["workers"]; if (workers_node is null) @@ -83,12 +88,13 @@ public async Task WorkersAreAssignedToQueue(CancellationToken cancellation if (queues_node is null) continue; JsonArray queues = (JsonArray)queues_node; - foreach (var queue in queues) + foreach (var currentQueue in queues) { - if ((string?)queue?["name"] == _options.CurrentValue.Queue) - return true; + string? currentQueueName = (string?)currentQueue?["name"]; + if (currentQueueName is not null) + queuesWithoutWorkers.Remove(currentQueueName); } } - return false; + return queuesWithoutWorkers; } } diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLMonitorService.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLMonitorService.cs index f166dabaf..05b931bee 100644 --- a/src/SIL.Machine.AspNetCore/Services/ClearMLMonitorService.cs +++ b/src/SIL.Machine.AspNetCore/Services/ClearMLMonitorService.cs @@ -4,16 +4,18 @@ public class ClearMLMonitorService( IServiceProvider services, IClearMLService clearMLService, ISharedFileService sharedFileService, - IOptions options, + IOptionsMonitor clearMLOptions, + IOptionsMonitor buildJobOptions, ILogger logger ) : RecurrentTask( "ClearML monitor service", services, - options.Value.BuildPollingTimeout, + clearMLOptions.CurrentValue.BuildPollingTimeout, logger, - options.Value.BuildPollingEnabled - ) + clearMLOptions.CurrentValue.BuildPollingEnabled + ), + IClearMLQueueService { private static readonly string EvalMetric = CreateMD5("eval"); private static readonly string BleuVariant = CreateMD5("bleu"); @@ -24,10 +26,21 @@ ILogger logger private readonly IClearMLService _clearMLService = clearMLService; private readonly ISharedFileService _sharedFileService = sharedFileService; - private readonly ILogger _logger = logger; + private readonly ILogger _logger = logger; private readonly Dictionary _curBuildStatus = new(); - public int QueueSize { get; private set; } + private readonly IReadOnlyDictionary _queuePerEngineType = + buildJobOptions.CurrentValue.ClearML.ToDictionary(x => x.TranslationEngineType, x => x.Queue); + + private readonly IDictionary _queueSizePerEngineType = new ConcurrentDictionary< + TranslationEngineType, + int + >(buildJobOptions.CurrentValue.ClearML.ToDictionary(x => x.TranslationEngineType, x => 0)); + + public int GetQueueSize(TranslationEngineType engineType) + { + return _queueSizePerEngineType[engineType]; + } protected override async Task DoWorkAsync(IServiceScope scope, CancellationToken cancellationToken) { @@ -35,28 +48,43 @@ protected override async Task DoWorkAsync(IServiceScope scope, CancellationToken { var buildJobService = scope.ServiceProvider.GetRequiredService(); IReadOnlyList trainingEngines = await buildJobService.GetBuildingEnginesAsync( - BuildJobRunner.ClearML, + BuildJobRunnerType.ClearML, cancellationToken ); if (trainingEngines.Count == 0) return; - Dictionary tasks = ( - await _clearMLService.GetTasksByIdAsync( - trainingEngines.Select(e => e.CurrentBuild!.JobId), - cancellationToken + Dictionary tasks = new(); + Dictionary queuePositions = new(); + + foreach (TranslationEngineType engineType in _queuePerEngineType.Keys) + { + Dictionary tasksPerEngineType = ( + await _clearMLService.GetTasksByIdAsync( + trainingEngines.Select(e => e.CurrentBuild!.JobId), + cancellationToken + ) ) - ) - .UnionBy(await _clearMLService.GetTasksForCurrentQueueAsync(cancellationToken), t => t.Id) - .ToDictionary(t => t.Id); + .UnionBy( + await _clearMLService.GetTasksForQueueAsync(_queuePerEngineType[engineType], cancellationToken), + t => t.Id + ) + .ToDictionary(t => t.Id); + // add new keys to dictionary + foreach (KeyValuePair kvp in tasksPerEngineType) + tasks.TryAdd(kvp.Key, kvp.Value); - Dictionary queuePositions = tasks - .Values.Where(t => t.Status is ClearMLTaskStatus.Queued or ClearMLTaskStatus.Created) - .OrderBy(t => t.Created) - .Select((t, i) => (Position: i, Task: t)) - .ToDictionary(e => e.Task.Name, e => e.Position); + Dictionary queuePositionsPerEngineType = tasksPerEngineType + .Values.Where(t => t.Status is ClearMLTaskStatus.Queued or ClearMLTaskStatus.Created) + .OrderBy(t => t.Created) + .Select((t, i) => (Position: i, Task: t)) + .ToDictionary(e => e.Task.Name, e => e.Position); + // add new keys to dictionary + foreach (KeyValuePair kvp in queuePositionsPerEngineType) + queuePositions.TryAdd(kvp.Key, kvp.Value); - QueueSize = queuePositions.Count; + _queueSizePerEngineType[engineType] = queuePositionsPerEngineType.Count; + } var platformService = scope.ServiceProvider.GetRequiredService(); var lockFactory = scope.ServiceProvider.GetRequiredService(); @@ -80,7 +108,7 @@ await UpdateTrainJobStatus( ); } - if (engine.CurrentBuild.Stage == NmtBuildStages.Train) + if (engine.CurrentBuild.Stage == BuildStage.Train) { if ( engine.CurrentBuild.JobState is BuildJobState.Pending @@ -220,11 +248,10 @@ CancellationToken cancellationToken await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) { return await buildJobService.StartBuildJobAsync( - BuildJobType.Cpu, - TranslationEngineType.Nmt, + BuildJobRunnerType.Hangfire, engineId, buildId, - NmtBuildStages.Postprocess, + BuildStage.Postprocess, (corpusSize, confidence), buildOptions, cancellationToken diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs index f52539284..d755aadcc 100644 --- a/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs +++ b/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs @@ -74,6 +74,7 @@ public async Task CreateTaskAsync( string buildId, string projectId, string script, + string dockerImage, CancellationToken cancellationToken = default ) { @@ -85,7 +86,7 @@ public async Task CreateTaskAsync( ["script"] = new JsonObject { ["diff"] = script }, ["container"] = new JsonObject { - ["image"] = _options.CurrentValue.DockerImage, + ["image"] = dockerImage, ["arguments"] = "--env ENV_FOR_DYNACONF=" + snakeCaseEnvironment, }, ["type"] = "training" @@ -107,9 +108,9 @@ public async Task DeleteTaskAsync(string id, CancellationToken cancellatio return deleted.Value; } - public async Task EnqueueTaskAsync(string id, CancellationToken cancellationToken = default) + public async Task EnqueueTaskAsync(string id, string queue, CancellationToken cancellationToken = default) { - var body = new JsonObject { ["task"] = id, ["queue_name"] = _options.CurrentValue.Queue }; + var body = new JsonObject { ["task"] = id, ["queue_name"] = queue }; JsonObject? result = await CallAsync("tasks", "enqueue", body, cancellationToken); var queued = (int?)result?["data"]?["queued"]; if (queued is null) @@ -137,11 +138,12 @@ public async Task StopTaskAsync(string id, CancellationToken cancellationT return updated == 1; } - public async Task> GetTasksForCurrentQueueAsync( + public async Task> GetTasksForQueueAsync( + string queue, CancellationToken cancellationToken = default ) { - var body = new JsonObject { ["name"] = _options.CurrentValue.Queue }; + var body = new JsonObject { ["name"] = queue }; JsonObject? result = await CallAsync("queues", "get_all_ex", body, cancellationToken); var tasks = (JsonArray?)result?["data"]?["queues"]?[0]?["entries"]; IEnumerable taskIds = tasks?.Select(t => (string)t?["id"]!) ?? new List(); diff --git a/src/SIL.Machine.AspNetCore/Services/HangfireBuildJobRunner.cs b/src/SIL.Machine.AspNetCore/Services/HangfireBuildJobRunner.cs index 2cf56dc0b..9b5c62ff3 100644 --- a/src/SIL.Machine.AspNetCore/Services/HangfireBuildJobRunner.cs +++ b/src/SIL.Machine.AspNetCore/Services/HangfireBuildJobRunner.cs @@ -36,7 +36,7 @@ public static Job CreateJob(string engineId, string buildId, string queue, private readonly Dictionary _buildJobFactories = buildJobFactories.ToDictionary(f => f.EngineType); - public BuildJobRunner Type => BuildJobRunner.Hangfire; + public BuildJobRunnerType Type => BuildJobRunnerType.Hangfire; public Task CreateEngineAsync(string engineId, string? name = null, CancellationToken cancellationToken = default) { @@ -52,7 +52,7 @@ public Task CreateJobAsync( TranslationEngineType engineType, string engineId, string buildId, - string stage, + BuildStage stage, object? data = null, string? buildOptions = null, CancellationToken cancellationToken = default @@ -68,7 +68,11 @@ public Task DeleteJobAsync(string jobId, CancellationToken cancellationTok return Task.FromResult(_jobClient.Delete(jobId)); } - public Task EnqueueJobAsync(string jobId, CancellationToken cancellationToken = default) + public Task EnqueueJobAsync( + string jobId, + TranslationEngineType engineType, + CancellationToken cancellationToken = default + ) { return Task.FromResult(_jobClient.Requeue(jobId)); } diff --git a/src/SIL.Machine.AspNetCore/Services/IBuildJobRunner.cs b/src/SIL.Machine.AspNetCore/Services/IBuildJobRunner.cs index 9f78205f3..1167efe1e 100644 --- a/src/SIL.Machine.AspNetCore/Services/IBuildJobRunner.cs +++ b/src/SIL.Machine.AspNetCore/Services/IBuildJobRunner.cs @@ -2,7 +2,7 @@ public interface IBuildJobRunner { - BuildJobRunner Type { get; } + BuildJobRunnerType Type { get; } Task CreateEngineAsync(string engineId, string? name = null, CancellationToken cancellationToken = default); Task DeleteEngineAsync(string engineId, CancellationToken cancellationToken = default); @@ -11,7 +11,7 @@ Task CreateJobAsync( TranslationEngineType engineType, string engineId, string buildId, - string stage, + BuildStage stage, object? data = null, string? buildOptions = null, CancellationToken cancellationToken = default @@ -19,7 +19,11 @@ Task CreateJobAsync( Task DeleteJobAsync(string jobId, CancellationToken cancellationToken = default); - Task EnqueueJobAsync(string jobId, CancellationToken cancellationToken = default); + Task EnqueueJobAsync( + string jobId, + TranslationEngineType engineType, + CancellationToken cancellationToken = default + ); Task StopJobAsync(string jobId, CancellationToken cancellationToken = default); } diff --git a/src/SIL.Machine.AspNetCore/Services/IBuildJobService.cs b/src/SIL.Machine.AspNetCore/Services/IBuildJobService.cs index c425d5afe..37b28c55b 100644 --- a/src/SIL.Machine.AspNetCore/Services/IBuildJobService.cs +++ b/src/SIL.Machine.AspNetCore/Services/IBuildJobService.cs @@ -1,39 +1,23 @@ namespace SIL.Machine.AspNetCore.Services; -public enum BuildJobType -{ - Cpu, - Gpu -} - public interface IBuildJobService { Task> GetBuildingEnginesAsync( - BuildJobRunner runner, + BuildJobRunnerType runner, CancellationToken cancellationToken = default ); Task IsEngineBuilding(string engineId, CancellationToken cancellationToken = default); - Task CreateEngineAsync( - IEnumerable jobTypes, - string engineId, - string? name = null, - CancellationToken cancellationToken = default - ); + Task CreateEngineAsync(string engineId, string? name = null, CancellationToken cancellationToken = default); - Task DeleteEngineAsync( - IEnumerable jobTypes, - string engineId, - CancellationToken cancellationToken = default - ); + Task DeleteEngineAsync(string engineId, CancellationToken cancellationToken = default); Task StartBuildJobAsync( - BuildJobType jobType, - TranslationEngineType engineType, + BuildJobRunnerType jobType, string engineId, string buildId, - string stage, + BuildStage stage, object? data = default, string? buildOptions = default, CancellationToken cancellationToken = default diff --git a/src/SIL.Machine.AspNetCore/Services/IClearMLBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/IClearMLBuildJobFactory.cs index ed14c12d7..eeb84149a 100644 --- a/src/SIL.Machine.AspNetCore/Services/IClearMLBuildJobFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/IClearMLBuildJobFactory.cs @@ -7,7 +7,8 @@ public interface IClearMLBuildJobFactory Task CreateJobScriptAsync( string engineId, string buildId, - string stage, + string modelType, + BuildStage stage, object? data = null, string? buildOptions = null, CancellationToken cancellationToken = default diff --git a/src/SIL.Machine.AspNetCore/Services/IClearMLQueueService.cs b/src/SIL.Machine.AspNetCore/Services/IClearMLQueueService.cs new file mode 100644 index 000000000..32bcf270a --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/IClearMLQueueService.cs @@ -0,0 +1,6 @@ +namespace SIL.Machine.AspNetCore.Services; + +public interface IClearMLQueueService +{ + public int GetQueueSize(TranslationEngineType engineType); +} diff --git a/src/SIL.Machine.AspNetCore/Services/IClearMLService.cs b/src/SIL.Machine.AspNetCore/Services/IClearMLService.cs index 674c35637..bad8bf576 100644 --- a/src/SIL.Machine.AspNetCore/Services/IClearMLService.cs +++ b/src/SIL.Machine.AspNetCore/Services/IClearMLService.cs @@ -14,13 +14,14 @@ Task CreateTaskAsync( string buildId, string projectId, string script, + string dockerImage, CancellationToken cancellationToken = default ); Task DeleteTaskAsync(string id, CancellationToken cancellationToken = default); - Task EnqueueTaskAsync(string id, CancellationToken cancellationToken = default); + Task EnqueueTaskAsync(string id, string queue, CancellationToken cancellationToken = default); Task DequeueTaskAsync(string id, CancellationToken cancellationToken = default); Task StopTaskAsync(string id, CancellationToken cancellationToken = default); - Task> GetTasksForCurrentQueueAsync(CancellationToken cancellationToken = default); + Task> GetTasksForQueueAsync(string queue, CancellationToken cancellationToken = default); Task GetTaskByNameAsync(string name, CancellationToken cancellationToken = default); Task> GetTasksByIdAsync( IEnumerable ids, diff --git a/src/SIL.Machine.AspNetCore/Services/IHangfireBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/IHangfireBuildJobFactory.cs index eb80b33dd..3c1c37510 100644 --- a/src/SIL.Machine.AspNetCore/Services/IHangfireBuildJobFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/IHangfireBuildJobFactory.cs @@ -4,5 +4,5 @@ public interface IHangfireBuildJobFactory { TranslationEngineType EngineType { get; } - Job CreateJob(string engineId, string buildId, string stage, object? data, string? buildOptions); + Job CreateJob(string engineId, string buildId, BuildStage stage, object? data, string? buildOptions); } diff --git a/src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs b/src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs index 46ef238b5..51d7fe1e5 100644 --- a/src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs @@ -9,6 +9,8 @@ IInteractiveTranslationModel Create( ITruecaser truecaser ); ITrainer CreateTrainer(string engineId, IRangeTokenizer tokenizer, IParallelTextCorpus corpus); + Task UploadBuiltEngineAsync(string engineId, CancellationToken cancellationToken); + Task DownloadBuiltEngineAsync(string engineId, CancellationToken cancellationToken); void InitNew(string engineId); void Cleanup(string engineId); } diff --git a/src/SIL.Machine.AspNetCore/Services/ITranslationEngineService.cs b/src/SIL.Machine.AspNetCore/Services/ITranslationEngineService.cs index c1238b1e2..0fbcdef03 100644 --- a/src/SIL.Machine.AspNetCore/Services/ITranslationEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/ITranslationEngineService.cs @@ -43,7 +43,7 @@ Task StartBuildAsync( Task GetModelDownloadUrlAsync(string engineId, CancellationToken cancellationToken = default); - Task GetQueueSizeAsync(CancellationToken cancellationToken = default); + int GetQueueSize(); bool IsLanguageNativeToModel(string language, out string internalCode); } diff --git a/src/SIL.Machine.AspNetCore/Services/ModelCleanupService.cs b/src/SIL.Machine.AspNetCore/Services/ModelCleanupService.cs index 21cf753d7..baf7d75a9 100644 --- a/src/SIL.Machine.AspNetCore/Services/ModelCleanupService.cs +++ b/src/SIL.Machine.AspNetCore/Services/ModelCleanupService.cs @@ -24,17 +24,26 @@ private async Task CheckModelsAsync(CancellationToken cancellationToken) NmtEngineService.ModelDirectory, cancellationToken: cancellationToken ); - // Get all engine ids from the database + // Get all NMT engine ids from the database IReadOnlyList? allEngines = await _engines.GetAllAsync(cancellationToken: cancellationToken); - IEnumerable validFilenames = allEngines.Select(e => - NmtEngineService.GetModelPath(e.EngineId, e.BuildRevision) - ); + IEnumerable validNmtFilenames = allEngines + .Where(e => e.Type == TranslationEngineType.Nmt) + .Select(e => NmtEngineService.GetModelPath(e.EngineId, e.BuildRevision)); // If there is a currently running build that creates and pushes a new file, but the database has not // updated yet, don't delete the new file. - IEnumerable validFilenamesForNextBuild = allEngines.Select(e => - NmtEngineService.GetModelPath(e.EngineId, e.BuildRevision + 1) - ); - HashSet filenameFilter = validFilenames.Concat(validFilenamesForNextBuild).ToHashSet(); + IEnumerable validNmtFilenamesForNextBuild = allEngines + .Where(e => e.Type == TranslationEngineType.Nmt) + .Select(e => NmtEngineService.GetModelPath(e.EngineId, e.BuildRevision + 1)); + + // Add SMT engines + IEnumerable validSmtModels = allEngines + .Where(e => e.Type == TranslationEngineType.SmtTransfer) + .Select(e => SmtTransferEngineService.GetModelPath(e.EngineId)); + + HashSet filenameFilter = validNmtFilenames + .Concat(validNmtFilenamesForNextBuild) + .Concat(validSmtModels) + .ToHashSet(); foreach (string path in paths) { diff --git a/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs index dfc8423ea..5d507bba4 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs @@ -3,27 +3,26 @@ public class NmtClearMLBuildJobFactory( ISharedFileService sharedFileService, ILanguageTagService languageTagService, - IRepository engines, - IOptionsMonitor options + IRepository engines ) : IClearMLBuildJobFactory { private readonly ISharedFileService _sharedFileService = sharedFileService; private readonly ILanguageTagService _languageTagService = languageTagService; private readonly IRepository _engines = engines; - private readonly IOptionsMonitor _options = options; public TranslationEngineType EngineType => TranslationEngineType.Nmt; public async Task CreateJobScriptAsync( string engineId, string buildId, - string stage, + string modelType, + BuildStage stage, object? data = null, string? buildOptions = null, CancellationToken cancellationToken = default ) { - if (stage == NmtBuildStages.Train) + if (stage == BuildStage.Train) { TranslationEngine? engine = await _engines.GetAsync(e => e.EngineId == engineId, cancellationToken); if (engine is null) @@ -36,7 +35,7 @@ public async Task CreateJobScriptAsync( _languageTagService.ConvertToFlores200Code(engine.TargetLanguage, out string trgLang); return "from machine.jobs.build_nmt_engine import run\n" + "args = {\n" - + $" 'model_type': '{_options.CurrentValue.ModelType}',\n" + + $" 'model_type': '{modelType}',\n" + $" 'engine_id': '{engineId}',\n" + $" 'build_id': '{buildId}',\n" + $" 'src_lang': '{srcLang}',\n" diff --git a/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs b/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs index 7c19e3c37..28af9ee9a 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs @@ -1,12 +1,5 @@ namespace SIL.Machine.AspNetCore.Services; -public static class NmtBuildStages -{ - public const string Preprocess = "preprocess"; - public const string Train = "train"; - public const string Postprocess = "postprocess"; -} - public class NmtEngineService( IPlatformService platformService, IDistributedReaderWriterLockFactory lockFactory, @@ -14,7 +7,7 @@ public class NmtEngineService( IRepository engines, IBuildJobService buildJobService, ILanguageTagService languageTagService, - ClearMLMonitorService clearMLMonitorService, + IClearMLQueueService clearMLQueueService, ISharedFileService sharedFileService ) : ITranslationEngineService { @@ -23,7 +16,7 @@ ISharedFileService sharedFileService private readonly IDataAccessContext _dataAccessContext = dataAccessContext; private readonly IRepository _engines = engines; private readonly IBuildJobService _buildJobService = buildJobService; - private readonly ClearMLMonitorService _clearMLMonitorService = clearMLMonitorService; + private readonly IClearMLQueueService _clearMLQueueService = clearMLQueueService; private readonly ILanguageTagService _languageTagService = languageTagService; private readonly ISharedFileService _sharedFileService = sharedFileService; public const string ModelDirectory = "models/"; @@ -54,15 +47,11 @@ public async Task CreateAsync( EngineId = engineId, SourceLanguage = sourceLanguage, TargetLanguage = targetLanguage, + Type = TranslationEngineType.Nmt, IsModelPersisted = isModelPersisted ?? false // models are not persisted if not specified }; await _engines.InsertAsync(translationEngine, ct); - await _buildJobService.CreateEngineAsync( - [BuildJobType.Cpu, BuildJobType.Gpu], - engineId, - engineName, - ct - ); + await _buildJobService.CreateEngineAsync(engineId, engineName, ct); return translationEngine; }, cancellationToken: cancellationToken @@ -78,11 +67,7 @@ public async Task DeleteAsync(string engineId, CancellationToken cancellationTok await CancelBuildJobAsync(engineId, cancellationToken); await _engines.DeleteAsync(e => e.EngineId == engineId, cancellationToken); - await _buildJobService.DeleteEngineAsync( - new[] { BuildJobType.Cpu, BuildJobType.Gpu }, - engineId, - CancellationToken.None - ); + await _buildJobService.DeleteEngineAsync(engineId, CancellationToken.None); } await _lockFactory.DeleteAsync(engineId, CancellationToken.None); } @@ -103,11 +88,10 @@ public async Task StartBuildAsync( throw new InvalidOperationException("The engine is already building or in the process of canceling."); await _buildJobService.StartBuildJobAsync( - BuildJobType.Cpu, - TranslationEngineType.Nmt, + BuildJobRunnerType.Hangfire, engineId, buildId, - NmtBuildStages.Preprocess, + BuildStage.Preprocess, corpora, buildOptions, cancellationToken @@ -185,9 +169,9 @@ public Task TrainSegmentPairAsync( throw new NotSupportedException(); } - public Task GetQueueSizeAsync(CancellationToken cancellationToken = default) + public int GetQueueSize() { - return Task.FromResult(_clearMLMonitorService.QueueSize); + return _clearMLQueueService.GetQueueSize(Type); } public bool IsLanguageNativeToModel(string language, out string internalCode) diff --git a/src/SIL.Machine.AspNetCore/Services/NmtHangfireBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/NmtHangfireBuildJobFactory.cs index 1f6fe0480..746c336aa 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtHangfireBuildJobFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtHangfireBuildJobFactory.cs @@ -6,11 +6,11 @@ public class NmtHangfireBuildJobFactory : IHangfireBuildJobFactory { public TranslationEngineType EngineType => TranslationEngineType.Nmt; - public Job CreateJob(string engineId, string buildId, string stage, object? data, string? buildOptions) + public Job CreateJob(string engineId, string buildId, BuildStage stage, object? data, string? buildOptions) { return stage switch { - NmtBuildStages.Preprocess + BuildStage.Preprocess => CreateJob>( engineId, buildId, @@ -18,9 +18,8 @@ public Job CreateJob(string engineId, string buildId, string stage, object? data data, buildOptions ), - NmtBuildStages.Postprocess + BuildStage.Postprocess => CreateJob(engineId, buildId, "nmt", data, buildOptions), - NmtBuildStages.Train => CreateJob(engineId, buildId, "nmt", buildOptions), _ => throw new ArgumentException("Unknown build stage.", nameof(stage)), }; } diff --git a/src/SIL.Machine.AspNetCore/Services/NmtPostprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/NmtPostprocessBuildJob.cs index d71937956..ba17ef414 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtPostprocessBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtPostprocessBuildJob.cs @@ -7,10 +7,8 @@ public class NmtPostprocessBuildJob( IBuildJobService buildJobService, ILogger logger, ISharedFileService sharedFileService -) : HangfireBuildJob<(int, double)>(platformService, engines, lockFactory, buildJobService, logger) +) : PostprocessBuildJob(platformService, engines, lockFactory, buildJobService, logger, sharedFileService) { - private readonly ISharedFileService _sharedFileService = sharedFileService; - protected override async Task DoWorkAsync( string engineId, string buildId, @@ -38,44 +36,4 @@ await PlatformService.BuildCompletedAsync( Logger.LogInformation("Build completed ({0}).", buildId); } - - protected override async Task CleanupAsync( - string engineId, - string buildId, - (int, double) data, - IDistributedReaderWriterLock @lock, - JobCompletionStatus completionStatus - ) - { - if (completionStatus is JobCompletionStatus.Restarting) - return; - - try - { - if (completionStatus is not JobCompletionStatus.Faulted) - await _sharedFileService.DeleteAsync($"builds/{buildId}/"); - } - catch (Exception e) - { - Logger.LogWarning(e, "Unable to to delete job data for build {0}.", buildId); - } - } - - private async Task InsertPretranslationsAsync(string engineId, string buildId, CancellationToken cancellationToken) - { - await using var targetPretranslateStream = await _sharedFileService.OpenReadAsync( - $"builds/{buildId}/pretranslate.trg.json", - cancellationToken - ); - - IAsyncEnumerable pretranslations = JsonSerializer - .DeserializeAsyncEnumerable( - targetPretranslateStream, - new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }, - cancellationToken - ) - .OfType(); - - await PlatformService.InsertPretranslationsAsync(engineId, pretranslations, cancellationToken); - } } diff --git a/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs index a7f86e20d..6cdcbfdee 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs @@ -1,15 +1,7 @@ namespace SIL.Machine.AspNetCore.Services; -public class NmtPreprocessBuildJob : HangfireBuildJob> +public class NmtPreprocessBuildJob : PreprocessBuildJob { - private static readonly JsonWriterOptions PretranslateWriterOptions = new() { Indented = true }; - - private readonly ISharedFileService _sharedFileService; - private readonly ICorpusService _corpusService; - private readonly ILanguageTagService _languageTagService; - private int _seed = 1234; - private Random _random; - public NmtPreprocessBuildJob( IPlatformService platformService, IRepository engines, @@ -20,434 +12,18 @@ public NmtPreprocessBuildJob( ICorpusService corpusService, ILanguageTagService languageTagService ) - : base(platformService, engines, lockFactory, buildJobService, logger) + : base(platformService, engines, lockFactory, logger, buildJobService, sharedFileService, corpusService) { - _sharedFileService = sharedFileService; - _corpusService = corpusService; _languageTagService = languageTagService; - _random = new Random(_seed); - } - - internal int Seed - { - get => _seed; - set - { - if (_seed != value) - { - _seed = value; - _random = new Random(_seed); - } - } - } - - protected override async Task DoWorkAsync( - string engineId, - string buildId, - IReadOnlyList data, - string? buildOptions, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) - { - (int trainCount, int pretranslateCount) = await WriteDataFilesAsync( - buildId, - data, - buildOptions, - cancellationToken - ); - - // Log summary of build data - JsonObject buildPreprocessSummary = - new() - { - { "Event", "BuildPreprocess" }, - { "EngineId", engineId }, - { "BuildId", buildId }, - { "NumTrainRows", trainCount }, - { "NumPretranslateRows", pretranslateCount } - }; - TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken); - if (engine is null) - throw new OperationCanceledException($"Engine {engineId} does not exist. Build canceled."); - - bool sourceTagInFlores200 = _languageTagService.ConvertToFlores200Code( - engine.SourceLanguage, - out string srcLang - ); - buildPreprocessSummary.Add("SourceLanguageResolved", srcLang); - bool targetTagInFlores200 = _languageTagService.ConvertToFlores200Code( - engine.TargetLanguage, - out string trgLang - ); - buildPreprocessSummary.Add("TargetLanguageResolved", trgLang); - Logger.LogInformation("{summary}", buildPreprocessSummary.ToJsonString()); - - if (trainCount == 0 && (!sourceTagInFlores200 || !targetTagInFlores200)) - { - throw new InvalidOperationException( - $"Neither language code in build {buildId} are known to the base model, and the data specified for training was empty. Build canceled." - ); - } - - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - bool canceling = !await BuildJobService.StartBuildJobAsync( - BuildJobType.Gpu, - TranslationEngineType.Nmt, - engineId, - buildId, - NmtBuildStages.Train, - buildOptions: buildOptions, - cancellationToken: cancellationToken - ); - if (canceling) - throw new OperationCanceledException(); - } - } - - private async Task<(int TrainCount, int PretranslateCount)> WriteDataFilesAsync( - string buildId, - IReadOnlyList corpora, - string? buildOptions, - CancellationToken cancellationToken - ) - { - JsonObject? buildOptionsObject = null; - if (buildOptions is not null) - { - buildOptionsObject = JsonSerializer.Deserialize(buildOptions); - } - await using StreamWriter sourceTrainWriter = - new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.src.txt", cancellationToken)); - await using StreamWriter targetTrainWriter = - new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken)); - await using Stream pretranslateStream = await _sharedFileService.OpenWriteAsync( - $"builds/{buildId}/pretranslate.src.json", - cancellationToken - ); - await using Utf8JsonWriter pretranslateWriter = new(pretranslateStream, PretranslateWriterOptions); - - int trainCount = 0; - int pretranslateCount = 0; - pretranslateWriter.WriteStartArray(); - foreach (Corpus corpus in corpora) - { - ITextCorpus[] sourceTextCorpora = _corpusService.CreateTextCorpora(corpus.SourceFiles).ToArray(); - ITextCorpus targetTextCorpus = - _corpusService.CreateTextCorpora(corpus.TargetFiles).FirstOrDefault() ?? new DictionaryTextCorpus(); - - if (sourceTextCorpora.Length == 0) - continue; - - int skipCount = 0; - foreach (Row?[] rows in AlignTrainCorpus(corpus, sourceTextCorpora, targetTextCorpus)) - { - if (skipCount > 0) - { - skipCount--; - continue; - } - - Row[] trainRows = rows.Where(row => IsInTrain(row, corpus)).Cast().ToArray(); - if (trainRows.Length > 0) - { - Row row = trainRows[0]; - if (rows.Length > 1) - { - Row[] nonEmptyRows = trainRows.Where(r => r.SourceSegment.Length > 0).ToArray(); - if (nonEmptyRows.Length > 0) - row = nonEmptyRows[_random.Next(nonEmptyRows.Length)]; - } - - await sourceTrainWriter.WriteAsync($"{row.SourceSegment}\n"); - await targetTrainWriter.WriteAsync($"{row.TargetSegment}\n"); - skipCount = row.RowCount - 1; - if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) - trainCount++; - } - } - - if ((bool?)buildOptionsObject?["use_key_terms"] ?? true) - { - ITextCorpus? sourceTermCorpus = _corpusService.CreateTermCorpora(corpus.SourceFiles).FirstOrDefault(); - ITextCorpus? targetTermCorpus = _corpusService.CreateTermCorpora(corpus.TargetFiles).FirstOrDefault(); - if (sourceTermCorpus is not null && targetTermCorpus is not null) - { - IParallelTextCorpus parallelKeyTermsCorpus = sourceTermCorpus.AlignRows(targetTermCorpus); - foreach (ParallelTextRow row in parallelKeyTermsCorpus) - { - await sourceTrainWriter.WriteAsync($"{row.SourceText}\n"); - await targetTrainWriter.WriteAsync($"{row.TargetText}\n"); - trainCount++; - } - } - } - - foreach (Row row in AlignPretranslateCorpus(corpus, sourceTextCorpora[0], targetTextCorpus)) - { - if ( - IsInPretranslate(row, corpus) - && row.SourceSegment.Length > 0 - && (row.TargetSegment.Length == 0 || !IsInTrain(row, corpus)) - ) - { - pretranslateWriter.WriteStartObject(); - pretranslateWriter.WriteString("corpusId", corpus.Id); - pretranslateWriter.WriteString("textId", row.TextId); - pretranslateWriter.WriteStartArray("refs"); - foreach (object rowRef in row.Refs) - pretranslateWriter.WriteStringValue(rowRef.ToString()); - pretranslateWriter.WriteEndArray(); - pretranslateWriter.WriteString("translation", row.SourceSegment); - pretranslateWriter.WriteEndObject(); - pretranslateCount++; - } - } - } - pretranslateWriter.WriteEndArray(); - - return (trainCount, pretranslateCount); - } - - protected override async Task CleanupAsync( - string engineId, - string buildId, - IReadOnlyList data, - IDistributedReaderWriterLock @lock, - JobCompletionStatus completionStatus - ) - { - if (completionStatus is JobCompletionStatus.Canceled) - { - try - { - await _sharedFileService.DeleteAsync($"builds/{buildId}/"); - } - catch (Exception e) - { - Logger.LogWarning(e, "Unable to to delete job data for build {BuildId}.", buildId); - } - } + PretranslationEnabled = true; + EngineType = TranslationEngineType.Nmt; } - private static bool IsInTrain(Row? row, Corpus corpus) - { - return IsIncluded(row, corpus.TrainOnTextIds, corpus.TrainOnChapters); - } - - private static bool IsInPretranslate(Row? row, Corpus corpus) - { - return IsIncluded(row, corpus.PretranslateTextIds, corpus.PretranslateChapters); - } - - private static bool IsIncluded( - Row? row, - IReadOnlySet? textIds, - IReadOnlyDictionary>? chapters - ) - { - if (row is null) - return false; - if (chapters is not null) - { - return row.Refs.Any(r => IsInChapters(chapters, r)); - } - if (textIds is not null) - { - return textIds.Contains(row.TextId); - } - return true; - } - - private static bool IsInChapters(IReadOnlyDictionary> bookChapters, object rowRef) - { - if (rowRef is not ScriptureRef sr) - return false; - return bookChapters.TryGetValue(sr.Book, out HashSet? chapters) - && (chapters.Contains(sr.ChapterNum) || chapters.Count == 0); - } - - private static IEnumerable AlignTrainCorpus( - Corpus corpus, - IReadOnlyList srcCorpora, - ITextCorpus trgCorpus - ) - { - IEnumerable? textIds = corpus.TrainOnChapters is not null - ? corpus.TrainOnChapters.Keys - : corpus.TrainOnTextIds; - srcCorpora = srcCorpora.Select(sc => sc.FilterTexts(textIds)).ToArray(); - trgCorpus = trgCorpus.FilterTexts(textIds); - - if (trgCorpus.IsScripture()) - { - return srcCorpora - .Select(sc => AlignScripture(sc, trgCorpus)) - .ZipMany(rows => rows.ToArray()) - // filter out every list that only contains completely empty rows - .Where(rows => rows.Any(r => r is null || r.SourceSegment.Length > 0 || r.TargetSegment.Length > 0)); - } - - IEnumerable sourceOnlyRows = srcCorpora - .Select(sc => sc.AlignRows(trgCorpus, allSourceRows: true)) - .ZipMany(rows => - rows.Where(r => r.TargetSegment.Count == 0) - .Select(r => new Row(r.TextId, r.Refs, r.SourceText, r.TargetText, 1)) - .ToArray() - ); - - IEnumerable targetRows = srcCorpora - .Select(sc => sc.AlignRows(trgCorpus, allTargetRows: true)) - .ZipMany(rows => - rows.Where(r => r.TargetSegment.Count > 0) - .Select(r => new Row(r.TextId, r.Refs, r.SourceText, r.TargetText, 1)) - .ToArray() - ); - - return sourceOnlyRows - .Concat(targetRows) - // filter out every list that only contains completely empty rows - .Where(rows => rows.Any(r => r.SourceSegment.Length > 0 || r.TargetSegment.Length > 0)); - } - - private static IEnumerable AlignScripture(ITextCorpus srcCorpus, ITextCorpus trgCorpus) - { - int rowCount = 0; - StringBuilder srcSegBuffer = new(); - StringBuilder trgSegBuffer = new(); - HashSet vrefs = []; - foreach ( - (VerseRef vref, string srcSegment, string trgSegment) in srcCorpus - .ExtractScripture() - .Select(r => (r.CorpusVerseRef, r.Text)) - .Zip( - trgCorpus.ExtractScripture().Select(r => r.Text), - (s, t) => (VerseRef: s.CorpusVerseRef, SourceSegment: s.Text, TargetSegment: t) - ) - ) - { - if (srcSegment == "" && trgSegment == "") - { - vrefs.UnionWith(vref.AllVerses()); - rowCount++; - } - else if (srcSegment == "") - { - vrefs.UnionWith(vref.AllVerses()); - if (trgSegment.Length > 0) - { - if (trgSegBuffer.Length > 0) - trgSegBuffer.Append(' '); - trgSegBuffer.Append(trgSegment); - } - rowCount++; - } - else if (trgSegment == "") - { - vrefs.UnionWith(vref.AllVerses()); - if (srcSegment.Length > 0) - { - if (srcSegBuffer.Length > 0) - srcSegBuffer.Append(' '); - srcSegBuffer.Append(srcSegment); - } - rowCount++; - } - else - { - if (rowCount > 0) - { - yield return new( - vrefs.First().Book, - vrefs.Order().Select(v => new ScriptureRef(v)).Cast().ToArray(), - srcSegBuffer.ToString(), - trgSegBuffer.ToString(), - rowCount - ); - for (int i = 0; i < rowCount - 1; i++) - yield return null; - srcSegBuffer.Clear(); - trgSegBuffer.Clear(); - vrefs.Clear(); - rowCount = 0; - } - vrefs.UnionWith(vref.AllVerses()); - srcSegBuffer.Append(srcSegment); - trgSegBuffer.Append(trgSegment); - rowCount++; - } - } - - if (rowCount > 0) - { - yield return new( - vrefs.First().Book, - vrefs.Order().Select(v => new ScriptureRef(v)).Cast().ToArray(), - srcSegBuffer.ToString(), - trgSegBuffer.ToString(), - rowCount - ); - for (int i = 0; i < rowCount - 1; i++) - yield return null; - } - } + private readonly ILanguageTagService _languageTagService; - private static IEnumerable AlignPretranslateCorpus(Corpus corpus, ITextCorpus srcCorpus, ITextCorpus trgCorpus) + protected override string ResolveLanguageCode(string languageCode) { - IEnumerable? textIds = corpus.PretranslateChapters is not null - ? corpus.PretranslateChapters.Keys - : corpus.PretranslateTextIds; - srcCorpus = srcCorpus.FilterTexts(textIds); - trgCorpus = trgCorpus.FilterTexts(textIds); - - int rowCount = 0; - StringBuilder srcSegBuffer = new(); - StringBuilder trgSegBuffer = new(); - List refs = []; - string textId = ""; - foreach (ParallelTextRow row in srcCorpus.AlignRows(trgCorpus, allSourceRows: true)) - { - if (!row.IsTargetRangeStart && row.IsTargetInRange) - { - refs.AddRange(row.Refs); - if (row.SourceText.Length > 0) - { - if (srcSegBuffer.Length > 0) - srcSegBuffer.Append(' '); - srcSegBuffer.Append(row.SourceText); - } - rowCount++; - } - else - { - if (rowCount > 0) - { - yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1); - textId = ""; - srcSegBuffer.Clear(); - trgSegBuffer.Clear(); - refs.Clear(); - rowCount = 0; - } - - textId = row.TextId; - refs.AddRange(row.Refs); - srcSegBuffer.Append(row.SourceText); - trgSegBuffer.Append(row.TargetText); - rowCount++; - } - } - - if (rowCount > 0) - yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1); + _languageTagService.ConvertToFlores200Code(languageCode, out string resolvedCode); + return resolvedCode; } - - private record Row( - string TextId, - IReadOnlyList Refs, - string SourceSegment, - string TargetSegment, - int RowCount - ); } diff --git a/src/SIL.Machine.AspNetCore/Services/NmtTrainBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/NmtTrainBuildJob.cs deleted file mode 100644 index 6f9008552..000000000 --- a/src/SIL.Machine.AspNetCore/Services/NmtTrainBuildJob.cs +++ /dev/null @@ -1,142 +0,0 @@ -namespace SIL.Machine.AspNetCore.Services; - -// TODO: The Hangfire implementation of the NMT train stage is not complete, DO NOT USE -// see https://github.com/sillsdev/machine/issues/103 -public class NmtTrainBuildJob( - IPlatformService platformService, - IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, - IBuildJobService buildJobService, - ILogger logger, - ISharedFileService sharedFileService, - IOptionsMonitor options -) : HangfireBuildJob(platformService, engines, lockFactory, buildJobService, logger) -{ - private readonly ISharedFileService _sharedFileService = sharedFileService; - private readonly IOptionsMonitor _options = options; - - protected override async Task DoWorkAsync( - string engineId, - string buildId, - object? data, - string? buildOptions, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) - { - TranslationEngine? engine = await Engines.GetAsync( - e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.BuildId == buildId, - cancellationToken - ); - if (engine is null) - throw new OperationCanceledException(); - - try - { - Installer.LogMessage += Log; - await Installer.SetupPython(); - await Installer.TryInstallPip(); - await PipInstallModuleAsync( - "sil-machine[jobs,huggingface,sentencepiece]", - cancellationToken: cancellationToken - ); - await PipInstallModuleAsync( - "torch", - indexUrl: "https://download.pytorch.org/whl/cu117", - cancellationToken: cancellationToken - ); - await PipInstallModuleAsync("accelerate", cancellationToken: cancellationToken); - - PythonEngine.Initialize(); - - using (Py.GIL()) - { - PythonEngine.Exec( - "from machine.jobs.build_nmt_engine import run\n" - + "args = {\n" - + $" 'model_type': '{_options.CurrentValue.ModelType}',\n" - + $" 'engine_id': '{engineId}',\n" - + $" 'build_id': '{buildId}',\n" - + $" 'src_lang': '{ConvertLanguageTag(engine.SourceLanguage)}',\n" - + $" 'trg_lang': '{ConvertLanguageTag(engine.TargetLanguage)}',\n" - + $" 'shared_file_uri': '{_sharedFileService.GetBaseUri()}',\n" - + (buildOptions is not null ? $" 'build_options': '''{buildOptions}''',\n" : "") - // buildRevision + 1 because the build revision is incremented after the build job - // is finished successfully but the file should be saved with the new revision number - + ( - engine.IsModelPersisted - ? $" 'save_model': '{engine.Id}_{engine.BuildRevision + 1}',\n" - : "" - ) - + $" 'clearml': False\n" - + "}\n" - + "run(args)\n" - ); - } - } - finally - { - Installer.LogMessage -= Log; - } - } - - private void Log(string message) - { - Logger.LogInformation(message); - } - - private static string ConvertLanguageTag(string languageTag) - { - if ( - !IetfLanguageTag.TryGetSubtags( - languageTag, - out LanguageSubtag languageSubtag, - out ScriptSubtag scriptSubtag, - out _, - out _ - ) - ) - { - return languageTag; - } - - // Convert to NLLB language codes - return $"{languageSubtag.Iso3Code}_{scriptSubtag.Code}"; - } - - public async Task PipInstallModuleAsync( - string module_name, - string version = "", - string indexUrl = "", - bool force = false, - CancellationToken cancellationToken = default - ) - { - try - { - Python.Deployment.Installer.LogMessage += Log; - if (!Installer.IsModuleInstalled(module_name) || force) - { - string text = Path.Combine(Python.Deployment.Installer.EmbeddedPythonHome, "Scripts", "pip"); - string text2 = (force ? " --force-reinstall" : ""); - if (version.Length > 0) - { - version = "==" + version; - } - if (indexUrl.Length > 0) - { - text2 += " --index-url " + indexUrl; - } - - await Python.Deployment.Installer.RunCommand( - text + " install " + module_name + version + " " + text2, - cancellationToken - ); - } - } - finally - { - Python.Deployment.Installer.LogMessage -= Log; - } - } -} diff --git a/src/SIL.Machine.AspNetCore/Services/PostprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/PostprocessBuildJob.cs new file mode 100644 index 000000000..009fda8ad --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/PostprocessBuildJob.cs @@ -0,0 +1,57 @@ +namespace SIL.Machine.AspNetCore.Services; + +public abstract class PostprocessBuildJob( + IPlatformService platformService, + IRepository engines, + IDistributedReaderWriterLockFactory lockFactory, + IBuildJobService buildJobService, + ILogger logger, + ISharedFileService sharedFileService +) : HangfireBuildJob<(int, double)>(platformService, engines, lockFactory, buildJobService, logger) +{ + protected readonly ISharedFileService SharedFileService = sharedFileService; + + protected override async Task CleanupAsync( + string engineId, + string buildId, + (int, double) data, + IDistributedReaderWriterLock @lock, + JobCompletionStatus completionStatus + ) + { + if (completionStatus is JobCompletionStatus.Restarting) + return; + + try + { + if (completionStatus is not JobCompletionStatus.Faulted) + await SharedFileService.DeleteAsync($"builds/{buildId}/"); + } + catch (Exception e) + { + Logger.LogWarning(e, "Unable to to delete job data for build {0}.", buildId); + } + } + + protected async Task InsertPretranslationsAsync( + string engineId, + string buildId, + CancellationToken cancellationToken + ) + { + await using var targetPretranslateStream = await SharedFileService.OpenReadAsync( + $"builds/{buildId}/pretranslate.trg.json", + cancellationToken + ); + + IAsyncEnumerable pretranslations = JsonSerializer + .DeserializeAsyncEnumerable( + targetPretranslateStream, + new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }, + cancellationToken + ) + .OfType(); + + await PlatformService.InsertPretranslationsAsync(engineId, pretranslations, cancellationToken); + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs new file mode 100644 index 000000000..9d562d19a --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs @@ -0,0 +1,431 @@ +namespace SIL.Machine.AspNetCore.Services; + +public abstract class PreprocessBuildJob : HangfireBuildJob> +{ + private static readonly JsonWriterOptions PretranslateWriterOptions = new() { Indented = true }; + + internal BuildJobRunnerType TrainJobRunnerType { get; init; } = BuildJobRunnerType.ClearML; + protected TranslationEngineType EngineType { get; init; } + protected bool PretranslationEnabled { get; init; } + + private readonly ISharedFileService _sharedFileService; + private readonly ICorpusService _corpusService; + private int _seed = 1234; + private Random _random; + + public PreprocessBuildJob( + IPlatformService platformService, + IRepository engines, + IDistributedReaderWriterLockFactory lockFactory, + ILogger logger, + IBuildJobService buildJobService, + ISharedFileService sharedFileService, + ICorpusService corpusService + ) + : base(platformService, engines, lockFactory, buildJobService, logger) + { + _sharedFileService = sharedFileService; + _corpusService = corpusService; + _random = new Random(_seed); + } + + internal int Seed + { + get => _seed; + set + { + if (_seed != value) + { + _seed = value; + _random = new Random(_seed); + } + } + } + + protected override async Task DoWorkAsync( + string engineId, + string buildId, + IReadOnlyList data, + string? buildOptions, + IDistributedReaderWriterLock @lock, + CancellationToken cancellationToken + ) + { + (int trainCount, int pretranslateCount) = await WriteDataFilesAsync( + buildId, + data, + buildOptions, + cancellationToken + ); + + // Log summary of build data + JsonObject buildPreprocessSummary = + new() + { + { "Event", "BuildPreprocess" }, + { "EngineId", engineId }, + { "BuildId", buildId }, + { "NumTrainRows", trainCount }, + { "NumPretranslateRows", pretranslateCount } + }; + TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken); + if (engine is null) + throw new OperationCanceledException($"Engine {engineId} does not exist. Build canceled."); + + buildPreprocessSummary.Add("SourceLanguageResolved", ResolveLanguageCode(engine.SourceLanguage)); + buildPreprocessSummary.Add("TargetLanguageResolved", ResolveLanguageCode(engine.TargetLanguage)); + Logger.LogInformation("{summary}", buildPreprocessSummary.ToJsonString()); + + cancellationToken.ThrowIfCancellationRequested(); + + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + { + bool canceling = !await BuildJobService.StartBuildJobAsync( + TrainJobRunnerType, + engineId, + buildId, + BuildStage.Train, + data: new object(), + buildOptions: buildOptions, + cancellationToken: cancellationToken + ); + if (canceling) + throw new OperationCanceledException(); + } + } + + private async Task<(int TrainCount, int PretranslateCount)> WriteDataFilesAsync( + string buildId, + IReadOnlyList corpora, + string? buildOptions, + CancellationToken cancellationToken + ) + { + JsonObject? buildOptionsObject = null; + if (buildOptions is not null) + { + buildOptionsObject = JsonSerializer.Deserialize(buildOptions); + } + await using StreamWriter sourceTrainWriter = + new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.src.txt", cancellationToken)); + await using StreamWriter targetTrainWriter = + new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken)); + + using Stream pretranslateStream = await _sharedFileService.OpenWriteAsync( + $"builds/{buildId}/pretranslate.src.json", + cancellationToken + ); + using Utf8JsonWriter pretranslateWriter = new(pretranslateStream, PretranslateWriterOptions); + + int trainCount = 0; + int pretranslateCount = 0; + pretranslateWriter.WriteStartArray(); + foreach (Corpus corpus in corpora) + { + ITextCorpus[] sourceTextCorpora = _corpusService.CreateTextCorpora(corpus.SourceFiles).ToArray(); + ITextCorpus targetTextCorpus = + _corpusService.CreateTextCorpora(corpus.TargetFiles).FirstOrDefault() ?? new DictionaryTextCorpus(); + + if (sourceTextCorpora.Length == 0) + continue; + + int skipCount = 0; + foreach (Row?[] rows in AlignTrainCorpus(sourceTextCorpora, targetTextCorpus)) + { + if (skipCount > 0) + { + skipCount--; + continue; + } + + Row[] trainRows = rows.Where(r => r is not null && IsInTrain(r, corpus)).Cast().ToArray(); + if (trainRows.Length > 0) + { + Row row = trainRows[0]; + if (rows.Length > 1) + { + Row[] nonEmptyRows = trainRows.Where(r => r.SourceSegment.Length > 0).ToArray(); + if (nonEmptyRows.Length > 0) + row = nonEmptyRows[_random.Next(nonEmptyRows.Length)]; + } + + await sourceTrainWriter.WriteAsync($"{row.SourceSegment}\n"); + await targetTrainWriter.WriteAsync($"{row.TargetSegment}\n"); + skipCount = row.RowCount - 1; + if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) + trainCount++; + } + } + + if ((bool?)buildOptionsObject?["use_key_terms"] ?? true) + { + ITextCorpus? sourceTermCorpus = _corpusService.CreateTermCorpora(corpus.SourceFiles).FirstOrDefault(); + ITextCorpus? targetTermCorpus = _corpusService.CreateTermCorpora(corpus.TargetFiles).FirstOrDefault(); + if (sourceTermCorpus is not null && targetTermCorpus is not null) + { + IParallelTextCorpus parallelKeyTermsCorpus = sourceTermCorpus.AlignRows(targetTermCorpus); + foreach (ParallelTextRow row in parallelKeyTermsCorpus) + { + await sourceTrainWriter.WriteAsync($"{row.SourceText}\n"); + await targetTrainWriter.WriteAsync($"{row.TargetText}\n"); + trainCount++; + } + } + } + + if (PretranslationEnabled) + { + foreach (Row row in AlignPretranslateCorpus(sourceTextCorpora[0], targetTextCorpus)) + { + if ( + IsInPretranslate(row, corpus) + && row.SourceSegment.Length > 0 + && (row.TargetSegment.Length == 0 || !IsInTrain(row, corpus)) + ) + { + pretranslateWriter.WriteStartObject(); + pretranslateWriter.WriteString("corpusId", corpus.Id); + pretranslateWriter.WriteString("textId", row.TextId); + pretranslateWriter.WriteStartArray("refs"); + foreach (object rowRef in row.Refs) + pretranslateWriter.WriteStringValue(rowRef.ToString()); + pretranslateWriter.WriteEndArray(); + pretranslateWriter.WriteString("translation", row.SourceSegment); + pretranslateWriter.WriteEndObject(); + pretranslateCount++; + } + } + } + } + + pretranslateWriter.WriteEndArray(); + + return (trainCount, pretranslateCount); + } + + protected override async Task CleanupAsync( + string engineId, + string buildId, + IReadOnlyList data, + IDistributedReaderWriterLock @lock, + JobCompletionStatus completionStatus + ) + { + if (completionStatus is JobCompletionStatus.Canceled) + { + try + { + await _sharedFileService.DeleteAsync($"builds/{buildId}/"); + } + catch (Exception e) + { + Logger.LogWarning(e, "Unable to to delete job data for build {BuildId}.", buildId); + } + } + } + + private static bool IsInTrain(Row row, Corpus corpus) + { + return IsIncluded(row, corpus.TrainOnAll, corpus.TrainOnTextIds, corpus.TrainOnChapters); + } + + private static bool IsInPretranslate(Row row, Corpus corpus) + { + return IsIncluded(row, corpus.PretranslateAll, corpus.PretranslateTextIds, corpus.PretranslateChapters); + } + + private static bool IsIncluded( + Row row, + bool all, + IReadOnlySet textIds, + IReadOnlyDictionary>? chapters + ) + { + if (chapters is not null) + { + if (row.Refs.Any(r => IsInChapters(chapters, r))) + return true; + } + return all || textIds.Contains(row.TextId); + } + + private static bool IsInChapters(IReadOnlyDictionary> bookChapters, object rowRef) + { + if (rowRef is not ScriptureRef sr) + return false; + return bookChapters.TryGetValue(sr.Book, out HashSet? chapters) + && (chapters.Contains(sr.ChapterNum) || chapters.Count == 0); + } + + private static IEnumerable AlignTrainCorpus(IReadOnlyList srcCorpora, ITextCorpus trgCorpus) + { + if (trgCorpus.IsScripture()) + { + return srcCorpora + .Select(sc => AlignScripture(sc, trgCorpus)) + .ZipMany(rows => rows.ToArray()) + // filter out every list that only contains completely empty rows + .Where(rows => rows.Any(r => r is null || r.SourceSegment.Length > 0 || r.TargetSegment.Length > 0)); + } + + IEnumerable sourceOnlyRows = srcCorpora + .Select(sc => sc.AlignRows(trgCorpus, allSourceRows: true)) + .ZipMany(rows => + rows.Where(r => r.TargetSegment.Count == 0) + .Select(r => new Row(r.TextId, r.Refs, r.SourceText, r.TargetText, 1)) + .ToArray() + ); + + IEnumerable targetRows = srcCorpora + .Select(sc => sc.AlignRows(trgCorpus, allTargetRows: true)) + .ZipMany(rows => + rows.Where(r => r.TargetSegment.Count > 0) + .Select(r => new Row(r.TextId, r.Refs, r.SourceText, r.TargetText, 1)) + .ToArray() + ); + + return sourceOnlyRows + .Concat(targetRows) + // filter out every list that only contains completely empty rows + .Where(rows => rows.Any(r => r.SourceSegment.Length > 0 || r.TargetSegment.Length > 0)); + } + + private static IEnumerable AlignScripture(ITextCorpus srcCorpus, ITextCorpus trgCorpus) + { + int rowCount = 0; + StringBuilder srcSegBuffer = new(); + StringBuilder trgSegBuffer = new(); + HashSet vrefs = []; + foreach ( + (VerseRef vref, string srcSegment, string trgSegment) in srcCorpus + .ExtractScripture() + .Select(r => (r.CorpusVerseRef, r.Text)) + .Zip( + trgCorpus.ExtractScripture().Select(r => r.Text), + (s, t) => (VerseRef: s.CorpusVerseRef, SourceSegment: s.Text, TargetSegment: t) + ) + ) + { + if (srcSegment == "" && trgSegment == "") + { + vrefs.UnionWith(vref.AllVerses()); + rowCount++; + } + else if (srcSegment == "") + { + vrefs.UnionWith(vref.AllVerses()); + if (trgSegment.Length > 0) + { + if (trgSegBuffer.Length > 0) + trgSegBuffer.Append(' '); + trgSegBuffer.Append(trgSegment); + } + rowCount++; + } + else if (trgSegment == "") + { + vrefs.UnionWith(vref.AllVerses()); + if (srcSegment.Length > 0) + { + if (srcSegBuffer.Length > 0) + srcSegBuffer.Append(' '); + srcSegBuffer.Append(srcSegment); + } + rowCount++; + } + else + { + if (rowCount > 0) + { + yield return new( + vrefs.First().Book, + vrefs.Order().Select(v => new ScriptureRef(v)).Cast().ToArray(), + srcSegBuffer.ToString(), + trgSegBuffer.ToString(), + rowCount + ); + for (int i = 0; i < rowCount - 1; i++) + yield return null; + srcSegBuffer.Clear(); + trgSegBuffer.Clear(); + vrefs.Clear(); + rowCount = 0; + } + vrefs.UnionWith(vref.AllVerses()); + srcSegBuffer.Append(srcSegment); + trgSegBuffer.Append(trgSegment); + rowCount++; + } + } + + if (rowCount > 0) + { + yield return new( + vrefs.First().Book, + vrefs.Order().Select(v => new ScriptureRef(v)).Cast().ToArray(), + srcSegBuffer.ToString(), + trgSegBuffer.ToString(), + rowCount + ); + for (int i = 0; i < rowCount - 1; i++) + yield return null; + } + } + + private static IEnumerable AlignPretranslateCorpus(ITextCorpus srcCorpus, ITextCorpus trgCorpus) + { + int rowCount = 0; + StringBuilder srcSegBuffer = new(); + StringBuilder trgSegBuffer = new(); + List refs = []; + string textId = ""; + foreach (ParallelTextRow row in srcCorpus.AlignRows(trgCorpus, allSourceRows: true)) + { + if (!row.IsTargetRangeStart && row.IsTargetInRange) + { + refs.AddRange(row.Refs); + if (row.SourceText.Length > 0) + { + if (srcSegBuffer.Length > 0) + srcSegBuffer.Append(' '); + srcSegBuffer.Append(row.SourceText); + } + rowCount++; + } + else + { + if (rowCount > 0) + { + yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1); + textId = ""; + srcSegBuffer.Clear(); + trgSegBuffer.Clear(); + refs.Clear(); + rowCount = 0; + } + + textId = row.TextId; + refs.AddRange(row.Refs); + srcSegBuffer.Append(row.SourceText); + trgSegBuffer.Append(row.TargetText); + rowCount++; + } + } + + if (rowCount > 0) + yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1); + } + + private record Row( + string TextId, + IReadOnlyList Refs, + string SourceSegment, + string TargetSegment, + int RowCount + ); + + protected virtual string ResolveLanguageCode(string languageCode) + { + return languageCode; + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/S3WriteStream.cs b/src/SIL.Machine.AspNetCore/Services/S3WriteStream.cs index b3194bec4..f96ed0ae2 100644 --- a/src/SIL.Machine.AspNetCore/Services/S3WriteStream.cs +++ b/src/SIL.Machine.AspNetCore/Services/S3WriteStream.cs @@ -46,56 +46,89 @@ public override void Flush() { } public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; - public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override void Write(byte[] buffer, int offset, int count) + { + try + { + using MemoryStream ms = new(buffer, offset, count); + WriteAsync(ms, count).Wait(); + } + catch (Exception e) + { + AbortAsync(e).Wait(); + throw; + } + } + + public override async ValueTask WriteAsync( + ReadOnlyMemory buffer, + CancellationToken cancellationToken = default + ) + { + try + { + using MemoryStream ms = new(buffer.ToArray(), 0, buffer.Length); + await WriteAsync(ms, buffer.Length); + } + catch (Exception e) + { + await AbortAsync(e); + throw; + } + } public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { try { using MemoryStream ms = new(buffer, offset, count); + await WriteAsync(ms, count); + } + catch (Exception e) + { + await AbortAsync(e); + throw; + } + } - int bytesWritten = 0; + private async Task WriteAsync(MemoryStream ms, int count) + { + int bytesWritten = 0; - while (count > bytesWritten) - { - int partNumber = _uploadResponses.Count + 1; - UploadPartRequest request = - new() - { - BucketName = _bucketName, - Key = _key, - UploadId = _uploadId, - PartNumber = partNumber, - InputStream = ms, - PartSize = MaxPartSize - }; - request.StreamTransferProgress += new EventHandler( - (_, e) => - { - _logger.LogDebug( - "Transferred {e.TransferredBytes}/{e.TotalBytes}", - e.TransferredBytes, - e.TotalBytes - ); - } - ); - UploadPartResponse response = await _client.UploadPartAsync(request); - if (response.HttpStatusCode != HttpStatusCode.OK) + while (count > bytesWritten) + { + int partNumber = _uploadResponses.Count + 1; + UploadPartRequest request = + new() { - throw new HttpRequestException( - $"Tried to upload part {partNumber} of upload {_uploadId} to {_bucketName}/{_key} but received response code {response.HttpStatusCode}" + BucketName = _bucketName, + Key = _key, + UploadId = _uploadId, + PartNumber = partNumber, + InputStream = ms, + PartSize = Math.Min(MaxPartSize, count - bytesWritten) + }; + request.StreamTransferProgress += new EventHandler( + (_, e) => + { + _logger.LogDebug( + "Transferred {e.TransferredBytes}/{e.TotalBytes}", + e.TransferredBytes, + e.TotalBytes ); } + ); + UploadPartResponse response = await _client.UploadPartAsync(request); + if (response.HttpStatusCode != HttpStatusCode.OK) + { + throw new HttpRequestException( + $"Tried to upload part {partNumber} of upload {_uploadId} to {_bucketName}/{_key} but received response code {response.HttpStatusCode}" + ); + } - _uploadResponses.Add(response); + _uploadResponses.Add(response); - bytesWritten += MaxPartSize; - } - } - catch (Exception e) - { - await AbortAsync(e); - throw; + bytesWritten += MaxPartSize; } } diff --git a/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs b/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs index ea712e637..78304ac81 100644 --- a/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs +++ b/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs @@ -156,13 +156,10 @@ ServerCallContext context } } - public override async Task GetQueueSize( - GetQueueSizeRequest request, - ServerCallContext context - ) + public override Task GetQueueSize(GetQueueSizeRequest request, ServerCallContext context) { ITranslationEngineService engineService = GetEngineService(request.EngineType); - return new GetQueueSizeResponse { Size = await engineService.GetQueueSizeAsync(context.CancellationToken) }; + return Task.FromResult(new GetQueueSizeResponse { Size = engineService.GetQueueSize() }); } public override Task GetLanguageInfo( diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferBuildJob.cs deleted file mode 100644 index 54a19cc7f..000000000 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferBuildJob.cs +++ /dev/null @@ -1,137 +0,0 @@ -namespace SIL.Machine.AspNetCore.Services; - -public class SmtTransferBuildJob( - IPlatformService platformService, - IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, - IBuildJobService buildJobService, - ILogger logger, - IRepository trainSegmentPairs, - ITruecaserFactory truecaserFactory, - ISmtModelFactory smtModelFactory, - ICorpusService corpusService -) : HangfireBuildJob>(platformService, engines, lockFactory, buildJobService, logger) -{ - private readonly IRepository _trainSegmentPairs = trainSegmentPairs; - private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; - private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; - private readonly ICorpusService _corpusService = corpusService; - - protected override Task InitializeAsync( - string engineId, - string buildId, - IReadOnlyList data, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) - { - return _trainSegmentPairs.DeleteAllAsync(p => p.TranslationEngineRef == engineId, cancellationToken); - } - - protected override async Task DoWorkAsync( - string engineId, - string buildId, - IReadOnlyList data, - string? buildOptions, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) - { - await PlatformService.BuildStartedAsync(buildId, cancellationToken); - Logger.LogInformation("Build started ({0})", buildId); - var stopwatch = new Stopwatch(); - stopwatch.Start(); - - cancellationToken.ThrowIfCancellationRequested(); - - JsonObject? buildOptionsObject = null; - if (buildOptions is not null) - { - buildOptionsObject = JsonSerializer.Deserialize(buildOptions); - } - - var targetCorpora = new List(); - var parallelCorpora = new List(); - foreach (Corpus corpus in data) - { - ITextCorpus? sourceTextCorpus = _corpusService.CreateTextCorpora(corpus.SourceFiles).FirstOrDefault(); - ITextCorpus? targetTextCorpus = _corpusService.CreateTextCorpora(corpus.TargetFiles).FirstOrDefault(); - if (sourceTextCorpus is null || targetTextCorpus is null) - continue; - - targetCorpora.Add(targetTextCorpus); - parallelCorpora.Add(sourceTextCorpus.AlignRows(targetTextCorpus)); - - if ((bool?)buildOptionsObject?["use_key_terms"] ?? true) - { - ITextCorpus? sourceTermCorpus = _corpusService.CreateTermCorpora(corpus.SourceFiles).FirstOrDefault(); - ITextCorpus? targetTermCorpus = _corpusService.CreateTermCorpora(corpus.TargetFiles).FirstOrDefault(); - if (sourceTermCorpus is not null && targetTermCorpus is not null) - { - IParallelTextCorpus parallelKeyTermsCorpus = sourceTermCorpus.AlignRows(targetTermCorpus); - parallelCorpora.Add(parallelKeyTermsCorpus); - } - } - } - - IParallelTextCorpus parallelCorpus = parallelCorpora.Flatten(); - ITextCorpus targetCorpus = targetCorpora.Flatten(); - - var tokenizer = new LatinWordTokenizer(); - var detokenizer = new LatinWordDetokenizer(); - - using ITrainer smtModelTrainer = _smtModelFactory.CreateTrainer(engineId, tokenizer, parallelCorpus); - using ITrainer truecaseTrainer = _truecaserFactory.CreateTrainer(engineId, tokenizer, targetCorpus); - - cancellationToken.ThrowIfCancellationRequested(); - - var progress = new BuildProgress(PlatformService, buildId); - await smtModelTrainer.TrainAsync(progress, cancellationToken); - await truecaseTrainer.TrainAsync(cancellationToken: cancellationToken); - - TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken); - if (engine is null) - throw new OperationCanceledException(); - - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - cancellationToken.ThrowIfCancellationRequested(); - await smtModelTrainer.SaveAsync(CancellationToken.None); - await truecaseTrainer.SaveAsync(CancellationToken.None); - ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineId); - IReadOnlyList segmentPairs = await _trainSegmentPairs.GetAllAsync( - p => p.TranslationEngineRef == engine.Id, - CancellationToken.None - ); - using ( - IInteractiveTranslationModel smtModel = _smtModelFactory.Create( - engineId, - tokenizer, - detokenizer, - truecaser - ) - ) - { - foreach (TrainSegmentPair segmentPair in segmentPairs) - { - await smtModel.TrainSegmentAsync( - segmentPair.Source, - segmentPair.Target, - cancellationToken: CancellationToken.None - ); - } - } - - await PlatformService.BuildCompletedAsync( - buildId, - smtModelTrainer.Stats.TrainCorpusSize + segmentPairs.Count, - smtModelTrainer.Stats.Metrics["bleu"] * 100.0, - CancellationToken.None - ); - await BuildJobService.BuildJobFinishedAsync(engineId, buildId, buildComplete: true, CancellationToken.None); - } - - stopwatch.Stop(); - Logger.LogInformation("Build completed in {0}s ({1})", stopwatch.Elapsed.TotalSeconds, buildId); - } -} diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferClearMLBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferClearMLBuildJobFactory.cs new file mode 100644 index 000000000..367bdb2d2 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferClearMLBuildJobFactory.cs @@ -0,0 +1,52 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class SmtTransferClearMLBuildJobFactory( + ISharedFileService sharedFileService, + IRepository engines +) : IClearMLBuildJobFactory +{ + private readonly ISharedFileService _sharedFileService = sharedFileService; + private readonly IRepository _engines = engines; + + public TranslationEngineType EngineType => TranslationEngineType.SmtTransfer; + + public async Task CreateJobScriptAsync( + string engineId, + string buildId, + string modelType, + BuildStage stage, + object? data = null, + string? buildOptions = null, + CancellationToken cancellationToken = default + ) + { + if (stage == BuildStage.Train) + { + TranslationEngine? engine = await _engines.GetAsync(e => e.EngineId == engineId, cancellationToken); + if (engine is null) + throw new InvalidOperationException("The engine does not exist."); + + Uri sharedFileUri = _sharedFileService.GetBaseUri(); + string baseUri = sharedFileUri.GetComponents(UriComponents.SchemeAndServer, UriFormat.Unescaped); + string folder = sharedFileUri.GetComponents(UriComponents.Path, UriFormat.Unescaped); + return "from machine.jobs.build_smt_engine import run\n" + + "args = {\n" + + $" 'model_type': '{modelType}',\n" + + $" 'engine_id': '{engineId}',\n" + + $" 'build_id': '{buildId}',\n" + + $" 'shared_file_uri': '{baseUri}',\n" + + $" 'shared_file_folder': '{folder}',\n" + + (buildOptions is not null ? $" 'build_options': '''{buildOptions}''',\n" : "") + // buildRevision + 1 because the build revision is incremented after the build job + // is finished successfully but the file should be saved with the new revision number + + (engine.IsModelPersisted ? $" 'save_model': '{engineId}',\n" : $"") + + $" 'clearml': True\n" + + "}\n" + + "run(args)\n"; + } + else + { + throw new ArgumentException("Unknown build stage.", nameof(stage)); + } + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs index d0c254a23..54c14a556 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs @@ -1,10 +1,5 @@ namespace SIL.Machine.AspNetCore.Services; -public static class SmtTransferBuildStages -{ - public const string Train = "train"; -} - public class SmtTransferEngineService( IDistributedReaderWriterLockFactory lockFactory, IPlatformService platformService, @@ -13,7 +8,7 @@ public class SmtTransferEngineService( IRepository trainSegmentPairs, SmtTransferEngineStateService stateService, IBuildJobService buildJobService, - JobStorage jobStorage + IClearMLQueueService clearMLQueueService ) : ITranslationEngineService { private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; @@ -23,10 +18,17 @@ JobStorage jobStorage private readonly IRepository _trainSegmentPairs = trainSegmentPairs; private readonly SmtTransferEngineStateService _stateService = stateService; private readonly IBuildJobService _buildJobService = buildJobService; - private readonly JobStorage _jobStorage = jobStorage; + private readonly IClearMLQueueService _clearMLQueueService = clearMLQueueService; public TranslationEngineType Type => TranslationEngineType.SmtTransfer; + public const string ModelDirectory = "models/"; + + public static string GetModelPath(string engineId) + { + return $"{ModelDirectory}{engineId}.zip"; + } + public async Task CreateAsync( string engineId, string? engineName, @@ -52,10 +54,11 @@ public async Task CreateAsync( EngineId = engineId, SourceLanguage = sourceLanguage, TargetLanguage = targetLanguage, + Type = TranslationEngineType.SmtTransfer, IsModelPersisted = isModelPersisted ?? true // models are persisted if not specified }; await _engines.InsertAsync(translationEngine, ct); - await _buildJobService.CreateEngineAsync([BuildJobType.Cpu], engineId, engineName, ct); + await _buildJobService.CreateEngineAsync(engineId, engineName, ct); return translationEngine; }, cancellationToken: cancellationToken @@ -85,6 +88,7 @@ await _dataAccessContext.WithTransactionAsync( }, cancellationToken: cancellationToken ); + await _buildJobService.DeleteEngineAsync(engineId, CancellationToken.None); if (_stateService.TryRemove(engineId, out SmtTransferEngineState? state)) { @@ -199,11 +203,10 @@ public async Task StartBuildAsync( throw new InvalidOperationException("The engine is already building or in the process of canceling."); await _buildJobService.StartBuildJobAsync( - BuildJobType.Cpu, - TranslationEngineType.SmtTransfer, + BuildJobRunnerType.Hangfire, engineId, buildId, - SmtTransferBuildStages.Train, + BuildStage.Preprocess, corpora, buildOptions, cancellationToken @@ -225,9 +228,9 @@ public async Task CancelBuildAsync(string engineId, CancellationToken cancellati } } - public Task GetQueueSizeAsync(CancellationToken cancellationToken = default) + public int GetQueueSize() { - return Task.FromResult(Convert.ToInt32(_jobStorage.GetMonitoringApi().EnqueuedCount("smt_transfer"))); + return _clearMLQueueService.GetQueueSize(Type); } public bool IsLanguageNativeToModel(string language, out string internalCode) diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferHangfireBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferHangfireBuildJobFactory.cs index 0da56c602..7ebd1bede 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferHangfireBuildJobFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferHangfireBuildJobFactory.cs @@ -6,18 +6,28 @@ public class SmtTransferHangfireBuildJobFactory : IHangfireBuildJobFactory { public TranslationEngineType EngineType => TranslationEngineType.SmtTransfer; - public Job CreateJob(string engineId, string buildId, string stage, object? data, string? buildOptions) + public Job CreateJob(string engineId, string buildId, BuildStage stage, object? data, string? buildOptions) { return stage switch { - SmtTransferBuildStages.Train - => CreateJob>( + BuildStage.Preprocess + => CreateJob>( engineId, buildId, "smt_transfer", data, buildOptions ), + BuildStage.Postprocess + => CreateJob( + engineId, + buildId, + "smt_transfer", + data, + buildOptions + ), + BuildStage.Train + => CreateJob(engineId, buildId, "smt_transfer", data, buildOptions), _ => throw new ArgumentException("Unknown build stage.", nameof(stage)), }; } diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferPostprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferPostprocessBuildJob.cs new file mode 100644 index 000000000..bc457a34a --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferPostprocessBuildJob.cs @@ -0,0 +1,80 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class SmtTransferPostprocessBuildJob( + IPlatformService platformService, + IRepository engines, + IDistributedReaderWriterLockFactory lockFactory, + IBuildJobService buildJobService, + ILogger logger, + ISharedFileService sharedFileService, + IRepository trainSegmentPairs, + ISmtModelFactory smtModelFactory, + ITruecaserFactory truecaserFactory +) : PostprocessBuildJob(platformService, engines, lockFactory, buildJobService, logger, sharedFileService) +{ + private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; + private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; + private readonly IRepository _trainSegmentPairs = trainSegmentPairs; + + protected override async Task DoWorkAsync( + string engineId, + string buildId, + (int, double) data, + string? buildOptions, + IDistributedReaderWriterLock @lock, + CancellationToken cancellationToken + ) + { + cancellationToken.ThrowIfCancellationRequested(); + + await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) + { + await _smtModelFactory.DownloadBuiltEngineAsync(engineId, cancellationToken); + int segmentPairsSize = await TrainOnNewSegmentPairs(engineId, cancellationToken); + await PlatformService.BuildCompletedAsync( + buildId, + trainSize: data.Item1 + segmentPairsSize, + confidence: Math.Round(data.Item2, 2, MidpointRounding.AwayFromZero), + cancellationToken: CancellationToken.None + ); + await BuildJobService.BuildJobFinishedAsync(engineId, buildId, buildComplete: true, CancellationToken.None); + } + + Logger.LogInformation("Build completed ({0}).", buildId); + } + + private async Task TrainOnNewSegmentPairs(string engineId, CancellationToken cancellationToken) + { + TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken); + if (engine is null) + throw new OperationCanceledException(); + + cancellationToken.ThrowIfCancellationRequested(); + IReadOnlyList segmentPairs = await _trainSegmentPairs.GetAllAsync( + p => p.TranslationEngineRef == engine.Id, + CancellationToken.None + ); + if (segmentPairs.Count == 0) + return segmentPairs.Count; + + var tokenizer = new LatinWordTokenizer(); + var detokenizer = new LatinWordDetokenizer(); + ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineId); + + using ( + IInteractiveTranslationModel smtModel = _smtModelFactory.Create(engineId, tokenizer, detokenizer, truecaser) + ) + { + foreach (TrainSegmentPair segmentPair in segmentPairs) + { + await smtModel.TrainSegmentAsync( + segmentPair.Source, + segmentPair.Target, + cancellationToken: CancellationToken.None + ); + } + await smtModel.SaveAsync(CancellationToken.None); + } + return segmentPairs.Count; + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferPreprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferPreprocessBuildJob.cs new file mode 100644 index 000000000..a6317c753 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferPreprocessBuildJob.cs @@ -0,0 +1,20 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class SmtTransferPreprocessBuildJob : PreprocessBuildJob +{ + public SmtTransferPreprocessBuildJob( + IPlatformService platformService, + IRepository engines, + IDistributedReaderWriterLockFactory lockFactory, + ILogger logger, + IBuildJobService buildJobService, + ISharedFileService sharedFileService, + ICorpusService corpusService + ) + : base(platformService, engines, lockFactory, logger, buildJobService, sharedFileService, corpusService) + { + EngineType = TranslationEngineType.SmtTransfer; + PretranslationEnabled = false; + TrainJobRunnerType = BuildJobRunnerType.Hangfire; + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs new file mode 100644 index 000000000..7796a6260 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs @@ -0,0 +1,91 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class SmtTransferTrainBuildJob( + IPlatformService platformService, + IRepository engines, + IDistributedReaderWriterLockFactory lockFactory, + IBuildJobService buildJobService, + ILogger logger, + ISharedFileService sharedFileService, + ITruecaserFactory truecaserFactory, + ISmtModelFactory smtModelFactory +) : HangfireBuildJob(platformService, engines, lockFactory, buildJobService, logger) +{ + private readonly ISharedFileService _sharedFileService = sharedFileService; + private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; + private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; + + protected override async Task DoWorkAsync( + string engineId, + string buildId, + object? data, + string? buildOptions, + IDistributedReaderWriterLock @lock, + CancellationToken cancellationToken + ) + { + DirectoryInfo tempDir = Directory.CreateTempSubdirectory(); + await DownloadTrainingText(buildId, tempDir.FullName, cancellationToken); + + // assemble corpus + DictionaryTextCorpus sourceCorpus = + new(new TextFileText("train", Path.Combine(tempDir.FullName, "train.src.txt"))); + DictionaryTextCorpus targetCorpus = + new(new TextFileText("train", Path.Combine(tempDir.FullName, "train.trg.txt"))); + ParallelTextCorpus parallelCorpus = new ParallelTextCorpus(sourceCorpus, targetCorpus); + int corpusSize = parallelCorpus.Count(includeEmpty: false); + + // train SMT model + var tokenizer = new LatinWordTokenizer(); + + using ITrainer smtModelTrainer = _smtModelFactory.CreateTrainer(engineId, tokenizer, parallelCorpus); + using ITrainer truecaseTrainer = _truecaserFactory.CreateTrainer(engineId, tokenizer, targetCorpus); + + cancellationToken.ThrowIfCancellationRequested(); + + var progress = new BuildProgress(PlatformService, buildId); + await smtModelTrainer.TrainAsync(progress, cancellationToken); + await truecaseTrainer.TrainAsync(cancellationToken: cancellationToken); + + cancellationToken.ThrowIfCancellationRequested(); + + await smtModelTrainer.SaveAsync(CancellationToken.None); + await truecaseTrainer.SaveAsync(CancellationToken.None); + + await _smtModelFactory.UploadBuiltEngineAsync(engineId, cancellationToken); + + cancellationToken.ThrowIfCancellationRequested(); + + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + { + bool canceling = !await BuildJobService.StartBuildJobAsync( + BuildJobRunnerType.Hangfire, + engineId, + buildId, + BuildStage.Postprocess, + data: (smtModelTrainer.Stats.TrainCorpusSize, smtModelTrainer.Stats.Metrics["bleu"] * 100.0), + buildOptions: buildOptions, + cancellationToken: cancellationToken + ); + if (canceling) + throw new OperationCanceledException(); + } + } + + private async Task DownloadTrainingText(string buildId, string directory, CancellationToken cancellationToken) + { + using Stream srcText = await _sharedFileService.OpenReadAsync( + $"builds/{buildId}/train.src.txt", + cancellationToken + ); + using FileStream srcFileStream = File.Create(Path.Combine(directory, "train.src.txt")); + await srcText.CopyToAsync(srcFileStream, cancellationToken); + + using Stream tgtText = await _sharedFileService.OpenReadAsync( + $"builds/{buildId}/train.trg.txt", + cancellationToken + ); + using FileStream tgtFileStream = File.Create(Path.Combine(directory, "train.trg.txt")); + await tgtText.CopyToAsync(tgtFileStream, cancellationToken); + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs b/src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs index 075bc4d51..d3eb4f7f7 100644 --- a/src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs @@ -2,11 +2,13 @@ public class ThotSmtModelFactory( IOptionsMonitor options, - IOptionsMonitor engineOptions + IOptionsMonitor engineOptions, + ISharedFileService sharedFileService ) : ISmtModelFactory { private readonly IOptionsMonitor _options = options; private readonly IOptionsMonitor _engineOptions = engineOptions; + private readonly ISharedFileService _sharedFileService = sharedFileService; public IInteractiveTranslationModel Create( string engineId, @@ -44,6 +46,32 @@ IParallelTextCorpus corpus }; } + public async Task DownloadBuiltEngineAsync(string engineId, CancellationToken cancellationToken) + { + string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); + if (!Directory.Exists(engineDir)) + Directory.CreateDirectory(engineDir); + string sharedFilePath = $"models/{engineId}.zip"; + using Stream sharedStream = await _sharedFileService.OpenReadAsync(sharedFilePath, cancellationToken); + ZipFile.ExtractToDirectory(sharedStream, engineDir, overwriteFiles: true); + await _sharedFileService.DeleteAsync(sharedFilePath); + } + + public async Task UploadBuiltEngineAsync(string engineId, CancellationToken cancellationToken) + { + // create zip archive in memory stream + // This cannot be created directly to the shared stream because it all needs to be written at once + using var memoryStream = new MemoryStream(); + string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); + ZipFile.CreateFromDirectory(engineDir, memoryStream); + + // copy to shared file + memoryStream.Seek(0, SeekOrigin.Begin); + string sharedFilePath = $"models/{engineId}.zip"; + using Stream sharedStream = await _sharedFileService.OpenWriteAsync(sharedFilePath, cancellationToken); + await sharedStream.WriteAsync(memoryStream.ToArray().AsMemory(0, (int)memoryStream.Length), cancellationToken); + } + public void InitNew(string engineId) { string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); diff --git a/src/SIL.Machine.AspNetCore/Usings.cs b/src/SIL.Machine.AspNetCore/Usings.cs index f42422c20..2f2bb7afb 100644 --- a/src/SIL.Machine.AspNetCore/Usings.cs +++ b/src/SIL.Machine.AspNetCore/Usings.cs @@ -40,8 +40,6 @@ global using Nito.AsyncEx; global using Nito.AsyncEx.Synchronous; global using Polly; -global using Python.Included; -global using Python.Runtime; global using SIL.DataAccess; global using SIL.Machine.AspNetCore.Configuration; global using SIL.Machine.AspNetCore.Models; diff --git a/src/SIL.Machine.Serval.EngineServer/appsettings.Development.json b/src/SIL.Machine.Serval.EngineServer/appsettings.Development.json index e171b5558..1f2a4ef6b 100644 --- a/src/SIL.Machine.Serval.EngineServer/appsettings.Development.json +++ b/src/SIL.Machine.Serval.EngineServer/appsettings.Development.json @@ -5,10 +5,8 @@ "Serval": "https://localhost:8444" }, "ClearML": { - "Queue": "jobs_backlog", "MaxSteps": 1000, - "Project": "dev", - "DockerImage": "ghcr.io/sillsdev/machine.py:latest" + "Project": "dev" }, "SharedFile": { "Uri": "s3://aqua-ml-data/dev/" diff --git a/src/SIL.Machine.Serval.EngineServer/appsettings.json b/src/SIL.Machine.Serval.EngineServer/appsettings.json index ee9d91a8e..828d2a41e 100644 --- a/src/SIL.Machine.Serval.EngineServer/appsettings.json +++ b/src/SIL.Machine.Serval.EngineServer/appsettings.json @@ -11,10 +11,20 @@ "Nmt" ], "BuildJob": { - "Runners": { - "Cpu": "Hangfire", - "Gpu": "ClearML" - } + "ClearML": [ + { + "TranslationEngineType": "Nmt", + "ModelType": "huggingface", + "Queue": "jobs_backlog", + "DockerImage": "ghcr.io/sillsdev/machine.py:latest" + }, + { + "TranslationEngineType": "SmtTransfer", + "ModelType": "hmm", + "Queue": "cpu_only", + "DockerImage": "ghcr.io/sillsdev/machine.py:latest" + } + ] }, "SmtTransferEngine": { "EnginesDir": "/var/lib/machine/engines" diff --git a/src/SIL.Machine.Serval.JobServer/appsettings.Development.json b/src/SIL.Machine.Serval.JobServer/appsettings.Development.json index e171b5558..1f2a4ef6b 100644 --- a/src/SIL.Machine.Serval.JobServer/appsettings.Development.json +++ b/src/SIL.Machine.Serval.JobServer/appsettings.Development.json @@ -5,10 +5,8 @@ "Serval": "https://localhost:8444" }, "ClearML": { - "Queue": "jobs_backlog", "MaxSteps": 1000, - "Project": "dev", - "DockerImage": "ghcr.io/sillsdev/machine.py:latest" + "Project": "dev" }, "SharedFile": { "Uri": "s3://aqua-ml-data/dev/" diff --git a/src/SIL.Machine.Serval.JobServer/appsettings.json b/src/SIL.Machine.Serval.JobServer/appsettings.json index d89391a2c..2e83382ba 100644 --- a/src/SIL.Machine.Serval.JobServer/appsettings.json +++ b/src/SIL.Machine.Serval.JobServer/appsettings.json @@ -11,10 +11,20 @@ "Nmt" ], "BuildJob": { - "Runners": { - "Cpu": "Hangfire", - "Gpu": "ClearML" - } + "ClearML": [ + { + "TranslationEngineType": "Nmt", + "ModelType": "huggingface", + "Queue": "jobs_backlog", + "DockerImage": "ghcr.io/sillsdev/machine.py:latest" + }, + { + "TranslationEngineType": "SmtTransfer", + "ModelType": "hmm", + "Queue": "jobs_backlog", + "DockerImage": "ghcr.io/sillsdev/machine.py:latest" + } + ] }, "SmtTransferEngine": { "EnginesDir": "/var/lib/machine/engines" diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/ClearMLServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/ClearMLServiceTests.cs index 89448a5bb..e4a9f75f8 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/ClearMLServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/ClearMLServiceTests.cs @@ -43,7 +43,7 @@ public async Task CreateTaskAsync() + "}\n" + "run(args)\n"; - string projectId = await service.CreateTaskAsync("build1", "project1", script); + string projectId = await service.CreateTaskAsync("build1", "project1", "dockerImage", script); Assert.That(projectId, Is.EqualTo("projectId")); mockHttp.VerifyNoOutstandingExpectation(); } diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/ModelCleanupServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/ModelCleanupServiceTests.cs index aa054696c..34a3e7cbd 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/ModelCleanupServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/ModelCleanupServiceTests.cs @@ -28,6 +28,7 @@ private async Task SetUpAsync() { Id = "engine1", EngineId = "engineId1", + Type = TranslationEngineType.Nmt, SourceLanguage = "es", TargetLanguage = "en", BuildRevision = 1, @@ -39,6 +40,7 @@ private async Task SetUpAsync() { Id = "engine2", EngineId = "engineId2", + Type = TranslationEngineType.Nmt, SourceLanguage = "es", TargetLanguage = "en", BuildRevision = 2, diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtClearMLBuildJobFactoryTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtClearMLBuildJobFactoryTests.cs index edb2f2a46..bd69741cb 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtClearMLBuildJobFactoryTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtClearMLBuildJobFactoryTests.cs @@ -10,7 +10,8 @@ public async Task CreateJobScriptAsync_BuildOptions() string script = await env.BuildJobFactory.CreateJobScriptAsync( "engine1", "build1", - NmtBuildStages.Train, + "test_model", + BuildStage.Train, buildOptions: "{ \"max_steps\": \"10\" }" ); Assert.That( @@ -38,7 +39,12 @@ public async Task CreateJobScriptAsync_BuildOptions() public async Task CreateJobScriptAsync_NoBuildOptions() { var env = new TestEnvironment(); - string script = await env.BuildJobFactory.CreateJobScriptAsync("engine1", "build1", NmtBuildStages.Train); + string script = await env.BuildJobFactory.CreateJobScriptAsync( + "engine1", + "build1", + "test_model", + BuildStage.Train + ); Assert.That( script, Is.EqualTo( @@ -75,6 +81,7 @@ public TestEnvironment() { Id = "engine1", EngineId = "engine1", + Type = TranslationEngineType.Nmt, SourceLanguage = "es", TargetLanguage = "en", BuildRevision = 1, @@ -83,14 +90,14 @@ public TestEnvironment() { BuildId = "build1", JobId = "job1", - JobRunner = BuildJobRunner.ClearML, - Stage = NmtBuildStages.Train, + BuildJobRunner = BuildJobRunnerType.ClearML, + Stage = BuildStage.Train, JobState = BuildJobState.Pending } } ); Options = Substitute.For>(); - Options.CurrentValue.Returns(new ClearMLOptions { ModelType = "test_model" }); + Options.CurrentValue.Returns(new ClearMLOptions { }); SharedFileService = Substitute.For(); SharedFileService.GetBaseUri().Returns(new Uri("s3://bucket/folder1/folder2")); LanguageTagService = Substitute.For(); @@ -110,7 +117,7 @@ public TestEnvironment() x[1] = "eng_Latn"; return true; }); - BuildJobFactory = new NmtClearMLBuildJobFactory(SharedFileService, LanguageTagService, Engines, Options); + BuildJobFactory = new NmtClearMLBuildJobFactory(SharedFileService, LanguageTagService, Engines); } } } diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs index 31b1c120e..98c08b00f 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs @@ -108,6 +108,7 @@ public TestEnvironment() { Id = "engine1", EngineId = "engine1", + Type = TranslationEngineType.Nmt, SourceLanguage = "es", TargetLanguage = "en", BuildRevision = 1, @@ -127,41 +128,66 @@ public TestEnvironment() .GetProjectIdAsync("engine1", Arg.Any()) .Returns(Task.FromResult("project1")); ClearMLService - .CreateTaskAsync("build1", "project1", Arg.Any(), Arg.Any()) + .CreateTaskAsync( + "build1", + "project1", + Arg.Any(), + Arg.Any(), + Arg.Any() + ) .Returns(Task.FromResult("job1")); ClearMLService - .When(x => x.EnqueueTaskAsync("job1", Arg.Any())) + .When(x => x.EnqueueTaskAsync("job1", Arg.Any(), Arg.Any())) .Do(_ => TrainJobTask = Task.Run(TrainJobFunc)); SharedFileService = new SharedFileService(Substitute.For()); - var clearMLOptions = Substitute.For>(); - clearMLOptions.CurrentValue.Returns(new ClearMLOptions()); - BuildJobService = new BuildJobService( - new IBuildJobRunner[] + var buildJobOptions = Substitute.For>(); + buildJobOptions.CurrentValue.Returns( + new BuildJobOptions { - new HangfireBuildJobRunner(_jobClient, new[] { new NmtHangfireBuildJobFactory() }), + ClearML = + [ + new ClearMLBuildQueue() + { + TranslationEngineType = TranslationEngineType.Nmt, + ModelType = "huggingface", + DockerImage = "default", + Queue = "default" + }, + new ClearMLBuildQueue() + { + TranslationEngineType = TranslationEngineType.SmtTransfer, + ModelType = "hmm", + DockerImage = "default", + Queue = "default" + } + ] + } + ); + BuildJobService = new BuildJobService( + [ + new HangfireBuildJobRunner(_jobClient, [new NmtHangfireBuildJobFactory()]), new ClearMLBuildJobRunner( ClearMLService, - new[] - { + [ new NmtClearMLBuildJobFactory( SharedFileService, Substitute.For(), - Engines, - clearMLOptions + Engines ) - } + ], + buildJobOptions ) - }, - Engines, - new OptionsWrapper(new BuildJobOptions()) + ], + Engines ); - var clearMLOptionsMonitor = Substitute.For>(); - clearMLOptionsMonitor.Value.Returns(new ClearMLOptions()); - ClearMLMonitorService = new ClearMLMonitorService( + var clearMLOptions = Substitute.For>(); + clearMLOptions.CurrentValue.Returns(new ClearMLOptions()); + ClearMLQueueService = new ClearMLMonitorService( Substitute.For(), ClearMLService, SharedFileService, - clearMLOptionsMonitor, + clearMLOptions, + buildJobOptions, Substitute.For>() ); _jobServer = CreateJobServer(); @@ -169,7 +195,7 @@ public TestEnvironment() } public NmtEngineService Service { get; private set; } - public ClearMLMonitorService ClearMLMonitorService { get; } + public IClearMLQueueService ClearMLQueueService { get; } public MemoryRepository Engines { get; } public IPlatformService PlatformService { get; } public IClearMLService ClearMLService { get; } @@ -209,7 +235,7 @@ private NmtEngineService CreateService() Engines, BuildJobService, new LanguageTagService(), - ClearMLMonitorService, + ClearMLQueueService, SharedFileService ); } @@ -222,7 +248,7 @@ public Task WaitForBuildToFinishAsync() public Task WaitForBuildToStartAsync() { return WaitForBuildState(e => - e.CurrentBuild!.JobState is BuildJobState.Active && e.CurrentBuild!.Stage == NmtBuildStages.Train + e.CurrentBuild!.JobState is BuildJobState.Active && e.CurrentBuild!.Stage == BuildStage.Train ); } @@ -250,11 +276,10 @@ private async Task RunMockTrainJob() } await BuildJobService.StartBuildJobAsync( - BuildJobType.Cpu, - TranslationEngineType.Nmt, + BuildJobRunnerType.Hangfire, "engine1", "build1", - NmtBuildStages.Postprocess, + BuildStage.Postprocess, (0, 0.0) ); } diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs index cf1b8002d..9fbbe4149 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs @@ -255,7 +255,7 @@ private class TestEnvironment : ObjectModel.DisposableBase public ILogger Logger { get; } public IClearMLService ClearMLService { get; } public NmtPreprocessBuildJob BuildJob { get; } - public IOptionsMonitor Options { get; } + public IOptionsMonitor BuildJobOptions { get; } public Corpus DefaultTextFileCorpus { get; } public Corpus DefaultMixedSourceTextFileCorpus { get; } @@ -323,6 +323,7 @@ public TestEnvironment() { Id = "engine1", EngineId = "engine1", + Type = TranslationEngineType.Nmt, SourceLanguage = "es", TargetLanguage = "en", BuildRevision = 1, @@ -332,8 +333,8 @@ public TestEnvironment() BuildId = "build1", JobId = "job1", JobState = BuildJobState.Pending, - JobRunner = BuildJobRunner.Hangfire, - Stage = NmtBuildStages.Preprocess + BuildJobRunner = BuildJobRunnerType.Hangfire, + Stage = BuildStage.Preprocess } } ); @@ -342,6 +343,7 @@ public TestEnvironment() { Id = "engine2", EngineId = "engine2", + Type = TranslationEngineType.Nmt, SourceLanguage = "xxx", TargetLanguage = "zzz", BuildRevision = 1, @@ -351,8 +353,28 @@ public TestEnvironment() BuildId = "build1", JobId = "job1", JobState = BuildJobState.Pending, - JobRunner = BuildJobRunner.Hangfire, - Stage = NmtBuildStages.Preprocess + BuildJobRunner = BuildJobRunnerType.Hangfire, + Stage = BuildStage.Preprocess + } + } + ); + Engines.Add( + new TranslationEngine + { + Id = "engine2", + EngineId = "engine2", + Type = TranslationEngineType.Nmt, + SourceLanguage = "xxx", + TargetLanguage = "zzz", + BuildRevision = 1, + IsModelPersisted = false, + CurrentBuild = new() + { + BuildId = "build1", + JobId = "job1", + JobState = BuildJobState.Pending, + BuildJobRunner = BuildJobRunnerType.Hangfire, + Stage = BuildStage.Preprocess } } ); @@ -363,8 +385,29 @@ public TestEnvironment() new MemoryRepository(), new ObjectIdGenerator() ); - Options = Substitute.For>(); - Options.CurrentValue.Returns(new ClearMLOptions { ModelType = "test_model" }); + BuildJobOptions = Substitute.For>(); + BuildJobOptions.CurrentValue.Returns( + new BuildJobOptions + { + ClearML = + [ + new ClearMLBuildQueue() + { + TranslationEngineType = TranslationEngineType.Nmt, + ModelType = "huggingface", + DockerImage = "default", + Queue = "default" + }, + new ClearMLBuildQueue() + { + TranslationEngineType = TranslationEngineType.SmtTransfer, + ModelType = "hmm", + DockerImage = "default", + Queue = "default" + } + ] + } + ); ClearMLService = Substitute.For(); ClearMLService .GetProjectIdAsync("engine1", Arg.Any()) @@ -373,30 +416,41 @@ public TestEnvironment() .GetProjectIdAsync("engine2", Arg.Any()) .Returns(Task.FromResult("project1")); ClearMLService - .CreateTaskAsync("build1", "project1", Arg.Any(), Arg.Any()) + .GetProjectIdAsync("engine2", Arg.Any()) + .Returns(Task.FromResult("project1")); + ClearMLService + .CreateTaskAsync( + "build1", + "project1", + Arg.Any(), + Arg.Any(), + Arg.Any() + ) .Returns(Task.FromResult("job1")); SharedFileService = new SharedFileService(Substitute.For()); Logger = Substitute.For>(); BuildJobService = new BuildJobService( + [ [ new HangfireBuildJobRunner( Substitute.For(), [new NmtHangfireBuildJobFactory()] + [new NmtHangfireBuildJobFactory()] ), new ClearMLBuildJobRunner( ClearMLService, + [ [ new NmtClearMLBuildJobFactory( SharedFileService, Substitute.For(), - Engines, - Options + Engines ) - ] + ], + BuildJobOptions ) ], - Engines, - new OptionsWrapper(new BuildJobOptions()) + Engines ); BuildJob = new NmtPreprocessBuildJob( PlatformService, diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs index c33974bf9..6b7cfa685 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs @@ -1,42 +1,49 @@ -namespace SIL.Machine.AspNetCore.Services; +using NSubstitute.ClearExtensions; +using SIL.Machine.Corpora; + +namespace SIL.Machine.AspNetCore.Services; [TestFixture] public class SmtTransferEngineServiceTests { + const string EngineId1 = "engine1"; + const string EngineId2 = "engine2"; + const string BuildId1 = "build1"; + const string CorpusId1 = "corpus1"; + [Test] public async Task CreateAsync() { using var env = new TestEnvironment(); - await env.Service.CreateAsync("engine2", "Engine 2", "es", "en"); - TranslationEngine? engine = await env.Engines.GetAsync(e => e.EngineId == "engine2"); + await env.Service.CreateAsync(EngineId2, "Engine 2", "es", "en"); + TranslationEngine? engine = await env.Engines.GetAsync(e => e.EngineId == EngineId2); Assert.Multiple(() => { Assert.That(engine, Is.Not.Null); - Assert.That(engine?.EngineId, Is.EqualTo("engine2")); + Assert.That(engine?.EngineId, Is.EqualTo(EngineId2)); Assert.That(engine?.BuildRevision, Is.EqualTo(0)); Assert.That(engine?.IsModelPersisted, Is.True); }); - env.SmtModelFactory.Received().InitNew("engine2"); - env.TransferEngineFactory.Received().InitNew("engine2"); + env.SmtModelFactory.Received().InitNew(EngineId2); + env.TransferEngineFactory.Received().InitNew(EngineId2); } [Test] public async Task StartBuildAsync() { using var env = new TestEnvironment(); - TranslationEngine engine = env.Engines.Get("engine1"); + TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.BuildRevision, Is.EqualTo(1)); // ensure that the SMT model was loaded before training - await env.Service.TranslateAsync("engine1", n: 1, "esto es una prueba."); + await env.Service.TranslateAsync(EngineId1, n: 1, "esto es una prueba."); await env.Service.StartBuildAsync( - "engine1", - "build1", + EngineId1, + BuildId1, null, - new[] - { + [ new Corpus() { - Id = "corpus1", + Id = CorpusId1, SourceLanguage = "es", TargetLanguage = "en", SourceFiles = [], @@ -44,7 +51,7 @@ await env.Service.StartBuildAsync( TrainOnTextIds = null, PretranslateTextIds = null } - } + ] ); await env.WaitForBuildToFinishAsync(); await env @@ -55,12 +62,12 @@ await env .TrainAsync(Arg.Any>(), Arg.Any()); await env.SmtBatchTrainer.Received().SaveAsync(Arg.Any()); await env.TruecaserTrainer.Received().SaveAsync(Arg.Any()); - engine = env.Engines.Get("engine1"); + engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Null); Assert.That(engine.BuildRevision, Is.EqualTo(2)); // check if SMT model was reloaded upon first use after training env.SmtModel.ClearReceivedCalls(); - await env.Service.TranslateAsync("engine1", n: 1, "esto es una prueba."); + await env.Service.TranslateAsync(EngineId1, n: 1, "esto es una prueba."); env.SmtModel.Received().Dispose(); await env.SmtModel.DidNotReceive().SaveAsync(); await env.Truecaser.DidNotReceive().SaveAsync(); @@ -81,16 +88,16 @@ await env.SmtBatchTrainer.TrainAsync( } }) ); - await env.Service.StartBuildAsync("engine1", "build1", "{}", Array.Empty()); + await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty()); await env.WaitForBuildToStartAsync(); - TranslationEngine engine = env.Engines.Get("engine1"); + TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); - await env.Service.CancelBuildAsync("engine1"); + await env.Service.CancelBuildAsync(EngineId1); await env.WaitForBuildToFinishAsync(); await env.SmtBatchTrainer.DidNotReceive().SaveAsync(); await env.TruecaserTrainer.DidNotReceive().SaveAsync(); - engine = env.Engines.Get("engine1"); + engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Null); } @@ -98,7 +105,7 @@ await env.SmtBatchTrainer.TrainAsync( public void CancelBuildAsync_NotBuilding() { using var env = new TestEnvironment(); - Assert.ThrowsAsync(() => env.Service.CancelBuildAsync("engine1")); + Assert.ThrowsAsync(() => env.Service.CancelBuildAsync(EngineId1)); } [Test] @@ -117,21 +124,21 @@ await env.SmtBatchTrainer.TrainAsync( } }) ); - await env.Service.StartBuildAsync("engine1", "build1", "{}", Array.Empty()); - await env.WaitForBuildToStartAsync(); - TranslationEngine engine = env.Engines.Get("engine1"); + await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty()); + await env.WaitForTrainingToStartAsync(); + TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); env.StopServer(); await env.WaitForBuildToRestartAsync(); - engine = env.Engines.Get("engine1"); + engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Pending)); - await env.PlatformService.Received().BuildRestartingAsync("build1"); + await env.PlatformService.Received().BuildRestartingAsync(BuildId1); env.SmtBatchTrainer.ClearSubstitute(ClearOptions.CallActions); env.StartServer(); await env.WaitForBuildToFinishAsync(); - engine = env.Engines.Get("engine1"); + engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Null); } @@ -150,17 +157,17 @@ await env.SmtBatchTrainer.TrainAsync( } }) ); - await env.Service.StartBuildAsync("engine1", "build1", "{}", Array.Empty()); + await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty()); await env.WaitForBuildToStartAsync(); - TranslationEngine engine = env.Engines.Get("engine1"); + TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); - await env.Service.DeleteAsync("engine1"); + await env.Service.DeleteAsync(EngineId1); // ensure that the build job was canceled await env.WaitForAllHangfireJobsToFinishAsync(); await env.SmtBatchTrainer.DidNotReceive().SaveAsync(); await env.TruecaserTrainer.DidNotReceive().SaveAsync(); - Assert.That(env.Engines.Contains("engine1"), Is.False); + Assert.That(env.Engines.Contains(EngineId1), Is.False); } [Test] @@ -179,16 +186,17 @@ await env.SmtBatchTrainer.TrainAsync( } }) ); - await env.Service.StartBuildAsync("engine1", "build1", "{}", Array.Empty()); + await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty()); await env.WaitForBuildToStartAsync(); - TranslationEngine engine = env.Engines.Get("engine1"); + TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); - await env.Service.TrainSegmentPairAsync("engine1", "esto es una prueba.", "this is a test.", true); + await env.Service.TrainSegmentPairAsync(EngineId1, "esto es una prueba.", "this is a test.", true); training = false; await env.WaitForBuildToFinishAsync(); - engine = env.Engines.Get("engine1"); + engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Null); + Assert.That(engine.BuildRevision, Is.EqualTo(2)); await env.SmtModel.Received(2).TrainSegmentAsync("esto es una prueba.", "this is a test.", true); } @@ -196,28 +204,28 @@ await env.SmtBatchTrainer.TrainAsync( public async Task CommitAsync_LoadedInactive() { using var env = new TestEnvironment(); - await env.Service.TrainSegmentPairAsync("engine1", "esto es una prueba.", "this is a test.", true); + await env.Service.TrainSegmentPairAsync(EngineId1, "esto es una prueba.", "this is a test.", true); await Task.Delay(10); await env.CommitAsync(TimeSpan.Zero); await env.SmtModel.Received().SaveAsync(); - Assert.That(env.StateService.Get("engine1").IsLoaded, Is.False); + Assert.That(env.StateService.Get(EngineId1).IsLoaded, Is.False); } [Test] public async Task CommitAsync_LoadedActive() { using var env = new TestEnvironment(); - await env.Service.TrainSegmentPairAsync("engine1", "esto es una prueba.", "this is a test.", true); + await env.Service.TrainSegmentPairAsync(EngineId1, "esto es una prueba.", "this is a test.", true); await env.CommitAsync(TimeSpan.FromHours(1)); await env.SmtModel.Received().SaveAsync(); - Assert.That(env.StateService.Get("engine1").IsLoaded, Is.True); + Assert.That(env.StateService.Get(EngineId1).IsLoaded, Is.True); } [Test] public async Task TranslateAsync() { using var env = new TestEnvironment(); - TranslationResult result = (await env.Service.TranslateAsync("engine1", n: 1, "esto es una prueba."))[0]; + TranslationResult result = (await env.Service.TranslateAsync(EngineId1, n: 1, "esto es una prueba."))[0]; Assert.That(result.Translation, Is.EqualTo("this is a TEST.")); } @@ -225,7 +233,7 @@ public async Task TranslateAsync() public async Task GetWordGraphAsync() { using var env = new TestEnvironment(); - WordGraph result = await env.Service.GetWordGraphAsync("engine1", "esto es una prueba."); + WordGraph result = await env.Service.GetWordGraphAsync(EngineId1, "esto es una prueba."); Assert.That( result.Arcs.Select(a => string.Join(' ', a.TargetTokens)), Is.EqualTo(new[] { "this is", "a test", "." }) @@ -239,16 +247,17 @@ private class TestEnvironment : ObjectModel.DisposableBase private BackgroundJobServer _jobServer; private readonly ITruecaserFactory _truecaserFactory; private readonly IDistributedReaderWriterLockFactory _lockFactory; - private readonly IBuildJobService _buildJobService; public TestEnvironment() { + TrainJobFunc = RunMockTrainJob; Engines = new MemoryRepository(); Engines.Add( new TranslationEngine { - Id = "engine1", - EngineId = "engine1", + Id = EngineId1, + EngineId = EngineId1, + Type = TranslationEngineType.SmtTransfer, SourceLanguage = "es", TargetLanguage = "en", BuildRevision = 1, @@ -261,10 +270,12 @@ public TestEnvironment() PlatformService = Substitute.For(); SmtModel = Substitute.For(); SmtBatchTrainer = Substitute.For(); - SmtBatchTrainer.Stats.Returns(new TrainStats { Metrics = { { "bleu", 0.0 }, { "perplexity", 0.0 } } }); + SmtBatchTrainer.Stats.Returns( + new TrainStats { TrainCorpusSize = 0, Metrics = { { "bleu", 0.0 }, { "perplexity", 0.0 } } } + ); Truecaser = Substitute.For(); TruecaserTrainer = Substitute.For(); - TruecaserTrainer.SaveAsync().Returns(Task.CompletedTask); + SmtModelFactory = CreateSmtModelFactory(); TransferEngineFactory = CreateTransferEngineFactory(); _truecaserFactory = CreateTruecaserFactory(); @@ -273,18 +284,51 @@ public TestEnvironment() new MemoryRepository(), new ObjectIdGenerator() ); - _buildJobService = new BuildJobService( - new[] { new HangfireBuildJobRunner(_jobClient, new[] { new SmtTransferHangfireBuildJobFactory() }) }, - Engines, - new OptionsWrapper( - new BuildJobOptions - { - Runners = new Dictionary + SharedFileService = new SharedFileService(Substitute.For()); + var clearMLOptions = Substitute.For>(); + clearMLOptions.CurrentValue.Returns(new ClearMLOptions()); + var buildJobOptions = Substitute.For>(); + buildJobOptions.CurrentValue.Returns( + new BuildJobOptions + { + ClearML = + [ + new ClearMLBuildQueue() + { + TranslationEngineType = TranslationEngineType.Nmt, + ModelType = "huggingface", + DockerImage = "default", + Queue = "default" + }, + new ClearMLBuildQueue() { - { BuildJobType.Cpu, BuildJobRunner.Hangfire } + TranslationEngineType = TranslationEngineType.SmtTransfer, + ModelType = "hmm", + DockerImage = "default", + Queue = "default" } - } - ) + ] + } + ); + ClearMLService = Substitute.For(); + ClearMLMonitorService = new ClearMLMonitorService( + Substitute.For(), + ClearMLService, + SharedFileService, + clearMLOptions, + buildJobOptions, + Substitute.For>() + ); + BuildJobService = new BuildJobService( + [ + new HangfireBuildJobRunner(_jobClient, [new SmtTransferHangfireBuildJobFactory()]), + new ClearMLBuildJobRunner( + ClearMLService, + [new SmtTransferClearMLBuildJobFactory(SharedFileService, Engines)], + buildJobOptions + ) + ], + Engines ); _jobServer = CreateJobServer(); StateService = CreateStateService(); @@ -303,6 +347,16 @@ public TestEnvironment() public ITrainer TruecaserTrainer { get; } public IPlatformService PlatformService { get; } + public IClearMLService ClearMLService { get; } + public IClearMLQueueService ClearMLMonitorService { get; } + + public ISharedFileService SharedFileService { get; } + + public IBuildJobService BuildJobService { get; } + public Func TrainJobFunc { get; set; } + + public Task? TrainJobTask { get; private set; } + public async Task CommitAsync(TimeSpan inactiveTimeout) { await StateService.CommitAsync(_lockFactory, Engines, inactiveTimeout); @@ -310,8 +364,8 @@ public async Task CommitAsync(TimeSpan inactiveTimeout) public void StopServer() { - StateService.Dispose(); _jobServer.Dispose(); + StateService.Dispose(); } public void StartServer() @@ -346,8 +400,8 @@ private SmtTransferEngineService CreateService() Engines, TrainSegmentPairs, StateService, - _buildJobService, - _memoryStorage + BuildJobService, + ClearMLMonitorService ); } @@ -520,6 +574,13 @@ public Task WaitForBuildToStartAsync() return WaitForBuildState(e => e.CurrentBuild!.JobState is BuildJobState.Active); } + public Task WaitForTrainingToStartAsync() + { + return WaitForBuildState(e => + e.CurrentBuild!.JobState is BuildJobState.Active && e.CurrentBuild!.Stage is BuildStage.Train + ); + } + public Task WaitForBuildToRestartAsync() { return WaitForBuildState(e => e.CurrentBuild!.JobState is BuildJobState.Pending); @@ -528,7 +589,7 @@ public Task WaitForBuildToRestartAsync() private async Task WaitForBuildState(Func predicate) { using ISubscription subscription = await Engines.SubscribeAsync(e => - e.EngineId == "engine1" + e.EngineId == EngineId1 ); while (true) { @@ -545,32 +606,57 @@ protected override void DisposeManagedResources() _jobServer.Dispose(); } + private Task RunMockTrainJob() + { + throw new InvalidOperationException(); + } + private class EnvActivator(TestEnvironment env) : JobActivator { private readonly TestEnvironment _env = env; public override object ActivateJob(Type jobType) { - if (jobType == typeof(SmtTransferBuildJob)) + if (jobType == typeof(SmtTransferPreprocessBuildJob)) { - ICorpusService corpusService = Substitute.For(); - corpusService - .CreateTextCorpora(Arg.Any>()) - .Returns([new DictionaryTextCorpus()]); - corpusService - .CreateTermCorpora(Arg.Any>()) - .Returns([new DictionaryTextCorpus()]); - - return new SmtTransferBuildJob( + return new SmtTransferPreprocessBuildJob( _env.PlatformService, _env.Engines, _env._lockFactory, - _env._buildJobService, - Substitute.For>(), + Substitute.For>(), + _env.BuildJobService, + _env.SharedFileService, + Substitute.For() + ) + { + TrainJobRunnerType = BuildJobRunnerType.Hangfire + }; + } + if (jobType == typeof(SmtTransferPostprocessBuildJob)) + { + return new SmtTransferPostprocessBuildJob( + _env.PlatformService, + _env.Engines, + _env._lockFactory, + _env.BuildJobService, + Substitute.For>(), + _env.SharedFileService, _env.TrainSegmentPairs, - _env._truecaserFactory, _env.SmtModelFactory, - corpusService + _env._truecaserFactory + ); + } + if (jobType == typeof(SmtTransferTrainBuildJob)) + { + return new SmtTransferTrainBuildJob( + _env.PlatformService, + _env.Engines, + _env._lockFactory, + _env.BuildJobService, + Substitute.For>(), + _env.SharedFileService, + _env._truecaserFactory, + _env.SmtModelFactory ); } return base.ActivateJob(jobType); diff --git a/tests/SIL.Machine.AspNetCore.Tests/Usings.cs b/tests/SIL.Machine.AspNetCore.Tests/Usings.cs index c4806736a..98efc5ed0 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Usings.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Usings.cs @@ -9,7 +9,6 @@ global using Microsoft.Extensions.Logging; global using Microsoft.Extensions.Options; global using NSubstitute; -global using NSubstitute.ClearExtensions; global using NSubstitute.ReceivedExtensions; global using NUnit.Framework; global using RichardSzalay.MockHttp; @@ -17,7 +16,6 @@ global using SIL.Machine.Annotations; global using SIL.Machine.AspNetCore.Configuration; global using SIL.Machine.AspNetCore.Models; -global using SIL.Machine.Corpora; global using SIL.Machine.Tokenization; global using SIL.Machine.Translation; global using SIL.Machine.Utils; From 7267440e1aeb86fbe1d94dc6ebe509e6faa44c01 Mon Sep 17 00:00:00 2001 From: Damien Daspit Date: Mon, 10 Jun 2024 09:32:39 -0400 Subject: [PATCH 2/4] Various updates to SMT engine - refactored SMT engine classes - more consistent model factories - fixed some dispose issues - Preserve changes from https://github.com/sillsdev/machine/pull/205. - change SMT model save location to be in build - always do it and auto-save when delete. - Revert to tar.gz - Remove SMT model cleanup (unneeded) --- .editorconfig | 3 + .../Models/ClearMLTask.cs | 3 + .../SIL.Machine.AspNetCore.csproj | 1 + .../Services/ClearMLMonitorService.cs | 47 +-- .../Services/ClearMLService.cs | 5 +- .../Services/ISmtModelFactory.cs | 22 +- .../Services/ITransferEngineFactory.cs | 11 +- .../Services/ITruecaserFactory.cs | 11 +- .../Services/ModelCleanupService.cs | 10 +- .../Services/NmtHangfireBuildJobFactory.cs | 2 +- .../Services/NmtPostprocessBuildJob.cs | 39 --- .../Services/NmtPreprocessBuildJob.cs | 35 +-- .../Services/PostprocessBuildJob.cs | 45 ++- .../Services/PreprocessBuildJob.cs | 82 ++--- .../Services/S3WriteStream.cs | 112 +++---- .../SmtTransferClearMLBuildJobFactory.cs | 3 - .../Services/SmtTransferEngineService.cs | 11 +- .../Services/SmtTransferEngineState.cs | 31 +- .../Services/SmtTransferEngineStateService.cs | 12 +- .../SmtTransferHangfireBuildJobFactory.cs | 5 +- .../SmtTransferPostprocessBuildJob.cs | 74 ++--- .../Services/SmtTransferPreprocessBuildJob.cs | 20 -- .../Services/SmtTransferTrainBuildJob.cs | 209 +++++++++++-- .../Services/ThotSmtModelFactory.cs | 112 ++++--- .../Services/TransferEngineFactory.cs | 25 +- .../Services/UnigramTruecaserFactory.cs | 29 +- src/SIL.Machine.AspNetCore/Usings.cs | 3 + .../Utils/DictionaryStringConverter.cs | 73 +++++ .../appsettings.json | 2 +- .../appsettings.json | 2 +- .../SIL.Machine.Translation.Thot.csproj | 2 +- src/SIL.Machine/Utils/ProgressStatus.cs | 2 +- .../Services/NmtEngineServiceTests.cs | 80 +++-- ...JobTests.cs => PreprocessBuildJobTests.cs} | 89 ++++-- .../Services/SmtTransferEngineServiceTests.cs | 295 +++++++++++------- tests/SIL.Machine.AspNetCore.Tests/Usings.cs | 2 + 36 files changed, 902 insertions(+), 607 deletions(-) delete mode 100644 src/SIL.Machine.AspNetCore/Services/NmtPostprocessBuildJob.cs delete mode 100644 src/SIL.Machine.AspNetCore/Services/SmtTransferPreprocessBuildJob.cs create mode 100644 src/SIL.Machine.AspNetCore/Utils/DictionaryStringConverter.cs rename tests/SIL.Machine.AspNetCore.Tests/Services/{NmtPreprocessBuildJobTests.cs => PreprocessBuildJobTests.cs} (88%) diff --git a/.editorconfig b/.editorconfig index db9ec586d..1a56c51fd 100644 --- a/.editorconfig +++ b/.editorconfig @@ -40,6 +40,9 @@ csharp_new_line_before_finally = true csharp_new_line_before_members_in_object_initializers = true csharp_new_line_before_members_in_anonymous_types = true +# Indentation settings +csharp_indent_case_contents_when_block = false + # Namespace settings csharp_style_namespace_declarations = file_scoped diff --git a/src/SIL.Machine.AspNetCore/Models/ClearMLTask.cs b/src/SIL.Machine.AspNetCore/Models/ClearMLTask.cs index 1c7adf3a9..9a05325ce 100644 --- a/src/SIL.Machine.AspNetCore/Models/ClearMLTask.cs +++ b/src/SIL.Machine.AspNetCore/Models/ClearMLTask.cs @@ -29,4 +29,7 @@ public required IReadOnlyDictionary< string, IReadOnlyDictionary > LastMetrics { get; init; } + + [JsonConverter(typeof(DictionaryStringStringConverter))] + public required IReadOnlyDictionary Runtime { get; init; } } diff --git a/src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj b/src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj index d4df12db7..159e7bab9 100644 --- a/src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj +++ b/src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj @@ -31,6 +31,7 @@ + diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLMonitorService.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLMonitorService.cs index 05b931bee..217a398e1 100644 --- a/src/SIL.Machine.AspNetCore/Services/ClearMLMonitorService.cs +++ b/src/SIL.Machine.AspNetCore/Services/ClearMLMonitorService.cs @@ -17,12 +17,9 @@ ILogger logger ), IClearMLQueueService { - private static readonly string EvalMetric = CreateMD5("eval"); - private static readonly string BleuVariant = CreateMD5("bleu"); - private static readonly string SummaryMetric = CreateMD5("Summary"); - private static readonly string CorpusSizeVariant = CreateMD5("corpus_size"); - private static readonly string ProgressVariant = CreateMD5("progress"); + private static readonly string TrainCorpusSizeVariant = CreateMD5("train_corpus_size"); + private static readonly string ConfidenceVariant = CreateMD5("confidence"); private readonly IClearMLService _clearMLService = clearMLService; private readonly ISharedFileService _sharedFileService = sharedFileService; @@ -134,24 +131,29 @@ or ClearMLTaskStatus.Completed switch (task.Status) { case ClearMLTaskStatus.InProgress: + { + double? percentCompleted = null; + if (task.Runtime.TryGetValue("progress", out string? progressStr)) + percentCompleted = int.Parse(progressStr, CultureInfo.InvariantCulture) / 100.0; + task.Runtime.TryGetValue("message", out string? message); await UpdateTrainJobStatus( platformService, engine.CurrentBuild.BuildId, - new ProgressStatus( - task.LastIteration ?? 0, - percentCompleted: GetMetric(task, SummaryMetric, ProgressVariant) - ), - 0, + new ProgressStatus(task.LastIteration ?? 0, percentCompleted, message), + queueDepth: 0, cancellationToken ); break; + } case ClearMLTaskStatus.Completed: + { + task.Runtime.TryGetValue("message", out string? message); await UpdateTrainJobStatus( platformService, engine.CurrentBuild.BuildId, - new ProgressStatus(task.LastIteration ?? 0, percentCompleted: 1.0), - 0, + new ProgressStatus(task.LastIteration ?? 0, percentCompleted: 1.0, message), + queueDepth: 0, cancellationToken ); bool canceling = !await TrainJobCompletedAsync( @@ -159,8 +161,8 @@ await UpdateTrainJobStatus( buildJobService, engine.EngineId, engine.CurrentBuild.BuildId, - (int)GetMetric(task, SummaryMetric, CorpusSizeVariant), - GetMetric(task, EvalMetric, BleuVariant), + (int)GetMetric(task, SummaryMetric, TrainCorpusSizeVariant), + GetMetric(task, SummaryMetric, ConfidenceVariant), engine.CurrentBuild.Options, cancellationToken ); @@ -176,8 +178,10 @@ await TrainJobCanceledAsync( ); } break; + } case ClearMLTaskStatus.Stopped: + { await TrainJobCanceledAsync( lockFactory, buildJobService, @@ -187,8 +191,10 @@ await TrainJobCanceledAsync( cancellationToken ); break; + } case ClearMLTaskStatus.Failed: + { await TrainJobFaultedAsync( lockFactory, buildJobService, @@ -199,6 +205,7 @@ await TrainJobFaultedAsync( cancellationToken ); break; + } } } } @@ -227,7 +234,7 @@ private async Task TrainJobStartedAsync( await platformService.BuildStartedAsync(buildId, CancellationToken.None); await UpdateTrainJobStatus(platformService, buildId, new ProgressStatus(0), 0, cancellationToken); - _logger.LogInformation("Build started ({0})", buildId); + _logger.LogInformation("Build started ({BuildId})", buildId); return true; } @@ -287,7 +294,7 @@ await buildJobService.BuildJobFinishedAsync( CancellationToken.None ); } - _logger.LogError("Build faulted ({0}). Error: {1}", buildId, message); + _logger.LogError("Build faulted ({BuildId}). Error: {ErrorMessage}", buildId, message); } finally { @@ -317,7 +324,7 @@ await buildJobService.BuildJobFinishedAsync( CancellationToken.None ); } - _logger.LogInformation("Build canceled ({0})", buildId); + _logger.LogInformation("Build canceled ({BuildId})", buildId); } finally { @@ -327,7 +334,7 @@ await buildJobService.BuildJobFinishedAsync( } catch (Exception e) { - _logger.LogWarning(e, "Unable to to delete job data for build {0}.", buildId); + _logger.LogWarning(e, "Unable to to delete job data for build {BuildId}.", buildId); } _curBuildStatus.Remove(buildId); } @@ -365,10 +372,8 @@ private static double GetMetric(ClearMLTask task, string metric, string variant) private static string CreateMD5(string input) { - using var md5 = MD5.Create(); - byte[] inputBytes = Encoding.UTF8.GetBytes(input); - byte[] hashBytes = md5.ComputeHash(inputBytes); + byte[] hashBytes = MD5.HashData(inputBytes); return Convert.ToHexString(hashBytes).ToLower(); } diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs index d755aadcc..fd11b25dd 100644 --- a/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs +++ b/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs @@ -15,7 +15,7 @@ IHostEnvironment env new() { PropertyNamingPolicy = JsonNamingPolicy, - Converters = { new Utils.CustomEnumConverterFactory(JsonNamingPolicy) } + Converters = { new CustomEnumConverterFactory(JsonNamingPolicy) } }; private readonly IClearMLAuthenticationService _clearMLAuthService = clearMLAuthService; @@ -181,7 +181,8 @@ private async Task> GetTasksAsync( "status_message", "created", "active_duration", - "last_metrics" + "last_metrics", + "runtime" ); JsonObject? result = await CallAsync("tasks", "get_all_ex", body, cancellationToken); var tasks = (JsonArray?)result?["data"]?["tasks"]; diff --git a/src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs b/src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs index 51d7fe1e5..292644766 100644 --- a/src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs @@ -2,15 +2,21 @@ public interface ISmtModelFactory { - IInteractiveTranslationModel Create( - string engineId, + Task CreateAsync( + string engineDir, IRangeTokenizer tokenizer, IDetokenizer detokenizer, - ITruecaser truecaser + ITruecaser truecaser, + CancellationToken cancellationToken = default ); - ITrainer CreateTrainer(string engineId, IRangeTokenizer tokenizer, IParallelTextCorpus corpus); - Task UploadBuiltEngineAsync(string engineId, CancellationToken cancellationToken); - Task DownloadBuiltEngineAsync(string engineId, CancellationToken cancellationToken); - void InitNew(string engineId); - void Cleanup(string engineId); + Task CreateTrainerAsync( + string engineDir, + IRangeTokenizer tokenizer, + IParallelTextCorpus corpus, + CancellationToken cancellationToken = default + ); + Task InitNewAsync(string engineDir, CancellationToken cancellationToken = default); + Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default); + Task UpdateEngineFromAsync(string engineDir, Stream source, CancellationToken cancellationToken = default); + Task SaveEngineToAsync(string engineDir, Stream destination, CancellationToken cancellationToken = default); } diff --git a/src/SIL.Machine.AspNetCore/Services/ITransferEngineFactory.cs b/src/SIL.Machine.AspNetCore/Services/ITransferEngineFactory.cs index c45014ec8..c99682e07 100644 --- a/src/SIL.Machine.AspNetCore/Services/ITransferEngineFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/ITransferEngineFactory.cs @@ -2,12 +2,13 @@ public interface ITransferEngineFactory { - ITranslationEngine? Create( - string engineId, + Task CreateAsync( + string engineDir, IRangeTokenizer tokenizer, IDetokenizer detokenizer, - ITruecaser truecaser + ITruecaser truecaser, + CancellationToken cancellationToken = default ); - void InitNew(string engineId); - void Cleanup(string engineId); + Task InitNewAsync(string engineDir, CancellationToken cancellationToken = default); + Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default); } diff --git a/src/SIL.Machine.AspNetCore/Services/ITruecaserFactory.cs b/src/SIL.Machine.AspNetCore/Services/ITruecaserFactory.cs index f937a80fa..4395a7f29 100644 --- a/src/SIL.Machine.AspNetCore/Services/ITruecaserFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/ITruecaserFactory.cs @@ -2,7 +2,12 @@ public interface ITruecaserFactory { - Task CreateAsync(string engineId); - ITrainer CreateTrainer(string engineId, ITokenizer tokenizer, ITextCorpus corpus); - void Cleanup(string engineId); + Task CreateAsync(string engineDir, CancellationToken cancellationToken = default); + Task CreateTrainerAsync( + string engineDir, + ITokenizer tokenizer, + ITextCorpus corpus, + CancellationToken cancellationToken = default + ); + Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default); } diff --git a/src/SIL.Machine.AspNetCore/Services/ModelCleanupService.cs b/src/SIL.Machine.AspNetCore/Services/ModelCleanupService.cs index baf7d75a9..2d132b27d 100644 --- a/src/SIL.Machine.AspNetCore/Services/ModelCleanupService.cs +++ b/src/SIL.Machine.AspNetCore/Services/ModelCleanupService.cs @@ -35,15 +35,7 @@ private async Task CheckModelsAsync(CancellationToken cancellationToken) .Where(e => e.Type == TranslationEngineType.Nmt) .Select(e => NmtEngineService.GetModelPath(e.EngineId, e.BuildRevision + 1)); - // Add SMT engines - IEnumerable validSmtModels = allEngines - .Where(e => e.Type == TranslationEngineType.SmtTransfer) - .Select(e => SmtTransferEngineService.GetModelPath(e.EngineId)); - - HashSet filenameFilter = validNmtFilenames - .Concat(validNmtFilenamesForNextBuild) - .Concat(validSmtModels) - .ToHashSet(); + HashSet filenameFilter = validNmtFilenames.Concat(validNmtFilenamesForNextBuild).ToHashSet(); foreach (string path in paths) { diff --git a/src/SIL.Machine.AspNetCore/Services/NmtHangfireBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/NmtHangfireBuildJobFactory.cs index 746c336aa..c0366078c 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtHangfireBuildJobFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtHangfireBuildJobFactory.cs @@ -19,7 +19,7 @@ public Job CreateJob(string engineId, string buildId, BuildStage stage, object? buildOptions ), BuildStage.Postprocess - => CreateJob(engineId, buildId, "nmt", data, buildOptions), + => CreateJob(engineId, buildId, "nmt", data, buildOptions), _ => throw new ArgumentException("Unknown build stage.", nameof(stage)), }; } diff --git a/src/SIL.Machine.AspNetCore/Services/NmtPostprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/NmtPostprocessBuildJob.cs deleted file mode 100644 index ba17ef414..000000000 --- a/src/SIL.Machine.AspNetCore/Services/NmtPostprocessBuildJob.cs +++ /dev/null @@ -1,39 +0,0 @@ -namespace SIL.Machine.AspNetCore.Services; - -public class NmtPostprocessBuildJob( - IPlatformService platformService, - IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, - IBuildJobService buildJobService, - ILogger logger, - ISharedFileService sharedFileService -) : PostprocessBuildJob(platformService, engines, lockFactory, buildJobService, logger, sharedFileService) -{ - protected override async Task DoWorkAsync( - string engineId, - string buildId, - (int, double) data, - string? buildOptions, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) - { - (int corpusSize, double confidence) = data; - - // The NMT job has successfully completed, so insert the generated pretranslations into the database. - await InsertPretranslationsAsync(engineId, buildId, cancellationToken); - - await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - await PlatformService.BuildCompletedAsync( - buildId, - corpusSize, - Math.Round(confidence, 2, MidpointRounding.AwayFromZero), - CancellationToken.None - ); - await BuildJobService.BuildJobFinishedAsync(engineId, buildId, buildComplete: true, CancellationToken.None); - } - - Logger.LogInformation("Build completed ({0}).", buildId); - } -} diff --git a/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs index 6cdcbfdee..5ba4d99d1 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs @@ -1,29 +1,20 @@ namespace SIL.Machine.AspNetCore.Services; -public class NmtPreprocessBuildJob : PreprocessBuildJob +public class NmtPreprocessBuildJob( + IPlatformService platformService, + IRepository engines, + IDistributedReaderWriterLockFactory lockFactory, + ILogger logger, + IBuildJobService buildJobService, + ISharedFileService sharedFileService, + ICorpusService corpusService, + ILanguageTagService languageTagService +) : PreprocessBuildJob(platformService, engines, lockFactory, logger, buildJobService, sharedFileService, corpusService) { - public NmtPreprocessBuildJob( - IPlatformService platformService, - IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, - ILogger logger, - IBuildJobService buildJobService, - ISharedFileService sharedFileService, - ICorpusService corpusService, - ILanguageTagService languageTagService - ) - : base(platformService, engines, lockFactory, logger, buildJobService, sharedFileService, corpusService) - { - _languageTagService = languageTagService; - PretranslationEnabled = true; - EngineType = TranslationEngineType.Nmt; - } - - private readonly ILanguageTagService _languageTagService; + private readonly ILanguageTagService _languageTagService = languageTagService; - protected override string ResolveLanguageCode(string languageCode) + protected override bool ResolveLanguageCodeForBaseModel(string languageCode, out string resolvedCode) { - _languageTagService.ConvertToFlores200Code(languageCode, out string resolvedCode); - return resolvedCode; + return _languageTagService.ConvertToFlores200Code(languageCode, out resolvedCode); } } diff --git a/src/SIL.Machine.AspNetCore/Services/PostprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/PostprocessBuildJob.cs index 009fda8ad..1bf7a4389 100644 --- a/src/SIL.Machine.AspNetCore/Services/PostprocessBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/PostprocessBuildJob.cs @@ -1,6 +1,6 @@ namespace SIL.Machine.AspNetCore.Services; -public abstract class PostprocessBuildJob( +public class PostprocessBuildJob( IPlatformService platformService, IRepository engines, IDistributedReaderWriterLockFactory lockFactory, @@ -9,7 +9,44 @@ public abstract class PostprocessBuildJob( ISharedFileService sharedFileService ) : HangfireBuildJob<(int, double)>(platformService, engines, lockFactory, buildJobService, logger) { - protected readonly ISharedFileService SharedFileService = sharedFileService; + private static readonly JsonSerializerOptions JsonSerializerOptions = + new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + + protected ISharedFileService SharedFileService { get; } = sharedFileService; + + protected override async Task DoWorkAsync( + string engineId, + string buildId, + (int, double) data, + string? buildOptions, + IDistributedReaderWriterLock @lock, + CancellationToken cancellationToken + ) + { + (int corpusSize, double confidence) = data; + + // The MT job has successfully completed, so insert the generated pretranslations into the database. + await InsertPretranslationsAsync(engineId, buildId, cancellationToken); + + await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) + { + int additionalCorpusSize = await SaveModelAsync(engineId, buildId); + await PlatformService.BuildCompletedAsync( + buildId, + corpusSize + additionalCorpusSize, + Math.Round(confidence, 2, MidpointRounding.AwayFromZero), + CancellationToken.None + ); + await BuildJobService.BuildJobFinishedAsync(engineId, buildId, buildComplete: true, CancellationToken.None); + } + + Logger.LogInformation("Build completed ({0}).", buildId); + } + + protected virtual Task SaveModelAsync(string engineId, string buildId) + { + return Task.FromResult(0); + } protected override async Task CleanupAsync( string engineId, @@ -39,7 +76,7 @@ protected async Task InsertPretranslationsAsync( CancellationToken cancellationToken ) { - await using var targetPretranslateStream = await SharedFileService.OpenReadAsync( + await using Stream targetPretranslateStream = await SharedFileService.OpenReadAsync( $"builds/{buildId}/pretranslate.trg.json", cancellationToken ); @@ -47,7 +84,7 @@ CancellationToken cancellationToken IAsyncEnumerable pretranslations = JsonSerializer .DeserializeAsyncEnumerable( targetPretranslateStream, - new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }, + JsonSerializerOptions, cancellationToken ) .OfType(); diff --git a/src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs index 9d562d19a..37d8535de 100644 --- a/src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs @@ -1,12 +1,10 @@ namespace SIL.Machine.AspNetCore.Services; -public abstract class PreprocessBuildJob : HangfireBuildJob> +public class PreprocessBuildJob : HangfireBuildJob> { private static readonly JsonWriterOptions PretranslateWriterOptions = new() { Indented = true }; internal BuildJobRunnerType TrainJobRunnerType { get; init; } = BuildJobRunnerType.ClearML; - protected TranslationEngineType EngineType { get; init; } - protected bool PretranslationEnabled { get; init; } private readonly ISharedFileService _sharedFileService; private readonly ICorpusService _corpusService; @@ -72,10 +70,19 @@ CancellationToken cancellationToken if (engine is null) throw new OperationCanceledException($"Engine {engineId} does not exist. Build canceled."); - buildPreprocessSummary.Add("SourceLanguageResolved", ResolveLanguageCode(engine.SourceLanguage)); - buildPreprocessSummary.Add("TargetLanguageResolved", ResolveLanguageCode(engine.TargetLanguage)); + bool sourceTagInBaseModel = ResolveLanguageCodeForBaseModel(engine.SourceLanguage, out string srcLang); + buildPreprocessSummary.Add("SourceLanguageResolved", srcLang); + bool targetTagInBaseModel = ResolveLanguageCodeForBaseModel(engine.TargetLanguage, out string trgLang); + buildPreprocessSummary.Add("TargetLanguageResolved", trgLang); Logger.LogInformation("{summary}", buildPreprocessSummary.ToJsonString()); + if (trainCount == 0 && (!sourceTagInBaseModel || !targetTagInBaseModel)) + { + throw new InvalidOperationException( + $"Neither language code in build {buildId} are known to the base model, and the data specified for training was empty. Build canceled." + ); + } + cancellationToken.ThrowIfCancellationRequested(); await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) @@ -85,7 +92,6 @@ CancellationToken cancellationToken engineId, buildId, BuildStage.Train, - data: new object(), buildOptions: buildOptions, cancellationToken: cancellationToken ); @@ -111,11 +117,11 @@ CancellationToken cancellationToken await using StreamWriter targetTrainWriter = new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken)); - using Stream pretranslateStream = await _sharedFileService.OpenWriteAsync( + await using Stream pretranslateStream = await _sharedFileService.OpenWriteAsync( $"builds/{buildId}/pretranslate.src.json", cancellationToken ); - using Utf8JsonWriter pretranslateWriter = new(pretranslateStream, PretranslateWriterOptions); + await using Utf8JsonWriter pretranslateWriter = new(pretranslateStream, PretranslateWriterOptions); int trainCount = 0; int pretranslateCount = 0; @@ -173,27 +179,24 @@ CancellationToken cancellationToken } } - if (PretranslationEnabled) + foreach (Row row in AlignPretranslateCorpus(sourceTextCorpora[0], targetTextCorpus)) { - foreach (Row row in AlignPretranslateCorpus(sourceTextCorpora[0], targetTextCorpus)) + if ( + IsInPretranslate(row, corpus) + && row.SourceSegment.Length > 0 + && (row.TargetSegment.Length == 0 || !IsInTrain(row, corpus)) + ) { - if ( - IsInPretranslate(row, corpus) - && row.SourceSegment.Length > 0 - && (row.TargetSegment.Length == 0 || !IsInTrain(row, corpus)) - ) - { - pretranslateWriter.WriteStartObject(); - pretranslateWriter.WriteString("corpusId", corpus.Id); - pretranslateWriter.WriteString("textId", row.TextId); - pretranslateWriter.WriteStartArray("refs"); - foreach (object rowRef in row.Refs) - pretranslateWriter.WriteStringValue(rowRef.ToString()); - pretranslateWriter.WriteEndArray(); - pretranslateWriter.WriteString("translation", row.SourceSegment); - pretranslateWriter.WriteEndObject(); - pretranslateCount++; - } + pretranslateWriter.WriteStartObject(); + pretranslateWriter.WriteString("corpusId", corpus.Id); + pretranslateWriter.WriteString("textId", row.TextId); + pretranslateWriter.WriteStartArray("refs"); + foreach (object rowRef in row.Refs) + pretranslateWriter.WriteStringValue(rowRef.ToString()); + pretranslateWriter.WriteEndArray(); + pretranslateWriter.WriteString("translation", row.SourceSegment); + pretranslateWriter.WriteEndObject(); + pretranslateCount++; } } } @@ -226,27 +229,31 @@ JobCompletionStatus completionStatus private static bool IsInTrain(Row row, Corpus corpus) { - return IsIncluded(row, corpus.TrainOnAll, corpus.TrainOnTextIds, corpus.TrainOnChapters); + return IsIncluded(row, corpus.TrainOnTextIds, corpus.TrainOnChapters); } private static bool IsInPretranslate(Row row, Corpus corpus) { - return IsIncluded(row, corpus.PretranslateAll, corpus.PretranslateTextIds, corpus.PretranslateChapters); + return IsIncluded(row, corpus.PretranslateTextIds, corpus.PretranslateChapters); } private static bool IsIncluded( - Row row, - bool all, - IReadOnlySet textIds, + Row? row, + IReadOnlySet? textIds, IReadOnlyDictionary>? chapters ) { + if (row is null) + return false; if (chapters is not null) { - if (row.Refs.Any(r => IsInChapters(chapters, r))) - return true; + return row.Refs.Any(r => IsInChapters(chapters, r)); + } + if (textIds is not null) + { + return textIds.Contains(row.TextId); } - return all || textIds.Contains(row.TextId); + return true; } private static bool IsInChapters(IReadOnlyDictionary> bookChapters, object rowRef) @@ -424,8 +431,9 @@ private record Row( int RowCount ); - protected virtual string ResolveLanguageCode(string languageCode) + protected virtual bool ResolveLanguageCodeForBaseModel(string languageCode, out string resolvedCode) { - return languageCode; + resolvedCode = languageCode; + return true; } } diff --git a/src/SIL.Machine.AspNetCore/Services/S3WriteStream.cs b/src/SIL.Machine.AspNetCore/Services/S3WriteStream.cs index f96ed0ae2..35130134f 100644 --- a/src/SIL.Machine.AspNetCore/Services/S3WriteStream.cs +++ b/src/SIL.Machine.AspNetCore/Services/S3WriteStream.cs @@ -1,10 +1,5 @@ namespace SIL.Machine.AspNetCore.Services; -[SuppressMessage( - "Usage", - "CA1844: Provide memory-based overrides of async methods when subclassing 'Stream'", - Justification = "Data would have to be copied anyway" -)] public class S3WriteStream( AmazonS3Client client, string key, @@ -46,19 +41,7 @@ public override void Flush() { } public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; - public override void Write(byte[] buffer, int offset, int count) - { - try - { - using MemoryStream ms = new(buffer, offset, count); - WriteAsync(ms, count).Wait(); - } - catch (Exception e) - { - AbortAsync(e).Wait(); - throw; - } - } + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); public override async ValueTask WriteAsync( ReadOnlyMemory buffer, @@ -67,22 +50,45 @@ public override async ValueTask WriteAsync( { try { - using MemoryStream ms = new(buffer.ToArray(), 0, buffer.Length); - await WriteAsync(ms, buffer.Length); - } - catch (Exception e) - { - await AbortAsync(e); - throw; - } - } + using Stream stream = buffer.AsStream(); - public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - try - { - using MemoryStream ms = new(buffer, offset, count); - await WriteAsync(ms, count); + int bytesWritten = 0; + + while (stream.Length > bytesWritten) + { + int partNumber = _uploadResponses.Count + 1; + UploadPartRequest request = + new() + { + BucketName = _bucketName, + Key = _key, + UploadId = _uploadId, + PartNumber = partNumber, + InputStream = stream, + PartSize = MaxPartSize + }; + request.StreamTransferProgress += new EventHandler( + (_, e) => + { + _logger.LogDebug( + "Transferred {e.TransferredBytes}/{e.TotalBytes}", + e.TransferredBytes, + e.TotalBytes + ); + } + ); + UploadPartResponse response = await _client.UploadPartAsync(request); + if (response.HttpStatusCode != HttpStatusCode.OK) + { + throw new HttpRequestException( + $"Tried to upload part {partNumber} of upload {_uploadId} to {_bucketName}/{_key} but received response code {response.HttpStatusCode}" + ); + } + + _uploadResponses.Add(response); + + bytesWritten += MaxPartSize; + } } catch (Exception e) { @@ -91,45 +97,9 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc } } - private async Task WriteAsync(MemoryStream ms, int count) + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - int bytesWritten = 0; - - while (count > bytesWritten) - { - int partNumber = _uploadResponses.Count + 1; - UploadPartRequest request = - new() - { - BucketName = _bucketName, - Key = _key, - UploadId = _uploadId, - PartNumber = partNumber, - InputStream = ms, - PartSize = Math.Min(MaxPartSize, count - bytesWritten) - }; - request.StreamTransferProgress += new EventHandler( - (_, e) => - { - _logger.LogDebug( - "Transferred {e.TransferredBytes}/{e.TotalBytes}", - e.TransferredBytes, - e.TotalBytes - ); - } - ); - UploadPartResponse response = await _client.UploadPartAsync(request); - if (response.HttpStatusCode != HttpStatusCode.OK) - { - throw new HttpRequestException( - $"Tried to upload part {partNumber} of upload {_uploadId} to {_bucketName}/{_key} but received response code {response.HttpStatusCode}" - ); - } - - _uploadResponses.Add(response); - - bytesWritten += MaxPartSize; - } + await WriteAsync(buffer.AsMemory(offset, count), cancellationToken); } protected override void Dispose(bool disposing) diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferClearMLBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferClearMLBuildJobFactory.cs index 367bdb2d2..6a1f42aeb 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferClearMLBuildJobFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferClearMLBuildJobFactory.cs @@ -37,9 +37,6 @@ public async Task CreateJobScriptAsync( + $" 'shared_file_uri': '{baseUri}',\n" + $" 'shared_file_folder': '{folder}',\n" + (buildOptions is not null ? $" 'build_options': '''{buildOptions}''',\n" : "") - // buildRevision + 1 because the build revision is incremented after the build job - // is finished successfully but the file should be saved with the new revision number - + (engine.IsModelPersisted ? $" 'save_model': '{engineId}',\n" : $"") + $" 'clearml': True\n" + "}\n" + "run(args)\n"; diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs index 54c14a556..941731cd8 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs @@ -22,13 +22,6 @@ IClearMLQueueService clearMLQueueService public TranslationEngineType Type => TranslationEngineType.SmtTransfer; - public const string ModelDirectory = "models/"; - - public static string GetModelPath(string engineId) - { - return $"{ModelDirectory}{engineId}.zip"; - } - public async Task CreateAsync( string engineId, string? engineName, @@ -46,7 +39,7 @@ public async Task CreateAsync( ); } - var translationEngine = await _dataAccessContext.WithTransactionAsync( + TranslationEngine translationEngine = await _dataAccessContext.WithTransactionAsync( async (ct) => { var translationEngine = new TranslationEngine @@ -68,7 +61,7 @@ public async Task CreateAsync( await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) { SmtTransferEngineState state = _stateService.Get(engineId); - state.InitNew(); + await state.InitNewAsync(CancellationToken.None); } return translationEngine; } diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineState.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineState.cs index c4336b3e1..d76350989 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineState.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineState.cs @@ -4,12 +4,14 @@ public class SmtTransferEngineState( ISmtModelFactory smtModelFactory, ITransferEngineFactory transferEngineFactory, ITruecaserFactory truecaserFactory, + IOptionsMonitor options, string engineId ) : AsyncDisposableBase { private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; private readonly ITransferEngineFactory _transferEngineFactory = transferEngineFactory; private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; + private readonly IOptionsMonitor _options = options; private readonly AsyncLock _lock = new(); private IInteractiveTranslationModel? _smtModel; @@ -22,10 +24,12 @@ string engineId public DateTime LastUsedTime { get; set; } = DateTime.UtcNow; public bool IsLoaded => _hybridEngine != null; - public void InitNew() + private string EngineDir => Path.Combine(_options.CurrentValue.EnginesDir, EngineId); + + public async Task InitNewAsync(CancellationToken cancellationToken = default) { - _smtModelFactory.InitNew(EngineId); - _transferEngineFactory.InitNew(EngineId); + await _smtModelFactory.InitNewAsync(EngineDir, cancellationToken); + await _transferEngineFactory.InitNewAsync(EngineDir, cancellationToken); } public async Task GetHybridEngineAsync(int buildRevision) @@ -40,11 +44,16 @@ public async Task GetHybridEngineAsync(int buildRevisio if (_hybridEngine is null) { - var tokenizer = new LatinWordTokenizer(); - var detokenizer = new LatinWordDetokenizer(); - var truecaser = await _truecaserFactory.CreateAsync(EngineId); - _smtModel = _smtModelFactory.Create(EngineId, tokenizer, detokenizer, truecaser); - var transferEngine = _transferEngineFactory.Create(EngineId, tokenizer, detokenizer, truecaser); + LatinWordTokenizer tokenizer = new(); + LatinWordDetokenizer detokenizer = new(); + ITruecaser truecaser = await _truecaserFactory.CreateAsync(EngineDir); + _smtModel = await _smtModelFactory.CreateAsync(EngineDir, tokenizer, detokenizer, truecaser); + ITranslationEngine? transferEngine = await _transferEngineFactory.CreateAsync( + EngineDir, + tokenizer, + detokenizer, + truecaser + ); _hybridEngine = new HybridTranslationEngine(_smtModel, transferEngine) { TargetDetokenizer = detokenizer @@ -58,9 +67,9 @@ public async Task GetHybridEngineAsync(int buildRevisio public async Task DeleteDataAsync() { await UnloadAsync(); - _smtModelFactory.Cleanup(EngineId); - _transferEngineFactory.Cleanup(EngineId); - _truecaserFactory.Cleanup(EngineId); + await _smtModelFactory.CleanupAsync(EngineDir); + await _transferEngineFactory.CleanupAsync(EngineDir); + await _truecaserFactory.CleanupAsync(EngineDir); } public async Task CommitAsync( diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineStateService.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineStateService.cs index bb0097ef6..66f4ebcf2 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineStateService.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineStateService.cs @@ -3,12 +3,14 @@ public class SmtTransferEngineStateService( ISmtModelFactory smtModelFactory, ITransferEngineFactory transferEngineFactory, - ITruecaserFactory truecaserFactory + ITruecaserFactory truecaserFactory, + IOptionsMonitor options ) : AsyncDisposableBase { private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; private readonly ITransferEngineFactory _transferEngineFactory = transferEngineFactory; private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; + private readonly IOptionsMonitor _options = options; private readonly ConcurrentDictionary _engineStates = new ConcurrentDictionary(); @@ -52,7 +54,13 @@ engine is not null private SmtTransferEngineState CreateState(string engineId) { - return new SmtTransferEngineState(_smtModelFactory, _transferEngineFactory, _truecaserFactory, engineId); + return new SmtTransferEngineState( + _smtModelFactory, + _transferEngineFactory, + _truecaserFactory, + _options, + engineId + ); } protected override async ValueTask DisposeAsyncCore() diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferHangfireBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferHangfireBuildJobFactory.cs index 7ebd1bede..efbeb1207 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferHangfireBuildJobFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferHangfireBuildJobFactory.cs @@ -11,7 +11,7 @@ public Job CreateJob(string engineId, string buildId, BuildStage stage, object? return stage switch { BuildStage.Preprocess - => CreateJob>( + => CreateJob>( engineId, buildId, "smt_transfer", @@ -26,8 +26,7 @@ public Job CreateJob(string engineId, string buildId, BuildStage stage, object? data, buildOptions ), - BuildStage.Train - => CreateJob(engineId, buildId, "smt_transfer", data, buildOptions), + BuildStage.Train => CreateJob(engineId, buildId, "smt_transfer", buildOptions), _ => throw new ArgumentException("Unknown build stage.", nameof(stage)), }; } diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferPostprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferPostprocessBuildJob.cs index bc457a34a..2065fe221 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferPostprocessBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferPostprocessBuildJob.cs @@ -9,72 +9,56 @@ public class SmtTransferPostprocessBuildJob( ISharedFileService sharedFileService, IRepository trainSegmentPairs, ISmtModelFactory smtModelFactory, - ITruecaserFactory truecaserFactory + ITruecaserFactory truecaserFactory, + IOptionsMonitor options ) : PostprocessBuildJob(platformService, engines, lockFactory, buildJobService, logger, sharedFileService) { private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; private readonly IRepository _trainSegmentPairs = trainSegmentPairs; + private readonly IOptionsMonitor _options = options; - protected override async Task DoWorkAsync( - string engineId, - string buildId, - (int, double) data, - string? buildOptions, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) + protected override async Task SaveModelAsync(string engineId, string buildId) { - cancellationToken.ThrowIfCancellationRequested(); - - await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) + await using ( + Stream engineStream = await SharedFileService.OpenReadAsync( + $"builds/{buildId}/model.tar.gz", + CancellationToken.None + ) + ) { - await _smtModelFactory.DownloadBuiltEngineAsync(engineId, cancellationToken); - int segmentPairsSize = await TrainOnNewSegmentPairs(engineId, cancellationToken); - await PlatformService.BuildCompletedAsync( - buildId, - trainSize: data.Item1 + segmentPairsSize, - confidence: Math.Round(data.Item2, 2, MidpointRounding.AwayFromZero), - cancellationToken: CancellationToken.None + await _smtModelFactory.UpdateEngineFromAsync( + Path.Combine(_options.CurrentValue.EnginesDir, engineId), + engineStream, + CancellationToken.None ); - await BuildJobService.BuildJobFinishedAsync(engineId, buildId, buildComplete: true, CancellationToken.None); } - - Logger.LogInformation("Build completed ({0}).", buildId); + return await TrainOnNewSegmentPairsAsync(engineId); } - private async Task TrainOnNewSegmentPairs(string engineId, CancellationToken cancellationToken) + private async Task TrainOnNewSegmentPairsAsync(string engineId) { - TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken); - if (engine is null) - throw new OperationCanceledException(); - - cancellationToken.ThrowIfCancellationRequested(); - IReadOnlyList segmentPairs = await _trainSegmentPairs.GetAllAsync( - p => p.TranslationEngineRef == engine.Id, - CancellationToken.None + IReadOnlyList segmentPairs = await _trainSegmentPairs.GetAllAsync(p => + p.TranslationEngineRef == engineId ); if (segmentPairs.Count == 0) return segmentPairs.Count; + string engineDir = Path.Combine(_options.CurrentValue.EnginesDir, engineId); var tokenizer = new LatinWordTokenizer(); var detokenizer = new LatinWordDetokenizer(); - ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineId); - - using ( - IInteractiveTranslationModel smtModel = _smtModelFactory.Create(engineId, tokenizer, detokenizer, truecaser) - ) + ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineDir); + using IInteractiveTranslationModel smtModel = await _smtModelFactory.CreateAsync( + engineDir, + tokenizer, + detokenizer, + truecaser + ); + foreach (TrainSegmentPair segmentPair in segmentPairs) { - foreach (TrainSegmentPair segmentPair in segmentPairs) - { - await smtModel.TrainSegmentAsync( - segmentPair.Source, - segmentPair.Target, - cancellationToken: CancellationToken.None - ); - } - await smtModel.SaveAsync(CancellationToken.None); + await smtModel.TrainSegmentAsync(segmentPair.Source, segmentPair.Target); } + await smtModel.SaveAsync(); return segmentPairs.Count; } } diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferPreprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferPreprocessBuildJob.cs deleted file mode 100644 index a6317c753..000000000 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferPreprocessBuildJob.cs +++ /dev/null @@ -1,20 +0,0 @@ -namespace SIL.Machine.AspNetCore.Services; - -public class SmtTransferPreprocessBuildJob : PreprocessBuildJob -{ - public SmtTransferPreprocessBuildJob( - IPlatformService platformService, - IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, - ILogger logger, - IBuildJobService buildJobService, - ISharedFileService sharedFileService, - ICorpusService corpusService - ) - : base(platformService, engines, lockFactory, logger, buildJobService, sharedFileService, corpusService) - { - EngineType = TranslationEngineType.SmtTransfer; - PretranslationEnabled = false; - TrainJobRunnerType = BuildJobRunnerType.Hangfire; - } -} diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs index 7796a6260..5946b50a4 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs @@ -8,12 +8,19 @@ public class SmtTransferTrainBuildJob( ILogger logger, ISharedFileService sharedFileService, ITruecaserFactory truecaserFactory, - ISmtModelFactory smtModelFactory -) : HangfireBuildJob(platformService, engines, lockFactory, buildJobService, logger) + ISmtModelFactory smtModelFactory, + ITransferEngineFactory transferEngineFactory +) : HangfireBuildJob(platformService, engines, lockFactory, buildJobService, logger) { + private static readonly JsonWriterOptions PretranslateWriterOptions = new() { Indented = true }; + private static readonly JsonSerializerOptions JsonSerializerOptions = + new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + private const int BatchSize = 128; + private readonly ISharedFileService _sharedFileService = sharedFileService; private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; + private readonly ITransferEngineFactory _transferEngineFactory = transferEngineFactory; protected override async Task DoWorkAsync( string engineId, @@ -24,37 +31,28 @@ protected override async Task DoWorkAsync( CancellationToken cancellationToken ) { - DirectoryInfo tempDir = Directory.CreateTempSubdirectory(); - await DownloadTrainingText(buildId, tempDir.FullName, cancellationToken); + using TempDirectory tempDir = new(buildId); + string corpusDir = Path.Combine(tempDir.Path, "corpus"); + await DownloadDataAsync(buildId, corpusDir, cancellationToken); // assemble corpus - DictionaryTextCorpus sourceCorpus = - new(new TextFileText("train", Path.Combine(tempDir.FullName, "train.src.txt"))); - DictionaryTextCorpus targetCorpus = - new(new TextFileText("train", Path.Combine(tempDir.FullName, "train.trg.txt"))); - ParallelTextCorpus parallelCorpus = new ParallelTextCorpus(sourceCorpus, targetCorpus); - int corpusSize = parallelCorpus.Count(includeEmpty: false); + ITextCorpus sourceCorpus = new TextFileTextCorpus(Path.Combine(corpusDir, "train.src.txt")); + ITextCorpus targetCorpus = new TextFileTextCorpus(Path.Combine(corpusDir, "train.trg.txt")); + IParallelTextCorpus parallelCorpus = sourceCorpus.AlignRows(targetCorpus); // train SMT model - var tokenizer = new LatinWordTokenizer(); - - using ITrainer smtModelTrainer = _smtModelFactory.CreateTrainer(engineId, tokenizer, parallelCorpus); - using ITrainer truecaseTrainer = _truecaserFactory.CreateTrainer(engineId, tokenizer, targetCorpus); - - cancellationToken.ThrowIfCancellationRequested(); - - var progress = new BuildProgress(PlatformService, buildId); - await smtModelTrainer.TrainAsync(progress, cancellationToken); - await truecaseTrainer.TrainAsync(cancellationToken: cancellationToken); + string engineDir = Path.Combine(tempDir.Path, "engine"); + (int trainCorpusSize, double confidence) = await TrainAsync( + buildId, + engineDir, + targetCorpus, + parallelCorpus, + cancellationToken + ); cancellationToken.ThrowIfCancellationRequested(); - await smtModelTrainer.SaveAsync(CancellationToken.None); - await truecaseTrainer.SaveAsync(CancellationToken.None); - - await _smtModelFactory.UploadBuiltEngineAsync(engineId, cancellationToken); - - cancellationToken.ThrowIfCancellationRequested(); + await GeneratePretranslationsAsync(buildId, engineDir, cancellationToken); await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) { @@ -63,7 +61,7 @@ CancellationToken cancellationToken engineId, buildId, BuildStage.Postprocess, - data: (smtModelTrainer.Stats.TrainCorpusSize, smtModelTrainer.Stats.Metrics["bleu"] * 100.0), + data: (trainCorpusSize, confidence), buildOptions: buildOptions, cancellationToken: cancellationToken ); @@ -72,20 +70,167 @@ CancellationToken cancellationToken } } - private async Task DownloadTrainingText(string buildId, string directory, CancellationToken cancellationToken) + protected override async Task CleanupAsync( + string engineId, + string buildId, + object? data, + IDistributedReaderWriterLock @lock, + JobCompletionStatus completionStatus + ) + { + if (completionStatus is JobCompletionStatus.Canceled) + { + try + { + await _sharedFileService.DeleteAsync($"builds/{buildId}/"); + } + catch (Exception e) + { + Logger.LogWarning(e, "Unable to to delete job data for build {BuildId}.", buildId); + } + } + } + + private async Task DownloadDataAsync(string buildId, string corpusDir, CancellationToken cancellationToken) { - using Stream srcText = await _sharedFileService.OpenReadAsync( + Directory.CreateDirectory(corpusDir); + await using Stream srcText = await _sharedFileService.OpenReadAsync( $"builds/{buildId}/train.src.txt", cancellationToken ); - using FileStream srcFileStream = File.Create(Path.Combine(directory, "train.src.txt")); + await using FileStream srcFileStream = File.Create(Path.Combine(corpusDir, "train.src.txt")); await srcText.CopyToAsync(srcFileStream, cancellationToken); - using Stream tgtText = await _sharedFileService.OpenReadAsync( + await using Stream tgtText = await _sharedFileService.OpenReadAsync( $"builds/{buildId}/train.trg.txt", cancellationToken ); - using FileStream tgtFileStream = File.Create(Path.Combine(directory, "train.trg.txt")); + await using FileStream tgtFileStream = File.Create(Path.Combine(corpusDir, "train.trg.txt")); await tgtText.CopyToAsync(tgtFileStream, cancellationToken); } + + private async Task<(int TrainCorpusSize, double Confidence)> TrainAsync( + string buildId, + string engineDir, + ITextCorpus targetCorpus, + IParallelTextCorpus parallelCorpus, + CancellationToken cancellationToken + ) + { + await _smtModelFactory.InitNewAsync(engineDir, cancellationToken); + LatinWordTokenizer tokenizer = new(); + int trainCorpusSize; + double confidence; + using ITrainer smtModelTrainer = await _smtModelFactory.CreateTrainerAsync( + engineDir, + tokenizer, + parallelCorpus, + cancellationToken + ); + using ITrainer truecaseTrainer = await _truecaserFactory.CreateTrainerAsync( + engineDir, + tokenizer, + targetCorpus, + cancellationToken + ); + cancellationToken.ThrowIfCancellationRequested(); + + var progress = new BuildProgress(PlatformService, buildId); + await smtModelTrainer.TrainAsync(progress, cancellationToken); + await truecaseTrainer.TrainAsync(cancellationToken: cancellationToken); + + trainCorpusSize = smtModelTrainer.Stats.TrainCorpusSize; + confidence = smtModelTrainer.Stats.Metrics["bleu"] * 100.0; + + cancellationToken.ThrowIfCancellationRequested(); + + await smtModelTrainer.SaveAsync(cancellationToken); + await truecaseTrainer.SaveAsync(cancellationToken); + + await using Stream engineStream = await _sharedFileService.OpenWriteAsync( + $"builds/{buildId}/model.tar.gz", + cancellationToken + ); + await _smtModelFactory.SaveEngineToAsync(engineDir, engineStream, cancellationToken); + return (trainCorpusSize, confidence); + } + + private async Task GeneratePretranslationsAsync( + string buildId, + string engineDir, + CancellationToken cancellationToken + ) + { + await using Stream sourceStream = await _sharedFileService.OpenReadAsync( + $"builds/{buildId}/pretranslate.src.json", + cancellationToken + ); + + IAsyncEnumerable pretranslations = JsonSerializer + .DeserializeAsyncEnumerable(sourceStream, JsonSerializerOptions, cancellationToken) + .OfType(); + + await using Stream targetStream = await _sharedFileService.OpenWriteAsync( + $"builds/{buildId}/pretranslate.trg.json", + cancellationToken + ); + await using Utf8JsonWriter targetWriter = new(targetStream, PretranslateWriterOptions); + + LatinWordTokenizer tokenizer = new(); + LatinWordDetokenizer detokenizer = new(); + ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineDir, CancellationToken.None); + using IInteractiveTranslationModel smtModel = await _smtModelFactory.CreateAsync( + engineDir, + tokenizer, + detokenizer, + truecaser, + cancellationToken + ); + using ITranslationEngine? transferEngine = await _transferEngineFactory.CreateAsync( + engineDir, + tokenizer, + detokenizer, + truecaser, + cancellationToken + ); + HybridTranslationEngine hybridEngine = new(smtModel, transferEngine) { TargetDetokenizer = detokenizer }; + + await foreach (IReadOnlyList batch in BatchAsync(pretranslations)) + { + string[] segments = batch.Select(p => p.Translation).ToArray(); + IReadOnlyList results = await hybridEngine.TranslateBatchAsync( + segments, + cancellationToken + ); + foreach ((Pretranslation pretranslation, TranslationResult result) in batch.Zip(results)) + { + JsonSerializer.Serialize( + targetWriter, + pretranslation with + { + Translation = result.Translation + }, + JsonSerializerOptions + ); + } + } + } + + public static async IAsyncEnumerable> BatchAsync( + IAsyncEnumerable pretranslations + ) + { + List batch = []; + await foreach (Pretranslation item in pretranslations) + { + batch.Add(item); + if (batch.Count == BatchSize) + { + yield return batch; + batch = []; + } + } + if (batch.Count > 0) + yield return batch; + } } diff --git a/src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs b/src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs index d3eb4f7f7..25b810a6c 100644 --- a/src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs @@ -1,24 +1,19 @@ namespace SIL.Machine.AspNetCore.Services; -public class ThotSmtModelFactory( - IOptionsMonitor options, - IOptionsMonitor engineOptions, - ISharedFileService sharedFileService -) : ISmtModelFactory +public class ThotSmtModelFactory(IOptionsMonitor options) : ISmtModelFactory { private readonly IOptionsMonitor _options = options; - private readonly IOptionsMonitor _engineOptions = engineOptions; - private readonly ISharedFileService _sharedFileService = sharedFileService; - public IInteractiveTranslationModel Create( - string engineId, + public Task CreateAsync( + string engineDir, IRangeTokenizer tokenizer, IDetokenizer detokenizer, - ITruecaser truecaser + ITruecaser truecaser, + CancellationToken cancellationToken = default ) { - string smtConfigFileName = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId, "smt.cfg"); - var model = new ThotSmtModel(ThotWordAlignmentModelType.Hmm, smtConfigFileName) + string smtConfigFileName = Path.Combine(engineDir, "smt.cfg"); + IInteractiveTranslationModel model = new ThotSmtModel(ThotWordAlignmentModelType.Hmm, smtConfigFileName) { SourceTokenizer = tokenizer, TargetTokenizer = tokenizer, @@ -27,64 +22,39 @@ ITruecaser truecaser LowercaseTarget = true, Truecaser = truecaser }; - return model; + return Task.FromResult(model); } - public ITrainer CreateTrainer( - string engineId, + public Task CreateTrainerAsync( + string engineDir, IRangeTokenizer tokenizer, - IParallelTextCorpus corpus + IParallelTextCorpus corpus, + CancellationToken cancellationToken = default ) { - string smtConfigFileName = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId, "smt.cfg"); - return new ThotSmtModelTrainer(ThotWordAlignmentModelType.Hmm, corpus, smtConfigFileName) + string smtConfigFileName = Path.Combine(engineDir, "smt.cfg"); + ITrainer trainer = new ThotSmtModelTrainer(ThotWordAlignmentModelType.Hmm, corpus, smtConfigFileName) { SourceTokenizer = tokenizer, TargetTokenizer = tokenizer, LowercaseSource = true, LowercaseTarget = true }; + return Task.FromResult(trainer); } - public async Task DownloadBuiltEngineAsync(string engineId, CancellationToken cancellationToken) + public Task InitNewAsync(string engineDir, CancellationToken cancellationToken = default) { - string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); - if (!Directory.Exists(engineDir)) - Directory.CreateDirectory(engineDir); - string sharedFilePath = $"models/{engineId}.zip"; - using Stream sharedStream = await _sharedFileService.OpenReadAsync(sharedFilePath, cancellationToken); - ZipFile.ExtractToDirectory(sharedStream, engineDir, overwriteFiles: true); - await _sharedFileService.DeleteAsync(sharedFilePath); - } - - public async Task UploadBuiltEngineAsync(string engineId, CancellationToken cancellationToken) - { - // create zip archive in memory stream - // This cannot be created directly to the shared stream because it all needs to be written at once - using var memoryStream = new MemoryStream(); - string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); - ZipFile.CreateFromDirectory(engineDir, memoryStream); - - // copy to shared file - memoryStream.Seek(0, SeekOrigin.Begin); - string sharedFilePath = $"models/{engineId}.zip"; - using Stream sharedStream = await _sharedFileService.OpenWriteAsync(sharedFilePath, cancellationToken); - await sharedStream.WriteAsync(memoryStream.ToArray().AsMemory(0, (int)memoryStream.Length), cancellationToken); - } - - public void InitNew(string engineId) - { - string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); if (!Directory.Exists(engineDir)) Directory.CreateDirectory(engineDir); ZipFile.ExtractToDirectory(_options.CurrentValue.NewModelFile, engineDir); + return Task.CompletedTask; } - public void Cleanup(string engineId) + public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default) { - string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); if (!Directory.Exists(engineDir)) - return; + return Task.CompletedTask; DirectoryHelper.DeleteDirectoryRobust(Path.Combine(engineDir, "lm")); DirectoryHelper.DeleteDirectoryRobust(Path.Combine(engineDir, "tm")); string smtConfigFileName = Path.Combine(engineDir, "smt.cfg"); @@ -92,5 +62,49 @@ public void Cleanup(string engineId) File.Delete(smtConfigFileName); if (!Directory.EnumerateFileSystemEntries(engineDir).Any()) Directory.Delete(engineDir); + return Task.CompletedTask; + } + + public async Task UpdateEngineFromAsync( + string engineDir, + Stream source, + CancellationToken cancellationToken = default + ) + { + if (!Directory.Exists(engineDir)) + Directory.CreateDirectory(engineDir); + + await using MemoryStream memoryStream = new(); + await using (GZipStream gzipStream = new(source, CompressionMode.Decompress)) + { + await gzipStream.CopyToAsync(memoryStream, cancellationToken); + } + memoryStream.Seek(0, SeekOrigin.Begin); + await TarFile.ExtractToDirectoryAsync( + memoryStream, + engineDir, + overwriteFiles: true, + cancellationToken: cancellationToken + ); + } + + public async Task SaveEngineToAsync( + string engineDir, + Stream destination, + CancellationToken cancellationToken = default + ) + { + // create zip archive in memory stream + // This cannot be created directly to the shared stream because it all needs to be written at once + await using MemoryStream memoryStream = new(); + await TarFile.CreateFromDirectoryAsync( + engineDir, + memoryStream, + includeBaseDirectory: false, + cancellationToken: cancellationToken + ); + memoryStream.Seek(0, SeekOrigin.Begin); + await using GZipStream gzipStream = new(destination, CompressionMode.Compress); + await memoryStream.CopyToAsync(gzipStream, cancellationToken); } } diff --git a/src/SIL.Machine.AspNetCore/Services/TransferEngineFactory.cs b/src/SIL.Machine.AspNetCore/Services/TransferEngineFactory.cs index 2aed5b0d9..300501b0b 100644 --- a/src/SIL.Machine.AspNetCore/Services/TransferEngineFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/TransferEngineFactory.cs @@ -1,20 +1,18 @@ namespace SIL.Machine.AspNetCore.Services; -public class TransferEngineFactory(IOptionsMonitor engineOptions) : ITransferEngineFactory +public class TransferEngineFactory : ITransferEngineFactory { - private readonly IOptionsMonitor _engineOptions = engineOptions; - - public ITranslationEngine? Create( - string engineId, + public Task CreateAsync( + string engineDir, IRangeTokenizer tokenizer, IDetokenizer detokenizer, - ITruecaser truecaser + ITruecaser truecaser, + CancellationToken cancellationToken = default ) { - string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); string hcSrcConfigFileName = Path.Combine(engineDir, "src-hc.xml"); string hcTrgConfigFileName = Path.Combine(engineDir, "trg-hc.xml"); - TransferEngine? transferEngine = null; + ITranslationEngine? transferEngine = null; if (File.Exists(hcSrcConfigFileName) && File.Exists(hcTrgConfigFileName)) { var hcTraceManager = new TraceManager(); @@ -37,19 +35,19 @@ ITruecaser truecaser Truecaser = truecaser }; } - return transferEngine; + return Task.FromResult(transferEngine); } - public void InitNew(string engineId) + public Task InitNewAsync(string engineDir, CancellationToken cancellationToken = default) { // TODO: generate source and target config files + return Task.CompletedTask; } - public void Cleanup(string engineId) + public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default) { - string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); if (!Directory.Exists(engineDir)) - return; + return Task.CompletedTask; string hcSrcConfigFileName = Path.Combine(engineDir, "src-hc.xml"); if (File.Exists(hcSrcConfigFileName)) File.Delete(hcSrcConfigFileName); @@ -58,5 +56,6 @@ public void Cleanup(string engineId) File.Delete(hcTrgConfigFileName); if (!Directory.EnumerateFileSystemEntries(engineDir).Any()) Directory.Delete(engineDir); + return Task.CompletedTask; } } diff --git a/src/SIL.Machine.AspNetCore/Services/UnigramTruecaserFactory.cs b/src/SIL.Machine.AspNetCore/Services/UnigramTruecaserFactory.cs index 2e66c1a16..c5b8d70d3 100644 --- a/src/SIL.Machine.AspNetCore/Services/UnigramTruecaserFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/UnigramTruecaserFactory.cs @@ -1,32 +1,37 @@ namespace SIL.Machine.AspNetCore.Services; -public class UnigramTruecaserFactory(IOptionsMonitor engineOptions) : ITruecaserFactory +public class UnigramTruecaserFactory : ITruecaserFactory { - private readonly IOptionsMonitor _engineOptions = engineOptions; - - public async Task CreateAsync(string engineId) + public async Task CreateAsync(string engineDir, CancellationToken cancellationToken = default) { var truecaser = new UnigramTruecaser(); - string path = GetModelPath(engineId); + string path = GetModelPath(engineDir); await truecaser.LoadAsync(path); return truecaser; } - public ITrainer CreateTrainer(string engineId, ITokenizer tokenizer, ITextCorpus corpus) + public Task CreateTrainerAsync( + string engineDir, + ITokenizer tokenizer, + ITextCorpus corpus, + CancellationToken cancellationToken = default + ) { - string path = GetModelPath(engineId); - return new UnigramTruecaserTrainer(path, corpus) { Tokenizer = tokenizer }; + string path = GetModelPath(engineDir); + ITrainer trainer = new UnigramTruecaserTrainer(path, corpus) { Tokenizer = tokenizer }; + return Task.FromResult(trainer); } - public void Cleanup(string engineId) + public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default) { - string path = GetModelPath(engineId); + string path = GetModelPath(engineDir); if (File.Exists(path)) File.Delete(path); + return Task.CompletedTask; } - private string GetModelPath(string engineId) + private static string GetModelPath(string engineDir) { - return Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId, "unigram-casing-model.txt"); + return Path.Combine(engineDir, "unigram-casing-model.txt"); } } diff --git a/src/SIL.Machine.AspNetCore/Usings.cs b/src/SIL.Machine.AspNetCore/Usings.cs index 2f2bb7afb..7b5434f72 100644 --- a/src/SIL.Machine.AspNetCore/Usings.cs +++ b/src/SIL.Machine.AspNetCore/Usings.cs @@ -2,6 +2,8 @@ global using System.Data; global using System.Diagnostics; global using System.Diagnostics.CodeAnalysis; +global using System.Formats.Tar; +global using System.Globalization; global using System.IO.Compression; global using System.Linq.Expressions; global using System.Net; @@ -19,6 +21,7 @@ global using Amazon.Runtime; global using Amazon.S3; global using Amazon.S3.Model; +global using CommunityToolkit.HighPerformance; global using Grpc.Core; global using Grpc.Core.Interceptors; global using Grpc.Net.Client.Configuration; diff --git a/src/SIL.Machine.AspNetCore/Utils/DictionaryStringConverter.cs b/src/SIL.Machine.AspNetCore/Utils/DictionaryStringConverter.cs new file mode 100644 index 000000000..28e8eeac1 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Utils/DictionaryStringConverter.cs @@ -0,0 +1,73 @@ +namespace SIL.Machine.AspNetCore.Utils; + +internal sealed class DictionaryStringStringConverter : JsonConverter> +{ + public override IReadOnlyDictionary Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options + ) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException($"JsonTokenType was of type {reader.TokenType}, only objects are supported"); + } + + var dictionary = new Dictionary(); + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + return dictionary; + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException("JsonTokenType was not PropertyName"); + } + + var propertyName = reader.GetString(); + + if (string.IsNullOrWhiteSpace(propertyName)) + { + throw new JsonException("Failed to get property name"); + } + + reader.Read(); + + dictionary.Add(propertyName!, ExtractValue(ref reader)); + } + + return dictionary; + } + + public override void Write( + Utf8JsonWriter writer, + IReadOnlyDictionary value, + JsonSerializerOptions options + ) + { + JsonSerializer.Serialize(writer, value, options); + } + + private static string ExtractValue(ref Utf8JsonReader reader) + { + switch (reader.TokenType) + { + case JsonTokenType.String: + return reader.GetString() ?? "Error Reading String."; + case JsonTokenType.False: + return "false"; + case JsonTokenType.True: + return "true"; + case JsonTokenType.Null: + return "null"; + case JsonTokenType.Number: + if (reader.TryGetDouble(out var result)) + return result.ToString(CultureInfo.InvariantCulture); + return "Error Reading Number."; + default: + throw new JsonException($"'{reader.TokenType}' is not supported"); + } + } +} diff --git a/src/SIL.Machine.Serval.EngineServer/appsettings.json b/src/SIL.Machine.Serval.EngineServer/appsettings.json index 828d2a41e..12f4a051c 100644 --- a/src/SIL.Machine.Serval.EngineServer/appsettings.json +++ b/src/SIL.Machine.Serval.EngineServer/appsettings.json @@ -20,7 +20,7 @@ }, { "TranslationEngineType": "SmtTransfer", - "ModelType": "hmm", + "ModelType": "thot", "Queue": "cpu_only", "DockerImage": "ghcr.io/sillsdev/machine.py:latest" } diff --git a/src/SIL.Machine.Serval.JobServer/appsettings.json b/src/SIL.Machine.Serval.JobServer/appsettings.json index 2e83382ba..4ff49d691 100644 --- a/src/SIL.Machine.Serval.JobServer/appsettings.json +++ b/src/SIL.Machine.Serval.JobServer/appsettings.json @@ -20,7 +20,7 @@ }, { "TranslationEngineType": "SmtTransfer", - "ModelType": "hmm", + "ModelType": "thot", "Queue": "jobs_backlog", "DockerImage": "ghcr.io/sillsdev/machine.py:latest" } diff --git a/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj b/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj index 862ebc60e..e7e5b9b27 100644 --- a/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj +++ b/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj @@ -12,7 +12,7 @@ - + diff --git a/src/SIL.Machine/Utils/ProgressStatus.cs b/src/SIL.Machine/Utils/ProgressStatus.cs index 5ad7e74d8..89b7eb371 100644 --- a/src/SIL.Machine/Utils/ProgressStatus.cs +++ b/src/SIL.Machine/Utils/ProgressStatus.cs @@ -7,7 +7,7 @@ public struct ProgressStatus : IEquatable public ProgressStatus(int step, int stepCount, string message = null) : this(step, stepCount == 0 ? 1.0 : (double)step / stepCount, message) { } - public ProgressStatus(int step, double? percentCompleted = null, string message = null, int? queueDepth = null) + public ProgressStatus(int step, double? percentCompleted = null, string message = null) { Step = step; PercentCompleted = percentCompleted; diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs index 98c08b00f..33d25da3b 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs @@ -24,18 +24,7 @@ public async Task StartBuildAsync() public async Task CancelBuildAsync_Building() { using var env = new TestEnvironment(); - - var cts = new CancellationTokenSource(); - env.ClearMLService.When(x => x.StopTaskAsync("job1", Arg.Any())).Do(_ => cts.Cancel()); - env.TrainJobFunc = async () => - { - await env.BuildJobService.BuildJobStartedAsync("engine1", "build1"); - - while (!cts.IsCancellationRequested) - await Task.Delay(50); - - await env.BuildJobService.BuildJobFinishedAsync("engine1", "build1", buildComplete: false); - }; + env.UseInfiniteTrainJob(); TranslationEngine engine = env.Engines.Get("engine1"); Assert.That(engine.BuildRevision, Is.EqualTo(1)); @@ -62,18 +51,7 @@ public void CancelBuildAsync_NotBuilding() public async Task DeleteAsync_WhileBuilding() { using var env = new TestEnvironment(); - - var cts = new CancellationTokenSource(); - env.ClearMLService.When(x => x.StopTaskAsync("job1", Arg.Any())).Do(_ => cts.Cancel()); - env.TrainJobFunc = async () => - { - await env.BuildJobService.BuildJobStartedAsync("engine1", "build1"); - - while (!cts.IsCancellationRequested) - await Task.Delay(50); - - await env.BuildJobService.BuildJobFinishedAsync("engine1", "build1", buildComplete: false); - }; + env.UseInfiniteTrainJob(); TranslationEngine engine = env.Engines.Get("engine1"); Assert.That(engine.BuildRevision, Is.EqualTo(1)); @@ -84,8 +62,7 @@ public async Task DeleteAsync_WhileBuilding() Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); await env.Service.DeleteAsync("engine1"); // ensure that the train job has completed - if (env.TrainJobTask is not null) - await env.TrainJobTask; + await env.WaitForBuildToFinishAsync(); Assert.That(env.Engines.Contains("engine1"), Is.False); } @@ -95,13 +72,16 @@ private class TestEnvironment : ObjectModel.DisposableBase private readonly BackgroundJobClient _jobClient; private BackgroundJobServer _jobServer; private readonly IDistributedReaderWriterLockFactory _lockFactory; + private readonly CancellationTokenSource _cancellationTokenSource = new(); + private Func _trainJobFunc; + private Task? _trainJobTask; public TestEnvironment() { if (!Sldr.IsInitialized) Sldr.Initialize(offlineMode: true); - TrainJobFunc = RunMockTrainJob; + _trainJobFunc = RunNormalTrainJob; Engines = new MemoryRepository(); Engines.Add( new TranslationEngine @@ -138,7 +118,10 @@ public TestEnvironment() .Returns(Task.FromResult("job1")); ClearMLService .When(x => x.EnqueueTaskAsync("job1", Arg.Any(), Arg.Any())) - .Do(_ => TrainJobTask = Task.Run(TrainJobFunc)); + .Do(_ => _trainJobTask = Task.Run(_trainJobFunc)); + ClearMLService + .When(x => x.StopTaskAsync("job1", Arg.Any())) + .Do(_ => _cancellationTokenSource.Cancel()); SharedFileService = new SharedFileService(Substitute.For()); var buildJobOptions = Substitute.For>(); buildJobOptions.CurrentValue.Returns( @@ -156,7 +139,7 @@ public TestEnvironment() new ClearMLBuildQueue() { TranslationEngineType = TranslationEngineType.SmtTransfer, - ModelType = "hmm", + ModelType = "thot", DockerImage = "default", Queue = "default" } @@ -201,8 +184,6 @@ public TestEnvironment() public IClearMLService ClearMLService { get; } public ISharedFileService SharedFileService { get; } public IBuildJobService BuildJobService { get; } - public Func TrainJobFunc { get; set; } - public Task? TrainJobTask { get; private set; } public void StopServer() { @@ -240,9 +221,11 @@ private NmtEngineService CreateService() ); } - public Task WaitForBuildToFinishAsync() + public async Task WaitForBuildToFinishAsync() { - return WaitForBuildState(e => e.CurrentBuild is null); + await WaitForBuildState(e => e.CurrentBuild is null); + if (_trainJobTask is not null) + await _trainJobTask; } public Task WaitForBuildToStartAsync() @@ -252,6 +235,11 @@ public Task WaitForBuildToStartAsync() ); } + public void UseInfiniteTrainJob() + { + _trainJobFunc = RunInfiniteTrainJob; + } + private async Task WaitForBuildState(Func predicate) { using ISubscription subscription = await Engines.SubscribeAsync(e => @@ -260,20 +248,17 @@ private async Task WaitForBuildState(Func predicate) while (true) { TranslationEngine? engine = subscription.Change.Entity; - if (engine is not null && predicate(engine)) + if (engine is null || predicate(engine)) break; await subscription.WaitForChangeAsync(); } } - private async Task RunMockTrainJob() + private async Task RunNormalTrainJob() { await BuildJobService.BuildJobStartedAsync("engine1", "build1"); - await using (Stream stream = await SharedFileService.OpenWriteAsync("builds/build1/pretranslate.trg.json")) - { - await JsonSerializer.SerializeAsync(stream, Array.Empty()); - } + await using Stream stream = await SharedFileService.OpenWriteAsync("builds/build1/pretranslate.trg.json"); await BuildJobService.StartBuildJobAsync( BuildJobRunnerType.Hangfire, @@ -284,9 +269,20 @@ await BuildJobService.StartBuildJobAsync( ); } + private async Task RunInfiniteTrainJob() + { + await BuildJobService.BuildJobStartedAsync("engine1", "build1"); + + while (!_cancellationTokenSource.IsCancellationRequested) + await Task.Delay(50); + + await BuildJobService.BuildJobFinishedAsync("engine1", "build1", buildComplete: false); + } + protected override void DisposeManagedResources() { _jobServer.Dispose(); + _cancellationTokenSource.Dispose(); } private class EnvActivator(TestEnvironment env) : JobActivator @@ -308,14 +304,14 @@ public override object ActivateJob(Type jobType) new LanguageTagService() ); } - if (jobType == typeof(NmtPostprocessBuildJob)) + if (jobType == typeof(PostprocessBuildJob)) { - return new NmtPostprocessBuildJob( + return new PostprocessBuildJob( _env.PlatformService, _env.Engines, _env._lockFactory, _env.BuildJobService, - Substitute.For>(), + Substitute.For>(), _env.SharedFileService ); } diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/PreprocessBuildJobTests.cs similarity index 88% rename from tests/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs rename to tests/SIL.Machine.AspNetCore.Tests/Services/PreprocessBuildJobTests.cs index 9fbbe4149..d082b8fa2 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/PreprocessBuildJobTests.cs @@ -1,7 +1,7 @@ namespace SIL.Machine.AspNetCore.Services; [TestFixture] -public class NmtPreprocessBuildJobTests +public class PreprocessBuildJobTests { [Test] public async Task RunAsync_FilterOutEverything() @@ -233,6 +233,15 @@ public void RunAsync_UnknownLanguageTagsNoData() }); } + [Test] + public async Task RunAsync_UnknownLanguageTagsNoDataSmtTransfer() + { + using TestEnvironment env = new(); + Corpus corpus1 = env.DefaultTextFileCorpus with { SourceLanguage = "xxx", TargetLanguage = "zzz" }; + + await env.RunBuildJobAsync(corpus1, engineId: "engine2", engineType: TranslationEngineType.SmtTransfer); + } + private class TestEnvironment : ObjectModel.DisposableBase { private static readonly string TestDataPath = Path.Combine( @@ -252,9 +261,7 @@ private class TestEnvironment : ObjectModel.DisposableBase public MemoryRepository Engines { get; } public IDistributedReaderWriterLockFactory LockFactory { get; } public IBuildJobService BuildJobService { get; } - public ILogger Logger { get; } public IClearMLService ClearMLService { get; } - public NmtPreprocessBuildJob BuildJob { get; } public IOptionsMonitor BuildJobOptions { get; } public Corpus DefaultTextFileCorpus { get; } @@ -267,7 +274,7 @@ public TestEnvironment() if (!Sldr.IsInitialized) Sldr.Initialize(offlineMode: true); - _tempDir = new TempDirectory("NmtPreprocessBuildJobTests"); + _tempDir = new TempDirectory("PreprocessBuildJobTests"); ZipParatextProject("pt-source1"); ZipParatextProject("pt-source2"); @@ -401,7 +408,7 @@ public TestEnvironment() new ClearMLBuildQueue() { TranslationEngineType = TranslationEngineType.SmtTransfer, - ModelType = "hmm", + ModelType = "thot", DockerImage = "default", Queue = "default" } @@ -428,18 +435,14 @@ public TestEnvironment() ) .Returns(Task.FromResult("job1")); SharedFileService = new SharedFileService(Substitute.For()); - Logger = Substitute.For>(); BuildJobService = new BuildJobService( - [ [ new HangfireBuildJobRunner( Substitute.For(), [new NmtHangfireBuildJobFactory()] - [new NmtHangfireBuildJobFactory()] ), new ClearMLBuildJobRunner( ClearMLService, - [ [ new NmtClearMLBuildJobFactory( SharedFileService, @@ -452,30 +455,58 @@ [new NmtHangfireBuildJobFactory()] ], Engines ); - BuildJob = new NmtPreprocessBuildJob( - PlatformService, - Engines, - LockFactory, - Logger, - BuildJobService, - SharedFileService, - CorpusService, - new LanguageTagService() - ) + } + + public PreprocessBuildJob GetBuildJob(TranslationEngineType engineType) + { + switch (engineType) { - Seed = 1234 - }; + case TranslationEngineType.Nmt: + { + return new NmtPreprocessBuildJob( + PlatformService, + Engines, + LockFactory, + Substitute.For>(), + BuildJobService, + SharedFileService, + CorpusService, + new LanguageTagService() + ) + { + Seed = 1234 + }; + } + case TranslationEngineType.SmtTransfer: + { + return new PreprocessBuildJob( + PlatformService, + Engines, + LockFactory, + Substitute.For>(), + BuildJobService, + SharedFileService, + CorpusService + ) + { + Seed = 1234 + }; + } + default: + throw new InvalidOperationException("Unknown engine type."); + } + ; } - public Task RunBuildJobAsync(Corpus corpus, bool useKeyTerms = true, string engineId = "engine1") + public Task RunBuildJobAsync( + Corpus corpus, + bool useKeyTerms = true, + string engineId = "engine1", + TranslationEngineType engineType = TranslationEngineType.Nmt + ) { - return BuildJob.RunAsync( - engineId, - "build1", - [corpus], - useKeyTerms ? null : "{\"use_key_terms\":false}", - default - ); + return GetBuildJob(engineType) + .RunAsync(engineId, "build1", [corpus], useKeyTerms ? null : "{\"use_key_terms\":false}", default); } public async Task<(int Source1Count, int Source2Count, int TargetCount, int TermCount)> GetTrainCountAsync() diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs index 6b7cfa685..40dbed2f5 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs @@ -1,7 +1,4 @@ -using NSubstitute.ClearExtensions; -using SIL.Machine.Corpora; - -namespace SIL.Machine.AspNetCore.Services; +namespace SIL.Machine.AspNetCore.Services; [TestFixture] public class SmtTransferEngineServiceTests @@ -24,14 +21,16 @@ public async Task CreateAsync() Assert.That(engine?.BuildRevision, Is.EqualTo(0)); Assert.That(engine?.IsModelPersisted, Is.True); }); - env.SmtModelFactory.Received().InitNew(EngineId2); - env.TransferEngineFactory.Received().InitNew(EngineId2); + string engineDir = Path.Combine("translation_engines", EngineId2); + _ = env.SmtModelFactory.Received().InitNewAsync(engineDir); + _ = env.TransferEngineFactory.Received().InitNewAsync(engineDir); } - [Test] - public async Task StartBuildAsync() + [TestCase(BuildJobRunnerType.Hangfire)] + [TestCase(BuildJobRunnerType.ClearML)] + public async Task StartBuildAsync(BuildJobRunnerType trainJobRunnerType) { - using var env = new TestEnvironment(); + using var env = new TestEnvironment(trainJobRunnerType); TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.BuildRevision, Is.EqualTo(1)); // ensure that the SMT model was loaded before training @@ -54,14 +53,14 @@ await env.Service.StartBuildAsync( ] ); await env.WaitForBuildToFinishAsync(); - await env + _ = env .SmtBatchTrainer.Received() .TrainAsync(Arg.Any>(), Arg.Any()); - await env + _ = env .TruecaserTrainer.Received() .TrainAsync(Arg.Any>(), Arg.Any()); - await env.SmtBatchTrainer.Received().SaveAsync(Arg.Any()); - await env.TruecaserTrainer.Received().SaveAsync(Arg.Any()); + _ = env.SmtBatchTrainer.Received().SaveAsync(Arg.Any()); + _ = env.TruecaserTrainer.Received().SaveAsync(Arg.Any()); engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Null); Assert.That(engine.BuildRevision, Is.EqualTo(2)); @@ -69,34 +68,26 @@ await env env.SmtModel.ClearReceivedCalls(); await env.Service.TranslateAsync(EngineId1, n: 1, "esto es una prueba."); env.SmtModel.Received().Dispose(); - await env.SmtModel.DidNotReceive().SaveAsync(); - await env.Truecaser.DidNotReceive().SaveAsync(); + _ = env.SmtModel.DidNotReceive().SaveAsync(); + _ = env.Truecaser.DidNotReceive().SaveAsync(); } - [Test] - public async Task CancelBuildAsync_Building() + [TestCase(BuildJobRunnerType.Hangfire)] + [TestCase(BuildJobRunnerType.ClearML)] + public async Task CancelBuildAsync_Building(BuildJobRunnerType trainJobRunnerType) { - using var env = new TestEnvironment(); - await env.SmtBatchTrainer.TrainAsync( - Arg.Any>(), - Arg.Do(cancellationToken => - { - while (true) - { - cancellationToken.ThrowIfCancellationRequested(); - Thread.Sleep(100); - } - }) - ); + using var env = new TestEnvironment(trainJobRunnerType); + env.UseInfiniteTrainJob(); + await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty()); - await env.WaitForBuildToStartAsync(); + await env.WaitForTrainingToStartAsync(); TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); await env.Service.CancelBuildAsync(EngineId1); await env.WaitForBuildToFinishAsync(); - await env.SmtBatchTrainer.DidNotReceive().SaveAsync(); - await env.TruecaserTrainer.DidNotReceive().SaveAsync(); + _ = env.SmtBatchTrainer.DidNotReceive().SaveAsync(); + _ = env.TruecaserTrainer.DidNotReceive().SaveAsync(); engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Null); } @@ -111,19 +102,9 @@ public void CancelBuildAsync_NotBuilding() [Test] public async Task StartBuildAsync_RestartUnfinishedBuild() { - using var env = new TestEnvironment(); + using var env = new TestEnvironment(BuildJobRunnerType.Hangfire); + env.UseInfiniteTrainJob(); - await env.SmtBatchTrainer.TrainAsync( - Arg.Any>(), - Arg.Do(cancellationToken => - { - while (true) - { - cancellationToken.ThrowIfCancellationRequested(); - Thread.Sleep(100); - } - }) - ); await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty()); await env.WaitForTrainingToStartAsync(); TranslationEngine engine = env.Engines.Get(EngineId1); @@ -134,7 +115,7 @@ await env.SmtBatchTrainer.TrainAsync( engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Pending)); - await env.PlatformService.Received().BuildRestartingAsync(BuildId1); + _ = env.PlatformService.Received().BuildRestartingAsync(BuildId1); env.SmtBatchTrainer.ClearSubstitute(ClearOptions.CallActions); env.StartServer(); await env.WaitForBuildToFinishAsync(); @@ -142,62 +123,45 @@ await env.SmtBatchTrainer.TrainAsync( Assert.That(engine.CurrentBuild, Is.Null); } - [Test] - public async Task DeleteAsync_WhileBuilding() + [TestCase(BuildJobRunnerType.Hangfire)] + [TestCase(BuildJobRunnerType.ClearML)] + public async Task DeleteAsync_WhileBuilding(BuildJobRunnerType trainJobRunnerType) { - using var env = new TestEnvironment(); - await env.SmtBatchTrainer.TrainAsync( - Arg.Any>(), - Arg.Do(cancellationToken => - { - while (true) - { - cancellationToken.ThrowIfCancellationRequested(); - Thread.Sleep(100); - } - }) - ); + using var env = new TestEnvironment(trainJobRunnerType); + env.UseInfiniteTrainJob(); + await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty()); - await env.WaitForBuildToStartAsync(); + await env.WaitForTrainingToStartAsync(); TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); await env.Service.DeleteAsync(EngineId1); - // ensure that the build job was canceled + await env.WaitForBuildToFinishAsync(); await env.WaitForAllHangfireJobsToFinishAsync(); - await env.SmtBatchTrainer.DidNotReceive().SaveAsync(); - await env.TruecaserTrainer.DidNotReceive().SaveAsync(); + _ = env.SmtBatchTrainer.DidNotReceive().SaveAsync(); + _ = env.TruecaserTrainer.DidNotReceive().SaveAsync(); Assert.That(env.Engines.Contains(EngineId1), Is.False); } - [Test] - public async Task TrainSegmentPairAsync() + [TestCase(BuildJobRunnerType.Hangfire)] + [TestCase(BuildJobRunnerType.ClearML)] + public async Task TrainSegmentPairAsync(BuildJobRunnerType trainJobRunnerType) { - using var env = new TestEnvironment(); - bool training = true; - await env.SmtBatchTrainer.TrainAsync( - Arg.Any>(), - Arg.Do(cancellationToken => - { - while (training) - { - cancellationToken.ThrowIfCancellationRequested(); - Thread.Sleep(100); - } - }) - ); + using var env = new TestEnvironment(trainJobRunnerType); + env.UseInfiniteTrainJob(); + await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty()); await env.WaitForBuildToStartAsync(); TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); await env.Service.TrainSegmentPairAsync(EngineId1, "esto es una prueba.", "this is a test.", true); - training = false; + env.StopTraining(); await env.WaitForBuildToFinishAsync(); engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Null); Assert.That(engine.BuildRevision, Is.EqualTo(2)); - await env.SmtModel.Received(2).TrainSegmentAsync("esto es una prueba.", "this is a test.", true); + _ = env.SmtModel.Received(2).TrainSegmentAsync("esto es una prueba.", "this is a test.", true); } [Test] @@ -207,7 +171,7 @@ public async Task CommitAsync_LoadedInactive() await env.Service.TrainSegmentPairAsync(EngineId1, "esto es una prueba.", "this is a test.", true); await Task.Delay(10); await env.CommitAsync(TimeSpan.Zero); - await env.SmtModel.Received().SaveAsync(); + _ = env.SmtModel.Received().SaveAsync(); Assert.That(env.StateService.Get(EngineId1).IsLoaded, Is.False); } @@ -217,7 +181,7 @@ public async Task CommitAsync_LoadedActive() using var env = new TestEnvironment(); await env.Service.TrainSegmentPairAsync(EngineId1, "esto es una prueba.", "this is a test.", true); await env.CommitAsync(TimeSpan.FromHours(1)); - await env.SmtModel.Received().SaveAsync(); + _ = env.SmtModel.Received().SaveAsync(); Assert.That(env.StateService.Get(EngineId1).IsLoaded, Is.True); } @@ -247,10 +211,14 @@ private class TestEnvironment : ObjectModel.DisposableBase private BackgroundJobServer _jobServer; private readonly ITruecaserFactory _truecaserFactory; private readonly IDistributedReaderWriterLockFactory _lockFactory; + private readonly BuildJobRunnerType _trainJobRunnerType; + private Task? _trainJobTask; + private readonly CancellationTokenSource _cancellationTokenSource = new(); + private bool _training = true; - public TestEnvironment() + public TestEnvironment(BuildJobRunnerType trainJobRunnerType = BuildJobRunnerType.ClearML) { - TrainJobFunc = RunMockTrainJob; + _trainJobRunnerType = trainJobRunnerType; Engines = new MemoryRepository(); Engines.Add( new TranslationEngine @@ -303,7 +271,7 @@ public TestEnvironment() new ClearMLBuildQueue() { TranslationEngineType = TranslationEngineType.SmtTransfer, - ModelType = "hmm", + ModelType = "thot", DockerImage = "default", Queue = "default" } @@ -311,6 +279,24 @@ public TestEnvironment() } ); ClearMLService = Substitute.For(); + ClearMLService + .GetProjectIdAsync("engine1", Arg.Any()) + .Returns(Task.FromResult("project1")); + ClearMLService + .CreateTaskAsync( + "build1", + "project1", + Arg.Any(), + Arg.Any(), + Arg.Any() + ) + .Returns(Task.FromResult("job1")); + ClearMLService + .When(x => x.EnqueueTaskAsync("job1", Arg.Any(), Arg.Any())) + .Do(_ => _trainJobTask = Task.Run(RunTrainJob)); + ClearMLService + .When(x => x.StopTaskAsync("job1", Arg.Any())) + .Do(_ => _cancellationTokenSource.Cancel()); ClearMLMonitorService = new ClearMLMonitorService( Substitute.For(), ClearMLService, @@ -353,9 +339,6 @@ [new SmtTransferClearMLBuildJobFactory(SharedFileService, Engines)], public ISharedFileService SharedFileService { get; } public IBuildJobService BuildJobService { get; } - public Func TrainJobFunc { get; set; } - - public Task? TrainJobTask { get; private set; } public async Task CommitAsync(TimeSpan inactiveTimeout) { @@ -375,6 +358,26 @@ public void StartServer() Service = CreateService(); } + public void UseInfiniteTrainJob() + { + SmtBatchTrainer.TrainAsync( + Arg.Any>(), + Arg.Do(cancellationToken => + { + while (_training) + { + cancellationToken.ThrowIfCancellationRequested(); + Thread.Sleep(100); + } + }) + ); + } + + public void StopTraining() + { + _training = false; + } + private BackgroundJobServer CreateJobServer() { var jobServerOptions = new BackgroundJobServerOptions @@ -388,7 +391,14 @@ private BackgroundJobServer CreateJobServer() private SmtTransferEngineStateService CreateStateService() { - return new SmtTransferEngineStateService(SmtModelFactory, TransferEngineFactory, _truecaserFactory); + var options = Substitute.For>(); + options.CurrentValue.Returns(new SmtTransferEngineOptions()); + return new SmtTransferEngineStateService( + SmtModelFactory, + TransferEngineFactory, + _truecaserFactory, + options + ); } private SmtTransferEngineService CreateService() @@ -479,20 +489,22 @@ [new Phrase(Range.Create(0, 5), 5)] ); factory - .Create( + .CreateAsync( Arg.Any(), Arg.Any>(), Arg.Any>(), - Arg.Any() + Arg.Any(), + Arg.Any() ) - .Returns(SmtModel); + .Returns(Task.FromResult(SmtModel)); factory - .CreateTrainer( + .CreateTrainerAsync( Arg.Any(), Arg.Any>(), - Arg.Any() + Arg.Any(), + Arg.Any() ) - .Returns(SmtBatchTrainer); + .Returns(Task.FromResult(SmtBatchTrainer)); return factory; } @@ -529,13 +541,14 @@ [new Phrase(Range.Create(0, 5), 5)] ) ); factory - .Create( + .CreateAsync( Arg.Any(), Arg.Any>(), Arg.Any>(), - Arg.Any() + Arg.Any(), + Arg.Any() ) - .Returns(engine); + .Returns(Task.FromResult(engine)); return factory; } @@ -544,8 +557,13 @@ private ITruecaserFactory CreateTruecaserFactory() ITruecaserFactory factory = Substitute.For(); factory.CreateAsync(Arg.Any()).Returns(Task.FromResult(Truecaser)); factory - .CreateTrainer(Arg.Any(), Arg.Any>(), Arg.Any()) - .Returns(TruecaserTrainer); + .CreateTrainerAsync( + Arg.Any(), + Arg.Any>(), + Arg.Any(), + Arg.Any() + ) + .Returns(Task.FromResult(TruecaserTrainer)); return factory; } @@ -564,9 +582,11 @@ public async Task WaitForAllHangfireJobsToFinishAsync() await Task.Delay(50); } - public Task WaitForBuildToFinishAsync() + public async Task WaitForBuildToFinishAsync() { - return WaitForBuildState(e => e.CurrentBuild is null); + await WaitForBuildState(e => e.CurrentBuild is null); + if (_trainJobTask is not null) + await _trainJobTask; } public Task WaitForBuildToStartAsync() @@ -594,7 +614,7 @@ private async Task WaitForBuildState(Func predicate) while (true) { TranslationEngine? engine = subscription.Change.Entity; - if (engine is not null && predicate(engine)) + if (engine is null || predicate(engine)) break; await subscription.WaitForChangeAsync(); } @@ -606,9 +626,58 @@ protected override void DisposeManagedResources() _jobServer.Dispose(); } - private Task RunMockTrainJob() + private async Task RunTrainJob() { - throw new InvalidOperationException(); + try + { + await BuildJobService.BuildJobStartedAsync("engine1", "build1", _cancellationTokenSource.Token); + + string engineDir = Path.Combine("translation_engines", EngineId1); + await SmtModelFactory.InitNewAsync(engineDir, _cancellationTokenSource.Token); + ITextCorpus sourceCorpus = new DictionaryTextCorpus(); + ITextCorpus targetCorpus = new DictionaryTextCorpus(); + IParallelTextCorpus parallelCorpus = sourceCorpus.AlignRows(targetCorpus); + LatinWordTokenizer tokenizer = new(); + using ITrainer smtModelTrainer = await SmtModelFactory.CreateTrainerAsync( + engineDir, + tokenizer, + parallelCorpus, + _cancellationTokenSource.Token + ); + using ITrainer truecaseTrainer = await _truecaserFactory.CreateTrainerAsync( + engineDir, + tokenizer, + targetCorpus, + _cancellationTokenSource.Token + ); + await smtModelTrainer.TrainAsync(null, _cancellationTokenSource.Token); + await truecaseTrainer.TrainAsync(cancellationToken: _cancellationTokenSource.Token); + + await smtModelTrainer.SaveAsync(_cancellationTokenSource.Token); + await truecaseTrainer.SaveAsync(_cancellationTokenSource.Token); + + await using Stream engineStream = await SharedFileService.OpenWriteAsync( + $"builds/{BuildId1}/model.tar.gz", + _cancellationTokenSource.Token + ); + + await using Stream targetStream = await SharedFileService.OpenWriteAsync( + $"builds/{BuildId1}/pretranslate.trg.json", + _cancellationTokenSource.Token + ); + + await BuildJobService.StartBuildJobAsync( + BuildJobRunnerType.Hangfire, + EngineId1, + BuildId1, + BuildStage.Postprocess, + data: (0, 0.0) + ); + } + catch (OperationCanceledException) + { + await BuildJobService.BuildJobFinishedAsync("engine1", "build1", buildComplete: false); + } } private class EnvActivator(TestEnvironment env) : JobActivator @@ -617,23 +686,25 @@ private class EnvActivator(TestEnvironment env) : JobActivator public override object ActivateJob(Type jobType) { - if (jobType == typeof(SmtTransferPreprocessBuildJob)) + if (jobType == typeof(PreprocessBuildJob)) { - return new SmtTransferPreprocessBuildJob( + return new PreprocessBuildJob( _env.PlatformService, _env.Engines, _env._lockFactory, - Substitute.For>(), + Substitute.For>(), _env.BuildJobService, _env.SharedFileService, Substitute.For() ) { - TrainJobRunnerType = BuildJobRunnerType.Hangfire + TrainJobRunnerType = _env._trainJobRunnerType }; } if (jobType == typeof(SmtTransferPostprocessBuildJob)) { + var options = Substitute.For>(); + options.CurrentValue.Returns(new SmtTransferEngineOptions()); return new SmtTransferPostprocessBuildJob( _env.PlatformService, _env.Engines, @@ -643,7 +714,8 @@ public override object ActivateJob(Type jobType) _env.SharedFileService, _env.TrainSegmentPairs, _env.SmtModelFactory, - _env._truecaserFactory + _env._truecaserFactory, + options ); } if (jobType == typeof(SmtTransferTrainBuildJob)) @@ -656,7 +728,8 @@ public override object ActivateJob(Type jobType) Substitute.For>(), _env.SharedFileService, _env._truecaserFactory, - _env.SmtModelFactory + _env.SmtModelFactory, + _env.TransferEngineFactory ); } return base.ActivateJob(jobType); diff --git a/tests/SIL.Machine.AspNetCore.Tests/Usings.cs b/tests/SIL.Machine.AspNetCore.Tests/Usings.cs index 98efc5ed0..c4806736a 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Usings.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Usings.cs @@ -9,6 +9,7 @@ global using Microsoft.Extensions.Logging; global using Microsoft.Extensions.Options; global using NSubstitute; +global using NSubstitute.ClearExtensions; global using NSubstitute.ReceivedExtensions; global using NUnit.Framework; global using RichardSzalay.MockHttp; @@ -16,6 +17,7 @@ global using SIL.Machine.Annotations; global using SIL.Machine.AspNetCore.Configuration; global using SIL.Machine.AspNetCore.Models; +global using SIL.Machine.Corpora; global using SIL.Machine.Tokenization; global using SIL.Machine.Translation; global using SIL.Machine.Utils; From cf5f45393b4e86b2813ca1beb2c915d88539d159 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Tue, 11 Jun 2024 13:13:57 -0400 Subject: [PATCH 3/4] Updates for Migration --- .../Models/TranslationEngine.cs | 2 +- .../Services/BuildJobService.cs | 11 +++++++++-- .../Services/NmtEngineService.cs | 7 +++++++ .../Services/SmtTransferEngineService.cs | 7 +++++++ 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs b/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs index cedb504c6..7824070e9 100644 --- a/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs +++ b/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs @@ -5,7 +5,7 @@ public record TranslationEngine : IEntity public string Id { get; set; } = ""; public int Revision { get; set; } = 1; public required string EngineId { get; init; } - public required TranslationEngineType Type { get; init; } + public required TranslationEngineType? Type { get; init; } public required string SourceLanguage { get; init; } public required string TargetLanguage { get; init; } public required bool IsModelPersisted { get; init; } diff --git a/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs b/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs index 406474283..a373c061e 100644 --- a/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs +++ b/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs @@ -77,8 +77,15 @@ public async Task StartBuildJobAsync( return false; IBuildJobRunner runner = _runners[runnerType]; + if (engine.Type is null) + { + throw new InvalidOperationException( + "Engine type is not set. This can come from a invalid migration to Serval 1.5." + ); + } + TranslationEngineType type = engine.Type.Value; string jobId = await runner.CreateJobAsync( - engine.Type, + type, engineId, buildId, stage, @@ -105,7 +112,7 @@ await _engines.UpdateAsync( ), cancellationToken: cancellationToken ); - await runner.EnqueueJobAsync(jobId, engine.Type, cancellationToken); + await runner.EnqueueJobAsync(jobId, type, cancellationToken); return true; } catch diff --git a/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs b/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs index 28af9ee9a..19cc04502 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs @@ -80,6 +80,13 @@ public async Task StartBuildAsync( CancellationToken cancellationToken = default ) { + // Update the engine type to Nmt if unset - for migrating to v1.5 + await _engines.UpdateAsync( + e => e.EngineId == engineId && e.Type == null, + u => u.Set(e => e.Type, TranslationEngineType.Nmt), + cancellationToken: cancellationToken + ); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) { diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs index 941731cd8..4bf562110 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs @@ -188,6 +188,13 @@ public async Task StartBuildAsync( CancellationToken cancellationToken = default ) { + // Update the engine type to Smt if unset - for migrating to v1.5 + await _engines.UpdateAsync( + e => e.EngineId == engineId && e.Type == null, + u => u.Set(e => e.Type, TranslationEngineType.SmtTransfer), + cancellationToken: cancellationToken + ); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) { From 960ee4b4367bf15c729a627bf0a3423ef75c2dca Mon Sep 17 00:00:00 2001 From: John Lambert Date: Tue, 11 Jun 2024 13:24:52 -0400 Subject: [PATCH 4/4] Revert "Updates for Migration" This reverts commit cf5f45393b4e86b2813ca1beb2c915d88539d159. --- .../Models/TranslationEngine.cs | 2 +- .../Services/BuildJobService.cs | 11 ++--------- .../Services/NmtEngineService.cs | 7 ------- .../Services/SmtTransferEngineService.cs | 7 ------- 4 files changed, 3 insertions(+), 24 deletions(-) diff --git a/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs b/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs index 7824070e9..cedb504c6 100644 --- a/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs +++ b/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs @@ -5,7 +5,7 @@ public record TranslationEngine : IEntity public string Id { get; set; } = ""; public int Revision { get; set; } = 1; public required string EngineId { get; init; } - public required TranslationEngineType? Type { get; init; } + public required TranslationEngineType Type { get; init; } public required string SourceLanguage { get; init; } public required string TargetLanguage { get; init; } public required bool IsModelPersisted { get; init; } diff --git a/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs b/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs index a373c061e..406474283 100644 --- a/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs +++ b/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs @@ -77,15 +77,8 @@ public async Task StartBuildJobAsync( return false; IBuildJobRunner runner = _runners[runnerType]; - if (engine.Type is null) - { - throw new InvalidOperationException( - "Engine type is not set. This can come from a invalid migration to Serval 1.5." - ); - } - TranslationEngineType type = engine.Type.Value; string jobId = await runner.CreateJobAsync( - type, + engine.Type, engineId, buildId, stage, @@ -112,7 +105,7 @@ await _engines.UpdateAsync( ), cancellationToken: cancellationToken ); - await runner.EnqueueJobAsync(jobId, type, cancellationToken); + await runner.EnqueueJobAsync(jobId, engine.Type, cancellationToken); return true; } catch diff --git a/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs b/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs index 19cc04502..28af9ee9a 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs @@ -80,13 +80,6 @@ public async Task StartBuildAsync( CancellationToken cancellationToken = default ) { - // Update the engine type to Nmt if unset - for migrating to v1.5 - await _engines.UpdateAsync( - e => e.EngineId == engineId && e.Type == null, - u => u.Set(e => e.Type, TranslationEngineType.Nmt), - cancellationToken: cancellationToken - ); - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) { diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs index 4bf562110..941731cd8 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs @@ -188,13 +188,6 @@ public async Task StartBuildAsync( CancellationToken cancellationToken = default ) { - // Update the engine type to Smt if unset - for migrating to v1.5 - await _engines.UpdateAsync( - e => e.EngineId == engineId && e.Type == null, - u => u.Set(e => e.Type, TranslationEngineType.SmtTransfer), - cancellationToken: cancellationToken - ); - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) {