From 19eb6b843ea4336ab5287e1f1e2d6285e4321ac4 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Thu, 21 Nov 2024 21:37:55 -0500 Subject: [PATCH] Add tokenizer to trainer --- src/SIL.Machine.Tool/AlignCommand.cs | 8 +- .../AlignmentModelCommandSpec.cs | 22 ++--- src/SIL.Machine.Tool/SymmetrizeCommand.cs | 13 ++- src/SIL.Machine.Tool/ToolHelpers.cs | 81 +++---------------- .../TranslationModelCommandSpec.cs | 4 +- .../SIL.Machine.Translation.Thot.csproj | 1 + .../ThotWordAlignmentModel.cs | 14 +++- .../ThotWordAlignmentModelType.cs | 40 ++++++++- src/SIL.Machine/SIL.Machine.csproj | 1 + .../Translation/SymmetrizationHeuristic.cs | 57 ++++++++++++- 10 files changed, 149 insertions(+), 92 deletions(-) diff --git a/src/SIL.Machine.Tool/AlignCommand.cs b/src/SIL.Machine.Tool/AlignCommand.cs index 7888c7396..a00750632 100644 --- a/src/SIL.Machine.Tool/AlignCommand.cs +++ b/src/SIL.Machine.Tool/AlignCommand.cs @@ -39,7 +39,7 @@ public AlignCommand() ); _symHeuristicOption = Option( "-sh|--sym-heuristic ", - $"The symmetrization heuristic.\nHeuristics: \"{ToolHelpers.Och}\" (default), \"{ToolHelpers.Union}\", \"{ToolHelpers.Intersection}\", \"{ToolHelpers.Grow}\", \"{ToolHelpers.GrowDiag}\", \"{ToolHelpers.GrowDiagFinal}\", \"{ToolHelpers.GrowDiagFinalAnd}\", \"{ToolHelpers.None}\".", + $"The symmetrization heuristic.\nHeuristics: \"{SymmetrizationHelpers.Och}\" (default), \"{SymmetrizationHelpers.Union}\", \"{SymmetrizationHelpers.Intersection}\", \"{SymmetrizationHelpers.Grow}\", \"{SymmetrizationHelpers.GrowDiag}\", \"{SymmetrizationHelpers.GrowDiagFinal}\", \"{SymmetrizationHelpers.GrowDiagFinalAnd}\", \"{SymmetrizationHelpers.None}\".", CommandOptionType.SingleValue ); _scoresOption = Option("-s|--scores", "Include scores in the output.", CommandOptionType.NoValue); @@ -53,7 +53,7 @@ protected override async Task ExecuteCommandAsync(CancellationToken cancell if (code != 0) return code; - if (!ToolHelpers.ValidateSymmetrizationHeuristicOption(_symHeuristicOption?.Value())) + if (!SymmetrizationHelpers.ValidateSymmetrizationHeuristicOption(_symHeuristicOption?.Value())) { Out.WriteLine("The specified symmetrization heuristic is invalid."); return 1; @@ -75,7 +75,9 @@ protected override async Task ExecuteCommandAsync(CancellationToken cancell int processorCount = Environment.ProcessorCount; - SymmetrizationHeuristic symHeuristic = ToolHelpers.GetSymmetrizationHeuristic(_symHeuristicOption?.Value()); + SymmetrizationHeuristic symHeuristic = SymmetrizationHelpers.GetSymmetrizationHeuristic( + _symHeuristicOption?.Value() + ); if (!_quietOption.HasValue()) Out.Write("Loading model... "); diff --git a/src/SIL.Machine.Tool/AlignmentModelCommandSpec.cs b/src/SIL.Machine.Tool/AlignmentModelCommandSpec.cs index b40c4fe11..a6804d6e6 100644 --- a/src/SIL.Machine.Tool/AlignmentModelCommandSpec.cs +++ b/src/SIL.Machine.Tool/AlignmentModelCommandSpec.cs @@ -37,7 +37,7 @@ public void AddParameters(CommandBase command) _modelArgument = command.Argument("MODEL_PATH", "The word alignment model.").IsRequired(); _modelTypeOption = command.Option( "-mt|--model-type ", - $"The word alignment model type.\nTypes: \"{ToolHelpers.Hmm}\" (default), \"{ToolHelpers.Ibm1}\", \"{ToolHelpers.Ibm2}\", \"{ToolHelpers.Ibm3}\", \"{ToolHelpers.Ibm4}\", \"{ToolHelpers.FastAlign}\".", + $"The word alignment model type.\nTypes: \"{ThotWordAlignmentHelpers.Hmm}\" (default), \"{ThotWordAlignmentHelpers.Ibm1}\", \"{ThotWordAlignmentHelpers.Ibm2}\", \"{ThotWordAlignmentHelpers.Ibm3}\", \"{ThotWordAlignmentHelpers.Ibm4}\", \"{ThotWordAlignmentHelpers.FastAlign}\".", CommandOptionType.SingleValue ); _pluginOption = command.Option( @@ -90,7 +90,7 @@ public IWordAlignmentModel CreateAlignmentModel( ThotWordAlignmentModelType modelType = ThotWordAlignmentModelType.Hmm; if (_modelTypeOption.HasValue()) { - modelType = ToolHelpers.GetThotWordAlignmentModelType(_modelTypeOption.Value()); + modelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType(_modelTypeOption.Value()); } else { @@ -103,7 +103,7 @@ public IWordAlignmentModel CreateAlignmentModel( yaml.Load(reader); var root = (YamlMappingNode)yaml.Documents.First().RootNode; var modelTypeStr = (string)root[new YamlScalarNode("model")]; - modelType = ToolHelpers.GetThotWordAlignmentModelType(modelTypeStr); + modelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType(modelTypeStr); } } } @@ -144,7 +144,9 @@ public ITrainer CreateAlignmentModelTrainer( return _modelFactory.CreateTrainer(_modelArgument.Value, corpus, maxSize, parameters, direct); } - ThotWordAlignmentModelType modelType = ToolHelpers.GetThotWordAlignmentModelType(_modelTypeOption.Value()); + ThotWordAlignmentModelType modelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType( + _modelTypeOption.Value() + ); string modelPath = _modelArgument.Value; if (ToolHelpers.IsDirectoryPath(modelPath)) @@ -227,12 +229,12 @@ private static bool ValidateAlignmentModelTypeOption(string value, IEnumerable { - ToolHelpers.Hmm, - ToolHelpers.Ibm1, - ToolHelpers.Ibm2, - ToolHelpers.FastAlign, - ToolHelpers.Ibm3, - ToolHelpers.Ibm4 + ThotWordAlignmentHelpers.Hmm, + ThotWordAlignmentHelpers.Ibm1, + ThotWordAlignmentHelpers.Ibm2, + ThotWordAlignmentHelpers.FastAlign, + ThotWordAlignmentHelpers.Ibm3, + ThotWordAlignmentHelpers.Ibm4 }; validTypes.UnionWith(pluginTypes); return string.IsNullOrEmpty(value) || validTypes.Contains(value); diff --git a/src/SIL.Machine.Tool/SymmetrizeCommand.cs b/src/SIL.Machine.Tool/SymmetrizeCommand.cs index eaa524206..7a403b8ee 100644 --- a/src/SIL.Machine.Tool/SymmetrizeCommand.cs +++ b/src/SIL.Machine.Tool/SymmetrizeCommand.cs @@ -37,7 +37,7 @@ public SymmetrizeCommand() ); _symHeuristicOption = Option( "-sh|--sym-heuristic ", - $"The symmetrization heuristic.\nHeuristics: \"{ToolHelpers.Och}\" (default), \"{ToolHelpers.Union}\", \"{ToolHelpers.Intersection}\", \"{ToolHelpers.Grow}\", \"{ToolHelpers.GrowDiag}\", \"{ToolHelpers.GrowDiagFinal}\", \"{ToolHelpers.GrowDiagFinalAnd}\".", + $"The symmetrization heuristic.\nHeuristics: \"{SymmetrizationHelpers.Och}\" (default), \"{SymmetrizationHelpers.Union}\", \"{SymmetrizationHelpers.Intersection}\", \"{SymmetrizationHelpers.Grow}\", \"{SymmetrizationHelpers.GrowDiag}\", \"{SymmetrizationHelpers.GrowDiagFinal}\", \"{SymmetrizationHelpers.GrowDiagFinalAnd}\".", CommandOptionType.SingleValue ); _quietOption = Option("-q|--quiet", "Only display results.", CommandOptionType.NoValue); @@ -67,14 +67,21 @@ protected override async Task ExecuteCommandAsync(CancellationToken cancell return 1; } - if (!ToolHelpers.ValidateSymmetrizationHeuristicOption(_symHeuristicOption.Value(), noneAllowed: false)) + if ( + !SymmetrizationHelpers.ValidateSymmetrizationHeuristicOption( + _symHeuristicOption.Value(), + noneAllowed: false + ) + ) { Out.WriteLine("The specified symmetrization heuristic is invalid."); return 1; } string outputFormat = _outputFormatOption.Value() ?? Pharaoh; - SymmetrizationHeuristic heuristic = ToolHelpers.GetSymmetrizationHeuristic(_symHeuristicOption.Value()); + SymmetrizationHeuristic heuristic = SymmetrizationHelpers.GetSymmetrizationHeuristic( + _symHeuristicOption.Value() + ); using var directReader = new StreamReader(_directArgument.Value); using var inverseReader = new StreamReader(_inverseArgument.Value); diff --git a/src/SIL.Machine.Tool/ToolHelpers.cs b/src/SIL.Machine.Tool/ToolHelpers.cs index 2c3b00ed1..497075a26 100644 --- a/src/SIL.Machine.Tool/ToolHelpers.cs +++ b/src/SIL.Machine.Tool/ToolHelpers.cs @@ -14,22 +14,6 @@ namespace SIL.Machine; internal static class ToolHelpers { - public const string FastAlign = "fast_align"; - public const string Ibm1 = "ibm1"; - public const string Ibm2 = "ibm2"; - public const string Hmm = "hmm"; - public const string Ibm3 = "ibm3"; - public const string Ibm4 = "ibm4"; - - public const string Och = "och"; - public const string Union = "union"; - public const string Intersection = "intersection"; - public const string Grow = "grow"; - public const string GrowDiag = "grow-diag"; - public const string GrowDiagFinal = "grow-diag-final"; - public const string GrowDiagFinalAnd = "grow-diag-final-and"; - public const string None = "none"; - public static bool ValidateCorpusFormatOption(string value) { return string.IsNullOrEmpty(value) || value.ToLowerInvariant().IsOneOf("dbl", "usx", "text", "pt", "pt_m"); @@ -154,29 +138,14 @@ public static string GetTranslationModelConfigFileName(string path) public static bool ValidateTranslationModelTypeOption(string value) { - var validTypes = new HashSet { Hmm, Ibm1, Ibm2, FastAlign }; - return string.IsNullOrEmpty(value) || validTypes.Contains(value); - } - - public static ThotWordAlignmentModelType GetThotWordAlignmentModelType(string modelType) - { - switch (modelType) + var validTypes = new HashSet { - case "fastAlign": - case FastAlign: - return ThotWordAlignmentModelType.FastAlign; - case Ibm1: - return ThotWordAlignmentModelType.Ibm1; - case Ibm2: - return ThotWordAlignmentModelType.Ibm2; - default: - case Hmm: - return ThotWordAlignmentModelType.Hmm; - case Ibm3: - return ThotWordAlignmentModelType.Ibm3; - case Ibm4: - return ThotWordAlignmentModelType.Ibm4; - } + ThotWordAlignmentHelpers.Hmm, + ThotWordAlignmentHelpers.Ibm1, + ThotWordAlignmentHelpers.Ibm2, + ThotWordAlignmentHelpers.FastAlign + }; + return string.IsNullOrEmpty(value) || validTypes.Contains(value); } public static ITrainer CreateTranslationModelTrainer( @@ -186,7 +155,9 @@ public static ITrainer CreateTranslationModelTrainer( int maxSize ) { - ThotWordAlignmentModelType wordAlignmentModelType = GetThotWordAlignmentModelType(modelType); + ThotWordAlignmentModelType wordAlignmentModelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType( + modelType + ); string modelDir = Path.GetDirectoryName(modelConfigFileName); if (!Directory.Exists(modelDir)) @@ -201,38 +172,6 @@ int maxSize }; } - public static bool ValidateSymmetrizationHeuristicOption(string value, bool noneAllowed = true) - { - var validHeuristics = new HashSet - { - Och, - Union, - Intersection, - Grow, - GrowDiag, - GrowDiagFinal, - GrowDiagFinalAnd - }; - if (noneAllowed) - validHeuristics.Add(None); - return string.IsNullOrEmpty(value) || validHeuristics.Contains(value.ToLowerInvariant()); - } - - public static SymmetrizationHeuristic GetSymmetrizationHeuristic(string value) - { - return value switch - { - None => SymmetrizationHeuristic.None, - Union => SymmetrizationHeuristic.Union, - Intersection => SymmetrizationHeuristic.Intersection, - Grow => SymmetrizationHeuristic.Grow, - GrowDiag => SymmetrizationHeuristic.GrowDiag, - GrowDiagFinal => SymmetrizationHeuristic.GrowDiagFinal, - GrowDiagFinalAnd => SymmetrizationHeuristic.GrowDiagFinalAnd, - _ => SymmetrizationHeuristic.Och, - }; - } - public static StreamWriter CreateStreamWriter(string fileName) { var utf8Encoding = new UTF8Encoding(false); diff --git a/src/SIL.Machine.Tool/TranslationModelCommandSpec.cs b/src/SIL.Machine.Tool/TranslationModelCommandSpec.cs index 3ca50cc79..8523d691b 100644 --- a/src/SIL.Machine.Tool/TranslationModelCommandSpec.cs +++ b/src/SIL.Machine.Tool/TranslationModelCommandSpec.cs @@ -22,7 +22,7 @@ public void AddParameters(CommandBase command) _modelArgument = command.Argument("MODEL_PATH", "The translation model.").IsRequired(); _modelTypeOption = command.Option( "-mt|--model-type ", - $"The word alignment model type.\nTypes: \"{ToolHelpers.Hmm}\" (default), \"{ToolHelpers.Ibm1}\", \"{ToolHelpers.Ibm2}\", \"{ToolHelpers.FastAlign}\".", + $"The word alignment model type.\nTypes: \"{ThotWordAlignmentHelpers.Hmm}\" (default), \"{ThotWordAlignmentHelpers.Ibm1}\", \"{ThotWordAlignmentHelpers.Ibm2}\", \"{ThotWordAlignmentHelpers.FastAlign}\".", CommandOptionType.SingleValue ); _pluginOption = command.Option( @@ -67,7 +67,7 @@ public ITranslationModel CreateModel() if (_modelFactory != null) return _modelFactory.CreateModel(_modelArgument.Value); - ThotWordAlignmentModelType wordAlignmentModelType = ToolHelpers.GetThotWordAlignmentModelType( + ThotWordAlignmentModelType wordAlignmentModelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType( _modelTypeOption.Value() ); diff --git a/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj b/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj index e7e5b9b27..b51de93f4 100644 --- a/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj +++ b/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj @@ -13,6 +13,7 @@ + diff --git a/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModel.cs b/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModel.cs index b10135bd5..bad13bcf0 100644 --- a/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModel.cs +++ b/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModel.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using SIL.Extensions; using SIL.Machine.Corpora; +using SIL.Machine.Tokenization; using SIL.ObjectModel; namespace SIL.Machine.Translation.Thot @@ -116,6 +117,11 @@ public void CreateNew(string prefFileName) } public ITrainer CreateTrainer(IParallelTextCorpus corpus) + { + return CreateTrainer(corpus, null); + } + + public ITrainer CreateTrainer(IParallelTextCorpus corpus, ITokenizer tokenizer = null) { CheckDisposed(); @@ -126,7 +132,13 @@ public ITrainer CreateTrainer(IParallelTextCorpus corpus) ); } - return new Trainer(this, corpus); + var trainer = new Trainer(this, corpus); + if (tokenizer != null) + { + trainer.SourceTokenizer = tokenizer; + trainer.TargetTokenizer = tokenizer; + } + return trainer; } public Task SaveAsync() diff --git a/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModelType.cs b/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModelType.cs index 1e3142fde..27cb264e2 100644 --- a/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModelType.cs +++ b/src/SIL.Machine.Translation.Thot/ThotWordAlignmentModelType.cs @@ -1,4 +1,7 @@ -namespace SIL.Machine.Translation.Thot +using System; +using CaseExtensions; + +namespace SIL.Machine.Translation.Thot { public enum ThotWordAlignmentModelType { @@ -9,4 +12,39 @@ public enum ThotWordAlignmentModelType Ibm3, Ibm4 } + + public static class ThotWordAlignmentHelpers + { + public const string FastAlign = "fast_align"; + public const string Ibm1 = "ibm1"; + public const string Ibm2 = "ibm2"; + public const string Hmm = "hmm"; + public const string Ibm3 = "ibm3"; + public const string Ibm4 = "ibm4"; + + public static ThotWordAlignmentModelType GetThotWordAlignmentModelType( + string modelType, + ThotWordAlignmentModelType? defaultType = null + ) + { + switch (modelType.ToSnakeCase()) + { + case FastAlign: + return ThotWordAlignmentModelType.FastAlign; + case Ibm1: + return ThotWordAlignmentModelType.Ibm1; + case Ibm2: + return ThotWordAlignmentModelType.Ibm2; + case Hmm: + return ThotWordAlignmentModelType.Hmm; + case Ibm3: + return ThotWordAlignmentModelType.Ibm3; + case Ibm4: + return ThotWordAlignmentModelType.Ibm4; + default: + return defaultType + ?? throw new ArgumentException($"Invalid word alignment model type: {modelType}"); + } + } + } } diff --git a/src/SIL.Machine/SIL.Machine.csproj b/src/SIL.Machine/SIL.Machine.csproj index 57cf0ccb4..6a7cfbcd0 100644 --- a/src/SIL.Machine/SIL.Machine.csproj +++ b/src/SIL.Machine/SIL.Machine.csproj @@ -42,6 +42,7 @@ + diff --git a/src/SIL.Machine/Translation/SymmetrizationHeuristic.cs b/src/SIL.Machine/Translation/SymmetrizationHeuristic.cs index e18b96111..544bd3fcc 100644 --- a/src/SIL.Machine/Translation/SymmetrizationHeuristic.cs +++ b/src/SIL.Machine/Translation/SymmetrizationHeuristic.cs @@ -1,4 +1,7 @@ -namespace SIL.Machine.Translation +using System.Collections.Generic; +using CaseExtensions; + +namespace SIL.Machine.Translation { public enum SymmetrizationHeuristic { @@ -39,4 +42,56 @@ public enum SymmetrizationHeuristic /// GrowDiagFinalAnd } + + public static class SymmetrizationHelpers + { + public const string Och = "och"; + public const string Union = "union"; + public const string Intersection = "intersection"; + public const string Grow = "grow"; + public const string GrowDiag = "grow-diag"; + public const string GrowDiagFinal = "grow-diag-final"; + public const string GrowDiagFinalAnd = "grow-diag-final-and"; + public const string None = "none"; + + public static bool ValidateSymmetrizationHeuristicOption(string value, bool noneAllowed = true) + { + var validHeuristics = new HashSet + { + Och, + Union, + Intersection, + Grow, + GrowDiag, + GrowDiagFinal, + GrowDiagFinalAnd + }; + if (noneAllowed) + validHeuristics.Add(None); + return string.IsNullOrEmpty(value) || validHeuristics.Contains(value.ToLowerInvariant()); + } + + public static SymmetrizationHeuristic GetSymmetrizationHeuristic(string value) + { + switch (value.ToKebabCase()) + { + case None: + return SymmetrizationHeuristic.None; + case Union: + return SymmetrizationHeuristic.Union; + case Intersection: + return SymmetrizationHeuristic.Intersection; + case Grow: + return SymmetrizationHeuristic.Grow; + case GrowDiag: + return SymmetrizationHeuristic.GrowDiag; + case GrowDiagFinal: + return SymmetrizationHeuristic.GrowDiagFinal; + case GrowDiagFinalAnd: + return SymmetrizationHeuristic.GrowDiagFinalAnd; + default: + return SymmetrizationHeuristic.Och; + } + } + } }