From ba8faf9655489a374dff81f8a2eed25ab8828b69 Mon Sep 17 00:00:00 2001
From: Charles Dickens
Date: Sat, 16 Sep 2023 10:39:49 -0700
Subject: [PATCH] Gradient descent reasoner.

---
 .../mpe/GradientDescentInference.java         |  46 ++++
 .../java/org/linqs/psl/config/Options.java    |  73 +++++-
 .../java/org/linqs/psl/reasoner/Reasoner.java |  20 +-
 .../GradientDescentReasoner.java              | 216 ++++++++++++++++++
 .../term/GradientDescentObjectiveTerm.java    |  41 ++++
 .../term/GradientDescentTermGenerator.java    |  68 ++++++
 .../term/GradientDescentTermStore.java        |  38 +++
 .../mpe/GradientDescentInferenceTest.java     |  73 ++++++
 .../gradient/optimalvalue/EnergyTest.java     |   8 +
 9 files changed, 577 insertions(+), 6 deletions(-)
 create mode 100644 psl-core/src/main/java/org/linqs/psl/application/inference/mpe/GradientDescentInference.java
 create mode 100644 psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/GradientDescentReasoner.java
 create mode 100644 psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/term/GradientDescentObjectiveTerm.java
 create mode 100644 psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/term/GradientDescentTermGenerator.java
 create mode 100644 psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/term/GradientDescentTermStore.java
 create mode 100644 psl-core/src/test/java/org/linqs/psl/application/inference/mpe/GradientDescentInferenceTest.java

diff --git a/psl-core/src/main/java/org/linqs/psl/application/inference/mpe/GradientDescentInference.java b/psl-core/src/main/java/org/linqs/psl/application/inference/mpe/GradientDescentInference.java
new file mode 100644
index 000000000..e07398e7d
--- /dev/null
+++ b/psl-core/src/main/java/org/linqs/psl/application/inference/mpe/GradientDescentInference.java
@@ -0,0 +1,46 @@
+/*
+ * This file is part of the PSL software.
+ * Copyright 2011-2015 University of Maryland
+ * Copyright 2013-2023 The Regents of the University of California
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.linqs.psl.application.inference.mpe;
+
+import org.linqs.psl.database.Database;
+import org.linqs.psl.model.rule.Rule;
+import org.linqs.psl.reasoner.Reasoner;
+import org.linqs.psl.reasoner.gradientdescent.GradientDescentReasoner;
+import org.linqs.psl.reasoner.gradientdescent.term.GradientDescentTermStore;
+import org.linqs.psl.reasoner.term.TermStore;
+
+import java.util.List;
+
+/**
+ * Use a GradientDescentReasoner to perform MPE inference.
+ */
+public class GradientDescentInference extends MPEInference {
+    public GradientDescentInference(List<Rule> rules, Database db) {
+        super(rules, db, true);
+    }
+
+    @Override
+    protected Reasoner createReasoner() {
+        return new GradientDescentReasoner();
+    }
+
+    @Override
+    public TermStore createTermStore() {
+        return new GradientDescentTermStore(database.getAtomStore());
+    }
+}
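For orientation, the new application class is used like the existing MPE applications (for example, SGDInference). The sketch below is illustrative rather than part of this patch: it assumes a `rules` list and an open `database` have already been built elsewhere, and the `inference()`/`close()` calls come from the existing InferenceApplication base class, not from anything introduced here.

```java
import java.util.List;

import org.linqs.psl.application.inference.mpe.GradientDescentInference;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;

public class GradientDescentInferenceExample {
    // Hypothetical driver: rules and database are assumed to be constructed elsewhere,
    // exactly as they would be for any other PSL MPE application.
    public static void runMPE(List<Rule> rules, Database database) {
        GradientDescentInference inference = new GradientDescentInference(rules, database);
        try {
            // Ground the model, build the gradient descent terms, and run the GradientDescentReasoner.
            inference.inference();
        } finally {
            inference.close();
        }
    }
}
```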
diff --git a/psl-core/src/main/java/org/linqs/psl/config/Options.java b/psl-core/src/main/java/org/linqs/psl/config/Options.java
index c78636b2c..a3d4fe93c 100644
--- a/psl-core/src/main/java/org/linqs/psl/config/Options.java
+++ b/psl-core/src/main/java/org/linqs/psl/config/Options.java
@@ -26,6 +26,7 @@ import org.linqs.psl.evaluation.statistics.DiscreteEvaluator;
 import org.linqs.psl.evaluation.statistics.AUCEvaluator;
 import org.linqs.psl.reasoner.InitialValue;
+import org.linqs.psl.reasoner.gradientdescent.GradientDescentReasoner;
 import org.linqs.psl.reasoner.sgd.SGDReasoner;
 import org.linqs.psl.util.SystemUtils;
@@ -547,7 +548,7 @@ public class Options {

     public static final Option INFERENCE_RELAX_SQUARED = new Option(
         "inference.relax.squared",
-        true,
+        false,
         "When relaxing a hard constraint into a soft one, this determines if the resulting weighted rule is squared."
     );

@@ -769,10 +770,66 @@ public class Options {
         "If true, run the suite of evaluators specified for the post-inference evaluation stage at regular intervals during inference."
     );

-    public static final Option REASONER_OBJECTIVE_BREAK = new Option(
-        "reasoner.objectivebreak",
-        false,
-        "Stop if the objective has not changed since the last iteration (or logging period)."
+    public static final Option GRADIENT_DESCENT_EXTENSION = new Option(
+        "reasoner.gradientdescent.extension",
+        GradientDescentReasoner.GradientDescentExtension.NONE.toString(),
+        "The gradient descent extension to use for gradient descent reasoning."
+        + " NONE (Default): The standard gradient descent optimizer takes steps in the direction of the negative gradient scaled by the learning rate."
+        + " MOMENTUM: Modify the descent direction with a momentum term."
+        + " NESTEROV_ACCELERATION: Use the Nesterov accelerated gradient method."
+    );
+
+    public static final Option GRADIENT_DESCENT_FIRST_ORDER_BREAK = new Option(
+        "reasoner.gradientdescent.firstorderbreak",
+        true,
+        "Stop gradient descent when the norm of the gradient is less than reasoner.gradientdescent.firstorderthreshold."
+    );
+
+    public static final Option GRADIENT_DESCENT_FIRST_ORDER_NORM = new Option(
+        "reasoner.gradientdescent.firstordernorm",
+        Float.POSITIVE_INFINITY,
+        "The p-norm used to measure the first order optimality condition."
+        + " The default is the infinity-norm, which is the maximum absolute value over the components of the gradient vector."
+        + " Note that the infinity-norm can be explicitly set with the string literal: 'Infinity'.",
+        Option.FLAG_NON_NEGATIVE
+    );
+
+    public static final Option GRADIENT_DESCENT_FIRST_ORDER_THRESHOLD = new Option(
+        "reasoner.gradientdescent.firstorderthreshold",
+        0.01f,
+        "Gradient descent stops when the norm of the gradient is less than this threshold.",
+        Option.FLAG_NON_NEGATIVE
+    );
+
+    public static final Option GRADIENT_DESCENT_INVERSE_TIME_EXP = new Option(
+        "reasoner.gradientdescent.inversescaleexp",
+        1.0f,
+        "If gradient descent is using the STEPDECAY learning schedule, then this value is the exponent"
+        + " of the iteration count used to scale the gradient step:"
+        + " (learning_rate / (iteration ^ reasoner.gradientdescent.inversescaleexp)).",
+        Option.FLAG_POSITIVE
+    );
+
+    public static final Option GRADIENT_DESCENT_LEARNING_RATE = new Option(
+        "reasoner.gradientdescent.learningrate",
+        0.1f,
+        "The learning rate for gradient descent inference.",
+        Option.FLAG_POSITIVE
+    );
+
+    public static final Option GRADIENT_DESCENT_LEARNING_SCHEDULE = new Option(
+        "reasoner.gradientdescent.learningschedule",
+        GradientDescentReasoner.GradientDescentLearningSchedule.CONSTANT.toString(),
+        "The learning schedule of the gradient descent reasoner changes the learning rate during optimization."
+        + " CONSTANT (Default): The learning rate is constant during optimization."
+        + " STEPDECAY: Decay the learning rate like: learningRate / (iteration ^ p), where p is set by reasoner.gradientdescent.inversescaleexp."
+    );
+
+    public static final Option GRADIENT_DESCENT_MAX_ITER = new Option(
+        "reasoner.gradientdescent.maxiterations",
+        2500,
+        "The maximum number of iterations of gradient descent to perform in a round of inference.",
+        Option.FLAG_POSITIVE
     );

     public static final Option REASONER_RUN_FULL_ITERATIONS = new Option(
@@ -781,6 +838,12 @@
         "Ignore all other stopping criteria and run until the maximum number of iterations."
     );

+    public static final Option REASONER_OBJECTIVE_BREAK = new Option(
+        "reasoner.objectivebreak",
+        false,
+        "Stop if the objective has not changed since the last iteration (or logging period)."
+    );
+
     public static final Option REASONER_OBJECTIVE_TOLERANCE = new Option(
         "reasoner.objectivetolerance",
         1e-5f,
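Taken together, these keys configure one reasoner. As a reference for how they fit together, the following sketch sets every new option programmatically with the same `Options.X.set(...)` pattern the new unit tests use; the specific values are illustrative, not recommended settings.

```java
import org.linqs.psl.config.Options;
import org.linqs.psl.reasoner.gradientdescent.GradientDescentReasoner;

public class GradientDescentOptionsExample {
    public static void configure() {
        // Update rule: NONE, MOMENTUM, or NESTEROV_ACCELERATION.
        Options.GRADIENT_DESCENT_EXTENSION.set(GradientDescentReasoner.GradientDescentExtension.MOMENTUM);

        // Step size and how it is annealed over iterations.
        Options.GRADIENT_DESCENT_LEARNING_RATE.set(0.1f);
        Options.GRADIENT_DESCENT_LEARNING_SCHEDULE.set(GradientDescentReasoner.GradientDescentLearningSchedule.STEPDECAY);
        Options.GRADIENT_DESCENT_INVERSE_TIME_EXP.set(0.5f);

        // Stopping criteria: a first-order condition plus an iteration cap.
        Options.GRADIENT_DESCENT_FIRST_ORDER_BREAK.set(true);
        Options.GRADIENT_DESCENT_FIRST_ORDER_NORM.set(Float.POSITIVE_INFINITY);
        Options.GRADIENT_DESCENT_FIRST_ORDER_THRESHOLD.set(0.01f);
        Options.GRADIENT_DESCENT_MAX_ITER.set(500);
    }
}
```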
+ + " Note that the infinity-norm can be explicitly set with the string literal: 'Infinity'.", + Option.FLAG_NON_NEGATIVE + ); + + public static final Option GRADIENT_DESCENT_FIRST_ORDER_THRESHOLD = new Option( + "reasoner.gradientdescent.firstorderthreshold", + 0.01f, + "Gradient descent stops when the norm of the gradient is less than this threshold.", + Option.FLAG_NON_NEGATIVE + ); + + public static final Option GRADIENT_DESCENT_INVERSE_TIME_EXP = new Option( + "reasoner.gradientdescent.inversescaleexp", + 1.0f, + "If GradientDescent is using the STEPDECAY learning schedule, then this value is the negative" + + " exponent of the iteration count which scales the gradient step using:" + + " (learning_rate / ( iteration ^ - GRADIENT_DESCENT_INVERSE_TIME_EXP)).", + Option.FLAG_POSITIVE + ); + + public static final Option GRADIENT_DESCENT_LEARNING_RATE = new Option( + "reasoner.gradientdescent.learningrate", + 0.1f, + "The learning rate for gradient descent inference.", + Option.FLAG_POSITIVE + ); + + public static final Option GRADIENT_DESCENT_LEARNING_SCHEDULE = new Option( + "reasoner.gradientdescent.learningschedule", + GradientDescentReasoner.GradientDescentLearningSchedule.CONSTANT.toString(), + "The learning schedule of the GradientDescent inference reasoner changes the learning rate during learning." + + " STEPDECAY (Default): Decay the learning rate like: learningRate / (n_epoch^p) where p is set by reasoner.gradientdescent.inversescaleexp." + + " CONSTANT: The learning rate is constant during learning." + ); + + public static final Option GRADIENT_DESCENT_MAX_ITER = new Option( + "reasoner.gradientdescent.maxiterations", + 2500, + "The maximum number of iterations of Gradient Descent to perform in a round of inference.", + Option.FLAG_POSITIVE ); public static final Option REASONER_RUN_FULL_ITERATIONS = new Option( @@ -781,6 +838,12 @@ public class Options { "Ignore all other stopping criteria and run until the maximum number of iterations." ); + public static final Option REASONER_OBJECTIVE_BREAK = new Option( + "reasoner.objectivebreak", + false, + "Stop if the objective has not changed since the last iteration (or logging period)." + ); + public static final Option REASONER_OBJECTIVE_TOLERANCE = new Option( "reasoner.objectivetolerance", 1e-5f, diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/Reasoner.java b/psl-core/src/main/java/org/linqs/psl/reasoner/Reasoner.java index a6cf16b4c..44057db5d 100644 --- a/psl-core/src/main/java/org/linqs/psl/reasoner/Reasoner.java +++ b/psl-core/src/main/java/org/linqs/psl/reasoner/Reasoner.java @@ -227,7 +227,7 @@ public void parallelComputeGradient(TermStore termStore, float[] rvAtomGradie Arrays.fill(rvAtomGradient, 0.0f); Arrays.fill(deepAtomGradient, 0.0f); for(int j = 0; j < numTermBlocks; j++) { - for(int i = 0; i < termStore.getAtomStore().getMaxRVAIndex(); i++) { + for(int i = 0; i <= termStore.getAtomStore().getMaxRVAIndex(); i++) { rvAtomGradient[i] += workerRVAtomGradients[j][i]; deepAtomGradient[i] += workerDeepGradients[j][i]; } @@ -247,6 +247,24 @@ protected void clipGradient(float[] variableValues, float[] gradient) { } } + /** + * Clip (sub)gradient magnitude. 
diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/GradientDescentReasoner.java b/psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/GradientDescentReasoner.java
new file mode 100644
index 000000000..11180c0bd
--- /dev/null
+++ b/psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/GradientDescentReasoner.java
@@ -0,0 +1,216 @@
+/*
+ * This file is part of the PSL software.
+ * Copyright 2011-2015 University of Maryland
+ * Copyright 2013-2023 The Regents of the University of California
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.linqs.psl.reasoner.gradientdescent;
+
+import org.linqs.psl.application.learning.weight.TrainingMap;
+import org.linqs.psl.config.Options;
+import org.linqs.psl.evaluation.EvaluationInstance;
+import org.linqs.psl.model.atom.GroundAtom;
+import org.linqs.psl.reasoner.Reasoner;
+import org.linqs.psl.reasoner.gradientdescent.term.GradientDescentObjectiveTerm;
+import org.linqs.psl.reasoner.sgd.term.SGDObjectiveTerm;
+import org.linqs.psl.reasoner.term.TermStore;
+import org.linqs.psl.util.Logger;
+import org.linqs.psl.util.MathUtils;
+
+import java.util.List;
+
+/**
+ * Uses a gradient descent optimization method to optimize its GroundRules.
+ */
+public class GradientDescentReasoner extends Reasoner<GradientDescentObjectiveTerm> {
+    private static final Logger log = Logger.getLogger(GradientDescentReasoner.class);
+
+    /**
+     * The gradient descent extension to use.
+     */
+    public enum GradientDescentExtension {
+        NONE,
+        MOMENTUM,
+        NESTEROV_ACCELERATION
+    }
+
+    /**
+     * The gradient descent learning schedule to use.
+     */
+    public enum GradientDescentLearningSchedule {
+        CONSTANT,
+        STEPDECAY
+    }
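As a quick sanity check on the two schedules, the standalone sketch below prints the step sizes they produce for a few iteration counts. It simply restates the documented formula (learning_rate / iteration^exp) with arbitrary sample values and is not code from this patch.

```java
public class LearningScheduleSketch {
    public static void main(String[] args) {
        float initialLearningRate = 0.1f;
        float inverseScaleExp = 1.0f;

        // CONSTANT keeps the initial rate; STEPDECAY divides it by iteration^inverseScaleExp.
        for (int iteration : new int[] {1, 10, 100, 1000}) {
            float constantRate = initialLearningRate;
            float stepDecayRate = initialLearningRate / (float)Math.pow(iteration, inverseScaleExp);
            System.out.printf("iteration %4d: CONSTANT = %.4f, STEPDECAY = %.6f%n",
                    iteration, constantRate, stepDecayRate);
        }
    }
}
```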
+
+    private final boolean firstOrderBreak;
+    private final float firstOrderTolerance;
+    private final float firstOrderNorm;
+
+    private float[] gradient;
+
+    private final float initialLearningRate;
+    private final float learningRateInverseScaleExp;
+    private final GradientDescentReasoner.GradientDescentLearningSchedule learningSchedule;
+    private final GradientDescentExtension gdExtension;
+
+    public GradientDescentReasoner() {
+        maxIterations = Options.GRADIENT_DESCENT_MAX_ITER.getInt();
+        firstOrderBreak = Options.GRADIENT_DESCENT_FIRST_ORDER_BREAK.getBoolean();
+        firstOrderTolerance = Options.GRADIENT_DESCENT_FIRST_ORDER_THRESHOLD.getFloat();
+        firstOrderNorm = Options.GRADIENT_DESCENT_FIRST_ORDER_NORM.getFloat();
+
+        gradient = null;
+
+        gdExtension = GradientDescentExtension.valueOf(Options.GRADIENT_DESCENT_EXTENSION.getString().toUpperCase());
+
+        initialLearningRate = Options.GRADIENT_DESCENT_LEARNING_RATE.getFloat();
+        learningRateInverseScaleExp = Options.GRADIENT_DESCENT_INVERSE_TIME_EXP.getFloat();
+        learningSchedule = GradientDescentReasoner.GradientDescentLearningSchedule.valueOf(Options.GRADIENT_DESCENT_LEARNING_SCHEDULE.getString().toUpperCase());
+    }
+
+    @Override
+    public double optimize(TermStore<GradientDescentObjectiveTerm> termStore, List<EvaluationInstance> evaluations, TrainingMap trainingMap) {
+        termStore.initForOptimization();
+        initForOptimization(termStore);
+
+        // Return if there are no decision variables.
+        boolean hasDecisionVariables = false;
+        for (GroundAtom atom : termStore.getAtomStore()) {
+            if (!(atom.isFixed())) {
+                hasDecisionVariables = true;
+                break;
+            }
+        }
+
+        if (!hasDecisionVariables) {
+            log.trace("No random variable atoms to optimize.");
+            return parallelComputeObjective(termStore).objective;
+        }
+
+        // Return if there are no terms.
+        if (termStore.size() == 0) {
+            log.trace("No terms to optimize.");
+            return parallelComputeObjective(termStore).objective;
+        }
+
+        float learningRate = 0.0f;
+
+        GroundAtom[] atoms = termStore.getAtomStore().getAtoms();
+        float[] atomValues = termStore.getAtomStore().getAtomValues();
+        float[] update = new float[termStore.getAtomStore().size()];
+        gradient = new float[termStore.getAtomStore().size()];
+        float[] deepAtomGradients = new float[termStore.getAtomStore().size()];
+        ObjectiveResult objectiveResult = parallelComputeObjective(termStore);
+        ObjectiveResult oldObjectiveResult = null;
+
+        long totalTime = 0;
+        boolean breakGradientDescent = false;
+        int iteration = 1;
+        while (!breakGradientDescent) {
+            long startTime = System.currentTimeMillis();
+
+            learningRate = calculateAnnealedLearningRate(iteration);
+
+            if (gdExtension == GradientDescentExtension.NESTEROV_ACCELERATION) {
+                for (int i = 0; i < gradient.length; i++) {
+                    if (atoms[i].isFixed()) {
+                        continue;
+                    }
+
+                    atomValues[i] = Math.min(Math.max(atomValues[i] - 0.9f * update[i], 0.0f), 1.0f);
+                }
+            }
+
+            parallelComputeGradient(termStore, gradient, deepAtomGradients);
+            clipGradientMagnitude(gradient, 1.0f);
+
+            for (int i = 0; i < gradient.length; i++) {
+                if (atoms[i].isFixed()) {
+                    continue;
+                }
+
+                switch (gdExtension) {
+                    case MOMENTUM:
+                        update[i] = 0.9f * update[i] + learningRate * gradient[i];
+                        atomValues[i] = Math.min(Math.max(atomValues[i] - update[i], 0.0f), 1.0f);
+                        break;
+                    case NESTEROV_ACCELERATION:
+                        update[i] = 0.9f * update[i] + learningRate * gradient[i];
+                        atomValues[i] = Math.min(Math.max(atomValues[i] - learningRate * gradient[i], 0.0f), 1.0f);
+                        break;
+                    case NONE:
+                        atomValues[i] = Math.min(Math.max(atomValues[i] - learningRate * gradient[i], 0.0f), 1.0f);
+                        break;
+                }
+            }
+
+            oldObjectiveResult = objectiveResult;
+            objectiveResult = parallelComputeObjective(termStore);
+
+            long endTime = System.currentTimeMillis();
+            totalTime += endTime - startTime;
+
+            breakGradientDescent = breakOptimization(iteration, termStore, objectiveResult, oldObjectiveResult);
+
+            log.trace("Iteration {} -- Objective: {}, Iteration Time: {}, Total Optimization Time: {}.",
+                    iteration, objectiveResult.objective, (endTime - startTime), totalTime);
+
+            evaluate(termStore, iteration, evaluations, trainingMap);
+
+            iteration++;
+        }
+
+        optimizationComplete(termStore, parallelComputeObjective(termStore), totalTime);
+        return objectiveResult.objective;
+    }
+
+    @Override
+    protected boolean breakOptimization(int iteration, TermStore<GradientDescentObjectiveTerm> termStore,
+            ObjectiveResult objective, ObjectiveResult oldObjective) {
+        if (super.breakOptimization(iteration, termStore, objective, oldObjective)) {
+            return true;
+        }
+
+        // Run through the maximum number of iterations.
+        if (runFullIterations) {
+            return false;
+        }
+
+        // Don't break if there are violated constraints.
+        if (objective != null && objective.violatedConstraints > 0) {
+            return false;
+        }
+
+        // Break if the norm of the gradient is within tolerance of zero.
+        if (firstOrderBreak
+                && MathUtils.equals(MathUtils.pNorm(gradient, firstOrderNorm), 0.0f, firstOrderTolerance)) {
+            log.trace("Breaking optimization. Gradient magnitude: {} below tolerance: {}.",
+                    MathUtils.pNorm(gradient, firstOrderNorm), firstOrderTolerance);
+            return true;
+        }
+
+        return false;
+    }
+
+    private float calculateAnnealedLearningRate(int iteration) {
+        switch (learningSchedule) {
+            case CONSTANT:
+                return initialLearningRate;
+            case STEPDECAY:
+                return initialLearningRate / ((float)Math.pow(iteration, learningRateInverseScaleExp));
+            default:
+                throw new IllegalArgumentException(String.format("Illegal value found for gradient descent learning schedule: '%s'", learningSchedule));
+        }
+    }
+}
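The three extensions above differ only in how the per-atom update is formed before the value is projected back into [0, 1]. The standalone sketch below restates that update logic outside of PSL's classes: the 0.9 momentum factor and the clamping mirror the code above, while the gradient function argument is a stand-in for parallelComputeGradient, not a real PSL API.

```java
import java.util.function.Function;

public class ProjectedGradientSketch {
    enum Extension { NONE, MOMENTUM, NESTEROV_ACCELERATION }

    // values: current atom values in [0, 1]; gradientAt: stand-in for the real gradient computation.
    public static void step(float[] values, float[] update, float learningRate,
            Extension extension, Function<float[], float[]> gradientAt) {
        if (extension == Extension.NESTEROV_ACCELERATION) {
            // Look ahead along the momentum direction before evaluating the gradient.
            for (int i = 0; i < values.length; i++) {
                values[i] = clamp(values[i] - 0.9f * update[i]);
            }
        }

        float[] gradient = gradientAt.apply(values);

        for (int i = 0; i < values.length; i++) {
            switch (extension) {
                case MOMENTUM:
                    update[i] = 0.9f * update[i] + learningRate * gradient[i];
                    values[i] = clamp(values[i] - update[i]);
                    break;
                case NESTEROV_ACCELERATION:
                    update[i] = 0.9f * update[i] + learningRate * gradient[i];
                    values[i] = clamp(values[i] - learningRate * gradient[i]);
                    break;
                case NONE:
                    values[i] = clamp(values[i] - learningRate * gradient[i]);
                    break;
            }
        }
    }

    // Atom values are truth values, so every step is projected back onto [0, 1].
    private static float clamp(float value) {
        return Math.min(Math.max(value, 0.0f), 1.0f);
    }
}
```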
diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/term/GradientDescentObjectiveTerm.java b/psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/term/GradientDescentObjectiveTerm.java
new file mode 100644
index 000000000..e72e06983
--- /dev/null
+++ b/psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/term/GradientDescentObjectiveTerm.java
@@ -0,0 +1,41 @@
+/*
+ * This file is part of the PSL software.
+ * Copyright 2011-2015 University of Maryland
+ * Copyright 2013-2023 The Regents of the University of California
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.linqs.psl.reasoner.gradientdescent.term;
+
+import org.linqs.psl.model.rule.Rule;
+import org.linqs.psl.model.rule.WeightedRule;
+import org.linqs.psl.reasoner.function.FunctionComparator;
+import org.linqs.psl.reasoner.sgd.term.SGDObjectiveTerm;
+import org.linqs.psl.reasoner.term.Hyperplane;
+import org.linqs.psl.reasoner.term.ReasonerTerm;
+
+public class GradientDescentObjectiveTerm extends ReasonerTerm {
+    public GradientDescentObjectiveTerm(WeightedRule rule, boolean squared, boolean hinge, Hyperplane hyperplane) {
+        super(hyperplane, rule, squared, hinge, null);
+    }
+
+    public GradientDescentObjectiveTerm(short size, float[] coefficients, float constant, int[] atomIndexes,
+            Rule rule, boolean squared, boolean hinge, FunctionComparator comparator) {
+        super(size, coefficients, constant, atomIndexes, rule, squared, hinge, comparator);
+    }
+
+    @Override
+    public GradientDescentObjectiveTerm copy() {
+        return new GradientDescentObjectiveTerm(size, coefficients, constant, atomIndexes, rule, squared, hinge, comparator);
+    }
+}
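GradientDescentObjectiveTerm only packages the hyperplane; the values and (sub)gradients the reasoner consumes come from the shared ReasonerTerm machinery. As a reminder of the math involved, and not a copy of the actual ReasonerTerm implementation, a weighted potential built from the inner value dot(coefficients, x) - constant evaluates and differentiates like this:

```java
public class HingeTermSketch {
    // d = dot(coefficients, x) - constant, the signed value of the hyperplane.
    public static float innerValue(float[] coefficients, float[] x, float constant) {
        float value = -constant;
        for (int i = 0; i < coefficients.length; i++) {
            value += coefficients[i] * x[i];
        }
        return value;
    }

    // Weighted potential: w * max(d, 0) for a hinge, squared if requested, w * d otherwise.
    public static float value(float weight, float d, boolean hinge, boolean squared) {
        float inner = hinge ? Math.max(d, 0.0f) : d;
        return weight * (squared ? inner * inner : inner);
    }

    // Subgradient with respect to x[i]: zero on the flat side of the hinge,
    // otherwise the chain rule applied to the (possibly squared) linear term.
    public static float partial(float weight, float d, float coefficient, boolean hinge, boolean squared) {
        if (hinge && d <= 0.0f) {
            return 0.0f;
        }
        float inner = hinge ? Math.max(d, 0.0f) : d;
        return weight * (squared ? 2.0f * inner * coefficient : coefficient);
    }
}
```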
diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/term/GradientDescentTermGenerator.java b/psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/term/GradientDescentTermGenerator.java
new file mode 100644
index 000000000..035f07adb
--- /dev/null
+++ b/psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/term/GradientDescentTermGenerator.java
@@ -0,0 +1,68 @@
+/*
+ * This file is part of the PSL software.
+ * Copyright 2011-2015 University of Maryland
+ * Copyright 2013-2023 The Regents of the University of California
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.linqs.psl.reasoner.gradientdescent.term;
+
+import org.linqs.psl.model.rule.GroundRule;
+import org.linqs.psl.model.rule.WeightedGroundRule;
+import org.linqs.psl.reasoner.function.FunctionComparator;
+import org.linqs.psl.reasoner.term.Hyperplane;
+import org.linqs.psl.reasoner.term.TermGenerator;
+import org.linqs.psl.util.Logger;
+
+import java.util.Collection;
+
+/**
+ * A TermGenerator for GradientDescent objective terms.
+ */
+public class GradientDescentTermGenerator extends TermGenerator<GradientDescentObjectiveTerm> {
+    private static final Logger log = Logger.getLogger(GradientDescentTermGenerator.class);
+
+    private boolean warnOnConstraint;
+
+    public GradientDescentTermGenerator() {
+        this(true, true);
+    }
+
+    public GradientDescentTermGenerator(boolean mergeConstants, boolean warnOnConstraint) {
+        super(mergeConstants);
+        this.warnOnConstraint = warnOnConstraint;
+    }
+
+    public void setWarnOnConstraint(boolean warn) {
+        warnOnConstraint = warn;
+    }
+
+
+    @Override
+    public int createLossTerm(Collection<GradientDescentObjectiveTerm> newTerms,
+            boolean isHinge, boolean isSquared, GroundRule groundRule, Hyperplane hyperplane) {
+        newTerms.add(new GradientDescentObjectiveTerm(((WeightedGroundRule)groundRule).getRule(), isSquared, isHinge, hyperplane));
+        return 1;
+    }
+
+    @Override
+    public int createLinearConstraintTerm(Collection<GradientDescentObjectiveTerm> newTerms,
+            GroundRule groundRule, Hyperplane hyperplane, FunctionComparator comparator) {
+        if (warnOnConstraint) {
+            log.warn("GradientDescent does not support hard constraints, e.g., " + groundRule);
+            warnOnConstraint = false;
+        }
+
+        return 0;
+    }
+}
diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/term/GradientDescentTermStore.java b/psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/term/GradientDescentTermStore.java
new file mode 100644
index 000000000..c05c20781
--- /dev/null
+++ b/psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/term/GradientDescentTermStore.java
@@ -0,0 +1,38 @@
+/*
+ * This file is part of the PSL software.
+ * Copyright 2011-2015 University of Maryland
+ * Copyright 2013-2023 The Regents of the University of California
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.linqs.psl.reasoner.gradientdescent.term;
+
+import org.linqs.psl.database.AtomStore;
+import org.linqs.psl.reasoner.term.SimpleTermStore;
+
+public class GradientDescentTermStore extends SimpleTermStore<GradientDescentObjectiveTerm> {
+    public GradientDescentTermStore(AtomStore atomStore) {
+        super(atomStore, new GradientDescentTermGenerator());
+    }
+
+    @Override
+    public GradientDescentTermStore copy() {
+        GradientDescentTermStore gradientDescentTermStoreCopy = new GradientDescentTermStore(atomStore.copy());
+
+        for (GradientDescentObjectiveTerm term : allTerms) {
+            gradientDescentTermStoreCopy.add(term.copy());
+        }
+
+        return gradientDescentTermStoreCopy;
+    }
+}
diff --git a/psl-core/src/test/java/org/linqs/psl/application/inference/mpe/GradientDescentInferenceTest.java b/psl-core/src/test/java/org/linqs/psl/application/inference/mpe/GradientDescentInferenceTest.java
new file mode 100644
index 000000000..ed92f3067
--- /dev/null
+++ b/psl-core/src/test/java/org/linqs/psl/application/inference/mpe/GradientDescentInferenceTest.java
@@ -0,0 +1,73 @@
+/*
+ * This file is part of the PSL software.
+ * Copyright 2011-2015 University of Maryland
+ * Copyright 2013-2023 The Regents of the University of California
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.linqs.psl.application.inference.mpe;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.linqs.psl.application.inference.InferenceApplication;
+import org.linqs.psl.application.inference.InferenceTest;
+import org.linqs.psl.reasoner.gradientdescent.GradientDescentReasoner;
+import org.linqs.psl.config.Options;
+import org.linqs.psl.database.Database;
+import org.linqs.psl.model.rule.Rule;
+
+import java.util.List;
+
+public class GradientDescentInferenceTest extends InferenceTest {
+    @Before
+    public void setup() {
+        Options.REASONER_OBJECTIVE_BREAK.set(false);
+    }
+
+    @Override
+    protected InferenceApplication getInference(List<Rule> rules, Database db) {
+        return new GradientDescentInference(rules, db);
+    }
+
+    @Override
+    public void initialValueTest() {
+        // Skip this test in favor of specific GradientDescent variants.
+    }
+
+    @Test
+    public void initialValueTestNoExtension() {
+        // No extension.
+        Options.GRADIENT_DESCENT_EXTENSION.set(GradientDescentReasoner.GradientDescentExtension.NONE);
+        super.initialValueTest();
+    }
+
+    @Test
+    public void initialValueTestNesterovAccelerationExtension() {
+        // Nesterov Acceleration.
+        Options.GRADIENT_DESCENT_EXTENSION.set(GradientDescentReasoner.GradientDescentExtension.NESTEROV_ACCELERATION);
+        super.initialValueTest();
+    }
+
+    @Test
+    public void initialValueTestMomentumExtension() {
+        // Momentum.
+        Options.GRADIENT_DESCENT_EXTENSION.set(GradientDescentReasoner.GradientDescentExtension.MOMENTUM);
+        super.initialValueTest();
+    }
+
+    @Override
+    public void testSimplexConstraints() {
+        Options.GRADIENT_DESCENT_LEARNING_RATE.set(0.01f);
+        super.testSimplexConstraints();
+    }
+}
diff --git a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/EnergyTest.java b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/EnergyTest.java
index 29ae60798..d88387802 100644
--- a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/EnergyTest.java
+++ b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/EnergyTest.java
@@ -19,6 +19,7 @@
 import org.linqs.psl.application.inference.mpe.ADMMInference;
 import org.linqs.psl.application.inference.mpe.DualBCDInference;
+import org.linqs.psl.application.inference.mpe.GradientDescentInference;
 import org.linqs.psl.application.inference.mpe.SGDInference;
 import org.linqs.psl.application.learning.weight.WeightLearningApplication;
 import org.linqs.psl.application.learning.weight.WeightLearningTest;
@@ -59,6 +60,13 @@ public void SGDFriendshipRankTest() {
         super.friendshipRankTest();
     }

+    @Test
+    public void GradientDescentFriendshipRankTest() {
+        Options.WLA_INFERENCE.set(GradientDescentInference.class.getName());
+
+        super.friendshipRankTest();
+    }
+
     @Test
     public void DualBCDFriendshipRankTest() {
         Options.WLA_INFERENCE.set(DualBCDInference.class.getName());