Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(algorithm): support biased second order random walk #280

Merged
merged 17 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,90 @@

package org.apache.hugegraph.computer.algorithm.sampling;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

import org.apache.hugegraph.computer.core.common.exception.ComputerException;
import org.apache.hugegraph.computer.core.config.Config;
import org.apache.hugegraph.computer.core.graph.edge.Edge;
import org.apache.hugegraph.computer.core.graph.edge.Edges;
import org.apache.hugegraph.computer.core.graph.id.Id;
import org.apache.hugegraph.computer.core.graph.value.DoubleValue;
import org.apache.hugegraph.computer.core.graph.value.IdList;
import org.apache.hugegraph.computer.core.graph.value.IdListList;
import org.apache.hugegraph.computer.core.graph.value.Value;
import org.apache.hugegraph.computer.core.graph.vertex.Vertex;
import org.apache.hugegraph.computer.core.worker.Computation;
import org.apache.hugegraph.computer.core.worker.ComputationContext;
import org.apache.hugegraph.util.Log;
import org.slf4j.Logger;

import java.util.Iterator;
import java.util.Random;

public class RandomWalk implements Computation<RandomWalkMessage> {

private static final Logger LOG = Log.logger(RandomWalk.class);

public static final String OPTION_WALK_PER_NODE = "randomwalk.walk_per_node";
public static final String OPTION_WALK_LENGTH = "randomwalk.walk_length";
public static final String OPTION_WALK_PER_NODE = "random_walk.walk_per_node";
public static final String OPTION_WALK_LENGTH = "random_walk.walk_length";

public static final String OPTION_WEIGHT_PROPERTY = "random_walk.weight_property";
public static final String OPTION_DEFAULT_WEIGHT = "random_walk.default_weight";
public static final String OPTION_MIN_WEIGHT_THRESHOLD = "random_walk.min_weight_threshold";
public static final String OPTION_MAX_WEIGHT_THRESHOLD = "random_walk.max_weight_threshold";

public static final String OPTION_RETURN_FACTOR = "random_walk.return_factor";
public static final String OPTION_INOUT_FACTOR = "random_walk.inout_factor";

/**
* Random
*/
private Random random;

/**
* number of times per vertex(source vertex) walks
* Number of times per vertex(source vertex) walks
*/
private Integer walkPerNode;

/**
* walk length
* Walk length
*/
private Integer walkLength;

/**
* random
* Weight property, related to the walking probability
*/
private Random random;
private String weightProperty;

/**
* Biased walk
* Default 1
*/
private Double defaultWeight;

/**
* Weight less than this threshold will be truncated.
* Default 0
*/
private Double minWeightThreshold;

/**
* Weight greater than this threshold will be truncated.
* Default Double.MAX_VALUE
*/
private Double maxWeightThreshold;

/**
* Controls the probability of re-walk to a previously walked vertex.
* Default 1
*/
private Double returnFactor;

/**
* Controls whether to walk inward or outward.
* Default 1
*/
private Double inOutFactor;

@Override
public String category() {
Expand All @@ -67,23 +114,77 @@

@Override
public void init(Config config) {
this.random = new Random();

this.walkPerNode = config.getInt(OPTION_WALK_PER_NODE, 3);
if (this.walkPerNode <= 0) {
throw new ComputerException("The param %s must be greater than 0, " +

Check warning on line 121 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L121

Added line #L121 was not covered by tests
"actual got '%s'",
OPTION_WALK_PER_NODE, this.walkPerNode);
"actual got '%s'",
OPTION_WALK_PER_NODE, this.walkPerNode);
}
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_PER_NODE, walkPerNode);
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_PER_NODE, this.walkPerNode);

