-
Notifications
You must be signed in to change notification settings - Fork 41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(algorithm): support biased second order random walk #280
Changes from 12 commits
9b7ec3e
7d760ee
5d8eb98
3e4da8d
a2b6efc
0c14381
5914bff
77d3752
70d8f0f
548148e
b4b94b1
7d6e8ca
18f1260
9a93341
37d2aa7
8cc1734
94bd35b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,43 +17,90 @@ | |
|
||
package org.apache.hugegraph.computer.algorithm.sampling; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Iterator; | ||
import java.util.List; | ||
import java.util.Random; | ||
|
||
import org.apache.hugegraph.computer.core.common.exception.ComputerException; | ||
import org.apache.hugegraph.computer.core.config.Config; | ||
import org.apache.hugegraph.computer.core.graph.edge.Edge; | ||
import org.apache.hugegraph.computer.core.graph.edge.Edges; | ||
import org.apache.hugegraph.computer.core.graph.id.Id; | ||
import org.apache.hugegraph.computer.core.graph.value.DoubleValue; | ||
import org.apache.hugegraph.computer.core.graph.value.IdList; | ||
import org.apache.hugegraph.computer.core.graph.value.IdListList; | ||
import org.apache.hugegraph.computer.core.graph.value.Value; | ||
import org.apache.hugegraph.computer.core.graph.vertex.Vertex; | ||
import org.apache.hugegraph.computer.core.worker.Computation; | ||
import org.apache.hugegraph.computer.core.worker.ComputationContext; | ||
import org.apache.hugegraph.util.Log; | ||
import org.slf4j.Logger; | ||
|
||
import java.util.Iterator; | ||
import java.util.Random; | ||
|
||
public class RandomWalk implements Computation<RandomWalkMessage> { | ||
|
||
private static final Logger LOG = Log.logger(RandomWalk.class); | ||
|
||
public static final String OPTION_WALK_PER_NODE = "randomwalk.walk_per_node"; | ||
public static final String OPTION_WALK_LENGTH = "randomwalk.walk_length"; | ||
public static final String OPTION_WALK_PER_NODE = "random_walk.walk_per_node"; | ||
public static final String OPTION_WALK_LENGTH = "random_walk.walk_length"; | ||
|
||
public static final String OPTION_WEIGHT_PROPERTY = "random_walk.weight_property"; | ||
public static final String OPTION_DEFAULT_WEIGHT = "random_walk.default_weight"; | ||
public static final String OPTION_MIN_WEIGHT_THRESHOLD = "random_walk.min_weight_threshold"; | ||
public static final String OPTION_MAX_WEIGHT_THRESHOLD = "random_walk.max_weight_threshold"; | ||
|
||
public static final String OPTION_RETURN_FACTOR = "random_walk.return_factor"; | ||
public static final String OPTION_INOUT_FACTOR = "random_walk.inout_factor"; | ||
|
||
/** | ||
* number of times per vertex(source vertex) walks | ||
* Random | ||
*/ | ||
private Random random; | ||
|
||
/** | ||
* Number of times per vertex(source vertex) walks | ||
*/ | ||
private Integer walkPerNode; | ||
|
||
/** | ||
* walk length | ||
* Walk length | ||
*/ | ||
private Integer walkLength; | ||
|
||
/** | ||
* random | ||
* Weight property, related to the walking probability | ||
*/ | ||
private Random random; | ||
private String weightProperty; | ||
|
||
/** | ||
* Biased walk | ||
* Default 1 | ||
*/ | ||
private Double defaultWeight; | ||
|
||
/** | ||
* Weight less than this threshold will be truncated. | ||
* Default 0 | ||
*/ | ||
private Double minWeightThreshold; | ||
|
||
/** | ||
* Weight greater than this threshold will be truncated. | ||
* Default Double.MAX_VALUE | ||
*/ | ||
private Double maxWeightThreshold; | ||
|
||
/** | ||
* Controls the probability of re-walk to a previously walked vertex. | ||
* Default 1 | ||
*/ | ||
private Double returnFactor; | ||
|
||
/** | ||
* Controls whether to walk inward or outward. | ||
* Default 1 | ||
*/ | ||
private Double inOutFactor; | ||
|
||
@Override | ||
public String category() { | ||
|
@@ -67,23 +114,63 @@ | |
|
||
@Override | ||
public void init(Config config) { | ||
this.random = new Random(); | ||
|
||
this.walkPerNode = config.getInt(OPTION_WALK_PER_NODE, 3); | ||
if (this.walkPerNode <= 0) { | ||
throw new ComputerException("The param %s must be greater than 0, " + | ||
Check warning on line 121 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L121
|
||
"actual got '%s'", | ||
OPTION_WALK_PER_NODE, this.walkPerNode); | ||
"actual got '%s'", | ||
OPTION_WALK_PER_NODE, this.walkPerNode); | ||
} | ||
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_PER_NODE, walkPerNode); | ||
|
||
this.walkLength = config.getInt(OPTION_WALK_LENGTH, 3); | ||
if (this.walkLength <= 0) { | ||
throw new ComputerException("The param %s must be greater than 0, " + | ||
Check warning on line 128 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L128
|
||
"actual got '%s'", | ||
OPTION_WALK_LENGTH, this.walkLength); | ||
"actual got '%s'", | ||
OPTION_WALK_LENGTH, this.walkLength); | ||
} | ||
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_LENGTH, walkLength); | ||
|
||
this.random = new Random(); | ||
this.weightProperty = config.getString(OPTION_WEIGHT_PROPERTY, ""); | ||
|
||
this.defaultWeight = config.getDouble(OPTION_DEFAULT_WEIGHT, 1); | ||
if (this.defaultWeight <= 0) { | ||
throw new ComputerException("The param %s must be greater than 0, " + | ||
Check warning on line 137 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L137
|
||
"actual got '%s'", | ||
OPTION_DEFAULT_WEIGHT, this.defaultWeight); | ||
} | ||
|
||
this.minWeightThreshold = config.getDouble(OPTION_MIN_WEIGHT_THRESHOLD, 0.0); | ||
if (this.minWeightThreshold < 0) { | ||
throw new ComputerException("The param %s must be greater than or equal 0, " + | ||
Check warning on line 144 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L144
|
||
"actual got '%s'", | ||
OPTION_MIN_WEIGHT_THRESHOLD, this.minWeightThreshold); | ||
} | ||
|
||
this.maxWeightThreshold = config.getDouble(OPTION_MAX_WEIGHT_THRESHOLD, Double.MAX_VALUE); | ||
if (this.maxWeightThreshold < 0) { | ||
throw new ComputerException("The param %s must be greater than or equal 0, " + | ||
Check warning on line 151 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L151
|
||
"actual got '%s'", | ||
OPTION_MAX_WEIGHT_THRESHOLD, this.maxWeightThreshold); | ||
} | ||
|
||
if (this.minWeightThreshold > this.maxWeightThreshold) { | ||
throw new ComputerException("%s must be greater than or equal %s, ", | ||
Check warning on line 157 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L157
|
||
OPTION_MAX_WEIGHT_THRESHOLD, OPTION_MIN_WEIGHT_THRESHOLD); | ||
} | ||
|
||
this.returnFactor = config.getDouble(OPTION_RETURN_FACTOR, 1); | ||
if (this.returnFactor <= 0) { | ||
throw new ComputerException("The param %s must be greater than 0, " + | ||
Check warning on line 163 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L163
|
||
"actual got '%s'", | ||
OPTION_RETURN_FACTOR, this.returnFactor); | ||
} | ||
|
||
this.inOutFactor = config.getDouble(OPTION_INOUT_FACTOR, 1); | ||
if (this.inOutFactor <= 0) { | ||
throw new ComputerException("The param %s must be greater than 0, " + | ||
Check warning on line 170 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L170
|
||
"actual got '%s'", | ||
OPTION_INOUT_FACTOR, this.inOutFactor); | ||
} | ||
} | ||
|
||
@Override | ||
|
@@ -95,14 +182,16 @@ | |
|
||
if (vertex.numEdges() <= 0) { | ||
// isolated vertex | ||
this.savePath(vertex, message.path()); // save result | ||
this.savePath(vertex, message.path()); | ||
vertex.inactivate(); | ||
return; | ||
} | ||
|
||
vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId())); | ||
|
||
for (int i = 0; i < walkPerNode; ++i) { | ||
// random select one edge and walk | ||
Edge selectedEdge = this.randomSelectEdge(vertex.edges()); | ||
Edge selectedEdge = this.randomSelectEdge(null, null, vertex.edges()); | ||
context.sendMessage(selectedEdge.targetId(), message); | ||
} | ||
} | ||
|
@@ -112,9 +201,11 @@ | |
Iterator<RandomWalkMessage> messages) { | ||
while (messages.hasNext()) { | ||
RandomWalkMessage message = messages.next(); | ||
// the last id of path is the previous id | ||
Id preVertexId = message.path().getLast(); | ||
|
||
if (message.isFinish()) { | ||
this.savePath(vertex, message.path()); // save result | ||
this.savePath(vertex, message.path()); | ||
|
||
vertex.inactivate(); | ||
continue; | ||
|
@@ -123,7 +214,7 @@ | |
message.addToPath(vertex); | ||
|
||
if (vertex.numEdges() <= 0) { | ||
// there is nowhere to walk,finish eariler | ||
// there is nowhere to walk, finish earlier | ||
message.finish(); | ||
context.sendMessage(this.getSourceId(message.path()), message); | ||
|
||
|
@@ -137,7 +228,7 @@ | |
|
||
if (vertex.id().equals(sourceId)) { | ||
// current vertex is the source vertex, no need to send message once more | ||
this.savePath(vertex, message.path()); // save result | ||
this.savePath(vertex, message.path()); | ||
} else { | ||
context.sendMessage(sourceId, message); | ||
} | ||
|
@@ -146,29 +237,129 @@ | |
continue; | ||
} | ||
|
||
vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId())); | ||
|
||
// random select one edge and walk | ||
Edge selectedEdge = this.randomSelectEdge(vertex.edges()); | ||
Edge selectedEdge = this.randomSelectEdge(preVertexId, message.preVertexAdjacence(), | ||
vertex.edges()); | ||
context.sendMessage(selectedEdge.targetId(), message); | ||
} | ||
} | ||
|
||
/** | ||
* random select one edge | ||
*/ | ||
private Edge randomSelectEdge(Edges edges) { | ||
Edge selectedEdge = null; | ||
int randomNum = random.nextInt(edges.size()); | ||
private Edge randomSelectEdge(Id preVertexId, IdList preVertexAdjacenceIdList, Edges edges) { | ||
List<Double> weightList = new ArrayList<>(); | ||
|
||
int i = 0; | ||
Iterator<Edge> iterator = edges.iterator(); | ||
while (iterator.hasNext()) { | ||
selectedEdge = iterator.next(); | ||
if (i == randomNum) { | ||
Edge edge = iterator.next(); | ||
// calculate edge weight | ||
double weight = this.getEdgeWeight(edge); | ||
Double finalWeight = this.calculateEdgeWeight(preVertexId, preVertexAdjacenceIdList, | ||
edge.targetId(), weight); | ||
weightList.add(finalWeight); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we add a mark here: |
||
} | ||
|
||
int selectedIndex = this.randomSelectIndex(weightList); | ||
Edge selectedEdge = this.selectEdge(edges.iterator(), selectedIndex); | ||
return selectedEdge; | ||
} | ||
|
||
/** | ||
* get the weight of an edge by its weight property | ||
*/ | ||
private double getEdgeWeight(Edge edge) { | ||
Value property = edge.property(this.weightProperty); | ||
if (property == null) { | ||
property = new DoubleValue(this.defaultWeight); | ||
} | ||
|
||
if (!property.isNumber()) { | ||
throw new ComputerException("The value of %s must be a numeric value, " + | ||
Check warning on line 280 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L280
|
||
"actual got '%s'", | ||
this.weightProperty, property.string()); | ||
Check warning on line 282 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L282
|
||
} | ||
|
||
// weight threshold truncation | ||
DoubleValue weight = (DoubleValue) property; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. based on a double value like:
|
||
if (weight.doubleValue() < this.minWeightThreshold) { | ||
weight = new DoubleValue(this.minWeightThreshold); | ||
} | ||
if (weight.doubleValue() > this.maxWeightThreshold) { | ||
weight = new DoubleValue(this.maxWeightThreshold); | ||
Check warning on line 291 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L291
|
||
} | ||
return weight.doubleValue(); | ||
} | ||
|
||
/** | ||
* calculate edge weight | ||
*/ | ||
private Double calculateEdgeWeight(Id preVertexId, IdList preVertexAdjacenceIdList, | ||
Id nextVertexId, double weight) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also keep |
||
/* | ||
* 3 types of vertices. | ||
* 1. current vertex, called v | ||
* 2. previous vertex, called t | ||
* 3. current vertex outer vertex, called x(x1, x2.. xn) | ||
* | ||
* Definition of weight correction coefficient α: | ||
* if distance(t, x) = 0, then α = 1.0 / returnFactor | ||
* if distance(t, x) = 1, then α = 1.0 | ||
* if distance(t, x) = 2, then α = 1.0 / inOutFactor | ||
* | ||
* Final edge weight π(v, x) = α * edgeWeight | ||
*/ | ||
Double finalWeight = 0.0; | ||
if (preVertexId != null && preVertexId.equals(nextVertexId)) { | ||
// distance(t, x) = 0 | ||
finalWeight = 1.0 / this.returnFactor * weight; | ||
} else if (preVertexAdjacenceIdList != null && | ||
preVertexAdjacenceIdList.contains(nextVertexId)) { | ||
// distance(t, x) = 1 | ||
finalWeight = 1.0 * weight; | ||
} else { | ||
// distance(t, x) = 2 | ||
finalWeight = 1.0 / this.inOutFactor * weight; | ||
} | ||
return finalWeight; | ||
} | ||
|
||
/** | ||
* random select index | ||
*/ | ||
private int randomSelectIndex(List<Double> weightList) { | ||
int selectedIndex = 0; | ||
double totalWeight = weightList.stream().mapToDouble(Double::doubleValue).sum(); | ||
double randomNum = random.nextDouble() * totalWeight; // [0, totalWeight) | ||
|
||
// determine which interval the random number falls into | ||
double cumulativeWeight = 0; | ||
for (int i = 0; i < weightList.size(); ++i) { | ||
cumulativeWeight += weightList.get(i); | ||
if (randomNum < cumulativeWeight) { | ||
selectedIndex = i; | ||
break; | ||
} | ||
i++; | ||
} | ||
return selectedIndex; | ||
} | ||
|
||
/** | ||
* select edge from iterator by index | ||
*/ | ||
private Edge selectEdge(Iterator<Edge> iterator, int selectedIndex) { | ||
Edge selectedEdge = null; | ||
|
||
int index = 0; | ||
while (iterator.hasNext()) { | ||
selectedEdge = iterator.next(); | ||
if (index == selectedIndex) { | ||
break; | ||
} | ||
index++; | ||
} | ||
return selectedEdge; | ||
} | ||
|
||
|
@@ -177,7 +368,7 @@ | |
*/ | ||
private Id getSourceId(IdList path) { | ||
// the first id of path is the source id | ||
return path.get(0); | ||
return path.getFirst(); | ||
} | ||
|
||
/** | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we add a mark here:
TODO: use primitive array instead, like DoubleArray, in order to reduce memory fragmentation generated during calculations
the same to https://github.com/search?q=repo%3Aapache%2Fincubator-hugegraph-computer+path%3A%2F%5Ecomputer-algorithm%5C%2F%2F++new+ArrayList&type=code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You mean `double[]`? I'll submit a new issue for this and try to fix it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, but we can use third-party libraries such as eclipse collections: https://eclipse.dev/collections/javadoc/11.1.0/org/eclipse/collections/api/list/primitive/DoubleList.html
some libs of primitive collections:
we also used them in https://github.com/apache/incubator-hugegraph/blob/25301f62288293e53c852f9c4eeb116a3a201594/hugegraph-server/hugegraph-core/src/main/java/org/apache/hugegraph/util/collection/CollectionFactory.java#L43
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
awesome