Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call 'GetTasksById' once per DoWork #504

Merged
merged 4 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,36 +51,32 @@ protected override async Task DoWorkAsync(IServiceScope scope, CancellationToken
if (trainingEngines.Count == 0)
return;

Dictionary<string, ClearMLTask> tasks = new();
Dictionary<string, int> queuePositions = new();
Dictionary<string, ClearMLTask> tasks = (
await _clearMLService.GetTasksByIdAsync(
trainingEngines.Select(e => e.CurrentBuild!.JobId),
cancellationToken
)
).ToDictionary(t => t.Id);
Dictionary<TranslationEngineType, Dictionary<string, int>> queuePositionsPerEngineType = new();

foreach (TranslationEngineType engineType in _queuePerEngineType.Keys)
foreach ((TranslationEngineType engineType, string queueName) in _queuePerEngineType)
{
var tasksPerEngineType = (
await _clearMLService.GetTasksByIdAsync(
trainingEngines.Select(e => e.CurrentBuild!.JobId),
cancellationToken
)
)
.UnionBy(
await _clearMLService.GetTasksForQueueAsync(_queuePerEngineType[engineType], cancellationToken),
t => t.Id
var tasksPerEngineType = tasks
.Where(kvp =>
trainingEngines.Where(te => te.CurrentBuild?.JobId == kvp.Key).FirstOrDefault()?.Type
== engineType
)
.Select(kvp => kvp.Value)
.UnionBy(await _clearMLService.GetTasksForQueueAsync(queueName, cancellationToken), t => t.Id)
.ToDictionary(t => t.Id);
// add new keys to dictionary
foreach (KeyValuePair<string, ClearMLTask> kvp in tasksPerEngineType)
tasks.TryAdd(kvp.Key, kvp.Value);

var queuePositionsPerEngineType = tasksPerEngineType
queuePositionsPerEngineType[engineType] = tasksPerEngineType
.Values.Where(t => t.Status is ClearMLTaskStatus.Queued or ClearMLTaskStatus.Created)
.OrderBy(t => t.Created)
.Select((t, i) => (Position: i, Task: t))
.ToDictionary(e => e.Task.Name, e => e.Position);
// add new keys to dictionary
foreach (KeyValuePair<string, int> kvp in queuePositionsPerEngineType)
queuePositions.TryAdd(kvp.Key, kvp.Value);

_queueSizePerEngineType[engineType] = queuePositionsPerEngineType.Count;
_queueSizePerEngineType[engineType] = queuePositionsPerEngineType[engineType].Count;
}

var dataAccessContext = scope.ServiceProvider.GetRequiredService<IDataAccessContext>();
Expand All @@ -100,7 +96,7 @@ await UpdateTrainJobStatus(
engine.CurrentBuild.BuildId,
new ProgressStatus(step: 0, percentCompleted: 0.0),
//CurrentBuild.BuildId should always equal the corresponding task.Name
queuePositions[engine.CurrentBuild.BuildId] + 1,
queuePositionsPerEngineType[engine.Type][engine.CurrentBuild.BuildId] + 1,
cancellationToken
);
}
Expand Down
39 changes: 35 additions & 4 deletions src/Machine/src/Serval.Machine.Shared/Services/ClearMLService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ ILogger<ClearMLService> logger

private readonly IClearMLAuthenticationService _clearMLAuthService = clearMLAuthService;
private readonly ILogger<ClearMLService> _logger = logger;
private readonly IDictionary<string, string> _queueNamesToIds = new ConcurrentDictionary<string, string>();

public async Task<string?> GetProjectIdAsync(string name, CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -145,13 +146,40 @@ public async Task<IReadOnlyList<ClearMLTask>> GetTasksForQueueAsync(
CancellationToken cancellationToken = default
)
{
var body = new JsonObject { ["name"] = queue };
JsonObject? result = await CallAsync("queues", "get_all_ex", body, cancellationToken);
var tasks = (JsonArray?)result?["data"]?["queues"]?[0]?["entries"];
IDictionary<string, string> queueNamesToIds = await PopulateQueueNamesToIdsAsync(
cancellationToken: cancellationToken
);
if (!queueNamesToIds.TryGetValue(queue, out string? queueId))
{
queueNamesToIds = await PopulateQueueNamesToIdsAsync(refresh: true, cancellationToken);
}
var body = new JsonObject { ["queue"] = queueId ?? queueNamesToIds[queue] };
JsonObject? result = await CallAsync("queues", "get_by_id", body, cancellationToken);
var tasks = (JsonArray?)result?["data"]?["queue"]?["entries"];
IEnumerable<string> taskIds = tasks?.Select(t => (string)t?["id"]!) ?? new List<string>();
return await GetTasksByIdAsync(taskIds, cancellationToken);
}

/// <summary>
/// Returns the cached map of ClearML queue names to queue ids, querying the server
/// ("queues" / "get_all") when the cache is empty or a refresh is requested.
/// </summary>
/// <param name="refresh">
/// When <c>true</c>, the server is queried even if the cache already has entries, and
/// the ids of queues returned by the server overwrite any cached ids for the same name.
/// </param>
/// <param name="cancellationToken">Token passed through to the server call.</param>
/// <returns>The shared name-to-id dictionary held by this service instance.</returns>
/// <exception cref="InvalidOperationException">
/// The response has no <c>data.queues</c> array, or a queue entry lacks a name or id.
/// </exception>
private async Task<IDictionary<string, string>> PopulateQueueNamesToIdsAsync(
    bool refresh = false,
    CancellationToken cancellationToken = default
)
{
    if (!refresh && _queueNamesToIds.Count > 0)
        return _queueNamesToIds;

    JsonObject? result = await CallAsync("queues", "get_all", new JsonObject(), cancellationToken);
    var queues = (JsonArray?)result?["data"]?["queues"];
    if (queues is null)
        throw new InvalidOperationException("Malformed response from ClearML server.");

    foreach (JsonNode? queue in queues)
    {
        string? name = (string?)queue?["name"];
        string? id = (string?)queue?["id"];
        if (name is null || id is null)
            throw new InvalidOperationException("Malformed response from ClearML server.");

        // Assign through the indexer rather than TryAdd so that a refresh actually
        // replaces a stale id cached under the same queue name. Entries for queues the
        // server no longer reports are left in place, as before.
        _queueNamesToIds[name] = id;
    }
    return _queueNamesToIds;
}

public async Task<ClearMLTask?> GetTaskByNameAsync(string name, CancellationToken cancellationToken = default)
{
IReadOnlyList<ClearMLTask> tasks = await GetTasksAsync(new JsonObject { ["name"] = name }, cancellationToken);
Expand All @@ -165,7 +193,10 @@ public Task<IReadOnlyList<ClearMLTask>> GetTasksByIdAsync(
CancellationToken cancellationToken = default
)
{
return GetTasksAsync(new JsonObject { ["id"] = JsonValue.Create(ids.ToArray()) }, cancellationToken);
string[] idArray = ids.ToArray();
if (!idArray.Any())
return Task.FromResult(Array.Empty<ClearMLTask>() as IReadOnlyList<ClearMLTask>);
return GetTasksAsync(new JsonObject { ["id"] = JsonValue.Create(idArray) }, cancellationToken);
}

private async Task<IReadOnlyList<ClearMLTask>> GetTasksAsync(
Expand Down
Loading