this.walkLength = config.getInt(OPTION_WALK_LENGTH, 3);
if (this.walkLength <= 0) {
throw new ComputerException("The param %s must be greater than 0, " +

Check warning on line 129 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L129

Added line #L129 was not covered by tests
"actual got '%s'",
OPTION_WALK_LENGTH, this.walkLength);
"actual got '%s'",
OPTION_WALK_LENGTH, this.walkLength);
}
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_LENGTH, walkLength);
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_LENGTH, this.walkLength);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can add a common method like logAlgorithmParam(name, value), and just use the this.name() as logged algorithm name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now, I feel that the logging of algorithm params here should be placed in the framework. I'll remove those logs.


this.random = new Random();
this.weightProperty = config.getString(OPTION_WEIGHT_PROPERTY, "");
LOG.info("[RandomWalk] algorithm param, {}: {}",
OPTION_WEIGHT_PROPERTY, this.weightProperty);

this.defaultWeight = config.getDouble(OPTION_DEFAULT_WEIGHT, 1);
if (this.defaultWeight <= 0) {
throw new ComputerException("The param %s must be greater than 0, " +

Check warning on line 141 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L141

Added line #L141 was not covered by tests
"actual got '%s'",
OPTION_DEFAULT_WEIGHT, this.defaultWeight);
}
LOG.info("[RandomWalk] algorithm param, {}: {}",
OPTION_DEFAULT_WEIGHT, this.defaultWeight);

this.minWeightThreshold = config.getDouble(OPTION_MIN_WEIGHT_THRESHOLD, 0.0);
if (this.minWeightThreshold < 0) {
throw new ComputerException("The param %s must be greater than or equal 0, " +

Check warning on line 150 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L150

Added line #L150 was not covered by tests
"actual got '%s'",
OPTION_MIN_WEIGHT_THRESHOLD, this.minWeightThreshold);
}
LOG.info("[RandomWalk] algorithm param, {}: {}",
OPTION_MIN_WEIGHT_THRESHOLD, this.minWeightThreshold);

this.maxWeightThreshold = config.getDouble(OPTION_MAX_WEIGHT_THRESHOLD, Double.MAX_VALUE);
if (this.maxWeightThreshold < 0) {
throw new ComputerException("The param %s must be greater than or equal 0, " +

Check warning on line 159 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L159

Added line #L159 was not covered by tests
"actual got '%s'",
OPTION_MAX_WEIGHT_THRESHOLD, this.maxWeightThreshold);
}
LOG.info("[RandomWalk] algorithm param, {}: {}",
OPTION_MAX_WEIGHT_THRESHOLD, this.maxWeightThreshold);

if (this.minWeightThreshold > this.maxWeightThreshold) {
throw new ComputerException("%s must be greater than or equal %s, ",

Check warning on line 167 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L167

Added line #L167 was not covered by tests
OPTION_MAX_WEIGHT_THRESHOLD, OPTION_MIN_WEIGHT_THRESHOLD);
}

this.returnFactor = config.getDouble(OPTION_RETURN_FACTOR, 1);
if (this.returnFactor <= 0) {
throw new ComputerException("The param %s must be greater than 0, " +

Check warning on line 173 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L173

Added line #L173 was not covered by tests
"actual got '%s'",
OPTION_RETURN_FACTOR, this.returnFactor);
}
LOG.info("[RandomWalk] algorithm param, {}: {}",
OPTION_RETURN_FACTOR, this.returnFactor);

this.inOutFactor = config.getDouble(OPTION_INOUT_FACTOR, 1);
if (this.inOutFactor <= 0) {
throw new ComputerException("The param %s must be greater than 0, " +

Check warning on line 182 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L182

Added line #L182 was not covered by tests
"actual got '%s'",
OPTION_INOUT_FACTOR, this.inOutFactor);
}
LOG.info("[RandomWalk] algorithm param, {}: {}",
OPTION_INOUT_FACTOR, this.inOutFactor);
}

@Override
Expand All @@ -95,14 +196,16 @@

if (vertex.numEdges() <= 0) {
// isolated vertex
this.savePath(vertex, message.path()); // save result
this.savePath(vertex, message.path());
vertex.inactivate();
return;
}

vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId()));

for (int i = 0; i < walkPerNode; ++i) {
// random select one edge and walk
Edge selectedEdge = this.randomSelectEdge(vertex.edges());
Edge selectedEdge = this.randomSelectEdge(null, null, vertex.edges());
context.sendMessage(selectedEdge.targetId(), message);
}
}
Expand All @@ -112,9 +215,11 @@
Iterator<RandomWalkMessage> messages) {
while (messages.hasNext()) {
RandomWalkMessage message = messages.next();
// the last id of path is the previous id
Id preVertexId = message.path().getLast();

if (message.isFinish()) {
this.savePath(vertex, message.path()); // save result
this.savePath(vertex, message.path());

vertex.inactivate();
continue;
Expand All @@ -123,7 +228,7 @@
message.addToPath(vertex);

if (vertex.numEdges() <= 0) {
// there is nowhere to walkfinish eariler
// there is nowhere to walk, finish earlier
message.finish();
context.sendMessage(this.getSourceId(message.path()), message);

Expand All @@ -137,7 +242,7 @@

if (vertex.id().equals(sourceId)) {
// current vertex is the source vertex, no need to send message once more
this.savePath(vertex, message.path()); // save result
this.savePath(vertex, message.path());
} else {
context.sendMessage(sourceId, message);
}
Expand All @@ -146,29 +251,128 @@
continue;
}

vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId()));

// random select one edge and walk
Edge selectedEdge = this.randomSelectEdge(vertex.edges());
Edge selectedEdge = this.randomSelectEdge(preVertexId, message.preVertexAdjacence(),
vertex.edges());
context.sendMessage(selectedEdge.targetId(), message);
}
}

/**
* random select one edge
*/
private Edge randomSelectEdge(Edges edges) {
Edge selectedEdge = null;
int randomNum = random.nextInt(edges.size());
private Edge randomSelectEdge(Id preVertexId, IdList preVertexAdjacenceIdList, Edges edges) {
List<Double> weightList = new ArrayList<>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a mark here: TODO: use primitive array instead, like DoubleArray, in order to reduce memory fragmentation generated during calculations
the same to https://github.com/search?q=repo%3Aapache%2Fincubator-hugegraph-computer+path%3A%2F%5Ecomputer-algorithm%5C%2F%2F++new+ArrayList&type=code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean double[]? I'll submit a new issue for this and try to fix it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome


int i = 0;
Iterator<Edge> iterator = edges.iterator();
while (iterator.hasNext()) {
selectedEdge = iterator.next();
if (i == randomNum) {
Edge edge = iterator.next();
// calculate weight
Value weight = this.getWeight(edge);
Double finalWeight = this.calculateWeight(preVertexId, preVertexAdjacenceIdList,
edge.targetId(), weight);
weightList.add(finalWeight);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a mark here: TODO: improve to avoid OOM

}

int selectedIndex = this.randomSelectIndex(weightList);
Edge selectedEdge = this.selectEdge(edges.iterator(), selectedIndex);
return selectedEdge;
}

/**
* get edge weight by weight property
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Get the weight of an edge by its weight property"

*/
private Value getWeight(Edge edge) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefer getEdgeWeight()

Value weight = edge.property(this.weightProperty);
if (weight == null) {
weight = new DoubleValue(this.defaultWeight);
}

if (!weight.isNumber()) {
throw new ComputerException("The value of %s must be a numeric value, " +

Check warning on line 294 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L294

Added line #L294 was not covered by tests
"actual got '%s'",
this.weightProperty, weight.string());

Check warning on line 296 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L296

Added line #L296 was not covered by tests
}

// weight threshold truncation
if ((Double) weight.value() < this.minWeightThreshold) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we call weight.doubleValue() here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Property value may not be a numeric value.

weight = new DoubleValue(this.minWeightThreshold);
}
if ((Double) weight.value() > this.maxWeightThreshold) {
weight = new DoubleValue(this.maxWeightThreshold);

Check warning on line 304 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L304

Added line #L304 was not covered by tests
}
return weight;
}

