Skip to content

Commit

Permalink
Store grounded deep weight rules in abstract rule class.
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed May 13, 2024
1 parent db6544f commit 73ac8d4
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public abstract class WeightLearningApplication implements ModelApplication {

protected List<Rule> allRules;
protected List<WeightedRule> mutableRules;
protected List<WeightedRule> deepRules;

protected TrainingMap trainingMap;
protected TrainingMap validationMap;
Expand Down Expand Up @@ -106,12 +107,17 @@ public WeightLearningApplication(List<Rule> rules, Database trainTargetDatabase,

allRules = new ArrayList<Rule>();
mutableRules = new ArrayList<WeightedRule>();
deepRules = new ArrayList<WeightedRule>();

for (Rule rule : rules) {
allRules.add(rule);

if (rule instanceof WeightedRule) {
mutableRules.add((WeightedRule)rule);
if (((WeightedRule) rule).getWeight().isDeep()) {
mutableRules.add((WeightedRule) rule);
} else {
deepRules.add((WeightedRule) rule);
}
}
}

Expand Down
37 changes: 35 additions & 2 deletions psl-core/src/main/java/org/linqs/psl/grounding/Grounding.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
import org.linqs.psl.grounding.collective.Coverage;
import org.linqs.psl.model.formula.Formula;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.rule.AbstractRule;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.Weight;
import org.linqs.psl.model.rule.arithmetic.AbstractArithmeticRule;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.model.term.Variable;
import org.linqs.psl.reasoner.term.TermStore;
Expand Down Expand Up @@ -60,11 +63,41 @@ public static void setGroundRuleCallback(GroundRuleCallback groundRuleCallback)

public static long groundAll(List<Rule> rules, TermStore termStore, Database database) {
boolean collective = Options.GROUNDING_COLLECTIVE.getBoolean();

long termCount = 0;
if (collective) {
return groundCollective(rules, termStore, database);
termCount = groundCollective(rules, termStore, database);
} else {
termCount = groundIndependent(rules, termStore, database);
}

// Substitute the rules with deepWeights for their grounded versions.
List<Rule> childrenDeepWeightRules = new ArrayList<Rule>();
List<Rule> parentDeepWeightRules = new ArrayList<Rule>();
for (Rule rule : rules) {
if (rule instanceof AbstractRule) {
AbstractRule abstractRule = (AbstractRule) rule;

if (abstractRule.getChildHashCodes().isEmpty()) {
// This is a base rule.
continue;
}

parentDeepWeightRules.add(rule);

for (int childHashCode : abstractRule.getChildHashCodes()) {
Rule childRule = AbstractRule.getRule(childHashCode);

childrenDeepWeightRules.add(childRule);
}
}
}

return groundIndependent(rules, termStore, database);
// Remove the parent rules from the list of rules and add the child rules.
rules.removeAll(parentDeepWeightRules);
rules.addAll(childrenDeepWeightRules);

return termCount;
}

/**
Expand Down
27 changes: 27 additions & 0 deletions psl-core/src/main/java/org/linqs/psl/model/rule/AbstractRule.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
package org.linqs.psl.model.rule;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Base class for all (first order, i.e., not ground) rules.
Expand All @@ -31,6 +33,9 @@ public abstract class AbstractRule implements Rule {
protected Boolean active;
protected int hashcode;

protected int parentHashCode;
protected Set<Integer> childHashCodes;

public static Rule getRule(int hashcode) {
return rules.get(hashcode);
}
Expand All @@ -44,6 +49,9 @@ protected AbstractRule() {
this.name = null;
this.hashcode = 0;
this.active = true;

this.parentHashCode = hashcode;
this.childHashCodes = new HashSet<Integer>();
}

protected AbstractRule(String name, int hashcode) {
Expand All @@ -52,6 +60,9 @@ protected AbstractRule(String name, int hashcode) {
this.active = true;

ensureRegistration();

this.parentHashCode = hashcode;
this.childHashCodes = new HashSet<Integer>();
}

public boolean isActive() {
Expand All @@ -66,6 +77,22 @@ public String getName() {
return this.name;
}

public int getParentHashCode() {
return parentHashCode;
}

public void setParentHashCode(int parentHashCode) {
this.parentHashCode = parentHashCode;
}

public Set<Integer> getChildHashCodes() {
return childHashCodes;
}

public void addChildHashCode(int childHashCode) {
this.childHashCodes.add(childHashCode);
}

private static void registerRule(Rule rule) {
rules.put(rule.hashCode(), rule);
}
Expand Down
12 changes: 10 additions & 2 deletions psl-core/src/main/java/org/linqs/psl/model/rule/Weight.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import org.linqs.psl.model.atom.Atom;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.DeepPredicate;

/**
* A weight for a rule.
Expand Down Expand Up @@ -78,7 +77,7 @@ public Atom getAtom() {
/**
* Returns whether the term is constant or if it is a function of an atom.
*/
public boolean isConstant() {
public boolean isDeep() {
return atom == null;
}

Expand All @@ -89,4 +88,13 @@ public String toString() {
return Float.toString(constantValue);
}
}

public int hashCode() {
// Use the hash of the atom if it exists. Else, use the object's hash.
if (atom != null) {
return atom.hashCode();
} else {
return super.hashCode();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.linqs.psl.model.term.VariableTypeMap;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.HashCode;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Parallel;
Expand Down Expand Up @@ -84,8 +85,8 @@ public abstract class AbstractArithmeticRule extends AbstractRule {

private volatile boolean validatedByDatabase;

public AbstractArithmeticRule(ArithmeticRuleExpression expression, Map<SummationVariable, Formula> filterClauses, String name) {
super(name, expression.hashCode());
public AbstractArithmeticRule(ArithmeticRuleExpression expression, Map<SummationVariable, Formula> filterClauses, String name, int hashcode) {
super(name, hashcode);
this.expression = expression;
this.filters = filterClauses;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public void parseExpression(ArithmeticRuleExpression expression, boolean compute
argumentBuffer[i] = new Constant[queryAtoms.get(i).getArity()];
}

if ((weight != null) && !(weight.isConstant())) {
if ((weight != null) && !(weight.isDeep())) {
assert (weight.getAtom() instanceof QueryAtom);

weightQueryAtom = (QueryAtom)weight.getAtom();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public UnweightedArithmeticRule(ArithmeticRuleExpression expression, Map<Summati
}

public UnweightedArithmeticRule(ArithmeticRuleExpression expression, Map<SummationVariable, Formula> filterClauses, String name) {
super(expression, filterClauses, name);
super(expression, filterClauses, name, expression.hashCode());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
import org.linqs.psl.model.rule.arithmetic.expression.ArithmeticRuleExpression;
import org.linqs.psl.model.rule.arithmetic.expression.SummationVariable;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.util.HashCode;
import org.linqs.psl.util.Parallel;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -53,7 +55,7 @@ public WeightedArithmeticRule(
ArithmeticRuleExpression expression, Map<SummationVariable, Formula> filterClauses,
Weight weight, boolean squared, String name
) {
super(expression, filterClauses, name);
super(expression, filterClauses, name, HashCode.build(expression.hashCode(), weight));

this.weight = weight;
this.squared = squared;
Expand All @@ -66,10 +68,18 @@ protected AbstractGroundArithmeticRule makeGroundRule(
if (groundedWeight == null) {
return new WeightedGroundArithmeticRule(this, coeffs, atoms, comparator, constant);
} else {
// Create grounded expression.
ArithmeticRuleExpression newExpression = new ArithmeticRuleExpression(
expression.getAtomCoefficients(), Arrays.asList(atoms), comparator, expression.getFinalCoefficient()
);

WeightedArithmeticRule groundedDeepWeightedRule = new WeightedArithmeticRule(
expression, groundedWeight, squared, groundedWeight.getAtom().toString() + ": " + name
newExpression, groundedWeight, squared, groundedWeight.getAtom().toString() + ": " + name
);

groundedDeepWeightedRule.setParentHashCode(hashCode());
addChildHashCode(groundedDeepWeightedRule.hashCode());

return new WeightedGroundArithmeticRule(groundedDeepWeightedRule, coeffs, atoms, comparator, constant);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ protected AbstractLogicalRule(Formula formula, String name) {
}

this.hashcode = hash;
this.parentHashCode = hash;

ensureRegistration();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,16 @@
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.Weight;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.model.rule.arithmetic.AbstractArithmeticRule;
import org.linqs.psl.model.rule.arithmetic.WeightedArithmeticRule;
import org.linqs.psl.model.rule.arithmetic.expression.ArithmeticRuleExpression;
import org.linqs.psl.model.rule.arithmetic.expression.SummationAtomOrAtom;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Coefficient;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.ConstantNumber;
import org.linqs.psl.model.rule.logical.WeightedLogicalRule;
import org.linqs.psl.model.term.UniqueStringID;
import org.linqs.psl.model.term.Variable;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.test.PSLBaseTest;
import org.linqs.psl.test.TestModel;
import org.linqs.psl.util.MathUtils;
Expand All @@ -42,6 +49,7 @@
import org.junit.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
Expand Down Expand Up @@ -304,6 +312,34 @@ public void ruleWithNoGroundingsTest() {
weightLearner.close();
}

@Test
public void ruleWithObservedAtomWeight() {
Rule newRule;
List<String> expected;
List<Coefficient> coefficients;
List<SummationAtomOrAtom> atoms;

// Nice(A): Friends(A, B) >= 1.0 ^2
coefficients = Arrays.asList(
(Coefficient)(new ConstantNumber(1))
);

atoms = Arrays.asList(
(SummationAtomOrAtom)(new QueryAtom(info.predicates.get("Friends"), new Variable("A"), new Variable("B")))
);

newRule = new WeightedArithmeticRule(
new ArithmeticRuleExpression(coefficients, atoms, FunctionComparator.GTE, new ConstantNumber(1)),
new Weight(1.0f, new QueryAtom(info.predicates.get("Nice"), new Variable("A"))),
true
);
info.model.addRule(newRule);

WeightLearningApplication weightLearner = getWLA();
weightLearner.learn();
weightLearner.close();
}

/**
* Assert that the rules (specified by the keys on the rule map) are in the same order as passed in.
* The order should be ascending.
Expand Down

0 comments on commit 73ac8d4

Please sign in to comment.