Skip to content

Commit

Permalink
more broken
Browse files (browse the repository at this point in the history)
  • Loading branch information
johnml1135 committed Nov 5, 2024
1 parent d9f6118 commit ea5aa40
Show file tree
Hide file tree
Showing 14 changed files with 60 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ IOptionsMonitor<BuildJobOptions> options
buildJobFactories.ToDictionary(f => f.EngineType);

private readonly Dictionary<EngineType, ClearMLBuildQueue> _options = options.CurrentValue.ClearML.ToDictionary(o =>
o.EngineType
Enum.Parse<EngineType>(o.EngineType)
);

public BuildJobRunnerType Type => BuildJobRunnerType.ClearML;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using System.Linq;

namespace Serval.Machine.Shared.Services;
namespace Serval.Machine.Shared.Services;

public class ClearMLMonitorService(
IServiceProvider services,
Expand Down Expand Up @@ -51,7 +49,7 @@ private async Task MonitorClearMLTasksPerDomain<TEngine>(IServiceScope scope, Ca
try
{
var translationBuildJobService = scope.ServiceProvider.GetRequiredService<
IBuildJobService<TranslationEngine>
IBuildJobService<ITrainingEngine>
>();
var wordAlignmentBuildJobService = scope.ServiceProvider.GetRequiredService<
IBuildJobService<WordAlignmentEngine>
Expand Down Expand Up @@ -100,7 +98,7 @@ await _clearMLService.GetTasksByIdAsync(

var dataAccessContext = scope.ServiceProvider.GetRequiredService<IDataAccessContext>();
var platformService = scope.ServiceProvider.GetRequiredService<IPlatformService>();
foreach (TranslationEngine engine in trainingEngines)
foreach (ITrainingEngine engine in trainingEngines)
{
if (engine.CurrentBuild is null || !tasks.TryGetValue(engine.CurrentBuild.JobId, out ClearMLTask? task))
continue;
Expand Down
46 changes: 29 additions & 17 deletions src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJob.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
namespace Serval.Machine.Shared.Services;

public abstract class HangfireBuildJob(
public abstract class HangfireBuildJob<TEngine>(
IPlatformService platformService,
IRepository<TranslationEngine> engines,
IRepository<TEngine> engines,
IDataAccessContext dataAccessContext,
IBuildJobService buildJobService,
ILogger<HangfireBuildJob> logger
) : HangfireBuildJob<object?>(platformService, engines, dataAccessContext, buildJobService, logger)
IBuildJobService<TEngine> buildJobService,
ILogger<HangfireBuildJob<TEngine>> logger
) : HangfireBuildJob<TEngine, object?>(platformService, engines, dataAccessContext, buildJobService, logger)
where TEngine : ITrainingEngine
{
public virtual Task RunAsync(
string engineId,
Expand All @@ -19,24 +20,25 @@ CancellationToken cancellationToken
}
}

public abstract class HangfireBuildJob<T>(
public abstract class HangfireBuildJob<TEngine, TData>(
IPlatformService platformService,
IRepository<TranslationEngine> engines,
IRepository<TEngine> engines,
IDataAccessContext dataAccessContext,
IBuildJobService buildJobService,
ILogger<HangfireBuildJob<T>> logger
IBuildJobService<TEngine> buildJobService,
ILogger<HangfireBuildJob<TEngine, TData>> logger
)
where TEngine : ITrainingEngine
{
protected IPlatformService PlatformService { get; } = platformService;
protected IRepository<TranslationEngine> Engines { get; } = engines;
protected IRepository<TEngine> Engines { get; } = engines;
protected IDataAccessContext DataAccessContext { get; } = dataAccessContext;
protected IBuildJobService BuildJobService { get; } = buildJobService;
protected ILogger<HangfireBuildJob<T>> Logger { get; } = logger;
protected IBuildJobService<TEngine> BuildJobService { get; } = buildJobService;
protected ILogger<HangfireBuildJob<TEngine, TData>> Logger { get; } = logger;

public virtual async Task RunAsync(
string engineId,
string buildId,
T data,
TData data,
string? buildOptions,
CancellationToken cancellationToken
)
Expand All @@ -56,7 +58,7 @@ CancellationToken cancellationToken
catch (OperationCanceledException)
{
// Check if the cancellation was initiated by an API call or a shutdown.
TranslationEngine? engine = await Engines.GetAsync(
TEngine? engine = await Engines.GetAsync(
e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.BuildId == buildId,
CancellationToken.None
);
Expand Down Expand Up @@ -123,20 +125,30 @@ await BuildJobService.BuildJobFinishedAsync(
}
}

// Overridable setup hook. The default implementation does no work: it returns an
// already-completed task.
// NOTE(review): presumably awaited by RunAsync before DoWorkAsync - the call site is not
// visible in this hunk, confirm against the full file.
// Diff view: the one-line signature directly below is the pre-change form (data typed as T);
// the multi-line signature after it is the post-change form (data typed as TData).
protected virtual Task InitializeAsync(string engineId, string buildId, T data, CancellationToken cancellationToken)
protected virtual Task InitializeAsync(
string engineId,
string buildId,
TData data,
CancellationToken cancellationToken
)
{
return Task.CompletedTask;
}

// Abstract hook with no default body: every concrete build job must supply it.
// NOTE(review): presumably performs the build itself and is awaited by RunAsync - the call
// site is not visible in this hunk, confirm against the full file.
protected abstract Task DoWorkAsync(
string engineId,
string buildId,
// Diff view: pre-change form of the data parameter.
T data,
// Diff view: post-change form of the same parameter, now typed by the TData type argument.
TData data,
string? buildOptions,
CancellationToken cancellationToken
);

// Overridable teardown hook that receives the job's JobCompletionStatus. The default
// implementation does no work: it returns an already-completed task.
// NOTE(review): presumably awaited by RunAsync once the job has finished, failed or been
// canceled - the call site is not visible in this hunk, confirm against the full file.
// Diff view: the one-line signature directly below is the pre-change form (data typed as T);
// the multi-line signature after it is the post-change form (data typed as TData).
protected virtual Task CleanupAsync(string engineId, string buildId, T data, JobCompletionStatus completionStatus)
protected virtual Task CleanupAsync(
string engineId,
string buildId,
TData data,
JobCompletionStatus completionStatus
)
{
return Task.CompletedTask;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
namespace Serval.Machine.Shared.Services;

public class HangfireBuildJobRunner(
public class HangfireBuildJobRunner<TEngine>(
IBackgroundJobClient jobClient,
IEnumerable<IHangfireBuildJobFactory> buildJobFactories
) : IBuildJobRunner
where TEngine : ITrainingEngine
{
public static Job CreateJob<TJob, TData>(
string engineId,
Expand All @@ -12,7 +13,7 @@ public static Job CreateJob<TJob, TData>(
object? data,
string? buildOptions
)
where TJob : HangfireBuildJob<TData>
where TJob : HangfireBuildJob<TEngine, TData>
{
ArgumentNullException.ThrowIfNull(data);
// Token "None" is used here because hangfire injects the proper cancellation token
Expand All @@ -23,7 +24,7 @@ public static Job CreateJob<TJob, TData>(
}

public static Job CreateJob<TJob>(string engineId, string buildId, string queue, string? buildOptions)
where TJob : HangfireBuildJob
where TJob : HangfireBuildJob<TEngine>
{
// Token "None" is used here because hangfire injects the proper cancellation token
return Job.FromExpression<TJob>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
public interface IBuildJobService<TEngine>
where TEngine : ITrainingEngine
{
Task<IReadOnlyList<ITrainingEngine>> GetBuildingEnginesAsync(
Task<IReadOnlyList<TEngine>> GetBuildingEnginesAsync(
BuildJobRunnerType runner,
CancellationToken cancellationToken = default
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@

// Service contract exposing GetQueueSize, which returns an int for a given engine type.
// NOTE(review): presumably the number of tasks waiting in the ClearML queue that serves that
// engine type - no implementation is visible here, confirm.
public interface IClearMLQueueService
{
// Diff view, pre-change form: generic over any Enum-constrained type.
public int GetQueueSize<TEnum>(TEnum engineType)
where TEnum : Enum;
// Diff view, post-change form: takes the concrete EngineType instead of a generic enum.
public int GetQueueSize(EngineType engineType);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using static Serval.Machine.Shared.Services.HangfireBuildJobRunner;

namespace Serval.Machine.Shared.Services;
namespace Serval.Machine.Shared.Services;

public class NmtHangfireBuildJobFactory : IHangfireBuildJobFactory
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ public class NmtPreprocessBuildJob(
IRepository<TranslationEngine> engines,
IDataAccessContext dataAccessContext,
ILogger<NmtPreprocessBuildJob> logger,
IBuildJobService buildJobService,
IBuildJobService<TranslationEngine> buildJobService,
ISharedFileService sharedFileService,
ICorpusService corpusService,
ILanguageTagService languageTagService
)
: PreprocessBuildJob(
: PreprocessBuildJob<TranslationEngine>(
platformService,
engines,
dataAccessContext,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
namespace Serval.Machine.Shared.Services;

public class PreprocessBuildJob : HangfireBuildJob<IReadOnlyList<ParallelCorpus>>
public class PreprocessBuildJob<TEngine> : HangfireBuildJob<TEngine, IReadOnlyList<ParallelCorpus>>
where TEngine : ITrainingEngine
{
private static readonly JsonWriterOptions PretranslateWriterOptions = new() { Indented = true };

Expand All @@ -13,10 +14,10 @@ public class PreprocessBuildJob : HangfireBuildJob<IReadOnlyList<ParallelCorpus>

public PreprocessBuildJob(
IPlatformService platformService,
IRepository<TranslationEngine> engines,
IRepository<TEngine> engines,
IDataAccessContext dataAccessContext,
ILogger<PreprocessBuildJob> logger,
IBuildJobService buildJobService,
ILogger<PreprocessBuildJob<TEngine>> logger,
IBuildJobService<TEngine> buildJobService,
ISharedFileService sharedFileService,
ICorpusService corpusService
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ public class SmtTransferPreprocessBuildJob(
IPlatformService platformService,
IRepository<TranslationEngine> engines,
IDataAccessContext dataAccessContext,
ILogger<PreprocessBuildJob> logger,
IBuildJobService buildJobService,
ILogger<SmtTransferPreprocessBuildJob> logger,
IBuildJobService<TranslationEngine> buildJobService,
ISharedFileService sharedFileService,
ICorpusService corpusService,
IDistributedReaderWriterLockFactory lockFactory,
IRepository<TrainSegmentPair> trainSegmentPairs
)
: PreprocessBuildJob(
: PreprocessBuildJob<TranslationEngine>(
platformService,
engines,
dataAccessContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public class SmtTransferTrainBuildJob(
ITruecaserFactory truecaserFactory,
ISmtModelFactory smtModelFactory,
ITransferEngineFactory transferEngineFactory
) : HangfireBuildJob(platformService, engines, dataAccessContext, buildJobService, logger)
) : HangfireBuildJob<TranslationEngine>(platformService, engines, dataAccessContext, buildJobService, logger)
{
private static readonly JsonWriterOptions PretranslateWriterOptions = new() { Indented = true };
private static readonly JsonSerializerOptions JsonSerializerOptions =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ public TestEnvironment()
]
}
);
BuildJobService = new BuildJobService(
BuildJobService = new BuildJobService<TranslationEngine>(
[
new HangfireBuildJobRunner(_jobClient, [new NmtHangfireBuildJobFactory()]),
new HangfireBuildJobRunner<TranslationEngine>(_jobClient, [new NmtHangfireBuildJobFactory()]),
new ClearMLBuildJobRunner(
ClearMLService,
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -754,9 +754,9 @@ public TestEnvironment()
)
.Returns(Task.FromResult("job1"));
SharedFileService = new SharedFileService(Substitute.For<ILoggerFactory>());
BuildJobService = new BuildJobService(
BuildJobService = new BuildJobService<TranslationEngine>(
[
new HangfireBuildJobRunner(
new HangfireBuildJobRunner<TranslationEngine>(
Substitute.For<IBackgroundJobClient>(),
[new NmtHangfireBuildJobFactory()]
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,12 @@ public TestEnvironment(BuildJobRunnerType trainJobRunnerType = BuildJobRunnerTyp
buildJobOptions,
Substitute.For<ILogger<ClearMLMonitorService>>()
);
BuildJobService = new BuildJobService(
BuildJobService = new BuildJobService<TranslationEngine>(
[
new HangfireBuildJobRunner(_jobClient, [new SmtTransferHangfireBuildJobFactory()]),
new HangfireBuildJobRunner<TranslationEngine>(
_jobClient,
[new SmtTransferHangfireBuildJobFactory()]
),
new ClearMLBuildJobRunner(
ClearMLService,
[new SmtTransferClearMLBuildJobFactory(SharedFileService, Engines)],
Expand Down

0 comments on commit ea5aa40

Please sign in to comment.