-
Notifications
You must be signed in to change notification settings - Fork 41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(algorithm): support biased second order random walk #280
Changes from 10 commits
9b7ec3e
7d760ee
5d8eb98
3e4da8d
a2b6efc
0c14381
5914bff
77d3752
70d8f0f
548148e
b4b94b1
7d6e8ca
18f1260
9a93341
37d2aa7
8cc1734
94bd35b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,43 +17,90 @@ | |
|
||
package org.apache.hugegraph.computer.algorithm.sampling; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Iterator; | ||
import java.util.List; | ||
import java.util.Random; | ||
|
||
import org.apache.hugegraph.computer.core.common.exception.ComputerException; | ||
import org.apache.hugegraph.computer.core.config.Config; | ||
import org.apache.hugegraph.computer.core.graph.edge.Edge; | ||
import org.apache.hugegraph.computer.core.graph.edge.Edges; | ||
import org.apache.hugegraph.computer.core.graph.id.Id; | ||
import org.apache.hugegraph.computer.core.graph.value.DoubleValue; | ||
import org.apache.hugegraph.computer.core.graph.value.IdList; | ||
import org.apache.hugegraph.computer.core.graph.value.IdListList; | ||
import org.apache.hugegraph.computer.core.graph.value.Value; | ||
import org.apache.hugegraph.computer.core.graph.vertex.Vertex; | ||
import org.apache.hugegraph.computer.core.worker.Computation; | ||
import org.apache.hugegraph.computer.core.worker.ComputationContext; | ||
import org.apache.hugegraph.util.Log; | ||
import org.slf4j.Logger; | ||
|
||
import java.util.Iterator; | ||
import java.util.Random; | ||
|
||
public class RandomWalk implements Computation<RandomWalkMessage> { | ||
|
||
private static final Logger LOG = Log.logger(RandomWalk.class); | ||
|
||
public static final String OPTION_WALK_PER_NODE = "randomwalk.walk_per_node"; | ||
public static final String OPTION_WALK_LENGTH = "randomwalk.walk_length"; | ||
public static final String OPTION_WALK_PER_NODE = "random_walk.walk_per_node"; | ||
public static final String OPTION_WALK_LENGTH = "random_walk.walk_length"; | ||
|
||
public static final String OPTION_WEIGHT_PROPERTY = "random_walk.weight_property"; | ||
public static final String OPTION_DEFAULT_WEIGHT = "random_walk.default_weight"; | ||
public static final String OPTION_MIN_WEIGHT_THRESHOLD = "random_walk.min_weight_threshold"; | ||
public static final String OPTION_MAX_WEIGHT_THRESHOLD = "random_walk.max_weight_threshold"; | ||
|
||
public static final String OPTION_RETURN_FACTOR = "random_walk.return_factor"; | ||
public static final String OPTION_INOUT_FACTOR = "random_walk.inout_factor"; | ||
|
||
/** | ||
* Random | ||
*/ | ||
private Random random; | ||
|
||
/** | ||
* number of times per vertex(source vertex) walks | ||
* Number of times per vertex(source vertex) walks | ||
*/ | ||
private Integer walkPerNode; | ||
|
||
/** | ||
* walk length | ||
* Walk length | ||
*/ | ||
private Integer walkLength; | ||
|
||
/** | ||
* random | ||
* Weight property, related to the walking probability | ||
*/ | ||
private Random random; | ||
private String weightProperty; | ||
|
||
/** | ||
* Biased walk | ||
* Default 1 | ||
*/ | ||
private Double defaultWeight; | ||
|
||
/** | ||
* Weight less than this threshold will be truncated. | ||
* Default 0 | ||
*/ | ||
private Double minWeightThreshold; | ||
|
||
/** | ||
* Weight greater than this threshold will be truncated. | ||
* Default Integer.MAX_VALUE | ||
*/ | ||
private Double maxWeightThreshold; | ||
|
||
/** | ||
* Controls the probability of re-walk to a previously walked vertex. | ||
* Default 1 | ||
*/ | ||
private Double returnFactor; | ||
|
||
/** | ||
* Controls whether to walk inward or outward. | ||
* Default 1 | ||
*/ | ||
private Double inOutFactor; | ||
|
||
@Override | ||
public String category() { | ||
|
@@ -67,23 +114,77 @@ | |
|
||
@Override | ||
public void init(Config config) { | ||
this.random = new Random(); | ||
|
||
this.walkPerNode = config.getInt(OPTION_WALK_PER_NODE, 3); | ||
if (this.walkPerNode <= 0) { | ||
throw new ComputerException("The param %s must be greater than 0, " + | ||
Check warning on line 121 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L121
|
||
"actual got '%s'", | ||
OPTION_WALK_PER_NODE, this.walkPerNode); | ||
"actual got '%s'", | ||
OPTION_WALK_PER_NODE, this.walkPerNode); | ||
} | ||
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_PER_NODE, walkPerNode); | ||
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_PER_NODE, this.walkPerNode); | ||
|
||
this.walkLength = config.getInt(OPTION_WALK_LENGTH, 3); | ||
if (this.walkLength <= 0) { | ||
throw new ComputerException("The param %s must be greater than 0, " + | ||
Check warning on line 129 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L129
|
||
"actual got '%s'", | ||
OPTION_WALK_LENGTH, this.walkLength); | ||
"actual got '%s'", | ||
OPTION_WALK_LENGTH, this.walkLength); | ||
} | ||
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_LENGTH, walkLength); | ||
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_LENGTH, this.walkLength); | ||
|
||
this.random = new Random(); | ||
this.weightProperty = config.getString(OPTION_WEIGHT_PROPERTY, ""); | ||
LOG.info("[RandomWalk] algorithm param, {}: {}", | ||
OPTION_WEIGHT_PROPERTY, this.weightProperty); | ||
|
||
this.defaultWeight = config.getDouble(OPTION_DEFAULT_WEIGHT, 1); | ||
if (this.defaultWeight <= 0) { | ||
throw new ComputerException("The param %s must be greater than 0, " + | ||
Check warning on line 141 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L141
|
||
"actual got '%s'", | ||
OPTION_DEFAULT_WEIGHT, this.defaultWeight); | ||
} | ||
LOG.info("[RandomWalk] algorithm param, {}: {}", | ||
OPTION_DEFAULT_WEIGHT, this.defaultWeight); | ||
|
||
this.minWeightThreshold = config.getDouble(OPTION_MIN_WEIGHT_THRESHOLD, 0.0); | ||
if (this.minWeightThreshold < 0) { | ||
throw new ComputerException("The param %s must be greater than or equal 0, " + | ||
Check warning on line 150 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L150
|
||
"actual got '%s'", | ||
OPTION_MIN_WEIGHT_THRESHOLD, this.minWeightThreshold); | ||
} | ||
LOG.info("[RandomWalk] algorithm param, {}: {}", | ||
OPTION_MIN_WEIGHT_THRESHOLD, this.minWeightThreshold); | ||
|
||
this.maxWeightThreshold = config.getDouble(OPTION_MAX_WEIGHT_THRESHOLD, Double.MAX_VALUE); | ||
if (this.maxWeightThreshold < 0) { | ||
throw new ComputerException("The param %s must be greater than or equal 0, " + | ||
Check warning on line 159 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L159
|
||
"actual got '%s'", | ||
OPTION_MAX_WEIGHT_THRESHOLD, this.maxWeightThreshold); | ||
} | ||
LOG.info("[RandomWalk] algorithm param, {}: {}", | ||
OPTION_MAX_WEIGHT_THRESHOLD, this.maxWeightThreshold); | ||
|
||
if (this.minWeightThreshold > this.maxWeightThreshold) { | ||
throw new ComputerException("%s must be greater than or equal %s, ", | ||
Check warning on line 167 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L167
|
||
OPTION_MAX_WEIGHT_THRESHOLD, OPTION_MIN_WEIGHT_THRESHOLD); | ||
} | ||
|
||
this.returnFactor = config.getDouble(OPTION_RETURN_FACTOR, 1); | ||
if (this.returnFactor <= 0) { | ||
throw new ComputerException("The param %s must be greater than 0, " + | ||
Check warning on line 173 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L173
|
||
"actual got '%s'", | ||
OPTION_RETURN_FACTOR, this.returnFactor); | ||
} | ||
LOG.info("[RandomWalk] algorithm param, {}: {}", | ||
OPTION_RETURN_FACTOR, this.returnFactor); | ||
|
||
this.inOutFactor = config.getDouble(OPTION_INOUT_FACTOR, 1); | ||
if (this.inOutFactor <= 0) { | ||
throw new ComputerException("The param %s must be greater than 0, " + | ||
Check warning on line 182 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L182
|
||
"actual got '%s'", | ||
OPTION_INOUT_FACTOR, this.inOutFactor); | ||
} | ||
LOG.info("[RandomWalk] algorithm param, {}: {}", | ||
OPTION_INOUT_FACTOR, this.inOutFactor); | ||
} | ||
|
||
@Override | ||
|
@@ -95,14 +196,16 @@ | |
|
||
if (vertex.numEdges() <= 0) { | ||
// isolated vertex | ||
this.savePath(vertex, message.path()); // save result | ||
this.savePath(vertex, message.path()); | ||
vertex.inactivate(); | ||
return; | ||
} | ||
|
||
vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId())); | ||
|
||
for (int i = 0; i < walkPerNode; ++i) { | ||
// random select one edge and walk | ||
Edge selectedEdge = this.randomSelectEdge(vertex.edges()); | ||
Edge selectedEdge = this.randomSelectEdge(null, null, vertex.edges()); | ||
context.sendMessage(selectedEdge.targetId(), message); | ||
} | ||
} | ||
|
@@ -112,9 +215,11 @@ | |
Iterator<RandomWalkMessage> messages) { | ||
while (messages.hasNext()) { | ||
RandomWalkMessage message = messages.next(); | ||
// the last id of path is the previous id | ||
Id preVertexId = message.path().getLast(); | ||
|
||
if (message.isFinish()) { | ||
this.savePath(vertex, message.path()); // save result | ||
this.savePath(vertex, message.path()); | ||
|
||
vertex.inactivate(); | ||
continue; | ||
|
@@ -123,7 +228,7 @@ | |
message.addToPath(vertex); | ||
|
||
if (vertex.numEdges() <= 0) { | ||
// there is nowhere to walk,finish eariler | ||
// there is nowhere to walk, finish eariler | ||
message.finish(); | ||
context.sendMessage(this.getSourceId(message.path()), message); | ||
|
||
|
@@ -137,7 +242,7 @@ | |
|
||
if (vertex.id().equals(sourceId)) { | ||
// current vertex is the source vertex,no need to send message once more | ||
this.savePath(vertex, message.path()); // save result | ||
this.savePath(vertex, message.path()); | ||
} else { | ||
context.sendMessage(sourceId, message); | ||
} | ||
|
@@ -146,29 +251,128 @@ | |
continue; | ||
} | ||
|
||
vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId())); | ||
|
||
// random select one edge and walk | ||
Edge selectedEdge = this.randomSelectEdge(vertex.edges()); | ||
Edge selectedEdge = this.randomSelectEdge(preVertexId, message.preVertexAdjacence(), | ||
vertex.edges()); | ||
context.sendMessage(selectedEdge.targetId(), message); | ||
} | ||
} | ||
|
||
/** | ||
* random select one edge | ||
*/ | ||
private Edge randomSelectEdge(Edges edges) { | ||
Edge selectedEdge = null; | ||
int randomNum = random.nextInt(edges.size()); | ||
private Edge randomSelectEdge(Id preVertexId, IdList preVertexAdjacenceIdList, Edges edges) { | ||
List<Double> weightList = new ArrayList<>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we add a mark here: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, but we can use third-party libraries such as eclipse collections: https://eclipse.dev/collections/javadoc/11.1.0/org/eclipse/collections/api/list/primitive/DoubleList.html some libs of primitive collections:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. awesome |
||
|
||
int i = 0; | ||
Iterator<Edge> iterator = edges.iterator(); | ||
while (iterator.hasNext()) { | ||
selectedEdge = iterator.next(); | ||
if (i == randomNum) { | ||
Edge edge = iterator.next(); | ||
// calculate weight | ||
Value weight = this.getWeight(edge); | ||
Double finalWeight = this.calculateWeight(preVertexId, preVertexAdjacenceIdList, | ||
edge.targetId(), weight); | ||
weightList.add(finalWeight); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we add a mark here: |
||
} | ||
|
||
int selectedIndex = this.randomSelectIndex(weightList); | ||
Edge selectedEdge = this.selectEdge(edges.iterator(), selectedIndex); | ||
return selectedEdge; | ||
} | ||
|
||
/** | ||
* get edge weight by weight property | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "Get the weight of a edge by its weight property" |
||
*/ | ||
private Value getWeight(Edge edge) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. prefer getEdgeWeight() |
||
Value weight = edge.property(this.weightProperty); | ||
if (weight == null) { | ||
weight = new DoubleValue(this.defaultWeight); | ||
} | ||
|
||
if (!weight.isNumber()) { | ||
throw new ComputerException("The value of %s must be a numeric value, " + | ||
Check warning on line 294 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L294
|
||
"actual got '%s'", | ||
this.weightProperty, weight.string()); | ||
Check warning on line 296 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L296
|
||
} | ||
|
||
// weight threshold truncation | ||
if ((Double) weight.value() < this.minWeightThreshold) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we call weight.doubleValue() here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Property value may not be a numeric value. |
||
weight = new DoubleValue(this.minWeightThreshold); | ||
} | ||
if ((Double) weight.value() > this.maxWeightThreshold) { | ||
weight = new DoubleValue(this.maxWeightThreshold); | ||
Check warning on line 304 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java Codecov / codecov/patchcomputer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L304
|
||
} | ||
return weight; | ||
} | ||
|
||
/** | ||
* calculate edge weight | ||
*/ | ||
private Double calculateWeight(Id preVertexId, IdList preVertexAdjacenceIdList, | ||
Id nextVertexId, Value weight) { | ||
/* | ||
* 3 types of vertices. | ||
* 1. current vertex, called v | ||
* 2. previous vertex, called t | ||
* 3. current vertex outer vertex, called x(x1, x2.. xn) | ||
* | ||
* Definition of weight correction coefficient α: | ||
* if distance(t, x) = 0, then α = 1.0 / returnFactor | ||
* if distance(t, x) = 1, then α = 1.0 | ||
* if distance(t, x) = 2, then α = 1.0 / inOutFactor | ||
* | ||
* Final edge weight π(v, x) = α * edgeWeight | ||
*/ | ||
Double finalWeight = 0.0; | ||
if (preVertexId != null && preVertexId.equals(nextVertexId)) { | ||
// distance(t, x) = 0 | ||
finalWeight = 1.0 / this.returnFactor * (Double) weight.value(); | ||
} else if (preVertexAdjacenceIdList != null && | ||
preVertexAdjacenceIdList.contains(nextVertexId)) { | ||
// distance(t, x) = 1 | ||
finalWeight = 1.0 * (Double) weight.value(); | ||
} else { | ||
// distance(t, x) = 2 | ||
finalWeight = 1.0 / this.inOutFactor * (Double) weight.value(); | ||
} | ||
return finalWeight; | ||
} | ||
|
||
/** | ||
* random select index | ||
*/ | ||
private int randomSelectIndex(List<Double> weightList) { | ||
int selectedIndex = 0; | ||
double totalWeight = weightList.stream().mapToDouble(Double::doubleValue).sum(); | ||
double randomNum = random.nextDouble() * totalWeight; // [0, totalWeight) | ||
|
||
// determine which interval the random number falls into | ||
double cumulativeWeight = 0; | ||
for (int i = 0; i < weightList.size(); ++i) { | ||
cumulativeWeight += weightList.get(i); | ||
if (randomNum < cumulativeWeight) { | ||
selectedIndex = i; | ||
break; | ||
} | ||
i++; | ||
} | ||
return selectedIndex; | ||
} | ||
|
||
/** | ||
* select edge from iterator by index | ||
*/ | ||
private Edge selectEdge(Iterator<Edge> iterator, int selectedIndex) { | ||
Edge selectedEdge = null; | ||
|
||
int index = 0; | ||
while (iterator.hasNext()) { | ||
selectedEdge = iterator.next(); | ||
if (index == selectedIndex) { | ||
break; | ||
} | ||
index++; | ||
} | ||
return selectedEdge; | ||
} | ||
|
||
|
@@ -177,7 +381,7 @@ | |
*/ | ||
private Id getSourceId(IdList path) { | ||
// the first id of path is the source id | ||
return path.get(0); | ||
return path.getFirst(); | ||
} | ||
|
||
/** | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can add a common method like
logAlgorithmParam(name, value)
, and just use thethis.name()
as logged algorithm nameThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now, I feel that the logs of algorithm param here should be placed in the framework. I'll remove thoes logs.