Skip to content

Commit

Permalink
it compiles.
Browse files Browse the repository at this point in the history
  • Loading branch information
johnml1135 committed Nov 5, 2024
1 parent ea5aa40 commit ac8bc21
Show file tree
Hide file tree
Showing 19 changed files with 119 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,10 @@ public int GetQueueSize(EngineType engineType)

protected override async Task DoWorkAsync(IServiceScope scope, CancellationToken cancellationToken)
{
await MonitorClearMLTasksPerDomain<EngineType>(scope, cancellationToken);
await MonitorClearMLTasksPerDomain(scope, cancellationToken);
}

private async Task MonitorClearMLTasksPerDomain<TEngine>(IServiceScope scope, CancellationToken cancellationToken)
where TEngine : Enum
private async Task MonitorClearMLTasksPerDomain(IServiceScope scope, CancellationToken cancellationToken)
{
try
{
Expand All @@ -55,22 +54,26 @@ private async Task MonitorClearMLTasksPerDomain<TEngine>(IServiceScope scope, Ca
IBuildJobService<WordAlignmentEngine>
>();

Dictionary<ITrainingEngine, IBuildJobService<ITrainingEngine>> trainingEngines = (
Dictionary<ITrainingEngine, IBuildJobService<ITrainingEngine>> engineToBuildServiceDict = (
await translationBuildJobService.GetBuildingEnginesAsync(BuildJobRunnerType.ClearML, cancellationToken)
).ToDictionary(e => e, e => translationBuildJobService as IBuildJobService<ITrainingEngine>);
).ToDictionary(e => e, e => translationBuildJobService);

trainingEngines.AddRange(
await wordAlignmentBuildJobService.GetBuildingEnginesAsync(
foreach (
var engine in await wordAlignmentBuildJobService.GetBuildingEnginesAsync(
BuildJobRunnerType.ClearML,
cancellationToken
)
);
if (trainingEngines.Count == 0)
)
{
engineToBuildServiceDict[engine] = (IBuildJobService<ITrainingEngine>)wordAlignmentBuildJobService;
}

if (engineToBuildServiceDict.Count == 0)
return;

Dictionary<string, ClearMLTask> tasks = (
await _clearMLService.GetTasksByIdAsync(
trainingEngines.Select(e => e.CurrentBuild!.JobId),
engineToBuildServiceDict.Select(e => e.Key.CurrentBuild!.JobId),
cancellationToken
)
).ToDictionary(t => t.Id);
Expand All @@ -80,8 +83,11 @@ await _clearMLService.GetTasksByIdAsync(
{
var tasksPerEngineType = tasks
.Where(kvp =>
trainingEngines.Where(te => te.CurrentBuild?.JobId == kvp.Key).FirstOrDefault()?.Type.ToString()
== engineType
engineToBuildServiceDict
.Where(te => te.Key.CurrentBuild?.JobId == kvp.Key)
.FirstOrDefault()
.Key?.Type
.ToString() == engineType
)
.Select(kvp => kvp.Value)
.UnionBy(await _clearMLService.GetTasksForQueueAsync(queueName, cancellationToken), t => t.Id)
Expand All @@ -98,7 +104,7 @@ await _clearMLService.GetTasksByIdAsync(

var dataAccessContext = scope.ServiceProvider.GetRequiredService<IDataAccessContext>();
var platformService = scope.ServiceProvider.GetRequiredService<IPlatformService>();
foreach (ITrainingEngine engine in trainingEngines)
foreach (ITrainingEngine engine in engineToBuildServiceDict.Keys)
{
if (engine.CurrentBuild is null || !tasks.TryGetValue(engine.CurrentBuild.JobId, out ClearMLTask? task))
continue;
Expand Down Expand Up @@ -131,7 +137,7 @@ or ClearMLTaskStatus.Completed
{
bool canceled = !await TrainJobStartedAsync(
dataAccessContext,
buildJobService,
engineToBuildServiceDict[engine],
platformService,
engine.EngineId,
engine.CurrentBuild.BuildId,
Expand Down Expand Up @@ -170,7 +176,7 @@ await UpdateTrainJobStatus(
cancellationToken
);
bool canceling = !await TrainJobCompletedAsync(
buildJobService,
engineToBuildServiceDict[engine],
engine.Type,
engine.EngineId,
engine.CurrentBuild.BuildId,
Expand All @@ -183,7 +189,7 @@ await UpdateTrainJobStatus(
{
await TrainJobCanceledAsync(
dataAccessContext,
buildJobService,
engineToBuildServiceDict[engine],
platformService,
engine.EngineId,
engine.CurrentBuild.BuildId,
Expand All @@ -197,7 +203,7 @@ await TrainJobCanceledAsync(
{
await TrainJobCanceledAsync(
dataAccessContext,
buildJobService,
engineToBuildServiceDict[engine],
platformService,
engine.EngineId,
engine.CurrentBuild.BuildId,
Expand All @@ -210,7 +216,7 @@ await TrainJobCanceledAsync(
{
await TrainJobFaultedAsync(
dataAccessContext,
buildJobService,
engineToBuildServiceDict[engine],
platformService,
engine.EngineId,
engine.CurrentBuild.BuildId,
Expand All @@ -231,7 +237,7 @@ await TrainJobFaultedAsync(

private async Task<bool> TrainJobStartedAsync(
IDataAccessContext dataAccessContext,
IBuildJobService buildJobService,
IBuildJobService<ITrainingEngine> buildJobService,
IPlatformService platformService,
string engineId,
string buildId,
Expand All @@ -254,7 +260,7 @@ private async Task<bool> TrainJobStartedAsync(
}

private async Task<bool> TrainJobCompletedAsync(
IBuildJobService buildJobService,
IBuildJobService<ITrainingEngine> buildJobService,
EngineType engineType,
string engineId,
string buildId,
Expand Down Expand Up @@ -285,7 +291,7 @@ CancellationToken cancellationToken

private async Task TrainJobFaultedAsync(
IDataAccessContext dataAccessContext,
IBuildJobService buildJobService,
IBuildJobService<ITrainingEngine> buildJobService,
IPlatformService platformService,
string engineId,
string buildId,
Expand Down Expand Up @@ -318,7 +324,7 @@ await buildJobService.BuildJobFinishedAsync(

private async Task TrainJobCanceledAsync(
IDataAccessContext dataAccessContext,
IBuildJobService buildJobService,
IBuildJobService<ITrainingEngine> buildJobService,
IPlatformService platformService,
string engineId,
string buildId,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
namespace Serval.Machine.Shared.Services;

public class HangfireBuildJobRunner<TEngine>(
public class HangfireBuildJobRunner(
IBackgroundJobClient jobClient,
IEnumerable<IHangfireBuildJobFactory> buildJobFactories
) : IBuildJobRunner
where TEngine : ITrainingEngine
{
public static Job CreateJob<TJob, TData>(
public static Job CreateJob<TEngine, TJob, TData>(
string engineId,
string buildId,
string queue,
object? data,
string? buildOptions
)
where TEngine : ITrainingEngine
where TJob : HangfireBuildJob<TEngine, TData>
{
ArgumentNullException.ThrowIfNull(data);
Expand All @@ -23,7 +23,8 @@ public static Job CreateJob<TJob, TData>(
);
}

public static Job CreateJob<TJob>(string engineId, string buildId, string queue, string? buildOptions)
public static Job CreateJob<TEngine, TJob>(string engineId, string buildId, string queue, string? buildOptions)
where TEngine : ITrainingEngine
where TJob : HangfireBuildJob<TEngine>
{
// Token "None" is used here because hangfire injects the proper cancellation token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Task<WordAlignmentEngine> CreateAsync(
Task DeleteAsync(string engineId, CancellationToken cancellationToken = default);

Task<TranslationResult> GetBestPhraseAlignmentAsync(
string engineId,
string sourceSegment,
string targetSegment,
CancellationToken cancellationToken = default
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ public class NmtEngineService(
IPlatformService platformService,
IDataAccessContext dataAccessContext,
IRepository<TranslationEngine> engines,
IBuildJobService buildJobService,
IBuildJobService<TranslationEngine> buildJobService,
ILanguageTagService languageTagService,
IClearMLQueueService clearMLQueueService,
ISharedFileService sharedFileService
Expand All @@ -13,7 +13,7 @@ ISharedFileService sharedFileService
private readonly IPlatformService _platformService = platformService;
private readonly IDataAccessContext _dataAccessContext = dataAccessContext;
private readonly IRepository<TranslationEngine> _engines = engines;
private readonly IBuildJobService _buildJobService = buildJobService;
private readonly IBuildJobService<TranslationEngine> _buildJobService = buildJobService;
private readonly IClearMLQueueService _clearMLQueueService = clearMLQueueService;
private readonly ILanguageTagService _languageTagService = languageTagService;
private readonly ISharedFileService _sharedFileService = sharedFileService;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
namespace Serval.Machine.Shared.Services;
using static Serval.Machine.Shared.Services.HangfireBuildJobRunner;

namespace Serval.Machine.Shared.Services;

public class NmtHangfireBuildJobFactory : IHangfireBuildJobFactory
{
Expand All @@ -9,15 +11,21 @@ public Job CreateJob(string engineId, string buildId, BuildStage stage, object?
return stage switch
{
BuildStage.Preprocess
=> CreateJob<NmtPreprocessBuildJob, IReadOnlyList<ParallelCorpus>>(
=> CreateJob<TranslationEngine, NmtPreprocessBuildJob, IReadOnlyList<ParallelCorpus>>(
engineId,
buildId,
"nmt",
data,
buildOptions
),
BuildStage.Postprocess
=> CreateJob<PostprocessBuildJob, (int, double)>(engineId, buildId, "nmt", data, buildOptions),
=> CreateJob<TranslationEngine, PostprocessBuildJob, (int, double)>(
engineId,
buildId,
"nmt",
data,
buildOptions
),
_ => throw new ArgumentException("Unknown build stage.", nameof(stage)),
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@ public class PostprocessBuildJob(
IPlatformService platformService,
IRepository<TranslationEngine> engines,
IDataAccessContext dataAccessContext,
IBuildJobService buildJobService,
IBuildJobService<TranslationEngine> buildJobService,
ILogger<PostprocessBuildJob> logger,
ISharedFileService sharedFileService,
IOptionsMonitor<BuildJobOptions> options
) : HangfireBuildJob<(int, double)>(platformService, engines, dataAccessContext, buildJobService, logger)
)
: HangfireBuildJob<TranslationEngine, (int, double)>(
platformService,
engines,
dataAccessContext,
buildJobService,
logger
)
{
protected ISharedFileService SharedFileService { get; } = sharedFileService;
private readonly BuildJobOptions _buildJobOptions = options.CurrentValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ protected override async Task DoWorkAsync(
CancellationToken cancellationToken
)
{
TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken);
TEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken);
if (engine is null)
throw new OperationCanceledException($"Engine {engineId} does not exist. Build canceled.");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ ServerCallContext context
try
{
result = await engineService.GetBestPhraseAlignmentAsync(
request.EngineId,
request.SourceSegment,
request.TargetSegment,
context.CancellationToken
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ public class SmtTransferEngineService(
IRepository<TranslationEngine> engines,
IRepository<TrainSegmentPair> trainSegmentPairs,
SmtTransferEngineStateService stateService,
IBuildJobService buildJobService,
IBuildJobService<TranslationEngine> buildJobService,
IClearMLQueueService clearMLQueueService
) : ITranslationEngineService
{
Expand All @@ -17,7 +17,7 @@ IClearMLQueueService clearMLQueueService
private readonly IRepository<TranslationEngine> _engines = engines;
private readonly IRepository<TrainSegmentPair> _trainSegmentPairs = trainSegmentPairs;
private readonly SmtTransferEngineStateService _stateService = stateService;
private readonly IBuildJobService _buildJobService = buildJobService;
private readonly IBuildJobService<TranslationEngine> _buildJobService = buildJobService;
private readonly IClearMLQueueService _clearMLQueueService = clearMLQueueService;

public EngineType Type => EngineType.SmtTransfer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,28 @@ public Job CreateJob(string engineId, string buildId, BuildStage stage, object?
return stage switch
{
BuildStage.Preprocess
=> CreateJob<SmtTransferPreprocessBuildJob, IReadOnlyList<ParallelCorpus>>(
=> CreateJob<TranslationEngine, SmtTransferPreprocessBuildJob, IReadOnlyList<ParallelCorpus>>(
engineId,
buildId,
"smt_transfer",
data,
buildOptions
),
BuildStage.Postprocess
=> CreateJob<SmtTransferPostprocessBuildJob, (int, double)>(
=> CreateJob<TranslationEngine, SmtTransferPostprocessBuildJob, (int, double)>(
engineId,
buildId,
"smt_transfer",
data,
buildOptions
),
BuildStage.Train => CreateJob<SmtTransferTrainBuildJob>(engineId, buildId, "smt_transfer", buildOptions),
BuildStage.Train
=> CreateJob<TranslationEngine, SmtTransferTrainBuildJob>(
engineId,
buildId,
"smt_transfer",
buildOptions
),
_ => throw new ArgumentException("Unknown build stage.", nameof(stage)),
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ public class SmtTransferPostprocessBuildJob(
IPlatformService platformService,
IRepository<TranslationEngine> engines,
IDataAccessContext dataAccessContext,
IBuildJobService buildJobService,
IBuildJobService<TranslationEngine> buildJobService,
ILogger<SmtTransferPostprocessBuildJob> logger,
ISharedFileService sharedFileService,
IDistributedReaderWriterLockFactory lockFactory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ public class SmtTransferTrainBuildJob(
IPlatformService platformService,
IRepository<TranslationEngine> engines,
IDataAccessContext dataAccessContext,
IBuildJobService buildJobService,
IBuildJobService<TranslationEngine> buildJobService,
ILogger<SmtTransferTrainBuildJob> logger,
ISharedFileService sharedFileService,
ITruecaserFactory truecaserFactory,
Expand Down
Loading

0 comments on commit ac8bc21

Please sign in to comment.