Gradient descent reasoner.
dickensc committed Sep 16, 2023
1 parent 6c203f2 commit ba8faf9
Showing 9 changed files with 577 additions and 6 deletions.
psl-core/src/main/java/org/linqs/psl/application/inference/mpe/GradientDescentInference.java
@@ -0,0 +1,46 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2023 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.application.inference.mpe;

import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.gradientdescent.GradientDescentReasoner;
import org.linqs.psl.reasoner.gradientdescent.term.GradientDescentTermStore;
import org.linqs.psl.reasoner.term.TermStore;

import java.util.List;

/**
* Use a GradientDescent reasoner to perform MPE inference.
*/
public class GradientDescentInference extends MPEInference {
public GradientDescentInference(List<Rule> rules, Database db) {
super(rules, db, true);
}

@Override
protected Reasoner createReasoner() {
return new GradientDescentReasoner();
}

@Override
public TermStore createTermStore() {
return new GradientDescentTermStore(database.getAtomStore());
}
}
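
For context, a hypothetical usage sketch (not part of this commit): it assumes a rule list and a database have already been constructed elsewhere, and that inference() and close() behave as the usual InferenceApplication entry points.

import java.util.List;

import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.inference.mpe.GradientDescentInference;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;

public class GradientDescentInferenceSketch {
    // Hypothetical helper: run MPE inference with the new gradient descent reasoner.
    // The rules and database are assumed to be built by the caller.
    public static double runInference(List<Rule> rules, Database database) {
        InferenceApplication inference = new GradientDescentInference(rules, database);
        double objective = inference.inference();
        inference.close();
        return objective;
    }
}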
73 changes: 68 additions & 5 deletions psl-core/src/main/java/org/linqs/psl/config/Options.java
@@ -26,6 +26,7 @@
import org.linqs.psl.evaluation.statistics.DiscreteEvaluator;
import org.linqs.psl.evaluation.statistics.AUCEvaluator;
import org.linqs.psl.reasoner.InitialValue;
import org.linqs.psl.reasoner.gradientdescent.GradientDescentReasoner;
import org.linqs.psl.reasoner.sgd.SGDReasoner;
import org.linqs.psl.util.SystemUtils;

@@ -547,7 +548,7 @@ public class Options {

public static final Option INFERENCE_RELAX_SQUARED = new Option(
"inference.relax.squared",
- true,
+ false,
"When relaxing a hard constraint into a soft one, this determines if the resulting weighted rule is squared."
);

@@ -769,10 +770,66 @@ public class Options {
"If true, run the suite of evaluators specified for the post-inference evaluation stage at regular intervals during inference."
);

- public static final Option REASONER_OBJECTIVE_BREAK = new Option(
- "reasoner.objectivebreak",
- false,
- "Stop if the objective has not changed since the last iteration (or logging period)."
+ public static final Option GRADIENT_DESCENT_EXTENSION = new Option(
+ "reasoner.gradientdescent.extension",
+ GradientDescentReasoner.GradientDescentExtension.NONE.toString(),
+ "The GD extension to use for GD reasoning."
+ + " NONE (Default): The standard GD optimizer takes steps in the direction of the negative gradient scaled by the learning rate."
+ + " MOMENTUM: Modify the descent direction with a momentum term."
+ + " NESTEROV_ACCELERATION: Use the Nesterov accelerated gradient method."
);

public static final Option GRADIENT_DESCENT_FIRST_ORDER_BREAK = new Option(
"reasoner.gradientdescent.firstorderbreak",
true,
"Stop gradient descent when the norm of the gradient is less than reasoner.gradientdescent.firstorderthreshold."
);

public static final Option GRADIENT_DESCENT_FIRST_ORDER_NORM = new Option(
"reasoner.gradientdescent.firstordernorm",
Float.POSITIVE_INFINITY,
"The p-norm used to measure the first order optimality condition."
+ " Default is the infinity-norm which is the absolute value of the maximum component of the gradient vector."
+ " Note that the infinity-norm can be explicitly set with the string literal: 'Infinity'.",
Option.FLAG_NON_NEGATIVE
);

public static final Option GRADIENT_DESCENT_FIRST_ORDER_THRESHOLD = new Option(
"reasoner.gradientdescent.firstorderthreshold",
0.01f,
"Gradient descent stops when the norm of the gradient is less than this threshold.",
Option.FLAG_NON_NEGATIVE
);

public static final Option GRADIENT_DESCENT_INVERSE_TIME_EXP = new Option(
"reasoner.gradientdescent.inversescaleexp",
1.0f,
"If GradientDescent is using the STEPDECAY learning schedule, then this value is the negative"
+ " exponent of the iteration count which scales the gradient step using:"
+ " (learning_rate / ( iteration ^ - GRADIENT_DESCENT_INVERSE_TIME_EXP)).",
Option.FLAG_POSITIVE
);

public static final Option GRADIENT_DESCENT_LEARNING_RATE = new Option(
"reasoner.gradientdescent.learningrate",
0.1f,
"The learning rate for gradient descent inference.",
Option.FLAG_POSITIVE
);

public static final Option GRADIENT_DESCENT_LEARNING_SCHEDULE = new Option(
"reasoner.gradientdescent.learningschedule",
GradientDescentReasoner.GradientDescentLearningSchedule.CONSTANT.toString(),
"The learning schedule of the GradientDescent inference reasoner changes the learning rate during learning."
+ " STEPDECAY (Default): Decay the learning rate like: learningRate / (n_epoch^p) where p is set by reasoner.gradientdescent.inversescaleexp."
+ " CONSTANT: The learning rate is constant during learning."
);

public static final Option GRADIENT_DESCENT_MAX_ITER = new Option(
"reasoner.gradientdescent.maxiterations",
2500,
"The maximum number of iterations of Gradient Descent to perform in a round of inference.",
Option.FLAG_POSITIVE
);

public static final Option REASONER_RUN_FULL_ITERATIONS = new Option(
@@ -781,6 +838,12 @@ public class Options {
"Ignore all other stopping criteria and run until the maximum number of iterations."
);

public static final Option REASONER_OBJECTIVE_BREAK = new Option(
"reasoner.objectivebreak",
false,
"Stop if the objective has not changed since the last iteration (or logging period)."
);

public static final Option REASONER_OBJECTIVE_TOLERANCE = new Option(
"reasoner.objectivetolerance",
1e-5f,
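
A minimal standalone sketch (illustrative only, independent of the PSL classes) of how two of the options above interact: the STEPDECAY learning schedule anneals the step size as learningRate / (iteration ^ inversescaleexp), and the first-order break stops once the infinity-norm of the gradient drops below the threshold.

public class GradientDescentOptionsSketch {
    public static void main(String[] args) {
        // Values mirror the defaults of the new options; the gradient itself is made up.
        float learningRate = 0.1f;          // reasoner.gradientdescent.learningrate
        float inverseScaleExp = 1.0f;       // reasoner.gradientdescent.inversescaleexp
        float firstOrderThreshold = 0.01f;  // reasoner.gradientdescent.firstorderthreshold

        float[] gradient = new float[]{0.04f, -0.008f, 0.002f};

        for (int iteration = 1; iteration <= 10; iteration++) {
            // STEPDECAY schedule: learningRate / (iteration ^ inverseScaleExp).
            float annealedRate = learningRate / (float)Math.pow(iteration, inverseScaleExp);

            // Infinity-norm of the gradient (the default first-order norm).
            float infinityNorm = 0.0f;
            for (float component : gradient) {
                infinityNorm = Math.max(infinityNorm, Math.abs(component));
            }

            System.out.printf("iteration %d: rate = %.4f, ||gradient||_inf = %.4f%n",
                    iteration, annealedRate, infinityNorm);

            // First-order break: stop once the gradient norm falls below the threshold.
            if (infinityNorm < firstOrderThreshold) {
                break;
            }

            // Pretend optimization shrinks the gradient each iteration.
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] *= 0.5f;
            }
        }
    }
}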
20 changes: 19 additions & 1 deletion psl-core/src/main/java/org/linqs/psl/reasoner/Reasoner.java
@@ -227,7 +227,7 @@ public void parallelComputeGradient(TermStore<T> termStore, float[] rvAtomGradie
Arrays.fill(rvAtomGradient, 0.0f);
Arrays.fill(deepAtomGradient, 0.0f);
for(int j = 0; j < numTermBlocks; j++) {
- for(int i = 0; i < termStore.getAtomStore().getMaxRVAIndex(); i++) {
+ for(int i = 0; i <= termStore.getAtomStore().getMaxRVAIndex(); i++) {
rvAtomGradient[i] += workerRVAtomGradients[j][i];
deepAtomGradient[i] += workerDeepGradients[j][i];
}
@@ -247,6 +247,24 @@ protected void clipGradient(float[] variableValues, float[] gradient) {
}
}

/**
* Clip (sub)gradient magnitude.
*/
protected void clipGradientMagnitude(float[] gradient, float maxMagnitude) {
float maxGradient = 0.0f;
for (int i = 0; i < gradient.length; i++) {
if (Math.abs(gradient[i]) > maxGradient) {
maxGradient = Math.abs(gradient[i]);
}
}

if (maxGradient > maxMagnitude) {
for (int i = 0; i < gradient.length; i++) {
gradient[i] = gradient[i] * maxMagnitude / maxGradient;
}
}
}

/**
* Compute the total weighted objective of the terms in their current state.
*/
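
A standalone illustration of the gradient magnitude clipping added above (the helper mirrors clipGradientMagnitude() but is independent of the Reasoner class; the input vector is made up).

import java.util.Arrays;

public class ClipGradientMagnitudeSketch {
    // Rescale the gradient so its largest absolute component is at most maxMagnitude,
    // preserving the direction of the step.
    static void clipGradientMagnitude(float[] gradient, float maxMagnitude) {
        float maxGradient = 0.0f;
        for (float component : gradient) {
            maxGradient = Math.max(maxGradient, Math.abs(component));
        }

        if (maxGradient > maxMagnitude) {
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] = gradient[i] * maxMagnitude / maxGradient;
            }
        }
    }

    public static void main(String[] args) {
        float[] gradient = new float[]{3.0f, -4.0f, 0.5f};
        clipGradientMagnitude(gradient, 1.0f);
        // The largest component was 4.0, so every entry is scaled by 1.0 / 4.0.
        System.out.println(Arrays.toString(gradient));  // [0.75, -1.0, 0.125]
    }
}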
psl-core/src/main/java/org/linqs/psl/reasoner/gradientdescent/GradientDescentReasoner.java
@@ -0,0 +1,216 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2021 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.reasoner.gradientdescent;

import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.EvaluationInstance;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.gradientdescent.term.GradientDescentObjectiveTerm;
import org.linqs.psl.reasoner.sgd.term.SGDObjectiveTerm;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;

import java.util.List;

/**
* Uses a gradient descent optimization method to optimize its GroundRules.
*/
public class GradientDescentReasoner extends Reasoner<GradientDescentObjectiveTerm> {
private static final Logger log = Logger.getLogger(GradientDescentReasoner.class);

/**
* The Gradient Descent Extension to use.
*/
public enum GradientDescentExtension {
NONE,
MOMENTUM,
NESTEROV_ACCELERATION
}

/**
* The Gradient Descent learning schedule to use.
*/
public static enum GradientDescentLearningSchedule {
CONSTANT,
STEPDECAY
}

private final boolean firstOrderBreak;
private final float firstOrderTolerance;
private final float firstOrderNorm;

private float[] gradient;

private final float initialLearningRate;
private final float learningRateInverseScaleExp;
private final GradientDescentReasoner.GradientDescentLearningSchedule learningSchedule;
private final GradientDescentExtension gdExtension;

public GradientDescentReasoner() {
maxIterations = Options.GRADIENT_DESCENT_MAX_ITER.getInt();
firstOrderBreak = Options.GRADIENT_DESCENT_FIRST_ORDER_BREAK.getBoolean();
firstOrderTolerance = Options.GRADIENT_DESCENT_FIRST_ORDER_THRESHOLD.getFloat();
firstOrderNorm = Options.GRADIENT_DESCENT_FIRST_ORDER_NORM.getFloat();

gradient = null;

gdExtension = GradientDescentExtension.valueOf(Options.GRADIENT_DESCENT_EXTENSION.getString().toUpperCase());

initialLearningRate = Options.GRADIENT_DESCENT_LEARNING_RATE.getFloat();
learningRateInverseScaleExp = Options.GRADIENT_DESCENT_INVERSE_TIME_EXP.getFloat();
learningSchedule = GradientDescentReasoner.GradientDescentLearningSchedule.valueOf(Options.GRADIENT_DESCENT_LEARNING_SCHEDULE.getString().toUpperCase());
}

@Override
public double optimize(TermStore<GradientDescentObjectiveTerm> termStore, List<EvaluationInstance> evaluations, TrainingMap trainingMap) {
termStore.initForOptimization();
initForOptimization(termStore);

// Return if there are no decision variables.
boolean hasDecisionVariables = false;
for (GroundAtom atom : termStore.getAtomStore()) {
if (!(atom.isFixed())) {
hasDecisionVariables = true;
break;
}
}

if (!hasDecisionVariables){
log.trace("No random variable atoms to optimize.");
return parallelComputeObjective(termStore).objective;
}

// Return if there are no terms.
if (termStore.size() == 0) {
log.trace("No terms to optimize.");
return parallelComputeObjective(termStore).objective;
}

float learningRate = 0.0f;

GroundAtom[] atoms = termStore.getAtomStore().getAtoms();
float[] atomValues = termStore.getAtomStore().getAtomValues();
float[] update = new float[termStore.getAtomStore().size()];
gradient = new float[termStore.getAtomStore().size()];
float[] deepAtomGradients = new float[termStore.getAtomStore().size()];
ObjectiveResult objectiveResult = parallelComputeObjective(termStore);
ObjectiveResult oldObjectiveResult = null;

long totalTime = 0;
boolean breakGradientDescent = false;
int iteration = 1;
while (!breakGradientDescent) {
long startTime = System.currentTimeMillis();

learningRate = calculateAnnealedLearningRate(iteration);

if (gdExtension == GradientDescentExtension.NESTEROV_ACCELERATION) {
for (int i = 0; i < gradient.length; i++) {
if (atoms[i].isFixed()) {
continue;
}

atomValues[i] = Math.min(Math.max(atomValues[i] - 0.9f * update[i], 0.0f), 1.0f);
}
}

parallelComputeGradient(termStore, gradient, deepAtomGradients);
clipGradientMagnitude(gradient, 1.0f);

for (int i = 0; i < gradient.length; i++) {
if (atoms[i].isFixed()) {
continue;
}

switch (gdExtension) {
case MOMENTUM:
update[i] = 0.9f * update[i] + learningRate * gradient[i];
atomValues[i] = Math.min(Math.max(atomValues[i] - update[i], 0.0f), 1.0f);
break;
case NESTEROV_ACCELERATION:
update[i] = 0.9f * update[i] + learningRate * gradient[i];
atomValues[i] = Math.min(Math.max(atomValues[i] - learningRate * gradient[i], 0.0f), 1.0f);
break;
case NONE:
atomValues[i] = Math.min(Math.max(atomValues[i] - learningRate * gradient[i], 0.0f), 1.0f);
break;
}
}

oldObjectiveResult = objectiveResult;
objectiveResult = parallelComputeObjective(termStore);

long endTime = System.currentTimeMillis();
totalTime += System.currentTimeMillis() - startTime;

breakGradientDescent = breakOptimization(iteration, termStore, objectiveResult, oldObjectiveResult);

log.trace("Iteration {} -- Objective: {}, Iteration Time: {}, Total Optimization Time: {}.",
iteration, objectiveResult.objective, (endTime - startTime), totalTime);

evaluate(termStore, iteration, evaluations, trainingMap);

iteration++;
}

optimizationComplete(termStore, parallelComputeObjective(termStore), totalTime);
return objectiveResult.objective;
}

@Override
protected boolean breakOptimization(int iteration, TermStore<GradientDescentObjectiveTerm> termStore,
ObjectiveResult objective, ObjectiveResult oldObjective) {
if (super.breakOptimization(iteration, termStore, objective, oldObjective)) {
return true;
}

// Run through the maximum number of iterations.
if (runFullIterations) {
return false;
}

// Don't break if there are violated constraints.
if (objective != null && objective.violatedConstraints > 0) {
return false;
}

// Break if the norm of the gradient is zero.
if (firstOrderBreak
&& MathUtils.equals(MathUtils.pNorm(gradient, firstOrderNorm), 0.0f, firstOrderTolerance)) {
log.trace("Breaking optimization. Gradient magnitude: {} below tolerance: {}.",
MathUtils.pNorm(gradient, firstOrderNorm), firstOrderTolerance);
return true;
}

return false;
}

private float calculateAnnealedLearningRate(int iteration) {
switch (learningSchedule) {
case CONSTANT:
return initialLearningRate;
case STEPDECAY:
return initialLearningRate / ((float)Math.pow(iteration, learningRateInverseScaleExp));
default:
throw new IllegalArgumentException(String.format("Illegal value found for gradient descent learning schedule: '%s'", learningSchedule));
}
}
}
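
As a rough standalone sketch of the three update rules in the switch above, the toy loop below minimizes a one-dimensional quadratic with each extension. The 0.9 momentum factor and the [0, 1] clamping mirror the reasoner code in this commit; the objective, learning rate, and iteration count are illustrative only.

public class GradientDescentExtensionSketch {
    enum Extension { NONE, MOMENTUM, NESTEROV_ACCELERATION }

    public static void main(String[] args) {
        // Minimize f(x) = (x - 0.25)^2 over [0, 1]; the gradient is 2 * (x - 0.25).
        for (Extension extension : Extension.values()) {
            float x = 1.0f;
            float update = 0.0f;
            float learningRate = 0.1f;

            for (int iteration = 1; iteration <= 200; iteration++) {
                if (extension == Extension.NESTEROV_ACCELERATION) {
                    // Look-ahead step taken before the gradient is computed.
                    x = clamp(x - 0.9f * update);
                }

                float gradient = 2.0f * (x - 0.25f);

                switch (extension) {
                    case MOMENTUM:
                        update = 0.9f * update + learningRate * gradient;
                        x = clamp(x - update);
                        break;
                    case NESTEROV_ACCELERATION:
                        update = 0.9f * update + learningRate * gradient;
                        x = clamp(x - learningRate * gradient);
                        break;
                    case NONE:
                        x = clamp(x - learningRate * gradient);
                        break;
                }
            }

            System.out.printf("%s finished near x = %.3f%n", extension, x);
        }
    }

    private static float clamp(float value) {
        return Math.min(Math.max(value, 0.0f), 1.0f);
    }
}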
