From 5c49c8745b7a633bf41a4f5083d91f14666e1465 Mon Sep 17 00:00:00 2001 From: Charles Dickens Date: Fri, 8 Sep 2023 07:09:55 -0700 Subject: [PATCH] Training evaluation break in gradient descent weight learning. --- .../weight/gradient/GradientDescent.java | 35 ++++++++++++++++--- .../weight/gradient/minimizer/Minimizer.java | 5 +++ .../java/org/linqs/psl/config/Options.java | 12 +++++++ 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java index 4dcdfbbbc..319f35fb9 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java @@ -20,8 +20,6 @@ import org.linqs.psl.application.inference.InferenceApplication; import org.linqs.psl.application.learning.weight.WeightLearningApplication; import org.linqs.psl.application.learning.weight.gradient.batchgenerator.BatchGenerator; -import org.linqs.psl.application.learning.weight.gradient.batchgenerator.ConnectedComponentBatchGenerator; -import org.linqs.psl.application.learning.weight.gradient.batchgenerator.FullBatchGenerator; import org.linqs.psl.config.Options; import org.linqs.psl.database.AtomStore; import org.linqs.psl.database.Database; @@ -78,6 +76,11 @@ public static enum GDExtension { protected List trainFullDeepModelPredicates; protected TermState[] trainFullMAPTermState; protected float[] trainFullMAPAtomValueState; + double currentFullMAPEvaluationMetric; + double bestFullMAPEvaluationMetric; + protected boolean fullMAPEvaluationBreak; + protected int fullMAPEvaluationPatience; + protected int lastFullMAPImprovementEpoch; protected TermState[] trainMAPTermState; protected float[] trainMAPAtomValueState; @@ -137,6 +140,11 @@ public GradientDescent(List rules, Database trainTargetDatabase, Database trainFullDeepModelPredicates = null; trainFullMAPTermState = null; trainFullMAPAtomValueState = null; + currentFullMAPEvaluationMetric = Double.NEGATIVE_INFINITY; + bestFullMAPEvaluationMetric = Double.NEGATIVE_INFINITY; + fullMAPEvaluationBreak = Options.WLA_GRADIENT_DESCENT_FULL_MAP_EVALUATION_BREAK.getBoolean(); + fullMAPEvaluationPatience = Options.WLA_GRADIENT_DESCENT_FULL_MAP_EVALUATION_PATIENCE.getInt(); + lastFullMAPImprovementEpoch = 0; trainMAPTermState = null; trainMAPAtomValueState = null; @@ -282,12 +290,17 @@ protected void initForLearning() { deepPredicate.predictDeepModel(true); } + currentFullMAPEvaluationMetric = Double.NEGATIVE_INFINITY; + bestFullMAPEvaluationMetric = Double.NEGATIVE_INFINITY; + lastFullMAPImprovementEpoch = 0; + bestValidationWeights = new float[mutableRules.size()]; for (int i = 0; i < mutableRules.size(); i++) { bestValidationWeights[i] = mutableRules.get(i).getWeight(); } currentValidationEvaluationMetric = Double.NEGATIVE_INFINITY; bestValidationEvaluationMetric = Double.NEGATIVE_INFINITY; + lastValidationImprovementEpoch = 0; trainMAPTermState = trainFullMAPTermState; trainMAPAtomValueState = trainFullMAPAtomValueState; @@ -309,7 +322,7 @@ protected void doLearn() { } if (log.isTraceEnabled() && (evaluation != null) && (epoch % trainingEvaluationComputePeriod == 0)) { - runMAPEvaluation(); + runMAPEvaluation(epoch); log.trace("MAP State Training Evaluation Metric: {}", evaluation.getNormalizedRepMetric()); } @@ -385,7 +398,7 @@ protected void doLearn() { if (saveBestValidationWeights) { finalMAPStateEvaluation = bestValidationEvaluationMetric; } else { - runMAPEvaluation(); + runMAPEvaluation(epoch); finalMAPStateEvaluation = evaluation.getNormalizedRepMetric(); } log.info("Final MAP State Evaluation Metric: {}", finalMAPStateEvaluation); @@ -502,7 +515,7 @@ protected void setValidationModel() { } } - protected void runMAPEvaluation() { + protected void runMAPEvaluation(int epoch) { setFullTrainModel(); // Compute the MAP state before evaluating so variables have assigned values. @@ -510,6 +523,13 @@ protected void runMAPEvaluation() { computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState); evaluation.compute(trainingMap); + currentFullMAPEvaluationMetric = evaluation.getNormalizedRepMetric(); + + if (currentFullMAPEvaluationMetric > bestFullMAPEvaluationMetric) { + lastFullMAPImprovementEpoch = epoch; + + bestFullMAPEvaluationMetric = currentFullMAPEvaluationMetric; + } // Evaluate the deep predicates. This calls predict with learning set to false. for (DeepPredicate deepPredicate : deepPredicates) { @@ -556,6 +576,11 @@ protected boolean breakOptimization(int epoch) { return false; } + if (fullMAPEvaluationBreak && (epoch - lastFullMAPImprovementEpoch) > fullMAPEvaluationPatience) { + log.trace("Breaking Weight Learning. No improvement in training evaluation metric for {} epochs.", (epoch - lastFullMAPImprovementEpoch)); + return true; + } + if (validationBreak && (epoch - lastValidationImprovementEpoch) > validationPatience) { log.trace("Breaking Weight Learning. No improvement in validation evaluation metric for {} epochs.", (epoch - lastValidationImprovementEpoch)); return true; diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java index 7d6bfba81..6eb337b83 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java @@ -320,6 +320,11 @@ protected boolean breakOptimization(int epoch) { return false; } + if (fullMAPEvaluationBreak && (epoch - lastFullMAPImprovementEpoch) > fullMAPEvaluationPatience) { + log.trace("Breaking Weight Learning. No improvement in training evaluation metric for {} epochs.", (epoch - lastFullMAPImprovementEpoch)); + return true; + } + if (validationBreak && (epoch - lastValidationImprovementEpoch) > validationPatience) { log.trace("Breaking Weight Learning. No improvement in validation evaluation metric for {} epochs.", (epoch - lastValidationImprovementEpoch)); return true; diff --git a/psl-core/src/main/java/org/linqs/psl/config/Options.java b/psl-core/src/main/java/org/linqs/psl/config/Options.java index f5d9630e4..c78636b2c 100644 --- a/psl-core/src/main/java/org/linqs/psl/config/Options.java +++ b/psl-core/src/main/java/org/linqs/psl/config/Options.java @@ -419,6 +419,18 @@ public class Options { "Compute training evaluation every this many iterations of gradient descent weight learning." ); + public static final Option WLA_GRADIENT_DESCENT_FULL_MAP_EVALUATION_BREAK = new Option( + "gradientdescent.trainingevaluationbreak", + false, + "Break gradient descent weight learning when the training evaluation stops improving." + ); + + public static final Option WLA_GRADIENT_DESCENT_FULL_MAP_EVALUATION_PATIENCE = new Option( + "gradientdescent.trainingevaluationpatience", + 25, + "Break gradient descent weight learning when the training evaluation stops improving after this many epochs." + ); + public static final Option WLA_GRADIENT_DESCENT_VALIDATION_BREAK = new Option( "gradientdescent.validationbreak", false,