diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs index 5a577cb5..0e496a84 100644 --- a/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs @@ -216,23 +216,23 @@ public static IMachineBuilder AddMongoHangfireJobClient( public static IMachineBuilder AddHangfireJobServer( this IMachineBuilder builder, - IEnumerable? engineTypes = null + IEnumerable? engineTypes = null ) { engineTypes ??= - builder.Configuration?.GetSection("TranslationEngines").Get() - ?? [TranslationEngineType.SmtTransfer, TranslationEngineType.Nmt]; + builder.Configuration?.GetSection("TranslationEngines").Get() + ?? [EngineType.SmtTransfer, EngineType.Nmt]; var queues = new List(); - foreach (TranslationEngineType engineType in engineTypes.Distinct()) + foreach (EngineType engineType in engineTypes.Distinct()) { switch (engineType) { - case TranslationEngineType.SmtTransfer: + case EngineType.SmtTransfer: builder.Services.AddSingleton(); builder.AddThotSmtModel().AddTransferEngine().AddUnigramTruecaser(); queues.Add("smt_transfer"); break; - case TranslationEngineType.Nmt: + case EngineType.Nmt: queues.Add("nmt"); break; } @@ -371,7 +371,7 @@ public static IMachineBuilder AddServalPlatformService( public static IMachineBuilder AddServalTranslationEngineService( this IMachineBuilder builder, string? connectionString = null, - IEnumerable? engineTypes = null + IEnumerable? engineTypes = null ) { builder.Services.AddGrpc(options => @@ -383,19 +383,19 @@ public static IMachineBuilder AddServalTranslationEngineService( builder.AddServalPlatformService(connectionString); engineTypes ??= - builder.Configuration?.GetSection("TranslationEngines").Get() - ?? [TranslationEngineType.SmtTransfer, TranslationEngineType.Nmt]; - foreach (TranslationEngineType engineType in engineTypes.Distinct()) + builder.Configuration?.GetSection("TranslationEngines").Get() + ?? [EngineType.SmtTransfer, EngineType.Nmt]; + foreach (EngineType engineType in engineTypes.Distinct()) { switch (engineType) { - case TranslationEngineType.SmtTransfer: + case EngineType.SmtTransfer: builder.Services.AddSingleton(); builder.Services.AddHostedService(); builder.AddThotSmtModel().AddTransferEngine().AddUnigramTruecaser(); builder.Services.AddScoped(); break; - case TranslationEngineType.Nmt: + case EngineType.Nmt: builder.Services.AddScoped(); break; } @@ -406,7 +406,8 @@ public static IMachineBuilder AddServalTranslationEngineService( public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder, string? smtTransferEngineDir = null) { - builder.Services.AddScoped(); + builder.Services.AddScoped, TranslationBuildJobService>(); + builder.Services.AddScoped, BuildJobService>(); builder.Services.AddScoped(); builder.Services.AddScoped(); diff --git a/src/Machine/src/Serval.Machine.Shared/Models/ITrainingEngine.cs b/src/Machine/src/Serval.Machine.Shared/Models/ITrainingEngine.cs new file mode 100644 index 00000000..3f7af125 --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Models/ITrainingEngine.cs @@ -0,0 +1,11 @@ +namespace Serval.Machine.Shared.Models; + +public interface ITrainingEngine : IEntity +{ + public string EngineId { get; init; } + public EngineType Type { get; init; } + public string SourceLanguage { get; init; } + public string TargetLanguage { get; init; } + public int BuildRevision { get; init; } + public Build? CurrentBuild { get; init; } +} diff --git a/src/Machine/src/Serval.Machine.Shared/Models/TranslationEngine.cs b/src/Machine/src/Serval.Machine.Shared/Models/TranslationEngine.cs index e3143a3c..53fa082b 100644 --- a/src/Machine/src/Serval.Machine.Shared/Models/TranslationEngine.cs +++ b/src/Machine/src/Serval.Machine.Shared/Models/TranslationEngine.cs @@ -1,11 +1,11 @@ namespace Serval.Machine.Shared.Models; -public record TranslationEngine : IEntity +public record TranslationEngine : ITrainingEngine { public string Id { get; set; } = ""; public int Revision { get; set; } = 1; public required string EngineId { get; init; } - public required TranslationEngineType Type { get; init; } + public required EngineType Type { get; init; } public required string SourceLanguage { get; init; } public required string TargetLanguage { get; init; } public required bool IsModelPersisted { get; init; } diff --git a/src/Machine/src/Serval.Machine.Shared/Models/WordAlignmentEngine.cs b/src/Machine/src/Serval.Machine.Shared/Models/WordAlignmentEngine.cs index 776f305c..7739e980 100644 --- a/src/Machine/src/Serval.Machine.Shared/Models/WordAlignmentEngine.cs +++ b/src/Machine/src/Serval.Machine.Shared/Models/WordAlignmentEngine.cs @@ -1,11 +1,11 @@ namespace Serval.Machine.Shared.Models; -public record WordAlignmentEngine : IEntity +public record WordAlignmentEngine : ITrainingEngine { public string Id { get; set; } = ""; public int Revision { get; set; } = 1; public required string EngineId { get; init; } - public required WordAlignmentEngineType Type { get; init; } + public required EngineType Type { get; init; } public required string SourceLanguage { get; init; } public required string TargetLanguage { get; init; } public int BuildRevision { get; init; } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs b/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs index da670439..7f263d9f 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs @@ -1,22 +1,24 @@ namespace Serval.Machine.Shared.Services; -public class BuildJobService(IEnumerable runners, IRepository engines) - : IBuildJobService +public class BuildJobService(IEnumerable runners, IRepository engines) + : IBuildJobService + where TEngine : ITrainingEngine { - private readonly Dictionary _runners = runners.ToDictionary(r => r.Type); - private readonly IRepository _engines = engines; + // TODO: make some sort of service to get the engine repos. + protected readonly Dictionary Runners = runners.ToDictionary(r => r.Type); + protected readonly IRepository Engines = engines; public Task IsEngineBuilding(string engineId, CancellationToken cancellationToken = default) { - return _engines.ExistsAsync(e => e.EngineId == engineId && e.CurrentBuild != null, cancellationToken); + return Engines.ExistsAsync(e => e.EngineId == engineId && e.CurrentBuild != null, cancellationToken); } - public Task> GetBuildingEnginesAsync( + public Task> GetBuildingEnginesAsync( BuildJobRunnerType runner, CancellationToken cancellationToken = default ) { - return _engines.GetAllAsync( + return Engines.GetAllAsync( e => e.CurrentBuild != null && e.CurrentBuild.BuildJobRunner == runner, cancellationToken ); @@ -28,7 +30,7 @@ public Task> GetBuildingEnginesAsync( CancellationToken cancellationToken = default ) { - TranslationEngine? engine = await _engines.GetAsync( + TEngine? engine = await Engines.GetAsync( e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.BuildId == buildId, cancellationToken ); @@ -41,25 +43,25 @@ public async Task CreateEngineAsync( CancellationToken cancellationToken = default ) { - foreach (BuildJobRunnerType runnerType in _runners.Keys) + foreach (BuildJobRunnerType runnerType in Runners.Keys) { - IBuildJobRunner runner = _runners[runnerType]; + IBuildJobRunner runner = Runners[runnerType]; await runner.CreateEngineAsync(engineId, name, cancellationToken); } } public async Task DeleteEngineAsync(string engineId, CancellationToken cancellationToken = default) { - foreach (BuildJobRunnerType runnerType in _runners.Keys) + foreach (BuildJobRunnerType runnerType in Runners.Keys) { - IBuildJobRunner runner = _runners[runnerType]; + IBuildJobRunner runner = Runners[runnerType]; await runner.DeleteEngineAsync(engineId, cancellationToken); } } public async Task StartBuildJobAsync( BuildJobRunnerType runnerType, - TranslationEngineType engineType, + EngineType engineType, string engineId, string buildId, BuildStage stage, @@ -68,7 +70,7 @@ public async Task StartBuildJobAsync( CancellationToken cancellationToken = default ) { - IBuildJobRunner runner = _runners[runnerType]; + IBuildJobRunner runner = Runners[runnerType]; string jobId = await runner.CreateJobAsync( engineType, engineId, @@ -80,7 +82,7 @@ public async Task StartBuildJobAsync( ); try { - TranslationEngine? engine = await _engines.UpdateAsync( + TEngine? engine = await Engines.UpdateAsync( e => e.EngineId == engineId && ( @@ -121,18 +123,17 @@ public async Task StartBuildJobAsync( } } - public async Task<(string? BuildId, BuildJobState State)> CancelBuildJobAsync( + public virtual async Task<(string? BuildId, BuildJobState State)> CancelBuildJobAsync( string engineId, CancellationToken cancellationToken = default ) { // cancel a job that hasn't started yet - TranslationEngine? engine = await _engines.UpdateAsync( + TEngine? engine = await Engines.UpdateAsync( e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.JobState == BuildJobState.Pending, u => { u.Unset(b => b.CurrentBuild); - u.Set(e => e.CollectTrainSegmentPairs, false); }, returnOriginal: true, cancellationToken: cancellationToken @@ -140,20 +141,20 @@ public async Task StartBuildJobAsync( if (engine is not null && engine.CurrentBuild is not null) { // job will be deleted from the queue - IBuildJobRunner runner = _runners[engine.CurrentBuild.BuildJobRunner]; + IBuildJobRunner runner = Runners[engine.CurrentBuild.BuildJobRunner]; await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); return (engine.CurrentBuild.BuildId, BuildJobState.None); } // cancel a job that is already running - engine = await _engines.UpdateAsync( + engine = await Engines.UpdateAsync( e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.JobState == BuildJobState.Active, u => u.Set(e => e.CurrentBuild!.JobState, BuildJobState.Canceling), cancellationToken: cancellationToken ); if (engine is not null && engine.CurrentBuild is not null) { - IBuildJobRunner runner = _runners[engine.CurrentBuild.BuildJobRunner]; + IBuildJobRunner runner = Runners[engine.CurrentBuild.BuildJobRunner]; await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); return (engine.CurrentBuild.BuildId, BuildJobState.Canceling); } @@ -167,7 +168,7 @@ public async Task BuildJobStartedAsync( CancellationToken cancellationToken = default ) { - TranslationEngine? engine = await _engines.UpdateAsync( + TEngine? engine = await Engines.UpdateAsync( e => e.EngineId == engineId && e.CurrentBuild != null @@ -179,19 +180,18 @@ public async Task BuildJobStartedAsync( return engine is not null; } - public Task BuildJobFinishedAsync( + public virtual Task BuildJobFinishedAsync( string engineId, string buildId, bool buildComplete, CancellationToken cancellationToken = default ) { - return _engines.UpdateAsync( + return Engines.UpdateAsync( e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.BuildId == buildId, u => { u.Unset(e => e.CurrentBuild); - u.Set(e => e.CollectTrainSegmentPairs, false); if (buildComplete) u.Inc(e => e.BuildRevision); }, @@ -201,7 +201,7 @@ public Task BuildJobFinishedAsync( public Task BuildJobRestartingAsync(string engineId, string buildId, CancellationToken cancellationToken = default) { - return _engines.UpdateAsync( + return Engines.UpdateAsync( e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.BuildId == buildId, u => u.Set(e => e.CurrentBuild!.JobState, BuildJobState.Pending), cancellationToken: cancellationToken diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ClearMLBuildJobRunner.cs b/src/Machine/src/Serval.Machine.Shared/Services/ClearMLBuildJobRunner.cs index 794f1b8b..e47ac4b6 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ClearMLBuildJobRunner.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ClearMLBuildJobRunner.cs @@ -7,11 +7,12 @@ IOptionsMonitor options ) : IBuildJobRunner { private readonly IClearMLService _clearMLService = clearMLService; - private readonly Dictionary _buildJobFactories = + private readonly Dictionary _buildJobFactories = buildJobFactories.ToDictionary(f => f.EngineType); - private readonly Dictionary _options = - options.CurrentValue.ClearML.ToDictionary(o => o.EngineType); + private readonly Dictionary _options = options.CurrentValue.ClearML.ToDictionary(o => + o.EngineType + ); public BuildJobRunnerType Type => BuildJobRunnerType.ClearML; @@ -32,7 +33,7 @@ public async Task DeleteEngineAsync(string engineId, CancellationToken cancellat } public async Task CreateJobAsync( - TranslationEngineType engineType, + EngineType engineType, string engineId, string buildId, BuildStage stage, @@ -74,7 +75,7 @@ public Task DeleteJobAsync(string jobId, CancellationToken cancellationTok public Task EnqueueJobAsync( string jobId, - TranslationEngineType engineType, + EngineType engineType, CancellationToken cancellationToken = default ) { diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ClearMLMonitorService.cs b/src/Machine/src/Serval.Machine.Shared/Services/ClearMLMonitorService.cs index b256c7c7..8cb5adf1 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ClearMLMonitorService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ClearMLMonitorService.cs @@ -1,4 +1,6 @@ -namespace Serval.Machine.Shared.Services; +using System.Linq; + +namespace Serval.Machine.Shared.Services; public class ClearMLMonitorService( IServiceProvider services, @@ -33,16 +35,14 @@ ILogger logger buildJobOptions.CurrentValue.ClearML.ToDictionary(x => x.EngineType, x => 0) ); - public int GetQueueSize(TEnum engineType) - where TEnum : Enum + public int GetQueueSize(EngineType engineType) { return _queueSizePerEngineType[engineType.ToString()]; } protected override async Task DoWorkAsync(IServiceScope scope, CancellationToken cancellationToken) { - await MonitorClearMLTasksPerDomain(scope, cancellationToken); - await MonitorClearMLTasksPerDomain(scope, cancellationToken); + await MonitorClearMLTasksPerDomain(scope, cancellationToken); } private async Task MonitorClearMLTasksPerDomain(IServiceScope scope, CancellationToken cancellationToken) @@ -50,10 +50,22 @@ private async Task MonitorClearMLTasksPerDomain(IServiceScope scope, Ca { try { - var buildJobService = scope.ServiceProvider.GetRequiredService(); - IReadOnlyList trainingEngines = await buildJobService.GetBuildingEnginesAsync( - BuildJobRunnerType.ClearML, - cancellationToken + var translationBuildJobService = scope.ServiceProvider.GetRequiredService< + IBuildJobService + >(); + var wordAlignmentBuildJobService = scope.ServiceProvider.GetRequiredService< + IBuildJobService + >(); + + Dictionary> trainingEngines = ( + await translationBuildJobService.GetBuildingEnginesAsync(BuildJobRunnerType.ClearML, cancellationToken) + ).ToDictionary(e => e, e => translationBuildJobService as IBuildJobService); + + trainingEngines.AddRange( + await wordAlignmentBuildJobService.GetBuildingEnginesAsync( + BuildJobRunnerType.ClearML, + cancellationToken + ) ); if (trainingEngines.Count == 0) return; @@ -245,7 +257,7 @@ private async Task TrainJobStartedAsync( private async Task TrainJobCompletedAsync( IBuildJobService buildJobService, - TranslationEngineType engineType, + EngineType engineType, string engineId, string buildId, int corpusSize, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/EngineType.cs b/src/Machine/src/Serval.Machine.Shared/Services/EngineType.cs new file mode 100644 index 00000000..9a0a7a3a --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Services/EngineType.cs @@ -0,0 +1,28 @@ +namespace Serval.Machine.Shared.Services; + +public enum EngineType +{ + SmtTransfer, + Nmt, + Statistical +} + +public enum EngineGroup +{ + Translation, + WordAlignment +} + +public static class EngineHelper +{ + public static EngineGroup GetEngineGroup(EngineType engineType) + { + return engineType switch + { + EngineType.SmtTransfer => EngineGroup.Translation, + EngineType.Nmt => EngineGroup.Translation, + EngineType.Statistical => EngineGroup.WordAlignment, + _ => throw new ArgumentOutOfRangeException(nameof(engineType), engineType, null) + }; + } +} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJobRunner.cs b/src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJobRunner.cs index d5be7f30..96830dfd 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJobRunner.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJobRunner.cs @@ -33,7 +33,7 @@ public static Job CreateJob(string engineId, string buildId, string queue, } private readonly IBackgroundJobClient _jobClient = jobClient; - private readonly Dictionary _buildJobFactories = + private readonly Dictionary _buildJobFactories = buildJobFactories.ToDictionary(f => f.EngineType); public BuildJobRunnerType Type => BuildJobRunnerType.Hangfire; @@ -49,7 +49,7 @@ public Task DeleteEngineAsync(string engineId, CancellationToken cancellationTok } public Task CreateJobAsync( - TranslationEngineType engineType, + EngineType engineType, string engineId, string buildId, BuildStage stage, @@ -70,7 +70,7 @@ public Task DeleteJobAsync(string jobId, CancellationToken cancellationTok public Task EnqueueJobAsync( string jobId, - TranslationEngineType engineType, + EngineType engineType, CancellationToken cancellationToken = default ) { diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobRunner.cs b/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobRunner.cs index 6f6d3696..0c04cbde 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobRunner.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobRunner.cs @@ -8,7 +8,7 @@ public interface IBuildJobRunner Task DeleteEngineAsync(string engineId, CancellationToken cancellationToken = default); Task CreateJobAsync( - TranslationEngineType engineType, + EngineType engineType, string engineId, string buildId, BuildStage stage, @@ -19,11 +19,7 @@ Task CreateJobAsync( Task DeleteJobAsync(string jobId, CancellationToken cancellationToken = default); - Task EnqueueJobAsync( - string jobId, - TranslationEngineType engineType, - CancellationToken cancellationToken = default - ); + Task EnqueueJobAsync(string jobId, EngineType engineType, CancellationToken cancellationToken = default); Task StopJobAsync(string jobId, CancellationToken cancellationToken = default); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobService.cs b/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobService.cs index 61c6122e..2a1ebf17 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobService.cs @@ -1,8 +1,9 @@ namespace Serval.Machine.Shared.Services; -public interface IBuildJobService +public interface IBuildJobService + where TEngine : ITrainingEngine { - Task> GetBuildingEnginesAsync( + Task> GetBuildingEnginesAsync( BuildJobRunnerType runner, CancellationToken cancellationToken = default ); @@ -15,7 +16,7 @@ Task> GetBuildingEnginesAsync( Task StartBuildJobAsync( BuildJobRunnerType runnerType, - TranslationEngineType engineType, + EngineType engineType, string engineId, string buildId, BuildStage stage, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IClearMLBuildJobFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/IClearMLBuildJobFactory.cs index bb5afc57..fe265fc6 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IClearMLBuildJobFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IClearMLBuildJobFactory.cs @@ -2,7 +2,7 @@ public interface IClearMLBuildJobFactory { - TranslationEngineType EngineType { get; } + EngineType EngineType { get; } Task CreateJobScriptAsync( string engineId, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IClearMLQueueService.cs b/src/Machine/src/Serval.Machine.Shared/Services/IClearMLQueueService.cs index 1e2425a4..d2e17ebd 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IClearMLQueueService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IClearMLQueueService.cs @@ -2,5 +2,6 @@ public interface IClearMLQueueService { - public int GetQueueSize(TranslationEngineType engineType); + public int GetQueueSize(TEnum engineType) + where TEnum : Enum; } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IHangfireBuildJobFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/IHangfireBuildJobFactory.cs index faabcfec..e57ac8c5 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IHangfireBuildJobFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IHangfireBuildJobFactory.cs @@ -2,7 +2,7 @@ public interface IHangfireBuildJobFactory { - TranslationEngineType EngineType { get; } + EngineType EngineType { get; } Job CreateJob(string engineId, string buildId, BuildStage stage, object? data, string? buildOptions); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ITranslationEngineService.cs b/src/Machine/src/Serval.Machine.Shared/Services/ITranslationEngineService.cs index b9e64472..3d4f983a 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ITranslationEngineService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ITranslationEngineService.cs @@ -2,7 +2,7 @@ public interface ITranslationEngineService { - TranslationEngineType Type { get; } + EngineType Type { get; } Task CreateAsync( string engineId, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IWordAlignmentEngineService.cs b/src/Machine/src/Serval.Machine.Shared/Services/IWordAlignmentEngineService.cs index 8495efce..867857b3 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IWordAlignmentEngineService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IWordAlignmentEngineService.cs @@ -2,14 +2,13 @@ public interface IWordAlignmentEngineService { - WordAlignmentEngineType WordAlignmentEngine { get; } + EngineType Type { get; } Task CreateAsync( string engineId, string? engineName, string sourceLanguage, string targetLanguage, - bool? isModelPersisted = null, CancellationToken cancellationToken = default ); Task DeleteAsync(string engineId, CancellationToken cancellationToken = default); diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ModelCleanupService.cs b/src/Machine/src/Serval.Machine.Shared/Services/ModelCleanupService.cs index 92b38d6a..da638014 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ModelCleanupService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ModelCleanupService.cs @@ -26,12 +26,12 @@ internal async Task CheckModelsAsync(IRepository engines, Can // Get all NMT engine ids from the database IReadOnlyList? allEngines = await engines.GetAllAsync(cancellationToken: cancellationToken); IEnumerable validNmtFilenames = allEngines - .Where(e => e.Type == TranslationEngineType.Nmt) + .Where(e => e.Type == EngineType.Nmt) .Select(e => NmtEngineService.GetModelPath(e.EngineId, e.BuildRevision)); // If there is a currently running build that creates and pushes a new file, but the database has not // updated yet, don't delete the new file. IEnumerable validNmtFilenamesForNextBuild = allEngines - .Where(e => e.Type == TranslationEngineType.Nmt) + .Where(e => e.Type == EngineType.Nmt) .Select(e => NmtEngineService.GetModelPath(e.EngineId, e.BuildRevision + 1)); var filenameFilter = validNmtFilenames.Concat(validNmtFilenamesForNextBuild).ToHashSet(); diff --git a/src/Machine/src/Serval.Machine.Shared/Services/NmtClearMLBuildJobFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/NmtClearMLBuildJobFactory.cs index 4f465936..ee2b07bc 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/NmtClearMLBuildJobFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/NmtClearMLBuildJobFactory.cs @@ -10,7 +10,7 @@ IRepository engines private readonly ILanguageTagService _languageTagService = languageTagService; private readonly IRepository _engines = engines; - public TranslationEngineType EngineType => TranslationEngineType.Nmt; + public EngineType EngineType => EngineType.Nmt; public async Task CreateJobScriptAsync( string engineId, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/NmtEngineService.cs b/src/Machine/src/Serval.Machine.Shared/Services/NmtEngineService.cs index 0dd66544..3b1cfec4 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/NmtEngineService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/NmtEngineService.cs @@ -24,7 +24,7 @@ public static string GetModelPath(string engineId, int buildRevision) return $"{ModelDirectory}{engineId}_{buildRevision}.tar.gz"; } - public TranslationEngineType Type => TranslationEngineType.Nmt; + public EngineType Type => EngineType.Nmt; private const int MinutesToExpire = 60; @@ -45,7 +45,7 @@ public async Task CreateAsync( EngineId = engineId, SourceLanguage = sourceLanguage, TargetLanguage = targetLanguage, - Type = TranslationEngineType.Nmt, + Type = EngineType.Nmt, IsModelPersisted = isModelPersisted ?? false // models are not persisted if not specified }; await _engines.InsertAsync(translationEngine, ct); @@ -75,7 +75,7 @@ public async Task StartBuildAsync( { bool building = !await _buildJobService.StartBuildJobAsync( BuildJobRunnerType.Hangfire, - TranslationEngineType.Nmt, + EngineType.Nmt, engineId, buildId, BuildStage.Preprocess, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/NmtHangfireBuildJobFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/NmtHangfireBuildJobFactory.cs index 4d250188..3cce8f53 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/NmtHangfireBuildJobFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/NmtHangfireBuildJobFactory.cs @@ -4,7 +4,7 @@ namespace Serval.Machine.Shared.Services; public class NmtHangfireBuildJobFactory : IHangfireBuildJobFactory { - public TranslationEngineType EngineType => TranslationEngineType.Nmt; + public EngineType EngineType => EngineType.Nmt; public Job CreateJob(string engineId, string buildId, BuildStage stage, object? data, string? buildOptions) { diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationEngineServiceV1.cs b/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationEngineServiceV1.cs index dfc52263..e93c86b7 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationEngineServiceV1.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationEngineServiceV1.cs @@ -8,8 +8,9 @@ public class ServalTranslationEngineServiceV1(IEnumerable _engineServices = - engineServices.ToDictionary(es => es.Type); + private readonly Dictionary _engineServices = engineServices.ToDictionary( + es => es.Type + ); public override async Task Create(CreateRequest request, ServerCallContext context) { @@ -175,10 +176,10 @@ private ITranslationEngineService GetEngineService(string engineTypeStr) throw new RpcException(new Status(StatusCode.InvalidArgument, "The engine type is invalid.")); } - private static TranslationEngineType GetEngineType(string engineTypeStr) + private static EngineType GetEngineType(string engineTypeStr) { engineTypeStr = engineTypeStr[0].ToString().ToUpperInvariant() + engineTypeStr[1..]; - if (System.Enum.TryParse(engineTypeStr, out TranslationEngineType engineType)) + if (System.Enum.TryParse(engineTypeStr, out EngineType engineType)) return engineType; throw new RpcException(new Status(StatusCode.InvalidArgument, "The engine type is invalid.")); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs index 2010c237..21d208df 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs @@ -8,8 +8,9 @@ public class ServalWordAlignmentEngineServiceV1(IEnumerable _engineServices = - engineServices.ToDictionary(es => es.WordAlignmentEngine); + private readonly Dictionary _engineServices = engineServices.ToDictionary( + es => es.Type + ); public override async Task Create(CreateRequest request, ServerCallContext context) { @@ -19,7 +20,6 @@ await engineService.CreateAsync( request.HasEngineName ? request.EngineName : null, request.SourceLanguage, request.TargetLanguage, - isModelPersisted: true, cancellationToken: context.CancellationToken ); return Empty; @@ -103,10 +103,10 @@ private IWordAlignmentEngineService GetEngineService(string engineTypeStr) throw new RpcException(new Status(StatusCode.InvalidArgument, "The engine type is invalid.")); } - private static WordAlignmentEngineType GetEngineType(string engineTypeStr) + private static EngineType GetEngineType(string engineTypeStr) { engineTypeStr = engineTypeStr[0].ToString().ToUpperInvariant() + engineTypeStr[1..]; - if (System.Enum.TryParse(engineTypeStr, out WordAlignmentEngineType engineType)) + if (System.Enum.TryParse(engineTypeStr, out EngineType engineType)) return engineType; throw new RpcException(new Status(StatusCode.InvalidArgument, "The engine type is invalid.")); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferClearMLBuildJobFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferClearMLBuildJobFactory.cs index 6e0b6b9c..fe97eaeb 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferClearMLBuildJobFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferClearMLBuildJobFactory.cs @@ -8,7 +8,7 @@ IRepository engines private readonly ISharedFileService _sharedFileService = sharedFileService; private readonly IRepository _engines = engines; - public TranslationEngineType EngineType => TranslationEngineType.SmtTransfer; + public EngineType EngineType => EngineType.SmtTransfer; public async Task CreateJobScriptAsync( string engineId, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineService.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineService.cs index 60e8c9a8..88f388ac 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineService.cs @@ -9,7 +9,7 @@ public class SmtTransferEngineService( SmtTransferEngineStateService stateService, IBuildJobService buildJobService, IClearMLQueueService clearMLQueueService -) : ITranslationEngineService, IWordAlignmentEngineService +) : ITranslationEngineService { private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; private readonly IPlatformService _platformService = platformService; @@ -20,8 +20,7 @@ IClearMLQueueService clearMLQueueService private readonly IBuildJobService _buildJobService = buildJobService; private readonly IClearMLQueueService _clearMLQueueService = clearMLQueueService; - public TranslationEngineType Type => TranslationEngineType.SmtTransfer; - public WordAlignmentEngineType WordAlignmentType => WordAlignmentEngineType.Statistical; + public EngineType Type => EngineType.SmtTransfer; public async Task CreateAsync( string engineId, @@ -48,7 +47,7 @@ public async Task CreateAsync( EngineId = engineId, SourceLanguage = sourceLanguage, TargetLanguage = targetLanguage, - Type = TranslationEngineType.SmtTransfer, + Type = EngineType.SmtTransfer, IsModelPersisted = isModelPersisted ?? true // models are persisted if not specified }; await _engines.InsertAsync(translationEngine, ct); @@ -187,7 +186,7 @@ public async Task StartBuildAsync( { bool building = !await _buildJobService.StartBuildJobAsync( BuildJobRunnerType.Hangfire, - TranslationEngineType.SmtTransfer, + EngineType.SmtTransfer, engineId, buildId, BuildStage.Preprocess, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferHangfireBuildJobFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferHangfireBuildJobFactory.cs index 71f2d09a..18fda0fc 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferHangfireBuildJobFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferHangfireBuildJobFactory.cs @@ -4,7 +4,7 @@ namespace Serval.Machine.Shared.Services; public class SmtTransferHangfireBuildJobFactory : IHangfireBuildJobFactory { - public TranslationEngineType EngineType => TranslationEngineType.SmtTransfer; + public EngineType EngineType => EngineType.SmtTransfer; public Job CreateJob(string engineId, string buildId, BuildStage stage, object? data, string? buildOptions) { diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferTrainBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferTrainBuildJob.cs index e81fc354..025bbf84 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferTrainBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferTrainBuildJob.cs @@ -55,7 +55,7 @@ CancellationToken cancellationToken bool canceling = !await BuildJobService.StartBuildJobAsync( BuildJobRunnerType.Hangfire, - TranslationEngineType.SmtTransfer, + EngineType.SmtTransfer, engineId, buildId, BuildStage.Postprocess, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalEngineService.cs b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalEngineService.cs index 63fcf111..8ccb46c4 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalEngineService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalEngineService.cs @@ -18,7 +18,7 @@ IClearMLQueueService clearMLQueueService private readonly IBuildJobService _buildJobService = buildJobService; private readonly IClearMLQueueService _clearMLQueueService = clearMLQueueService; - public WordAlignmentEngineType Type => WordAlignmentEngineType.Statistical; + public EngineType Type => EngineType.Statistical; public async Task CreateAsync( string engineId, @@ -36,7 +36,7 @@ public async Task CreateAsync( EngineId = engineId, SourceLanguage = sourceLanguage, TargetLanguage = targetLanguage, - Type = WordAlignmentEngineType.Statistical, + Type = EngineType.Statistical, }; await _engines.InsertAsync(waEngine, ct); await _buildJobService.CreateEngineAsync(engineId, engineName, ct); @@ -71,31 +71,6 @@ await _dataAccessContext.WithTransactionAsync( await _lockFactory.DeleteAsync(engineId, CancellationToken.None); } - public async Task> TranslateAsync( - string engineId, - int n, - string segment, - CancellationToken cancellationToken = default - ) - { - TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); - SmtTransferEngineState state = _stateService.Get(engineId); - - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - IReadOnlyList results = await @lock.ReaderLockAsync( - async ct => - { - HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision, ct); - // there is no way to cancel this call - return hybridEngine.Translate(n, segment); - }, - cancellationToken: cancellationToken - ); - - state.Touch(); - return results; - } - public async Task StartBuildAsync( string engineId, string buildId, @@ -106,7 +81,7 @@ public async Task StartBuildAsync( { bool building = !await _buildJobService.StartBuildJobAsync( BuildJobRunnerType.Hangfire, - TranslationEngineType.SmtTransfer, + EngineType.Statistical, engineId, buildId, BuildStage.Preprocess, @@ -137,11 +112,6 @@ public int GetQueueSize() return _clearMLQueueService.GetQueueSize(Type); } - public bool IsLanguageNativeToModel(string language, out string internalCode) - { - throw new NotSupportedException("SMT transfer engines do not support language info."); - } - private async Task CancelBuildJobAsync(string engineId, CancellationToken cancellationToken) { string? buildId = null; @@ -165,17 +135,17 @@ public Task GetModelDownloadUrlAsync( throw new NotSupportedException(); } - private async Task GetEngineAsync(string engineId, CancellationToken cancellationToken) + private async Task GetEngineAsync(string engineId, CancellationToken cancellationToken) { - TranslationEngine? engine = await _engines.GetAsync(e => e.EngineId == engineId, cancellationToken); + WordAlignmentEngine? engine = await _engines.GetAsync(e => e.EngineId == engineId, cancellationToken); if (engine is null) throw new InvalidOperationException($"The engine {engineId} does not exist."); return engine; } - private async Task GetBuiltEngineAsync(string engineId, CancellationToken cancellationToken) + private async Task GetBuiltEngineAsync(string engineId, CancellationToken cancellationToken) { - TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken); + WordAlignmentEngine engine = await GetEngineAsync(engineId, cancellationToken); if (engine.BuildRevision == 0) throw new EngineNotBuiltException("The engine must be built first."); return engine; diff --git a/src/Machine/src/Serval.Machine.Shared/Services/TranslationBuildJobService.cs b/src/Machine/src/Serval.Machine.Shared/Services/TranslationBuildJobService.cs new file mode 100644 index 00000000..8dd2c493 --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Services/TranslationBuildJobService.cs @@ -0,0 +1,65 @@ +namespace Serval.Machine.Shared.Services; + +public class TranslationBuildJobService(IEnumerable runners, IRepository engines) + : BuildJobService(runners, engines) +{ + public override async Task<(string? BuildId, BuildJobState State)> CancelBuildJobAsync( + string engineId, + CancellationToken cancellationToken = default + ) + { + // cancel a job that hasn't started yet + TranslationEngine? engine = await Engines.UpdateAsync( + e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.JobState == BuildJobState.Pending, + u => + { + u.Unset(b => b.CurrentBuild); + u.Set(e => e.CollectTrainSegmentPairs, false); + }, + returnOriginal: true, + cancellationToken: cancellationToken + ); + if (engine is not null && engine.CurrentBuild is not null) + { + // job will be deleted from the queue + IBuildJobRunner runner = Runners[engine.CurrentBuild.BuildJobRunner]; + await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); + return (engine.CurrentBuild.BuildId, BuildJobState.None); + } + + // cancel a job that is already running + engine = await Engines.UpdateAsync( + e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.JobState == BuildJobState.Active, + u => u.Set(e => e.CurrentBuild!.JobState, BuildJobState.Canceling), + cancellationToken: cancellationToken + ); + if (engine is not null && engine.CurrentBuild is not null) + { + IBuildJobRunner runner = Runners[engine.CurrentBuild.BuildJobRunner]; + await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); + return (engine.CurrentBuild.BuildId, BuildJobState.Canceling); + } + + return (null, BuildJobState.None); + } + + public override Task BuildJobFinishedAsync( + string engineId, + string buildId, + bool buildComplete, + CancellationToken cancellationToken = default + ) + { + return Engines.UpdateAsync( + e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.BuildId == buildId, + u => + { + u.Unset(e => e.CurrentBuild); + u.Set(e => e.CollectTrainSegmentPairs, false); + if (buildComplete) + u.Inc(e => e.BuildRevision); + }, + cancellationToken: cancellationToken + ); + } +} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/TranslationEngineType.cs b/src/Machine/src/Serval.Machine.Shared/Services/TranslationEngineType.cs deleted file mode 100644 index 61df1966..00000000 --- a/src/Machine/src/Serval.Machine.Shared/Services/TranslationEngineType.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Serval.Machine.Shared.Services; - -public enum TranslationEngineType -{ - SmtTransfer, - Nmt -} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentEngineType.cs b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentEngineType.cs deleted file mode 100644 index a8ed74cb..00000000 --- a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentEngineType.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Serval.Machine.Shared.Services; - -public enum WordAlignmentEngineType -{ - Statistical, -} diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/ModelCleanupServiceTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/ModelCleanupServiceTests.cs index 49923372..88797059 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/ModelCleanupServiceTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/ModelCleanupServiceTests.cs @@ -46,7 +46,7 @@ public TestEnvironment() { Id = "engine1", EngineId = "engineId1", - Type = TranslationEngineType.Nmt, + Type = EngineType.Nmt, SourceLanguage = "es", TargetLanguage = "en", BuildRevision = 1, @@ -58,7 +58,7 @@ public TestEnvironment() { Id = "engine2", EngineId = "engineId2", - Type = TranslationEngineType.Nmt, + Type = EngineType.Nmt, SourceLanguage = "es", TargetLanguage = "en", BuildRevision = 2, diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtClearMLBuildJobFactoryTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtClearMLBuildJobFactoryTests.cs index 439b8d7c..f5e5ceaa 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtClearMLBuildJobFactoryTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtClearMLBuildJobFactoryTests.cs @@ -81,7 +81,7 @@ public TestEnvironment() { Id = "engine1", EngineId = "engine1", - Type = TranslationEngineType.Nmt, + Type = EngineType.Nmt, SourceLanguage = "es", TargetLanguage = "en", BuildRevision = 1, diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtEngineServiceTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtEngineServiceTests.cs index f0d131a1..ace30e5c 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtEngineServiceTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtEngineServiceTests.cs @@ -88,7 +88,7 @@ public TestEnvironment() { Id = "engine1", EngineId = "engine1", - Type = TranslationEngineType.Nmt, + Type = EngineType.Nmt, SourceLanguage = "es", TargetLanguage = "en", BuildRevision = 1, @@ -132,14 +132,14 @@ public TestEnvironment() [ new ClearMLBuildQueue() { - EngineType = TranslationEngineType.Nmt.ToString(), + EngineType = EngineType.Nmt.ToString(), ModelType = "huggingface", DockerImage = "default", Queue = "default" }, new ClearMLBuildQueue() { - EngineType = TranslationEngineType.SmtTransfer.ToString(), + EngineType = EngineType.SmtTransfer.ToString(), ModelType = "thot", DockerImage = "default", Queue = "default" @@ -262,7 +262,7 @@ private async Task RunNormalTrainJob() await BuildJobService.StartBuildJobAsync( BuildJobRunnerType.Hangfire, - TranslationEngineType.Nmt, + EngineType.Nmt, "engine1", "build1", BuildStage.Postprocess, diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs index 439fa94e..5ff509a1 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs @@ -232,7 +232,7 @@ public async Task RunAsync_UnknownLanguageTagsNoDataSmtTransfer() using TestEnvironment env = new(); ParallelCorpus corpus1 = TestEnvironment.TextFileCorpus(sourceLanguage: "xxx", targetLanguage: "zzz"); - await env.RunBuildJobAsync(corpus1, engineId: "engine2", engineType: TranslationEngineType.SmtTransfer); + await env.RunBuildJobAsync(corpus1, engineId: "engine2", engineType: EngineType.SmtTransfer); } [Test] @@ -643,7 +643,7 @@ public TestEnvironment() { Id = "engine1", EngineId = "engine1", - Type = TranslationEngineType.Nmt, + Type = EngineType.Nmt, SourceLanguage = "es", TargetLanguage = "en", BuildRevision = 1, @@ -663,7 +663,7 @@ public TestEnvironment() { Id = "engine2", EngineId = "engine2", - Type = TranslationEngineType.Nmt, + Type = EngineType.Nmt, SourceLanguage = "xxx", TargetLanguage = "zzz", BuildRevision = 1, @@ -683,7 +683,7 @@ public TestEnvironment() { Id = "engine2", EngineId = "engine2", - Type = TranslationEngineType.Nmt, + Type = EngineType.Nmt, SourceLanguage = "xxx", TargetLanguage = "zzz", BuildRevision = 1, @@ -715,14 +715,14 @@ public TestEnvironment() [ new ClearMLBuildQueue() { - EngineType = TranslationEngineType.Nmt.ToString(), + EngineType = EngineType.Nmt.ToString(), ModelType = "huggingface", DockerImage = "default", Queue = "default" }, new ClearMLBuildQueue() { - EngineType = TranslationEngineType.SmtTransfer.ToString(), + EngineType = EngineType.SmtTransfer.ToString(), ModelType = "thot", DockerImage = "default", Queue = "default" @@ -772,11 +772,11 @@ [new NmtHangfireBuildJobFactory()] ); } - public PreprocessBuildJob GetBuildJob(TranslationEngineType engineType) + public PreprocessBuildJob GetBuildJob(EngineType engineType) { switch (engineType) { - case TranslationEngineType.Nmt: + case EngineType.Nmt: { return new NmtPreprocessBuildJob( PlatformService, @@ -792,7 +792,7 @@ public PreprocessBuildJob GetBuildJob(TranslationEngineType engineType) Seed = 1234 }; } - case TranslationEngineType.SmtTransfer: + case EngineType.SmtTransfer: { return new SmtTransferPreprocessBuildJob( PlatformService, @@ -941,7 +941,7 @@ public Task RunBuildJobAsync( ParallelCorpus corpus, bool useKeyTerms = true, string engineId = "engine1", - TranslationEngineType engineType = TranslationEngineType.Nmt + EngineType engineType = EngineType.Nmt ) { return RunBuildJobAsync([corpus], useKeyTerms, engineId, engineType); @@ -951,7 +951,7 @@ public Task RunBuildJobAsync( IEnumerable corpora, bool useKeyTerms = true, string engineId = "engine1", - TranslationEngineType engineType = TranslationEngineType.Nmt + EngineType engineType = EngineType.Nmt ) { return GetBuildJob(engineType) diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs index f5fabd6e..000d7634 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs @@ -238,7 +238,7 @@ public TestEnvironment(BuildJobRunnerType trainJobRunnerType = BuildJobRunnerTyp { Id = EngineId1, EngineId = EngineId1, - Type = TranslationEngineType.SmtTransfer, + Type = EngineType.SmtTransfer, SourceLanguage = "es", TargetLanguage = "en", BuildRevision = 1, @@ -277,14 +277,14 @@ public TestEnvironment(BuildJobRunnerType trainJobRunnerType = BuildJobRunnerTyp [ new ClearMLBuildQueue() { - EngineType = TranslationEngineType.Nmt.ToString().ToString(), + EngineType = EngineType.Nmt.ToString().ToString(), ModelType = "huggingface", DockerImage = "default", Queue = "default" }, new ClearMLBuildQueue() { - EngineType = TranslationEngineType.SmtTransfer.ToString(), + EngineType = EngineType.SmtTransfer.ToString(), ModelType = "thot", DockerImage = "default", Queue = "default" @@ -659,7 +659,7 @@ private async Task RunTrainJob() await BuildJobService.StartBuildJobAsync( BuildJobRunnerType.Hangfire, - TranslationEngineType.SmtTransfer, + EngineType.SmtTransfer, EngineId1, BuildId1, BuildStage.Postprocess,