diff --git a/src/SIL.Machine.Tool/AlignCommand.cs b/src/SIL.Machine.Tool/AlignCommand.cs index 7888c7396..a00750632 100644 --- a/src/SIL.Machine.Tool/AlignCommand.cs +++ b/src/SIL.Machine.Tool/AlignCommand.cs @@ -39,7 +39,7 @@ public AlignCommand() ); _symHeuristicOption = Option( "-sh|--sym-heuristic ", - $"The symmetrization heuristic.\nHeuristics: \"{ToolHelpers.Och}\" (default), \"{ToolHelpers.Union}\", \"{ToolHelpers.Intersection}\", \"{ToolHelpers.Grow}\", \"{ToolHelpers.GrowDiag}\", \"{ToolHelpers.GrowDiagFinal}\", \"{ToolHelpers.GrowDiagFinalAnd}\", \"{ToolHelpers.None}\".", + $"The symmetrization heuristic.\nHeuristics: \"{SymmetrizationHelpers.Och}\" (default), \"{SymmetrizationHelpers.Union}\", \"{SymmetrizationHelpers.Intersection}\", \"{SymmetrizationHelpers.Grow}\", \"{SymmetrizationHelpers.GrowDiag}\", \"{SymmetrizationHelpers.GrowDiagFinal}\", \"{SymmetrizationHelpers.GrowDiagFinalAnd}\", \"{SymmetrizationHelpers.None}\".", CommandOptionType.SingleValue ); _scoresOption = Option("-s|--scores", "Include scores in the output.", CommandOptionType.NoValue); @@ -53,7 +53,7 @@ protected override async Task ExecuteCommandAsync(CancellationToken cancell if (code != 0) return code; - if (!ToolHelpers.ValidateSymmetrizationHeuristicOption(_symHeuristicOption?.Value())) + if (!SymmetrizationHelpers.ValidateSymmetrizationHeuristicOption(_symHeuristicOption?.Value())) { Out.WriteLine("The specified symmetrization heuristic is invalid."); return 1; @@ -75,7 +75,9 @@ protected override async Task ExecuteCommandAsync(CancellationToken cancell int processorCount = Environment.ProcessorCount; - SymmetrizationHeuristic symHeuristic = ToolHelpers.GetSymmetrizationHeuristic(_symHeuristicOption?.Value()); + SymmetrizationHeuristic symHeuristic = SymmetrizationHelpers.GetSymmetrizationHeuristic( + _symHeuristicOption?.Value() + ); if (!_quietOption.HasValue()) Out.Write("Loading model... "); diff --git a/src/SIL.Machine.Tool/AlignmentModelCommandSpec.cs b/src/SIL.Machine.Tool/AlignmentModelCommandSpec.cs index b40c4fe11..a6804d6e6 100644 --- a/src/SIL.Machine.Tool/AlignmentModelCommandSpec.cs +++ b/src/SIL.Machine.Tool/AlignmentModelCommandSpec.cs @@ -37,7 +37,7 @@ public void AddParameters(CommandBase command) _modelArgument = command.Argument("MODEL_PATH", "The word alignment model.").IsRequired(); _modelTypeOption = command.Option( "-mt|--model-type ", - $"The word alignment model type.\nTypes: \"{ToolHelpers.Hmm}\" (default), \"{ToolHelpers.Ibm1}\", \"{ToolHelpers.Ibm2}\", \"{ToolHelpers.Ibm3}\", \"{ToolHelpers.Ibm4}\", \"{ToolHelpers.FastAlign}\".", + $"The word alignment model type.\nTypes: \"{ThotWordAlignmentHelpers.Hmm}\" (default), \"{ThotWordAlignmentHelpers.Ibm1}\", \"{ThotWordAlignmentHelpers.Ibm2}\", \"{ThotWordAlignmentHelpers.Ibm3}\", \"{ThotWordAlignmentHelpers.Ibm4}\", \"{ThotWordAlignmentHelpers.FastAlign}\".", CommandOptionType.SingleValue ); _pluginOption = command.Option( @@ -90,7 +90,7 @@ public IWordAlignmentModel CreateAlignmentModel( ThotWordAlignmentModelType modelType = ThotWordAlignmentModelType.Hmm; if (_modelTypeOption.HasValue()) { - modelType = ToolHelpers.GetThotWordAlignmentModelType(_modelTypeOption.Value()); + modelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType(_modelTypeOption.Value()); } else { @@ -103,7 +103,7 @@ public IWordAlignmentModel CreateAlignmentModel( yaml.Load(reader); var root = (YamlMappingNode)yaml.Documents.First().RootNode; var modelTypeStr = (string)root[new YamlScalarNode("model")]; - modelType = ToolHelpers.GetThotWordAlignmentModelType(modelTypeStr); + modelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType(modelTypeStr); } } } @@ -144,7 +144,9 @@ public ITrainer CreateAlignmentModelTrainer( return _modelFactory.CreateTrainer(_modelArgument.Value, corpus, maxSize, parameters, direct); } - ThotWordAlignmentModelType modelType = ToolHelpers.GetThotWordAlignmentModelType(_modelTypeOption.Value()); + ThotWordAlignmentModelType modelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType( + _modelTypeOption.Value() + ); string modelPath = _modelArgument.Value; if (ToolHelpers.IsDirectoryPath(modelPath)) @@ -227,12 +229,12 @@ private static bool ValidateAlignmentModelTypeOption(string value, IEnumerable { - ToolHelpers.Hmm, - ToolHelpers.Ibm1, - ToolHelpers.Ibm2, - ToolHelpers.FastAlign, - ToolHelpers.Ibm3, - ToolHelpers.Ibm4 + ThotWordAlignmentHelpers.Hmm, + ThotWordAlignmentHelpers.Ibm1, + ThotWordAlignmentHelpers.Ibm2, + ThotWordAlignmentHelpers.FastAlign, + ThotWordAlignmentHelpers.Ibm3, + ThotWordAlignmentHelpers.Ibm4 }; validTypes.UnionWith(pluginTypes); return string.IsNullOrEmpty(value) || validTypes.Contains(value); diff --git a/src/SIL.Machine.Tool/SymmetrizeCommand.cs b/src/SIL.Machine.Tool/SymmetrizeCommand.cs index eaa524206..7a403b8ee 100644 --- a/src/SIL.Machine.Tool/SymmetrizeCommand.cs +++ b/src/SIL.Machine.Tool/SymmetrizeCommand.cs @@ -37,7 +37,7 @@ public SymmetrizeCommand() ); _symHeuristicOption = Option( "-sh|--sym-heuristic ", - $"The symmetrization heuristic.\nHeuristics: \"{ToolHelpers.Och}\" (default), \"{ToolHelpers.Union}\", \"{ToolHelpers.Intersection}\", \"{ToolHelpers.Grow}\", \"{ToolHelpers.GrowDiag}\", \"{ToolHelpers.GrowDiagFinal}\", \"{ToolHelpers.GrowDiagFinalAnd}\".", + $"The symmetrization heuristic.\nHeuristics: \"{SymmetrizationHelpers.Och}\" (default), \"{SymmetrizationHelpers.Union}\", \"{SymmetrizationHelpers.Intersection}\", \"{SymmetrizationHelpers.Grow}\", \"{SymmetrizationHelpers.GrowDiag}\", \"{SymmetrizationHelpers.GrowDiagFinal}\", \"{SymmetrizationHelpers.GrowDiagFinalAnd}\".", CommandOptionType.SingleValue ); _quietOption = Option("-q|--quiet", "Only display results.", CommandOptionType.NoValue); @@ -67,14 +67,21 @@ protected override async Task ExecuteCommandAsync(CancellationToken cancell return 1; } - if (!ToolHelpers.ValidateSymmetrizationHeuristicOption(_symHeuristicOption.Value(), noneAllowed: false)) + if ( + !SymmetrizationHelpers.ValidateSymmetrizationHeuristicOption( + _symHeuristicOption.Value(), + noneAllowed: false + ) + ) { Out.WriteLine("The specified symmetrization heuristic is invalid."); return 1; } string outputFormat = _outputFormatOption.Value() ?? Pharaoh; - SymmetrizationHeuristic heuristic = ToolHelpers.GetSymmetrizationHeuristic(_symHeuristicOption.Value()); + SymmetrizationHeuristic heuristic = SymmetrizationHelpers.GetSymmetrizationHeuristic( + _symHeuristicOption.Value() + ); using var directReader = new StreamReader(_directArgument.Value); using var inverseReader = new StreamReader(_inverseArgument.Value); diff --git a/src/SIL.Machine.Tool/ToolHelpers.cs b/src/SIL.Machine.Tool/ToolHelpers.cs index 2c3b00ed1..497075a26 100644 --- a/src/SIL.Machine.Tool/ToolHelpers.cs +++ b/src/SIL.Machine.Tool/ToolHelpers.cs @@ -14,22 +14,6 @@ namespace SIL.Machine; internal static class ToolHelpers { - public const string FastAlign = "fast_align"; - public const string Ibm1 = "ibm1"; - public const string Ibm2 = "ibm2"; - public const string Hmm = "hmm"; - public const string Ibm3 = "ibm3"; - public const string Ibm4 = "ibm4"; - - public const string Och = "och"; - public const string Union = "union"; - public const string Intersection = "intersection"; - public const string Grow = "grow"; - public const string GrowDiag = "grow-diag"; - public const string GrowDiagFinal = "grow-diag-final"; - public const string GrowDiagFinalAnd = "grow-diag-final-and"; - public const string None = "none"; - public static bool ValidateCorpusFormatOption(string value) { return string.IsNullOrEmpty(value) || value.ToLowerInvariant().IsOneOf("dbl", "usx", "text", "pt", "pt_m"); @@ -154,29 +138,14 @@ public static string GetTranslationModelConfigFileName(string path) public static bool ValidateTranslationModelTypeOption(string value) { - var validTypes = new HashSet { Hmm, Ibm1, Ibm2, FastAlign }; - return string.IsNullOrEmpty(value) || validTypes.Contains(value); - } - - public static ThotWordAlignmentModelType GetThotWordAlignmentModelType(string modelType) - { - switch (modelType) + var validTypes = new HashSet { - case "fastAlign": - case FastAlign: - return ThotWordAlignmentModelType.FastAlign; - case Ibm1: - return ThotWordAlignmentModelType.Ibm1; - case Ibm2: - return ThotWordAlignmentModelType.Ibm2; - default: - case Hmm: - return ThotWordAlignmentModelType.Hmm; - case Ibm3: - return ThotWordAlignmentModelType.Ibm3; - case Ibm4: - return ThotWordAlignmentModelType.Ibm4; - } + ThotWordAlignmentHelpers.Hmm, + ThotWordAlignmentHelpers.Ibm1, + ThotWordAlignmentHelpers.Ibm2, + ThotWordAlignmentHelpers.FastAlign + }; + return string.IsNullOrEmpty(value) || validTypes.Contains(value); } public static ITrainer CreateTranslationModelTrainer( @@ -186,7 +155,9 @@ public static ITrainer CreateTranslationModelTrainer( int maxSize ) { - ThotWordAlignmentModelType wordAlignmentModelType = GetThotWordAlignmentModelType(modelType); + ThotWordAlignmentModelType wordAlignmentModelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType( + modelType + ); string modelDir = Path.GetDirectoryName(modelConfigFileName); if (!Directory.Exists(modelDir)) @@ -201,38 +172,6 @@ int maxSize }; } - public static bool ValidateSymmetrizationHeuristicOption(string value, bool noneAllowed = true) - { - var validHeuristics = new HashSet - { - Och, - Union, - Intersection, - Grow, - GrowDiag, - GrowDiagFinal, - GrowDiagFinalAnd - }; - if (noneAllowed) - validHeuristics.Add(None); - return string.IsNullOrEmpty(value) || validHeuristics.Contains(value.ToLowerInvariant()); - } - - public static SymmetrizationHeuristic GetSymmetrizationHeuristic(string value) - { - return value switch - { - None => SymmetrizationHeuristic.None, - Union => SymmetrizationHeuristic.Union, - Intersection => SymmetrizationHeuristic.Intersection, - Grow => SymmetrizationHeuristic.Grow, - GrowDiag => SymmetrizationHeuristic.GrowDiag, - GrowDiagFinal => SymmetrizationHeuristic.GrowDiagFinal, - GrowDiagFinalAnd => SymmetrizationHeuristic.GrowDiagFinalAnd, - _ => SymmetrizationHeuristic.Och, - }; - } - public static StreamWriter CreateStreamWriter(string fileName) { var utf8Encoding = new UTF8Encoding(false); diff --git a/src/SIL.Machine.Tool/TranslationModelCommandSpec.cs b/src/SIL.Machine.Tool/TranslationModelCommandSpec.cs index 3ca50cc79..8523d691b 100644 --- a/src/SIL.Machine.Tool/TranslationModelCommandSpec.cs +++ b/src/SIL.Machine.Tool/TranslationModelCommandSpec.cs @@ -22,7 +22,7 @@ public void AddParameters(CommandBase command) _modelArgument = command.Argument("MODEL_PATH", "The translation model.").IsRequired(); _modelTypeOption = command.Option( "-mt|--model-type ", - $"The word alignment model type.\nTypes: \"{ToolHelpers.Hmm}\" (default), \"{ToolHelpers.Ibm1}\", \"{ToolHelpers.Ibm2}\", \"{ToolHelpers.FastAlign}\".", + $"The word alignment model type.\nTypes: \"{ThotWordAlignmentHelpers.Hmm}\" (default), \"{ThotWordAlignmentHelpers.Ibm1}\", \"{ThotWordAlignmentHelpers.Ibm2}\", \"{ThotWordAlignmentHelpers.FastAlign}\".", CommandOptionType.SingleValue ); _pluginOption = command.Option( @@ -67,7 +67,7 @@ public ITranslationModel CreateModel() if (_modelFactory != null) return _modelFactory.CreateModel(_modelArgument.Value); - ThotWordAlignmentModelType wordAlignmentModelType = ToolHelpers.GetThotWordAlignmentModelType( + ThotWordAlignmentModelType wordAlignmentModelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType( _modelTypeOption.Value() ); diff --git a/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj b/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj index e7e5b9b27..b51de93f4 100644 --- a/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj +++ b/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj @@ -13,6 +13,7 @@ + diff --git a/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModel.cs b/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModel.cs index b10135bd5..bad13bcf0 100644 --- a/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModel.cs +++ b/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModel.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using SIL.Extensions; using SIL.Machine.Corpora; +using SIL.Machine.Tokenization; using SIL.ObjectModel; namespace SIL.Machine.Translation.Thot @@ -116,6 +117,11 @@ public void CreateNew(string prefFileName) } public ITrainer CreateTrainer(IParallelTextCorpus corpus) + { + return CreateTrainer(corpus, null); + } + + public ITrainer CreateTrainer(IParallelTextCorpus corpus, ITokenizer tokenizer = null) { CheckDisposed(); @@ -126,7 +132,13 @@ public ITrainer CreateTrainer(IParallelTextCorpus corpus) ); } - return new Trainer(this, corpus); + var trainer = new Trainer(this, corpus); + if (tokenizer != null) + { + trainer.SourceTokenizer = tokenizer; + trainer.TargetTokenizer = tokenizer; + } + return trainer; } public Task SaveAsync() diff --git a/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModelType.cs b/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModelType.cs index 1e3142fde..27cb264e2 100644 --- a/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModelType.cs +++ b/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModelType.cs @@ -1,4 +1,7 @@ -namespace SIL.Machine.Translation.Thot +using System; +using CaseExtensions; + +namespace SIL.Machine.Translation.Thot { public enum ThotWordAlignmentModelType { @@ -9,4 +12,39 @@ public enum ThotWordAlignmentModelType Ibm3, Ibm4 } + + public static class ThotWordAlignmentHelpers + { + public const string FastAlign = "fast_align"; + public const string Ibm1 = "ibm1"; + public const string Ibm2 = "ibm2"; + public const string Hmm = "hmm"; + public const string Ibm3 = "ibm3"; + public const string Ibm4 = "ibm4"; + + public static ThotWordAlignmentModelType GetThotWordAlignmentModelType( + string modelType, + ThotWordAlignmentModelType? defaultType = null + ) + { + switch (modelType.ToSnakeCase()) + { + case FastAlign: + return ThotWordAlignmentModelType.FastAlign; + case Ibm1: + return ThotWordAlignmentModelType.Ibm1; + case Ibm2: + return ThotWordAlignmentModelType.Ibm2; + case Hmm: + return ThotWordAlignmentModelType.Hmm; + case Ibm3: + return ThotWordAlignmentModelType.Ibm3; + case Ibm4: + return ThotWordAlignmentModelType.Ibm4; + default: + return defaultType + ?? throw new ArgumentException($"Invalid word alignment model type: {modelType}"); + } + } + } } diff --git a/src/SIL.Machine/Corpora/AlignedWordPair.cs b/src/SIL.Machine/Corpora/AlignedWordPair.cs index cf25108c8..534b87f81 100644 --- a/src/SIL.Machine/Corpora/AlignedWordPair.cs +++ b/src/SIL.Machine/Corpora/AlignedWordPair.cs @@ -23,6 +23,20 @@ public static IReadOnlyCollection Parse(string alignments, bool return result; } + public static bool TryParse(string alignments, out IReadOnlyCollection alignedWordPairs) + { + alignedWordPairs = null; + try + { + alignedWordPairs = Parse(alignments); + return true; + } + catch + { + return false; + } + } + public AlignedWordPair(int sourceIndex, int targetIndex) { SourceIndex = sourceIndex; diff --git a/src/SIL.Machine/SIL.Machine.csproj b/src/SIL.Machine/SIL.Machine.csproj index 57cf0ccb4..6a7cfbcd0 100644 --- a/src/SIL.Machine/SIL.Machine.csproj +++ b/src/SIL.Machine/SIL.Machine.csproj @@ -42,6 +42,7 @@ + diff --git a/src/SIL.Machine/Translation/IWordAlignmentEngine.cs b/src/SIL.Machine/Translation/IWordAlignmentEngine.cs new file mode 100644 index 000000000..b033bb9ff --- /dev/null +++ b/src/SIL.Machine/Translation/IWordAlignmentEngine.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using SIL.Machine.Corpora; +using SIL.ObjectModel; + +namespace SIL.Machine.Translation +{ + public interface IWordAlignmentEngine : IWordAligner, IDisposable + { + IWordVocabulary SourceWords { get; } + IWordVocabulary TargetWords { get; } + IReadOnlySet SpecialSymbolIndices { get; } + + IEnumerable<(string TargetWord, double Score)> GetTranslations(string sourceWord, double threshold = 0); + IEnumerable<(int TargetWordIndex, double Score)> GetTranslations(int sourceWordIndex, double threshold = 0); + + double GetTranslationScore(string sourceWord, string targetWord); + double GetTranslationScore(int sourceWordIndex, int targetWordIndex); + + IReadOnlyCollection GetBestAlignedWordPairs( + IReadOnlyList sourceSegment, + IReadOnlyList targetSegment + ); + void ComputeAlignedWordPairScores( + IReadOnlyList sourceSegment, + IReadOnlyList targetSegment, + IReadOnlyCollection wordPairs + ); + } +} diff --git a/src/SIL.Machine/Translation/IWordAlignmentModel.cs b/src/SIL.Machine/Translation/IWordAlignmentModel.cs index 483f47d30..7fa07505c 100644 --- a/src/SIL.Machine/Translation/IWordAlignmentModel.cs +++ b/src/SIL.Machine/Translation/IWordAlignmentModel.cs @@ -1,32 +1,9 @@ -using System; -using System.Collections.Generic; -using SIL.Machine.Corpora; -using SIL.ObjectModel; +using SIL.Machine.Corpora; namespace SIL.Machine.Translation { - public interface IWordAlignmentModel : IWordAligner, IDisposable + public interface IWordAlignmentModel : IWordAlignmentEngine { - IWordVocabulary SourceWords { get; } - IWordVocabulary TargetWords { get; } - IReadOnlySet SpecialSymbolIndices { get; } - ITrainer CreateTrainer(IParallelTextCorpus corpus); - - IEnumerable<(string TargetWord, double Score)> GetTranslations(string sourceWord, double threshold = 0); - IEnumerable<(int TargetWordIndex, double Score)> GetTranslations(int sourceWordIndex, double threshold = 0); - - double GetTranslationScore(string sourceWord, string targetWord); - double GetTranslationScore(int sourceWordIndex, int targetWordIndex); - - IReadOnlyCollection GetBestAlignedWordPairs( - IReadOnlyList sourceSegment, - IReadOnlyList targetSegment - ); - void ComputeAlignedWordPairScores( - IReadOnlyList sourceSegment, - IReadOnlyList targetSegment, - IReadOnlyCollection wordPairs - ); } } diff --git a/src/SIL.Machine/Translation/SymmetrizationHeuristic.cs b/src/SIL.Machine/Translation/SymmetrizationHeuristic.cs index e18b96111..544bd3fcc 100644 --- a/src/SIL.Machine/Translation/SymmetrizationHeuristic.cs +++ b/src/SIL.Machine/Translation/SymmetrizationHeuristic.cs @@ -1,4 +1,7 @@ -namespace SIL.Machine.Translation +using System.Collections.Generic; +using CaseExtensions; + +namespace SIL.Machine.Translation { public enum SymmetrizationHeuristic { @@ -39,4 +42,56 @@ public enum SymmetrizationHeuristic /// GrowDiagFinalAnd } + + public static class SymmetrizationHelpers + { + public const string Och = "och"; + public const string Union = "union"; + public const string Intersection = "intersection"; + public const string Grow = "grow"; + public const string GrowDiag = "grow-diag"; + public const string GrowDiagFinal = "grow-diag-final"; + public const string GrowDiagFinalAnd = "grow-diag-final-and"; + public const string None = "none"; + + public static bool ValidateSymmetrizationHeuristicOption(string value, bool noneAllowed = true) + { + var validHeuristics = new HashSet + { + Och, + Union, + Intersection, + Grow, + GrowDiag, + GrowDiagFinal, + GrowDiagFinalAnd + }; + if (noneAllowed) + validHeuristics.Add(None); + return string.IsNullOrEmpty(value) || validHeuristics.Contains(value.ToLowerInvariant()); + } + + public static SymmetrizationHeuristic GetSymmetrizationHeuristic(string value) + { + switch (value.ToKebabCase()) + { + case None: + return SymmetrizationHeuristic.None; + case Union: + return SymmetrizationHeuristic.Union; + case Intersection: + return SymmetrizationHeuristic.Intersection; + case Grow: + return SymmetrizationHeuristic.Grow; + case GrowDiag: + return SymmetrizationHeuristic.GrowDiag; + case GrowDiagFinal: + return SymmetrizationHeuristic.GrowDiagFinal; + case GrowDiagFinalAnd: + return SymmetrizationHeuristic.GrowDiagFinalAnd; + default: + return SymmetrizationHeuristic.Och; + } + } + } } diff --git a/src/SIL.Machine/Translation/SymmetrizedWordAlignmentEngine.cs b/src/SIL.Machine/Translation/SymmetrizedWordAlignmentEngine.cs new file mode 100644 index 000000000..924d908b9 --- /dev/null +++ b/src/SIL.Machine/Translation/SymmetrizedWordAlignmentEngine.cs @@ -0,0 +1,173 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using SIL.Machine.Corpora; +using SIL.ObjectModel; + +namespace SIL.Machine.Translation +{ + public class SymmetrizedWordAlignmentEngine : DisposableBase, IWordAlignmentEngine + { + private readonly IWordAlignmentEngine _directWordAlignmentEngine; + private readonly IWordAlignmentEngine _inverseWordAlignmentEngine; + private readonly SymmetrizedWordAligner _aligner; + + public SymmetrizedWordAlignmentEngine( + IWordAlignmentEngine directWordAlignmentEngine, + IWordAlignmentEngine inverseWordAlignmentEngine + ) + { + _directWordAlignmentEngine = directWordAlignmentEngine; + _inverseWordAlignmentEngine = inverseWordAlignmentEngine; + _aligner = new SymmetrizedWordAligner(DirectWordAlignmentEngine, InverseWordAlignmentEngine); + } + + public SymmetrizationHeuristic Heuristic + { + get => _aligner.Heuristic; + set => _aligner.Heuristic = value; + } + + public IWordAlignmentEngine DirectWordAlignmentEngine + { + get + { + CheckDisposed(); + + return _directWordAlignmentEngine; + } + } + + public IWordAlignmentEngine InverseWordAlignmentEngine + { + get + { + CheckDisposed(); + + return _inverseWordAlignmentEngine; + } + } + + public IWordVocabulary SourceWords + { + get + { + CheckDisposed(); + + return _directWordAlignmentEngine.SourceWords; + } + } + + public IWordVocabulary TargetWords + { + get + { + CheckDisposed(); + + return _directWordAlignmentEngine.TargetWords; + } + } + + public IReadOnlySet SpecialSymbolIndices => _directWordAlignmentEngine.SpecialSymbolIndices; + + public WordAlignmentMatrix Align(IReadOnlyList sourceSegment, IReadOnlyList targetSegment) + { + CheckDisposed(); + + return _aligner.Align(sourceSegment, targetSegment); + } + + public IReadOnlyList AlignBatch( + IReadOnlyList<(IReadOnlyList SourceSegment, IReadOnlyList TargetSegment)> segments + ) + { + CheckDisposed(); + + return _aligner.AlignBatch(segments); + } + + public IEnumerable<(string TargetWord, double Score)> GetTranslations(string sourceWord, double threshold = 0) + { + CheckDisposed(); + + foreach ((string targetWord, double dirScore) in _directWordAlignmentEngine.GetTranslations(sourceWord)) + { + double invScore = _inverseWordAlignmentEngine.GetTranslationScore(targetWord, sourceWord); + double score = Math.Max(dirScore, invScore); + if (score > threshold) + yield return (targetWord, score); + } + } + + public IEnumerable<(int TargetWordIndex, double Score)> GetTranslations( + int sourceWordIndex, + double threshold = 0 + ) + { + CheckDisposed(); + + foreach ( + (int targetWordIndex, double dirScore) in _directWordAlignmentEngine.GetTranslations(sourceWordIndex) + ) + { + double invScore = _inverseWordAlignmentEngine.GetTranslationScore(targetWordIndex, sourceWordIndex); + double score = Math.Max(dirScore, invScore); + if (score > threshold) + yield return (targetWordIndex, score); + } + } + + public double GetTranslationScore(string sourceWord, string targetWord) + { + CheckDisposed(); + + double dirScore = _directWordAlignmentEngine.GetTranslationScore(sourceWord, targetWord); + double invScore = _inverseWordAlignmentEngine.GetTranslationScore(targetWord, sourceWord); + return Math.Max(dirScore, invScore); + } + + public double GetTranslationScore(int sourceWordIndex, int targetWordIndex) + { + CheckDisposed(); + + double dirScore = _directWordAlignmentEngine.GetTranslationScore(sourceWordIndex, targetWordIndex); + double invScore = _inverseWordAlignmentEngine.GetTranslationScore(targetWordIndex, sourceWordIndex); + return Math.Max(dirScore, invScore); + } + + public IReadOnlyCollection GetBestAlignedWordPairs( + IReadOnlyList sourceSegment, + IReadOnlyList targetSegment + ) + { + CheckDisposed(); + + WordAlignmentMatrix matrix = Align(sourceSegment, targetSegment); + IReadOnlyCollection wordPairs = matrix.ToAlignedWordPairs(); + ComputeAlignedWordPairScores(sourceSegment, targetSegment, wordPairs); + return wordPairs; + } + + public void ComputeAlignedWordPairScores( + IReadOnlyList sourceSegment, + IReadOnlyList targetSegment, + IReadOnlyCollection wordPairs + ) + { + AlignedWordPair[] inverseWordPairs = wordPairs.Select(wp => wp.Invert()).ToArray(); + _directWordAlignmentEngine.ComputeAlignedWordPairScores(sourceSegment, targetSegment, wordPairs); + _inverseWordAlignmentEngine.ComputeAlignedWordPairScores(targetSegment, sourceSegment, inverseWordPairs); + foreach (var (wordPair, inverseWordPair) in wordPairs.Zip(inverseWordPairs, (wp, invWp) => (wp, invWp))) + { + wordPair.TranslationScore = Math.Max(wordPair.TranslationScore, inverseWordPair.TranslationScore); + wordPair.AlignmentScore = Math.Max(wordPair.AlignmentScore, inverseWordPair.AlignmentScore); + } + } + + protected override void DisposeManagedResources() + { + _directWordAlignmentEngine.Dispose(); + _inverseWordAlignmentEngine.Dispose(); + } + } +} diff --git a/src/SIL.Machine/Translation/SymmetrizedWordAlignmentModel.cs b/src/SIL.Machine/Translation/SymmetrizedWordAlignmentModel.cs index aedfdc573..a98e8e666 100644 --- a/src/SIL.Machine/Translation/SymmetrizedWordAlignmentModel.cs +++ b/src/SIL.Machine/Translation/SymmetrizedWordAlignmentModel.cs @@ -1,167 +1,20 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using SIL.Machine.Corpora; -using SIL.ObjectModel; +using SIL.Machine.Corpora; namespace SIL.Machine.Translation { - public class SymmetrizedWordAlignmentModel : DisposableBase, IWordAlignmentModel + public class SymmetrizedWordAlignmentModel : SymmetrizedWordAlignmentEngine, IWordAlignmentModel { private readonly IWordAlignmentModel _directWordAlignmentModel; private readonly IWordAlignmentModel _inverseWordAlignmentModel; - private readonly SymmetrizedWordAligner _aligner; public SymmetrizedWordAlignmentModel( IWordAlignmentModel directWordAlignmentModel, IWordAlignmentModel inverseWordAlignmentModel ) + : base(directWordAlignmentModel, inverseWordAlignmentModel) { _directWordAlignmentModel = directWordAlignmentModel; _inverseWordAlignmentModel = inverseWordAlignmentModel; - _aligner = new SymmetrizedWordAligner(DirectWordAlignmentModel, InverseWordAlignmentModel); - } - - public SymmetrizationHeuristic Heuristic - { - get => _aligner.Heuristic; - set => _aligner.Heuristic = value; - } - - public IWordAlignmentModel DirectWordAlignmentModel - { - get - { - CheckDisposed(); - - return _directWordAlignmentModel; - } - } - - public IWordAlignmentModel InverseWordAlignmentModel - { - get - { - CheckDisposed(); - - return _inverseWordAlignmentModel; - } - } - - public IWordVocabulary SourceWords - { - get - { - CheckDisposed(); - - return _directWordAlignmentModel.SourceWords; - } - } - - public IWordVocabulary TargetWords - { - get - { - CheckDisposed(); - - return _directWordAlignmentModel.TargetWords; - } - } - - public IReadOnlySet SpecialSymbolIndices => _directWordAlignmentModel.SpecialSymbolIndices; - - public WordAlignmentMatrix Align(IReadOnlyList sourceSegment, IReadOnlyList targetSegment) - { - CheckDisposed(); - - return _aligner.Align(sourceSegment, targetSegment); - } - - public IReadOnlyList AlignBatch( - IReadOnlyList<(IReadOnlyList SourceSegment, IReadOnlyList TargetSegment)> segments - ) - { - CheckDisposed(); - - return _aligner.AlignBatch(segments); - } - - public IEnumerable<(string TargetWord, double Score)> GetTranslations(string sourceWord, double threshold = 0) - { - CheckDisposed(); - - foreach ((string targetWord, double dirScore) in _directWordAlignmentModel.GetTranslations(sourceWord)) - { - double invScore = _inverseWordAlignmentModel.GetTranslationScore(targetWord, sourceWord); - double score = Math.Max(dirScore, invScore); - if (score > threshold) - yield return (targetWord, score); - } - } - - public IEnumerable<(int TargetWordIndex, double Score)> GetTranslations( - int sourceWordIndex, - double threshold = 0 - ) - { - CheckDisposed(); - - foreach ( - (int targetWordIndex, double dirScore) in _directWordAlignmentModel.GetTranslations(sourceWordIndex) - ) - { - double invScore = _inverseWordAlignmentModel.GetTranslationScore(targetWordIndex, sourceWordIndex); - double score = Math.Max(dirScore, invScore); - if (score > threshold) - yield return (targetWordIndex, score); - } - } - - public double GetTranslationScore(string sourceWord, string targetWord) - { - CheckDisposed(); - - double dirScore = _directWordAlignmentModel.GetTranslationScore(sourceWord, targetWord); - double invScore = _inverseWordAlignmentModel.GetTranslationScore(targetWord, sourceWord); - return Math.Max(dirScore, invScore); - } - - public double GetTranslationScore(int sourceWordIndex, int targetWordIndex) - { - CheckDisposed(); - - double dirScore = _directWordAlignmentModel.GetTranslationScore(sourceWordIndex, targetWordIndex); - double invScore = _inverseWordAlignmentModel.GetTranslationScore(targetWordIndex, sourceWordIndex); - return Math.Max(dirScore, invScore); - } - - public IReadOnlyCollection GetBestAlignedWordPairs( - IReadOnlyList sourceSegment, - IReadOnlyList targetSegment - ) - { - CheckDisposed(); - - WordAlignmentMatrix matrix = Align(sourceSegment, targetSegment); - IReadOnlyCollection wordPairs = matrix.ToAlignedWordPairs(); - ComputeAlignedWordPairScores(sourceSegment, targetSegment, wordPairs); - return wordPairs; - } - - public void ComputeAlignedWordPairScores( - IReadOnlyList sourceSegment, - IReadOnlyList targetSegment, - IReadOnlyCollection wordPairs - ) - { - AlignedWordPair[] inverseWordPairs = wordPairs.Select(wp => wp.Invert()).ToArray(); - _directWordAlignmentModel.ComputeAlignedWordPairScores(sourceSegment, targetSegment, wordPairs); - _inverseWordAlignmentModel.ComputeAlignedWordPairScores(targetSegment, sourceSegment, inverseWordPairs); - foreach (var (wordPair, inverseWordPair) in wordPairs.Zip(inverseWordPairs, (wp, invWp) => (wp, invWp))) - { - wordPair.TranslationScore = Math.Max(wordPair.TranslationScore, inverseWordPair.TranslationScore); - wordPair.AlignmentScore = Math.Max(wordPair.AlignmentScore, inverseWordPair.AlignmentScore); - } } public ITrainer CreateTrainer(IParallelTextCorpus corpus) diff --git a/src/SIL.Machine/Translation/WordAlignmentResult.cs b/src/SIL.Machine/Translation/WordAlignmentResult.cs new file mode 100644 index 000000000..01f7fb728 --- /dev/null +++ b/src/SIL.Machine/Translation/WordAlignmentResult.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace SIL.Machine.Translation +{ + public class WordAlignmentResult + { + public WordAlignmentResult( + IEnumerable sourceTokens, + IEnumerable targetTokens, + IEnumerable confidences, + WordAlignmentMatrix alignment + ) + { + SourceTokens = sourceTokens.ToArray(); + TargetTokens = targetTokens.ToArray(); + Confidences = confidences.ToArray(); + if (Confidences.Count != TargetTokens.Count) + { + throw new ArgumentException( + "The confidences must be the same length as the target segment.", + nameof(confidences) + ); + } + Alignment = alignment; + if (Alignment.RowCount != SourceTokens.Count) + { + throw new ArgumentException( + "The alignment source length must be the same length as the source segment.", + nameof(alignment) + ); + } + if (Alignment.ColumnCount != TargetTokens.Count) + { + throw new ArgumentException( + "The alignment target length must be the same length as the target segment.", + nameof(alignment) + ); + } + } + + public IReadOnlyList SourceTokens { get; } + public IReadOnlyList TargetTokens { get; } + public IReadOnlyList Confidences { get; } + public WordAlignmentMatrix Alignment { get; } + } +}