Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes Throw error for unbuilt engine #81 #82

Merged
merged 6 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,21 @@ public override async Task<Empty> Delete(DeleteRequest request, ServerCallContex
public override async Task<TranslateResponse> Translate(TranslateRequest request, ServerCallContext context)
{
ITranslationEngineService engineService = GetEngineService(request.EngineType);
IEnumerable<Translation.TranslationResult> results = await engineService.TranslateAsync(
request.EngineId,
request.N,
request.Segment,
context.CancellationToken
);
IEnumerable<Translation.TranslationResult> results;
try
{
results = await engineService.TranslateAsync(
request.EngineId,
request.N,
request.Segment,
context.CancellationToken
);
}
catch (EngineNotBuiltException e)
{
throw new RpcException(new Status(StatusCode.Aborted, e.Message));
}

return new TranslateResponse { Results = { results.Select(Map) } };
}

Expand All @@ -52,11 +61,19 @@ ServerCallContext context
)
{
ITranslationEngineService engineService = GetEngineService(request.EngineType);
Translation.WordGraph wordGraph = await engineService.GetWordGraphAsync(
request.EngineId,
request.Segment,
context.CancellationToken
);
Translation.WordGraph wordGraph;
try
{
wordGraph = await engineService.GetWordGraphAsync(
request.EngineId,
request.Segment,
context.CancellationToken
);
}
catch (EngineNotBuiltException e)
{
throw new RpcException(new Status(StatusCode.Aborted, e.Message));
}
return new GetWordGraphResponse { WordGraph = Map(wordGraph) };
}

Expand Down
12 changes: 2 additions & 10 deletions src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public override async Task<IReadOnlyList<TranslationResult>> TranslateAsync(
IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken);
await using (await @lock.ReaderLockAsync(cancellationToken: cancellationToken))
{
TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken);
TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken);
HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision);
IReadOnlyList<TranslationResult> results = await hybridEngine.TranslateAsync(n, segment, cancellationToken);
state.LastUsedTime = DateTime.Now;
Expand All @@ -88,7 +88,7 @@ public override async Task<WordGraph> GetWordGraphAsync(
IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken);
await using (await @lock.ReaderLockAsync(cancellationToken: cancellationToken))
{
TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken);
TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken);
HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision);
WordGraph result = await hybridEngine.GetWordGraphAsync(segment, cancellationToken);
state.LastUsedTime = DateTime.Now;
Expand Down Expand Up @@ -170,12 +170,4 @@ IReadOnlyList<Corpus> corpora
// Token "None" is used here because hangfire injects the proper cancellation token
return r => r.RunAsync(engineId, buildId, corpora, CancellationToken.None);
}

private async Task<TranslationEngine> GetEngineAsync(string engineId, CancellationToken cancellationToken)
{
TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken);
if (engine is null)
throw new InvalidOperationException("");
return engine;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,4 +209,20 @@ CancellationToken cancellationToken
}
return false;
}

protected async Task<TranslationEngine> GetEngineAsync(string engineId, CancellationToken cancellationToken)
{
TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken);
if (engine is null)
throw new InvalidOperationException($"Engine with id {engineId} does not exist");
return engine;
}

protected async Task<TranslationEngine> GetBuiltEngineAsync(string engineId, CancellationToken cancellationToken)
{
TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken);
if (engine.BuildState != BuildState.None || engine.BuildRevision == 0)
throw new EngineNotBuiltException("The engine must be built first");
return engine;
}
}
9 changes: 9 additions & 0 deletions src/SIL.Machine.AspNetCore/Utils/EngineNotBuiltException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace SIL.Machine.AspNetCore.Utils
{
/// </summary> This exception is thrown when an unbuilt engine is requested to perform an action that requires it being built <summary>
public class EngineNotBuiltException : Exception
{
public EngineNotBuiltException(string message)
: base(message) { }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public async Task StartBuildAsync()
{
using var env = new TestEnvironment();
TranslationEngine engine = env.Engines.Get("engine1");
Assert.That(engine.BuildRevision, Is.EqualTo(0));
Assert.That(engine.BuildRevision, Is.EqualTo(1)); //For testing purposes BuildRevision is set to 1 (i.e., an already built engine)
// ensure that the SMT model was loaded before training
await env.Service.TranslateAsync("engine1", n: 1, "esto es una prueba.");
await env.Service.StartBuildAsync("engine1", "build1", Array.Empty<Corpus>());
Expand All @@ -25,7 +25,7 @@ await env.TruecaserTrainer
await env.TruecaserTrainer.Received().SaveAsync(Arg.Any<CancellationToken>());
engine = env.Engines.Get("engine1");
Assert.That(engine.BuildState, Is.EqualTo(BuildState.None));
Assert.That(engine.BuildRevision, Is.EqualTo(1));
Assert.That(engine.BuildRevision, Is.EqualTo(2)); //For testing purposes BuildRevision was initially set to 1 (i.e., an already built engine), so now it ought to be 2
// check if SMT model was reloaded upon first use after training
env.SmtModel.ClearReceivedCalls();
await env.Service.TranslateAsync("engine1", n: 1, "esto es una prueba.");
Expand Down Expand Up @@ -137,7 +137,9 @@ public TestEnvironment()
Id = "engine1",
EngineId = "engine1",
SourceLanguage = "es",
TargetLanguage = "en"
TargetLanguage = "en",
BuildRevision = 1,
BuildState = BuildState.None,
}
);
TrainSegmentPairs = new MemoryRepository<TrainSegmentPair>();
Expand Down
Loading