From c70ff92096d4c0468286d436a730867cecfe0847 Mon Sep 17 00:00:00 2001
From: Charles Dickens
Date: Fri, 1 Sep 2023 16:06:16 -0700
Subject: [PATCH] Validation break.

---
 .../weight/gradient/GradientDescent.java     | 21 +++++++--
 .../weight/gradient/minimizer/Minimizer.java | 11 +++--
 .../java/org/linqs/psl/config/Options.java   | 45 ++++++-------
 3 files changed, 38 insertions(+), 39 deletions(-)

diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java
index d8e385a2b..dc7bdfceb 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java
@@ -87,6 +87,9 @@ public static enum GDExtension {
     protected List<float[]> batchMAPAtomValueStates;
 
     protected int validationEvaluationComputePeriod;
+    protected boolean validationBreak;
+    protected int validationPatience;
+    protected int lastValidationImprovementEpoch;
     protected TermState[] validationMAPTermState;
     protected float[] validationMAPAtomValueState;
     protected boolean saveBestValidationWeights;
@@ -152,6 +155,9 @@ public GradientDescent(List<Rule> rules, Database trainTargetDatabase, Database
         bestValidationWeights = null;
         currentValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
         bestValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
+        validationBreak = Options.WLA_GRADIENT_DESCENT_VALIDATION_BREAK.getBoolean();
+        validationPatience = Options.WLA_GRADIENT_DESCENT_VALIDATION_PATIENCE.getInt();
+        lastValidationImprovementEpoch = 0;
 
         if (saveBestValidationWeights && (!this.runValidation)) {
             throw new IllegalArgumentException("If saveBestValidationWeights is true, then runValidation must also be true.");
@@ -307,7 +313,7 @@ protected void doLearn() {
             }
 
             if (runValidation && (epoch % validationEvaluationComputePeriod == 0)) {
-                runValidationEvaluation();
+                runValidationEvaluation(epoch);
 
                 log.debug("Current MAP State Validation Evaluation Metric: {}", currentValidationEvaluationMetric);
             }
@@ -389,7 +395,7 @@ protected void doLearn() {
             if (saveBestValidationWeights) {
                 finalMAPStateValidationEvaluation = bestValidationEvaluationMetric;
             } else {
-                runValidationEvaluation();
+                runValidationEvaluation(epoch);
                 finalMAPStateValidationEvaluation = currentValidationEvaluationMetric;
             }
             log.info("Final MAP State Validation Evaluation Metric: {}", finalMAPStateValidationEvaluation);
@@ -510,7 +516,7 @@ protected void runMAPEvaluation() {
         }
     }
 
-    protected void runValidationEvaluation() {
+    protected void runValidationEvaluation(int epoch) {
         setValidationModel();
 
         log.trace("Running Validation Inference.");
@@ -520,6 +526,8 @@ protected void runValidationEvaluation() {
         currentValidationEvaluationMetric = evaluation.getNormalizedRepMetric();
 
         if (currentValidationEvaluationMetric > bestValidationEvaluationMetric) {
+            lastValidationImprovementEpoch = epoch;
+
             bestValidationEvaluationMetric = currentValidationEvaluationMetric;
 
             // Save the best rule weights.
@@ -539,7 +547,7 @@
 
     protected boolean breakOptimization(int epoch) {
         if (epoch >= maxNumSteps) {
-            log.trace("Breaking Weight Learning. Reached maximum number of iterations: {}", maxNumSteps);
+            log.trace("Breaking Weight Learning. Reached maximum number of epochs: {}", maxNumSteps);
             return true;
         }
 
@@ -547,6 +555,11 @@ protected boolean breakOptimization(int epoch) {
             return false;
         }
 
+        if (validationBreak && (epoch - lastValidationImprovementEpoch) > validationPatience) {
+            log.trace("Breaking Weight Learning. No improvement in validation evaluation metric for {} epochs.", (epoch - lastValidationImprovementEpoch));
+            return true;
+        }
+
         if (movementBreak && MathUtils.equals(parameterMovement, 0.0f, movementTolerance)) {
             log.trace("Breaking Weight Learning. Parameter Movement: {} is within tolerance: {}", parameterMovement, movementTolerance);
             return true;
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java
index 08f88d6a5..4b9f56258 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java
@@ -310,9 +310,9 @@ protected void setBatch(int batch) {
     }
 
     @Override
-    protected boolean breakOptimization(int iteration) {
-        if (iteration >= maxNumSteps) {
-            log.trace("Breaking Weight Learning. Reached maximum number of iterations: {}", maxNumSteps);
+    protected boolean breakOptimization(int epoch) {
+        if (epoch >= maxNumSteps) {
+            log.trace("Breaking Weight Learning. Reached maximum number of epochs: {}", maxNumSteps);
             return true;
         }
 
@@ -320,6 +320,11 @@ protected boolean breakOptimization(int epoch) {
             return false;
         }
 
+        if (validationBreak && (epoch - lastValidationImprovementEpoch) > validationPatience) {
+            log.trace("Breaking Weight Learning. No improvement in validation evaluation metric for {} epochs.", (epoch - lastValidationImprovementEpoch));
+            return true;
+        }
+
         float totalObjectiveDifference = computeTotalObjectiveDifference();
         if (totalObjectiveDifference < finalConstraintTolerance) {
             log.trace("Breaking Weight Learning. Objective difference {} is less than final constraint tolerance {}.",
diff --git a/psl-core/src/main/java/org/linqs/psl/config/Options.java b/psl-core/src/main/java/org/linqs/psl/config/Options.java
index 183497e5e..a63eee7ef 100644
--- a/psl-core/src/main/java/org/linqs/psl/config/Options.java
+++ b/psl-core/src/main/java/org/linqs/psl/config/Options.java
@@ -350,22 +350,6 @@ public class Options {
         Option.FLAG_POSITIVE
     );
 
-    public static final Option WLA_GRADIENT_DESCENT_NORM_BREAK = new Option(
-        "gradientdescent.normbreak",
-        false,
-        "When the gradient norm is below the tolerance "
-        + " set by gradientdescent.normtolerance, gradient descent weight learning is stopped."
-    );
-
-    public static final Option WLA_GRADIENT_DESCENT_NORM_TOLERANCE = new Option(
-        "gradientdescent.normtolerance",
-        1.0e-3f,
-        "If gradientdescent.runfulliterations=false and gradientdescent.normbreak=true,"
-        + " then when the norm of the gradient is below this tolerance "
-        + " gradient descent weight learning is stopped.",
-        Option.FLAG_POSITIVE
-    );
-
     public static final Option WLA_GRADIENT_DESCENT_NUM_STEPS = new Option(
         "gradientdescent.numsteps",
         500,
@@ -387,22 +371,6 @@ public class Options {
         Option.FLAG_POSITIVE
     );
 
-    public static final Option WLA_GRADIENT_DESCENT_OBJECTIVE_BREAK = new Option(
-        "gradientdescent.objectivebreak",
-        false,
-        "When the objective change between iterates is below the tolerance "
-        + " set by gradientdescent.objectivetolerance, gradient descent weight learning is stopped."
-    );
-
-    public static final Option WLA_GRADIENT_DESCENT_OBJECTIVE_TOLERANCE = new Option(
-        "gradientdescent.objectivetolerance",
-        1.0e-5f,
-        "If gradientdescent.runfulliterations=false and gradientdescent.objectivebreak=true,"
-        + " then when the objective change between iterates is below this tolerance"
-        + " gradient descent weight learning is stopped.",
-        Option.FLAG_POSITIVE
-    );
-
     public static final Option WLA_GRADIENT_DESCENT_RUN_FULL_ITERATIONS = new Option(
         "gradientdescent.runfulliterations",
         false,
@@ -417,6 +385,7 @@ public class Options {
         + " If true, then gradientdescent.runvalidation must be true."
     );
 
+
     public static final Option WLA_GRADIENT_DESCENT_SCALE_STEP = new Option(
         "gradientdescent.scalestepsize",
         true,
@@ -444,12 +413,24 @@ public class Options {
         "Compute training evaluation every this many iterations of gradient descent weight learning."
     );
 
+    public static final Option WLA_GRADIENT_DESCENT_VALIDATION_BREAK = new Option(
+        "gradientdescent.validationbreak",
+        false,
+        "Break gradient descent weight learning when the validation evaluation stops improving."
+    );
+
     public static final Option WLA_GRADIENT_DESCENT_VALIDATION_COMPUTE_PERIOD = new Option(
         "gradientdescent.validationcomputeperiod",
         1,
         "Compute validation evaluation every this many iterations of gradient descent weight learning."
    );
 
+    public static final Option WLA_GRADIENT_DESCENT_VALIDATION_PATIENCE = new Option(
+        "gradientdescent.validationpatience",
+        25,
+        "Break gradient descent weight learning when the validation evaluation has not improved for this many epochs."
+    );
+
     public static final Option WLA_GS_POSSIBLE_WEIGHTS = new Option(
         "gridsearch.weights",
         "0.001:0.01:0.1:1:10",
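
Note on the new break rule: validationBreak and validationPatience implement standard
patience-based early stopping. runValidationEvaluation(epoch) records the last epoch at
which the validation metric reached a new best, and breakOptimization() halts learning
once more than validationPatience epochs pass without such an improvement. A minimal,
self-contained Java sketch of the same logic follows; the class and method names are
illustrative only, not part of the PSL API.

    // Illustrative sketch of the patience rule added by this patch.
    // Only the logic mirrors breakOptimization(); all names here are hypothetical.
    public final class PatienceTracker {
        private final int patience;  // e.g., the value of gradientdescent.validationpatience
        private double bestMetric = Double.NEGATIVE_INFINITY;
        private int lastImprovementEpoch = 0;

        public PatienceTracker(int patience) {
            this.patience = patience;
        }

        // Record this epoch's validation metric; a new best resets the patience window.
        public void observe(int epoch, double metric) {
            if (metric > bestMetric) {
                bestMetric = metric;
                lastImprovementEpoch = epoch;
            }
        }

        // True once no improvement has been seen for more than `patience` epochs.
        public boolean shouldStop(int epoch) {
            return (epoch - lastImprovementEpoch) > patience;
        }
    }

As configured above, the behavior is off by default (gradientdescent.validationbreak=false)
and is only meaningful when validation is running (gradientdescent.runvalidation=true);
with the default patience of 25, learning stops once 26 or more epochs elapse without a
new best validation metric, unless another break condition fires first.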