diff --git a/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs b/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs index 8fbcdee20..28d95580e 100644 --- a/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs +++ b/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs @@ -37,12 +37,21 @@ public override async Task Delete(DeleteRequest request, ServerCallContex public override async Task Translate(TranslateRequest request, ServerCallContext context) { ITranslationEngineService engineService = GetEngineService(request.EngineType); - IEnumerable results = await engineService.TranslateAsync( - request.EngineId, - request.N, - request.Segment, - context.CancellationToken - ); + IEnumerable results; + try + { + results = await engineService.TranslateAsync( + request.EngineId, + request.N, + request.Segment, + context.CancellationToken + ); + } + catch (EngineNotBuiltException e) + { + throw new RpcException(new Status(StatusCode.Aborted, e.Message)); + } + return new TranslateResponse { Results = { results.Select(Map) } }; } @@ -52,11 +61,19 @@ ServerCallContext context ) { ITranslationEngineService engineService = GetEngineService(request.EngineType); - Translation.WordGraph wordGraph = await engineService.GetWordGraphAsync( - request.EngineId, - request.Segment, - context.CancellationToken - ); + Translation.WordGraph wordGraph; + try + { + wordGraph = await engineService.GetWordGraphAsync( + request.EngineId, + request.Segment, + context.CancellationToken + ); + } + catch (EngineNotBuiltException e) + { + throw new RpcException(new Status(StatusCode.Aborted, e.Message)); + } return new GetWordGraphResponse { WordGraph = Map(wordGraph) }; } diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs index e7619bc63..2cfccdc7e 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs @@ -70,7 +70,7 @@ public override async Task> TranslateAsync( IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken); await using (await @lock.ReaderLockAsync(cancellationToken: cancellationToken)) { - TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken); + TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision); IReadOnlyList results = await hybridEngine.TranslateAsync(n, segment, cancellationToken); state.LastUsedTime = DateTime.Now; @@ -88,7 +88,7 @@ public override async Task GetWordGraphAsync( IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken); await using (await @lock.ReaderLockAsync(cancellationToken: cancellationToken)) { - TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken); + TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision); WordGraph result = await hybridEngine.GetWordGraphAsync(segment, cancellationToken); state.LastUsedTime = DateTime.Now; @@ -170,12 +170,4 @@ IReadOnlyList corpora // Token "None" is used here because hangfire injects the proper cancellation token return r => r.RunAsync(engineId, buildId, corpora, CancellationToken.None); } - - private async Task GetEngineAsync(string engineId, CancellationToken cancellationToken) - { - TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken); - if (engine is null) - throw new InvalidOperationException(""); - return engine; - } } diff --git a/src/SIL.Machine.AspNetCore/Services/TranslationEngineServiceBase.cs b/src/SIL.Machine.AspNetCore/Services/TranslationEngineServiceBase.cs index 0f3d710af..bb11a4395 100644 --- a/src/SIL.Machine.AspNetCore/Services/TranslationEngineServiceBase.cs +++ b/src/SIL.Machine.AspNetCore/Services/TranslationEngineServiceBase.cs @@ -209,4 +209,20 @@ CancellationToken cancellationToken } return false; } + + protected async Task GetEngineAsync(string engineId, CancellationToken cancellationToken) + { + TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken); + if (engine is null) + throw new InvalidOperationException($"Engine with id {engineId} does not exist"); + return engine; + } + + protected async Task GetBuiltEngineAsync(string engineId, CancellationToken cancellationToken) + { + TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken); + if (engine.BuildState != BuildState.None || engine.BuildRevision == 0) + throw new EngineNotBuiltException("The engine must be built first"); + return engine; + } } diff --git a/src/SIL.Machine.AspNetCore/Utils/EngineNotBuiltException.cs b/src/SIL.Machine.AspNetCore/Utils/EngineNotBuiltException.cs new file mode 100644 index 000000000..96d52e31b --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Utils/EngineNotBuiltException.cs @@ -0,0 +1,9 @@ +namespace SIL.Machine.AspNetCore.Utils +{ + /// This exception is thrown when an unbuilt engine is requested to perform an action that requires it being built + public class EngineNotBuiltException : Exception + { + public EngineNotBuiltException(string message) + : base(message) { } + } +} diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs index e2a712b4d..7c69e79e4 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs @@ -10,7 +10,7 @@ public async Task StartBuildAsync() { using var env = new TestEnvironment(); TranslationEngine engine = env.Engines.Get("engine1"); - Assert.That(engine.BuildRevision, Is.EqualTo(0)); + Assert.That(engine.BuildRevision, Is.EqualTo(1)); //For testing purposes BuildRevision is set to 1 (i.e., an already built engine) // ensure that the SMT model was loaded before training await env.Service.TranslateAsync("engine1", n: 1, "esto es una prueba."); await env.Service.StartBuildAsync("engine1", "build1", Array.Empty()); @@ -25,7 +25,7 @@ await env.TruecaserTrainer await env.TruecaserTrainer.Received().SaveAsync(Arg.Any()); engine = env.Engines.Get("engine1"); Assert.That(engine.BuildState, Is.EqualTo(BuildState.None)); - Assert.That(engine.BuildRevision, Is.EqualTo(1)); + Assert.That(engine.BuildRevision, Is.EqualTo(2)); //For testing purposes BuildRevision was initially set to 1 (i.e., an already built engine), so now it ought to be 2 // check if SMT model was reloaded upon first use after training env.SmtModel.ClearReceivedCalls(); await env.Service.TranslateAsync("engine1", n: 1, "esto es una prueba."); @@ -137,7 +137,9 @@ public TestEnvironment() Id = "engine1", EngineId = "engine1", SourceLanguage = "es", - TargetLanguage = "en" + TargetLanguage = "en", + BuildRevision = 1, + BuildState = BuildState.None, } ); TrainSegmentPairs = new MemoryRepository();