/**
 * Apply the second-order bias to an edge weight.
 *
 * Three kinds of vertices are involved:
 *   v - the current vertex
 *   t - the previous vertex of the walk (may be null on the first step)
 *   x - a candidate next vertex, i.e. the target of one of v's edges
 *
 * The correction coefficient α depends on distance(t, x):
 *   distance(t, x) = 0  ->  α = 1.0 / returnFactor   (x is t itself)
 *   distance(t, x) = 1  ->  α = 1.0                  (x is adjacent to t)
 *   distance(t, x) = 2  ->  α = 1.0 / inOutFactor    (everything else)
 *
 * @param preVertexId              id of t, or null when there is no previous vertex
 * @param preVertexAdjacenceIdList ids adjacent to t, or null when unknown
 * @param nextVertexId             id of the candidate vertex x
 * @param weight                   edge weight of (v, x); its value() is cast to Double
 * @return the final edge weight π(v, x) = α * edgeWeight
 */
private Double calculateWeight(Id preVertexId, IdList preVertexAdjacenceIdList,
                               Id nextVertexId, Value weight) {
    boolean returnsToPrevious = preVertexId != null &&
                                preVertexId.equals(nextVertexId);

    double alpha;
    if (returnsToPrevious) {
        // distance(t, x) = 0
        alpha = 1.0 / this.returnFactor;
    } else if (preVertexAdjacenceIdList != null &&
               preVertexAdjacenceIdList.contains(nextVertexId)) {
        // distance(t, x) = 1
        alpha = 1.0;
    } else {
        // distance(t, x) = 2
        alpha = 1.0 / this.inOutFactor;
    }

    // the coefficient is resolved first, then the weight is read and scaled
    return alpha * (Double) weight.value();
}

/**
 * Randomly select an index, with probability proportional to its weight.
 *
 * The weights are treated as consecutive intervals on [0, totalWeight).
 * A uniform random number is drawn from that range and the index of the
 * interval it falls into is returned.
 *
 * @param weightList one weight per candidate, in candidate order
 * @return the selected index; 0 if the random number falls into no
 *         interval (e.g. the list is empty or the total weight is 0)
 */
private int randomSelectIndex(List<Double> weightList) {
    double totalWeight = weightList.stream().mapToDouble(Double::doubleValue).sum();
    double randomNum = random.nextDouble() * totalWeight; // [0, totalWeight)

    // determine which interval the random number falls into
    double cumulativeWeight = 0;
    /*
     * The loop header is the only place the index advances. A second
     * increment inside the body would skip every other weight, making
     * those candidates unselectable and skewing the distribution.
     */
    for (int i = 0; i < weightList.size(); ++i) {
        cumulativeWeight += weightList.get(i);
        if (randomNum < cumulativeWeight) {
            return i;
        }
    }
    return 0;
}

/**
 * Walk the iterator and return the edge found at the given position.
 *
 * @param iterator      edges to scan, consumed up to the selected edge
 * @param selectedIndex zero-based position of the wanted edge
 * @return the edge at selectedIndex; the last edge seen if the iterator
 *         ends before that position is reached; null if it has no edges
 */
private Edge selectEdge(Iterator<Edge> iterator, int selectedIndex) {
    Edge current = null;
    for (int position = 0; iterator.hasNext(); position++) {
        current = iterator.next();
        if (position == selectedIndex) {
            return current;
        }
    }
    // iterator exhausted without reaching selectedIndex
    return current;
}

Expand All @@ -177,7 +381,7 @@
*/
private Id getSourceId(IdList path) {
// the first id of path is the source id
return path.get(0);
return path.getFirst();
}

/**
Expand Down
Loading
Loading