Skip to content

Commit

Permalink
Refactor ClearML NMT build job
Browse files Browse the repository at this point in the history
- add support for multiple build stages
- add support for running build jobs on Hangfire or ClearML
- add BuildJobService
- categorize build jobs into CPU or GPU jobs
- decouple build job runners from translation engines
- fix issues with S3FileStorage
- fix issues with ClearMLService
  • Loading branch information
ddaspit committed Oct 6, 2023
1 parent 62ddd19 commit c8a27cc
Show file tree
Hide file tree
Showing 66 changed files with 2,954 additions and 1,776 deletions.
9 changes: 9 additions & 0 deletions src/SIL.Machine.AspNetCore/Configuration/BuildJobOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace SIL.Machine.AspNetCore.Configuration;

public class BuildJobOptions
{
public const string Key = "BuildJob";

public Dictionary<BuildJobType, BuildJobRunner> Runners { get; set; } =
new() { { BuildJobType.Cpu, BuildJobRunner.Hangfire }, { BuildJobType.Gpu, BuildJobRunner.ClearML } };
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
namespace SIL.Machine.AspNetCore.Configuration;

public class ClearMLNmtEngineOptions
public class ClearMLOptions
{
public const string Key = "ClearMLNmtEngine";
public const string Key = "ClearML";

public string ApiServer { get; set; } = "http://localhost:8008";
public string Queue { get; set; } = "default";
public string AccessKey { get; set; } = "";
public string SecretKey { get; set; } = "";
public TimeSpan BuildPollingTimeout { get; set; } = TimeSpan.FromSeconds(2);
public bool BuildPollingEnabled { get; set; } = false;
public TimeSpan BuildPollingTimeout { get; set; } = TimeSpan.FromSeconds(10);
public string ModelType { get; set; } = "huggingface";
public int MaxSteps { get; set; } = 20_000;
public string RootProject { get; set; } = "Machine";
Expand Down
121 changes: 89 additions & 32 deletions src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Microsoft.AspNetCore.Http;
using Serval.Translation.V1;
using Serval.Translation.V1;

namespace Microsoft.Extensions.DependencyInjection;

Expand Down Expand Up @@ -35,18 +34,18 @@ public static IMachineBuilder AddSmtTransferEngineOptions(this IMachineBuilder b
return builder;
}

public static IMachineBuilder AddClearMLNmtEngineOptions(
public static IMachineBuilder AddClearMLOptions(
this IMachineBuilder builder,
Action<ClearMLNmtEngineOptions> configureOptions
Action<ClearMLOptions> configureOptions
)
{
builder.Services.Configure(configureOptions);
return builder;
}

public static IMachineBuilder AddClearMLNmtEngineOptions(this IMachineBuilder builder, IConfiguration config)
public static IMachineBuilder AddClearMLOptions(this IMachineBuilder builder, IConfiguration config)
{
builder.Services.Configure<ClearMLNmtEngineOptions>(config);
builder.Services.Configure<ClearMLOptions>(config);
return builder;
}

Expand All @@ -67,8 +66,10 @@ public static IMachineBuilder AddSharedFileOptions(this IMachineBuilder builder,

public static IMachineBuilder AddThotSmtModel(this IMachineBuilder builder)
{
builder.Services.AddSingleton<ISmtModelFactory, ThotSmtModelFactory>();
return builder;
if (builder.Configuration is null)
return builder.AddThotSmtModel(o => { });
else
return builder.AddThotSmtModel(builder.Configuration.GetSection(ThotSmtModelOptions.Key));
}

public static IMachineBuilder AddThotSmtModel(
Expand All @@ -77,13 +78,15 @@ Action<ThotSmtModelOptions> configureOptions
)
{
builder.Services.Configure(configureOptions);
return builder.AddThotSmtModel();
builder.Services.AddSingleton<ISmtModelFactory, ThotSmtModelFactory>();
return builder;
}

public static IMachineBuilder AddThotSmtModel(this IMachineBuilder builder, IConfiguration config)
{
builder.Services.Configure<ThotSmtModelOptions>(config);
return builder.AddThotSmtModel();
builder.Services.AddSingleton<ISmtModelFactory, ThotSmtModelFactory>();
return builder;
}

public static IMachineBuilder AddTransferEngine(this IMachineBuilder builder)
Expand All @@ -98,7 +101,7 @@ public static IMachineBuilder AddUnigramTruecaser(this IMachineBuilder builder)
return builder;
}

public static IMachineBuilder AddClearMLService(this IMachineBuilder builder)
private static IMachineBuilder AddClearMLBuildJobRunner(this IMachineBuilder builder)
{
builder.Services.AddSingleton<IClearMLService, ClearMLService>();
//Add retry policy; fail after approx. 2 + 4 + 8 = 14 seconds
Expand All @@ -111,20 +114,33 @@ public static IMachineBuilder AddClearMLService(this IMachineBuilder builder)
// workaround register satisfying the interface and as a hosted service.
builder.Services.AddSingleton<IClearMLAuthenticationService, ClearMLAuthenticationService>();
builder.Services.AddHostedService(p => p.GetRequiredService<IClearMLAuthenticationService>());
//Add retry policy; fail after approx. 2 + 4 + 8 = 14 seconds
// Add retry policy; fail after approx. 2 + 4 + 8 = 14 seconds
builder.Services
.AddHttpClient<IClearMLAuthenticationService, ClearMLAuthenticationService>()
.AddTransientHttpErrorPolicy(
b => b.WaitAndRetryAsync(3, retryAttempt => TimeSpan.FromSeconds(Math.Pow(2, retryAttempt)))
);

builder.Services.AddSingleton<S3HealthCheck>();
builder.Services.AddScoped<IBuildJobRunner, ClearMLBuildJobRunner>();
builder.Services.AddScoped<IClearMLBuildJobFactory, NmtClearMLBuildJobFactory>();
builder.Services.AddHostedService<ClearMLMonitorService>();

builder.Services.AddHealthChecks().AddCheck<ClearMLHealthCheck>("ClearML Health Check");

return builder;
}

public static IMachineBuilder AddMongoBackgroundJobClient(
private static IMachineBuilder AddHangfireBuildJobRunner(this IMachineBuilder builder)
{
builder.Services.AddScoped<IBuildJobRunner, HangfireBuildJobRunner>();

builder.Services.AddScoped<IHangfireBuildJobFactory, SmtTransferHangfireBuildJobFactory>();
builder.Services.AddScoped<IHangfireBuildJobFactory, NmtHangfireBuildJobFactory>();

return builder;
}

public static IMachineBuilder AddMongoHangfireJobClient(
this IMachineBuilder builder,
string? connectionString = null
)
Expand All @@ -147,12 +163,13 @@ public static IMachineBuilder AddMongoBackgroundJobClient(
CheckQueuedJobsStrategy = CheckQueuedJobsStrategy.TailNotificationsCollection,
}
)
.UseFilter(new AutomaticRetryAttribute { Attempts = 0 })
);
builder.Services.AddHealthChecks().AddCheck<HangfireHealthCheck>(name: "Hangfire");
return builder;
}

public static IMachineBuilder AddBackgroundJobServer(
public static IMachineBuilder AddHangfireJobServer(
this IMachineBuilder builder,
IEnumerable<TranslationEngineType>? engineTypes = null
)
Expand All @@ -170,7 +187,6 @@ public static IMachineBuilder AddBackgroundJobServer(
queues.Add("smt_transfer");
break;
case TranslationEngineType.Nmt:
builder.AddClearMLService();
queues.Add("nmt");
break;
}
Expand Down Expand Up @@ -205,28 +221,24 @@ public static IMachineBuilder AddMongoDataAccess(this IMachineBuilder builder, s
{
o.AddRepository<TranslationEngine>(
"translation_engines",
init: c =>
c.Indexes.CreateOrUpdateAsync(
new CreateIndexModel<TranslationEngine>(
Builders<TranslationEngine>.IndexKeys.Ascending(p => p.EngineId)
)
)
);
o.AddRepository<RWLock>(
"locks",
mapSetup: m => m.SetIgnoreExtraElements(true),
init: async c =>
{
await c.Indexes.CreateOrUpdateAsync(
new CreateIndexModel<RWLock>(Builders<RWLock>.IndexKeys.Ascending("writerLock._id"))
);
await c.Indexes.CreateOrUpdateAsync(
new CreateIndexModel<RWLock>(Builders<RWLock>.IndexKeys.Ascending("readerLocks._id"))
new CreateIndexModel<TranslationEngine>(
Builders<TranslationEngine>.IndexKeys
.Ascending(e => e.EngineId)
.Ascending("currentBuild._id")
)
);
await c.Indexes.CreateOrUpdateAsync(
new CreateIndexModel<RWLock>(Builders<RWLock>.IndexKeys.Ascending("writerQueue._id"))
new CreateIndexModel<TranslationEngine>(
Builders<TranslationEngine>.IndexKeys.Ascending(e => e.CurrentBuild!.JobRunner)
)
);
}
);
o.AddRepository<RWLock>("locks");
o.AddRepository<TrainSegmentPair>(
"train_segment_pairs",
init: c =>
Expand Down Expand Up @@ -313,13 +325,58 @@ public static IMachineBuilder AddServalTranslationEngineService(
builder.Services.AddScoped<ITranslationEngineService, SmtTransferEngineService>();
break;
case TranslationEngineType.Nmt:
builder.AddClearMLService();
builder.Services.AddScoped<ITranslationEngineService, ClearMLNmtEngineService>();
builder.Services.AddScoped<ITranslationEngineService, NmtEngineService>();
break;
}
}
builder.Services.AddGrpcHealthChecks();

return builder;
}

public static IMachineBuilder AddBuildJobService(
this IMachineBuilder builder,
Action<BuildJobOptions> configureOptions
)
{
builder.Services.Configure(configureOptions);
var options = new BuildJobOptions();
configureOptions(options);
return builder.AddBuildJobService(options);
}

public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder, IConfiguration config)
{
builder.Services.Configure<BuildJobOptions>(config);
var options = config.Get<BuildJobOptions>();
return builder.AddBuildJobService(options);
}

public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder)
{
if (builder.Configuration is null)
builder.AddBuildJobService(o => { });
else
builder.AddBuildJobService(builder.Configuration.GetSection(BuildJobOptions.Key));
return builder;
}

private static IMachineBuilder AddBuildJobService(this IMachineBuilder builder, BuildJobOptions options)
{
builder.Services.AddScoped<IBuildJobService, BuildJobService>();

foreach (BuildJobRunner runnerType in options.Runners.Values.Distinct())
{
switch (runnerType)
{
case BuildJobRunner.ClearML:
builder.AddClearMLBuildJobRunner();
break;
case BuildJobRunner.Hangfire:
builder.AddHangfireBuildJobRunner();
break;
}
}
return builder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ public static class IServiceCollectionExtensions
{
public static IMachineBuilder AddMachine(this IServiceCollection services, IConfiguration? configuration = null)
{
if (!Sldr.IsInitialized)
Sldr.Initialize();

services.AddSingleton<ISharedFileService, SharedFileService>();
services.AddSingleton<S3HealthCheck>();
services.AddHealthChecks().AddCheck<S3HealthCheck>("S3 Bucket");

services.AddScoped<IDistributedReaderWriterLockFactory, DistributedReaderWriterLockFactory>();
services.AddSingleton<ICorpusService, CorpusService>();
services.AddStartupTask((sp, ct) => sp.GetRequiredService<IDistributedReaderWriterLockFactory>().InitAsync(ct));
Expand All @@ -17,14 +21,14 @@ public static IMachineBuilder AddMachine(this IServiceCollection services, IConf
builder.AddServiceOptions(o => { });
builder.AddSharedFileOptions(o => { });
builder.AddSmtTransferEngineOptions(o => { });
builder.AddClearMLNmtEngineOptions(o => { });
builder.AddClearMLOptions(o => { });
}
else
{
builder.AddServiceOptions(configuration.GetSection(ServiceOptions.Key));
builder.AddSharedFileOptions(configuration.GetSection(SharedFileOptions.Key));
builder.AddSmtTransferEngineOptions(configuration.GetSection(SmtTransferEngineOptions.Key));
builder.AddClearMLNmtEngineOptions(configuration.GetSection(ClearMLNmtEngineOptions.Key));
builder.AddClearMLOptions(configuration.GetSection(ClearMLOptions.Key));
}
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,4 @@ public class SmtTransferEngineOptions
public string EnginesDir { get; set; } = "translation_engines";
public TimeSpan EngineCommitFrequency { get; set; } = TimeSpan.FromMinutes(5);
public TimeSpan InactiveEngineTimeout { get; set; } = TimeSpan.FromMinutes(10);
public ISet<TranslationEngineType> Types { get; set; } =
new HashSet<TranslationEngineType> { TranslationEngineType.Nmt, TranslationEngineType.SmtTransfer };
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

public class ThotSmtModelOptions
{
public const string ThotSmtModel = "ThotSmtModel";
public const string Key = "ThotSmtModel";

public ThotSmtModelOptions()
{
Expand Down
24 changes: 24 additions & 0 deletions src/SIL.Machine.AspNetCore/Models/Build.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
namespace SIL.Machine.AspNetCore.Models;

public enum BuildJobState
{
None,
Pending,
Active,
Canceling
}

public enum BuildJobRunner
{
Hangfire,
ClearML
}

public class Build
{
public string BuildId { get; set; } = default!;
public BuildJobState JobState { get; set; }
public string JobId { get; set; } = default!;
public BuildJobRunner JobRunner { get; set; }
public string Stage { get; set; } = default!;
}
12 changes: 12 additions & 0 deletions src/SIL.Machine.AspNetCore/Models/ClearMLMetricsEvent.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace SIL.Machine.AspNetCore.Models;

public class ClearMLMetricsEvent
{
public string Metric { get; set; } = default!;
public string Variant { get; set; } = default!;
public double Value { get; set; }
public double MinValue { get; set; }
public int MinValueIteration { get; set; }
public double MaxValue { get; set; }
public int MaxValueIteration { get; set; }
}
1 change: 1 addition & 0 deletions src/SIL.Machine.AspNetCore/Models/ClearMLTask.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ public class ClearMLTask
public string StatusMessage { get; set; } = default!;
public int LastIteration { get; set; }
public int ActiveDuration { get; set; }
public Dictionary<string, Dictionary<string, ClearMLMetricsEvent>> LastMetrics { get; set; } = default!;
}
12 changes: 1 addition & 11 deletions src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
namespace SIL.Machine.AspNetCore.Models;

public enum BuildState
{
None,
Pending,
Active
}

public class TranslationEngine : IEntity
{
public string Id { get; set; } = default!;
public int Revision { get; set; } = 1;
public string EngineId { get; set; } = default!;
public string SourceLanguage { get; set; } = default!;
public string TargetLanguage { get; set; } = default!;
public BuildState BuildState { get; set; } = BuildState.None;
public bool IsCanceled { get; set; }
public string? BuildId { get; set; }
public int BuildRevision { get; set; }
public string? JobId { get; set; }
public Build? CurrentBuild { get; set; }
}
14 changes: 7 additions & 7 deletions src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@

<ItemGroup>
<PackageReference Include="AspNetCore.HealthChecks.MongoDb" Version="6.0.2" />
<PackageReference Include="AWSSDK.S3" Version="3.7.201.7" />
<PackageReference Include="Grpc.AspNetCore" Version="2.52.0" />
<PackageReference Include="Grpc.AspNetCore.HealthChecks" Version="2.52.0" />
<PackageReference Include="HangFire" Version="1.8.0" />
<PackageReference Include="Hangfire.MemoryStorage" Version="1.7.0" />
<PackageReference Include="Hangfire.Mongo" Version="1.9.3" />
<PackageReference Include="AWSSDK.S3" Version="3.7.205.8" />
<PackageReference Include="Grpc.AspNetCore" Version="2.57.0" />
<PackageReference Include="Grpc.AspNetCore.HealthChecks" Version="2.57.0" />
<PackageReference Include="HangFire" Version="1.8.5" />
<PackageReference Include="Hangfire.Mongo" Version="1.9.10" />
<PackageReference Include="Microsoft.AspNetCore.Mvc.NewtonsoftJson" Version="6.0.16" />
<PackageReference Include="Microsoft.Extensions.Http.Polly" Version="6.0.14" />
<PackageReference Include="Python.Included" Version="3.11.4" />
<PackageReference Include="Serval.Grpc" Version="0.8.0" Condition="!Exists('..\..\..\serval\src\Serval.Grpc\Serval.Grpc.csproj')" />
<PackageReference Include="SIL.DataAccess" Version="0.5.1" Condition="!Exists('..\..\..\serval\src\SIL.DataAccess\SIL.DataAccess.csproj')" />
<PackageReference Include="SIL.DataAccess" Version="0.5.2" Condition="!Exists('..\..\..\serval\src\SIL.DataAccess\SIL.DataAccess.csproj')" />
<PackageReference Include="SIL.WritingSystems" Version="12.0.1" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
</ItemGroup>
Expand Down
Loading

0 comments on commit c8a27cc

Please sign in to comment.