Training evaluation break in gradient descent weight learning.
dickensc committed Sep 8, 2023
1 parent 37fa46d commit 5c49c87
Showing 3 changed files with 47 additions and 5 deletions.
@@ -20,8 +20,6 @@
import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.application.learning.weight.gradient.batchgenerator.BatchGenerator;
import org.linqs.psl.application.learning.weight.gradient.batchgenerator.ConnectedComponentBatchGenerator;
import org.linqs.psl.application.learning.weight.gradient.batchgenerator.FullBatchGenerator;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
@@ -78,6 +76,11 @@ public static enum GDExtension {
protected List<DeepModelPredicate> trainFullDeepModelPredicates;
protected TermState[] trainFullMAPTermState;
protected float[] trainFullMAPAtomValueState;
double currentFullMAPEvaluationMetric;
double bestFullMAPEvaluationMetric;
protected boolean fullMAPEvaluationBreak;
protected int fullMAPEvaluationPatience;
protected int lastFullMAPImprovementEpoch;

protected TermState[] trainMAPTermState;
protected float[] trainMAPAtomValueState;
@@ -137,6 +140,11 @@ public GradientDescent(List<Rule> rules, Database trainTargetDatabase, Database
trainFullDeepModelPredicates = null;
trainFullMAPTermState = null;
trainFullMAPAtomValueState = null;
currentFullMAPEvaluationMetric = Double.NEGATIVE_INFINITY;
bestFullMAPEvaluationMetric = Double.NEGATIVE_INFINITY;
fullMAPEvaluationBreak = Options.WLA_GRADIENT_DESCENT_FULL_MAP_EVALUATION_BREAK.getBoolean();
fullMAPEvaluationPatience = Options.WLA_GRADIENT_DESCENT_FULL_MAP_EVALUATION_PATIENCE.getInt();
lastFullMAPImprovementEpoch = 0;

trainMAPTermState = null;
trainMAPAtomValueState = null;
@@ -282,12 +290,17 @@ protected void initForLearning() {
deepPredicate.predictDeepModel(true);
}

currentFullMAPEvaluationMetric = Double.NEGATIVE_INFINITY;
bestFullMAPEvaluationMetric = Double.NEGATIVE_INFINITY;
lastFullMAPImprovementEpoch = 0;

bestValidationWeights = new float[mutableRules.size()];
for (int i = 0; i < mutableRules.size(); i++) {
bestValidationWeights[i] = mutableRules.get(i).getWeight();
}
currentValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
bestValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
lastValidationImprovementEpoch = 0;

trainMAPTermState = trainFullMAPTermState;
trainMAPAtomValueState = trainFullMAPAtomValueState;
@@ -309,7 +322,7 @@ protected void doLearn() {
}

if (log.isTraceEnabled() && (evaluation != null) && (epoch % trainingEvaluationComputePeriod == 0)) {
runMAPEvaluation();
runMAPEvaluation(epoch);
log.trace("MAP State Training Evaluation Metric: {}", evaluation.getNormalizedRepMetric());
}

@@ -385,7 +398,7 @@ protected void doLearn() {
if (saveBestValidationWeights) {
finalMAPStateEvaluation = bestValidationEvaluationMetric;
} else {
runMAPEvaluation();
runMAPEvaluation(epoch);
finalMAPStateEvaluation = evaluation.getNormalizedRepMetric();
}
log.info("Final MAP State Evaluation Metric: {}", finalMAPStateEvaluation);
@@ -502,14 +515,21 @@ protected void setValidationModel() {
}
}

protected void runMAPEvaluation() {
protected void runMAPEvaluation(int epoch) {
setFullTrainModel();

// Compute the MAP state before evaluating so variables have assigned values.
log.trace("Running MAP Inference.");
computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState);

evaluation.compute(trainingMap);
currentFullMAPEvaluationMetric = evaluation.getNormalizedRepMetric();

if (currentFullMAPEvaluationMetric > bestFullMAPEvaluationMetric) {
lastFullMAPImprovementEpoch = epoch;

bestFullMAPEvaluationMetric = currentFullMAPEvaluationMetric;
}

// Evaluate the deep predicates. This calls predict with learning set to false.
for (DeepPredicate deepPredicate : deepPredicates) {
@@ -556,6 +576,11 @@ protected boolean breakOptimization(int epoch) {
return false;
}

if (fullMAPEvaluationBreak && (epoch - lastFullMAPImprovementEpoch) > fullMAPEvaluationPatience) {
log.trace("Breaking Weight Learning. No improvement in training evaluation metric for {} epochs.", (epoch - lastFullMAPImprovementEpoch));
return true;
}

if (validationBreak && (epoch - lastValidationImprovementEpoch) > validationPatience) {
log.trace("Breaking Weight Learning. No improvement in validation evaluation metric for {} epochs.", (epoch - lastValidationImprovementEpoch));
return true;
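Taken together, the changes in this file implement a standard patience-based early stop: runMAPEvaluation() now records the epoch of the last improvement in the full MAP training evaluation metric, and breakOptimization() halts weight learning once more than fullMAPEvaluationPatience epochs pass without a new best. A minimal standalone sketch of that pattern follows; only the arithmetic is taken from the diff, and the class and method names are illustrative, not PSL API.

// Standalone sketch of the patience-based break wired into
// breakOptimization() above. Names here are illustrative only.
public class TrainingEvaluationBreakSketch {
    private final boolean evaluationBreak;  // gradientdescent.trainingevaluationbreak
    private final int evaluationPatience;   // gradientdescent.trainingevaluationpatience

    private double bestMetric = Double.NEGATIVE_INFINITY;
    private int lastImprovementEpoch = 0;

    public TrainingEvaluationBreakSketch(boolean evaluationBreak, int evaluationPatience) {
        this.evaluationBreak = evaluationBreak;
        this.evaluationPatience = evaluationPatience;
    }

    // Record this epoch's training evaluation metric (higher is better).
    public void observe(int epoch, double metric) {
        if (metric > bestMetric) {
            bestMetric = metric;
            lastImprovementEpoch = epoch;
        }
    }

    // True once more than evaluationPatience epochs have passed with no new best.
    public boolean shouldBreak(int epoch) {
        return evaluationBreak && (epoch - lastImprovementEpoch) > evaluationPatience;
    }
}

The same check is duplicated in the second changed file below, which overrides breakOptimization() and must repeat the clause before its own break conditions.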
@@ -320,6 +320,11 @@ protected boolean breakOptimization(int epoch) {
return false;
}

if (fullMAPEvaluationBreak && (epoch - lastFullMAPImprovementEpoch) > fullMAPEvaluationPatience) {
log.trace("Breaking Weight Learning. No improvement in training evaluation metric for {} epochs.", (epoch - lastFullMAPImprovementEpoch));
return true;
}

if (validationBreak && (epoch - lastValidationImprovementEpoch) > validationPatience) {
log.trace("Breaking Weight Learning. No improvement in validation evaluation metric for {} epochs.", (epoch - lastValidationImprovementEpoch));
return true;
12 changes: 12 additions & 0 deletions psl-core/src/main/java/org/linqs/psl/config/Options.java
@@ -419,6 +419,18 @@ public class Options {
"Compute training evaluation every this many iterations of gradient descent weight learning."
);

public static final Option WLA_GRADIENT_DESCENT_FULL_MAP_EVALUATION_BREAK = new Option(
"gradientdescent.trainingevaluationbreak",
false,
"Break gradient descent weight learning when the training evaluation stops improving."
);

public static final Option WLA_GRADIENT_DESCENT_FULL_MAP_EVALUATION_PATIENCE = new Option(
"gradientdescent.trainingevaluationpatience",
25,
"Break gradient descent weight learning when the training evaluation stops improving after this many epochs."
);

public static final Option WLA_GRADIENT_DESCENT_VALIDATION_BREAK = new Option(
"gradientdescent.validationbreak",
false,
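The new pair mirrors the existing validation-break options (gradientdescent.validationbreak and its patience counterpart), with the break disabled by default and a patience of 25 epochs. A hedged usage sketch follows; it assumes the usual PSL pattern of setting options programmatically through their Option handles, as seen elsewhere in the codebase, so verify the set() call against your PSL version.

import org.linqs.psl.config.Options;

public class EnableTrainingEvaluationBreak {
    public static void main(String[] args) {
        // Stop gradient descent weight learning once the training
        // evaluation metric has gone 50 epochs without improving.
        // (Option.set() is assumed to exist here; confirm for your version.)
        Options.WLA_GRADIENT_DESCENT_FULL_MAP_EVALUATION_BREAK.set(true);
        Options.WLA_GRADIENT_DESCENT_FULL_MAP_EVALUATION_PATIENCE.set(50);

        // ... build and run a GradientDescent weight learner as usual ...
    }
}

From the command line, the equivalent would presumably be -D gradientdescent.trainingevaluationbreak=true -D gradientdescent.trainingevaluationpatience=50, assuming the standard psl-cli -D key=value flag; this is likewise an assumption to verify.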
