Skip to content

Commit

Permalink
Merge branch 'master' into run_e2etests_on_release_#100
Browse files Browse the repository at this point in the history
  • Loading branch information
Enkidu93 authored Sep 14, 2023
2 parents 6063402 + 6716915 commit ee48366
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,21 @@ public override async Task<Empty> Delete(DeleteRequest request, ServerCallContex
public override async Task<TranslateResponse> Translate(TranslateRequest request, ServerCallContext context)
{
ITranslationEngineService engineService = GetEngineService(request.EngineType);
IEnumerable<Translation.TranslationResult> results = await engineService.TranslateAsync(
request.EngineId,
request.N,
request.Segment,
context.CancellationToken
);
IEnumerable<Translation.TranslationResult> results;
try
{
results = await engineService.TranslateAsync(
request.EngineId,
request.N,
request.Segment,
context.CancellationToken
);
}
catch (EngineNotBuiltException e)
{
throw new RpcException(new Status(StatusCode.Aborted, e.Message));
}

return new TranslateResponse { Results = { results.Select(Map) } };
}

Expand All @@ -52,11 +61,19 @@ ServerCallContext context
)
{
ITranslationEngineService engineService = GetEngineService(request.EngineType);
Translation.WordGraph wordGraph = await engineService.GetWordGraphAsync(
request.EngineId,
request.Segment,
context.CancellationToken
);
Translation.WordGraph wordGraph;
try
{
wordGraph = await engineService.GetWordGraphAsync(
request.EngineId,
request.Segment,
context.CancellationToken
);
}
catch (EngineNotBuiltException e)
{
throw new RpcException(new Status(StatusCode.Aborted, e.Message));
}
return new GetWordGraphResponse { WordGraph = Map(wordGraph) };
}

Expand Down
12 changes: 2 additions & 10 deletions src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public override async Task<IReadOnlyList<TranslationResult>> TranslateAsync(
IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken);
await using (await @lock.ReaderLockAsync(cancellationToken: cancellationToken))
{
TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken);
TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken);
HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision);
IReadOnlyList<TranslationResult> results = await hybridEngine.TranslateAsync(n, segment, cancellationToken);
state.LastUsedTime = DateTime.Now;
Expand All @@ -88,7 +88,7 @@ public override async Task<WordGraph> GetWordGraphAsync(
IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken);
await using (await @lock.ReaderLockAsync(cancellationToken: cancellationToken))
{
TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken);
TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken);
HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision);
WordGraph result = await hybridEngine.GetWordGraphAsync(segment, cancellationToken);
state.LastUsedTime = DateTime.Now;
Expand Down Expand Up @@ -170,12 +170,4 @@ IReadOnlyList<Corpus> corpora
// Token "None" is used here because hangfire injects the proper cancellation token
return r => r.RunAsync(engineId, buildId, corpora, CancellationToken.None);
}

private async Task<TranslationEngine> GetEngineAsync(string engineId, CancellationToken cancellationToken)
{
TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken);
if (engine is null)
throw new InvalidOperationException("");
return engine;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,4 +209,20 @@ CancellationToken cancellationToken
}
return false;
}

protected async Task<TranslationEngine> GetEngineAsync(string engineId, CancellationToken cancellationToken)
{
TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken);
if (engine is null)
throw new InvalidOperationException($"Engine with id {engineId} does not exist");
return engine;
}

protected async Task<TranslationEngine> GetBuiltEngineAsync(string engineId, CancellationToken cancellationToken)
{
TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken);
if (engine.BuildState != BuildState.None || engine.BuildRevision == 0)
throw new EngineNotBuiltException("The engine must be built first");
return engine;
}
}
9 changes: 9 additions & 0 deletions src/SIL.Machine.AspNetCore/Utils/EngineNotBuiltException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace SIL.Machine.AspNetCore.Utils
{
/// </summary> This exception is thrown when an unbuilt engine is requested to perform an action that requires it being built <summary>
public class EngineNotBuiltException : Exception
{
public EngineNotBuiltException(string message)
: base(message) { }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public async Task StartBuildAsync()
{
using var env = new TestEnvironment();
TranslationEngine engine = env.Engines.Get("engine1");
Assert.That(engine.BuildRevision, Is.EqualTo(0));
Assert.That(engine.BuildRevision, Is.EqualTo(1)); //For testing purposes BuildRevision is set to 1 (i.e., an already built engine)
// ensure that the SMT model was loaded before training
await env.Service.TranslateAsync("engine1", n: 1, "esto es una prueba.");
await env.Service.StartBuildAsync("engine1", "build1", Array.Empty<Corpus>());
Expand All @@ -25,7 +25,7 @@ await env.TruecaserTrainer
await env.TruecaserTrainer.Received().SaveAsync(Arg.Any<CancellationToken>());
engine = env.Engines.Get("engine1");
Assert.That(engine.BuildState, Is.EqualTo(BuildState.None));
Assert.That(engine.BuildRevision, Is.EqualTo(1));
Assert.That(engine.BuildRevision, Is.EqualTo(2)); //For testing purposes BuildRevision was initially set to 1 (i.e., an already built engine), so now it ought to be 2
// check if SMT model was reloaded upon first use after training
env.SmtModel.ClearReceivedCalls();
await env.Service.TranslateAsync("engine1", n: 1, "esto es una prueba.");
Expand Down Expand Up @@ -137,7 +137,9 @@ public TestEnvironment()
Id = "engine1",
EngineId = "engine1",
SourceLanguage = "es",
TargetLanguage = "en"
TargetLanguage = "en",
BuildRevision = 1,
BuildState = BuildState.None,
}
);
TrainSegmentPairs = new MemoryRepository<TrainSegmentPair>();
Expand Down

0 comments on commit ee48366

Please sign in to comment.