Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(algorithm): support biased second order random walk #280

Merged
merged 17 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,90 @@

package org.apache.hugegraph.computer.algorithm.sampling;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

import org.apache.hugegraph.computer.core.common.exception.ComputerException;
import org.apache.hugegraph.computer.core.config.Config;
import org.apache.hugegraph.computer.core.graph.edge.Edge;
import org.apache.hugegraph.computer.core.graph.edge.Edges;
import org.apache.hugegraph.computer.core.graph.id.Id;
import org.apache.hugegraph.computer.core.graph.value.DoubleValue;
import org.apache.hugegraph.computer.core.graph.value.IdList;
import org.apache.hugegraph.computer.core.graph.value.IdListList;
import org.apache.hugegraph.computer.core.graph.value.Value;
import org.apache.hugegraph.computer.core.graph.vertex.Vertex;
import org.apache.hugegraph.computer.core.worker.Computation;
import org.apache.hugegraph.computer.core.worker.ComputationContext;
import org.apache.hugegraph.util.Log;
import org.slf4j.Logger;

import java.util.Iterator;
import java.util.Random;

public class RandomWalk implements Computation<RandomWalkMessage> {

private static final Logger LOG = Log.logger(RandomWalk.class);

public static final String OPTION_WALK_PER_NODE = "randomwalk.walk_per_node";
public static final String OPTION_WALK_LENGTH = "randomwalk.walk_length";
public static final String OPTION_WALK_PER_NODE = "random_walk.walk_per_node";
public static final String OPTION_WALK_LENGTH = "random_walk.walk_length";

public static final String OPTION_WEIGHT_PROPERTY = "random_walk.weight_property";
public static final String OPTION_DEFAULT_WEIGHT = "random_walk.default_weight";
public static final String OPTION_MIN_WEIGHT_THRESHOLD = "random_walk.min_weight_threshold";
public static final String OPTION_MAX_WEIGHT_THRESHOLD = "random_walk.max_weight_threshold";

public static final String OPTION_RETURN_FACTOR = "random_walk.return_factor";
public static final String OPTION_INOUT_FACTOR = "random_walk.inout_factor";

/**
* number of times per vertex(source vertex) walks
* Random
*/
private Random random;

/**
* Number of times per vertex(source vertex) walks
*/
private Integer walkPerNode;

/**
* walk length
* Walk length
*/
private Integer walkLength;

/**
* random
* Weight property, related to the walking probability
*/
private Random random;
private String weightProperty;

/**
* Default edge weight for the biased walk, used when an edge has no weight property.
* Default 1
*/
private Double defaultWeight;

/**
* Weight less than this threshold will be truncated.
* Default 0
*/
private Double minWeightThreshold;

/**
* Weight greater than this threshold will be truncated.
* Default Double.MAX_VALUE
*/
private Double maxWeightThreshold;

/**
* Controls the probability of re-walk to a previously walked vertex.
* Default 1
*/
private Double returnFactor;

/**
* Controls whether to walk inward or outward.
* Default 1
*/
private Double inOutFactor;

@Override
public String category() {
Expand All @@ -67,23 +114,63 @@

@Override
public void init(Config config) {
this.random = new Random();

this.walkPerNode = config.getInt(OPTION_WALK_PER_NODE, 3);
if (this.walkPerNode <= 0) {
throw new ComputerException("The param %s must be greater than 0, " +

Check warning on line 121 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L121

Added line #L121 was not covered by tests
"actual got '%s'",
OPTION_WALK_PER_NODE, this.walkPerNode);
"actual got '%s'",
OPTION_WALK_PER_NODE, this.walkPerNode);
}
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_PER_NODE, walkPerNode);

this.walkLength = config.getInt(OPTION_WALK_LENGTH, 3);
if (this.walkLength <= 0) {
throw new ComputerException("The param %s must be greater than 0, " +

Check warning on line 128 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L128

Added line #L128 was not covered by tests
"actual got '%s'",
OPTION_WALK_LENGTH, this.walkLength);
"actual got '%s'",
OPTION_WALK_LENGTH, this.walkLength);
}
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_LENGTH, walkLength);

this.random = new Random();
this.weightProperty = config.getString(OPTION_WEIGHT_PROPERTY, "");

this.defaultWeight = config.getDouble(OPTION_DEFAULT_WEIGHT, 1);
if (this.defaultWeight <= 0) {
throw new ComputerException("The param %s must be greater than 0, " +

Check warning on line 137 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L137

Added line #L137 was not covered by tests
"actual got '%s'",
OPTION_DEFAULT_WEIGHT, this.defaultWeight);
}

this.minWeightThreshold = config.getDouble(OPTION_MIN_WEIGHT_THRESHOLD, 0.0);
if (this.minWeightThreshold < 0) {
throw new ComputerException("The param %s must be greater than or equal 0, " +

Check warning on line 144 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L144

Added line #L144 was not covered by tests
"actual got '%s'",
OPTION_MIN_WEIGHT_THRESHOLD, this.minWeightThreshold);
}

this.maxWeightThreshold = config.getDouble(OPTION_MAX_WEIGHT_THRESHOLD, Double.MAX_VALUE);
if (this.maxWeightThreshold < 0) {
throw new ComputerException("The param %s must be greater than or equal 0, " +

Check warning on line 151 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L151

Added line #L151 was not covered by tests
"actual got '%s'",
OPTION_MAX_WEIGHT_THRESHOLD, this.maxWeightThreshold);
}

if (this.minWeightThreshold > this.maxWeightThreshold) {
throw new ComputerException("%s must be greater than or equal %s, ",

Check warning on line 157 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L157

Added line #L157 was not covered by tests
OPTION_MAX_WEIGHT_THRESHOLD, OPTION_MIN_WEIGHT_THRESHOLD);
}

this.returnFactor = config.getDouble(OPTION_RETURN_FACTOR, 1);
if (this.returnFactor <= 0) {
throw new ComputerException("The param %s must be greater than 0, " +

Check warning on line 163 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L163

Added line #L163 was not covered by tests
"actual got '%s'",
OPTION_RETURN_FACTOR, this.returnFactor);
}

this.inOutFactor = config.getDouble(OPTION_INOUT_FACTOR, 1);
if (this.inOutFactor <= 0) {
throw new ComputerException("The param %s must be greater than 0, " +

Check warning on line 170 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L170

Added line #L170 was not covered by tests
"actual got '%s'",
OPTION_INOUT_FACTOR, this.inOutFactor);
}
}

@Override
Expand All @@ -95,14 +182,16 @@

if (vertex.numEdges() <= 0) {
// isolated vertex
this.savePath(vertex, message.path()); // save result
this.savePath(vertex, message.path());
vertex.inactivate();
return;
}

vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId()));

for (int i = 0; i < walkPerNode; ++i) {
// random select one edge and walk
Edge selectedEdge = this.randomSelectEdge(vertex.edges());
Edge selectedEdge = this.randomSelectEdge(null, null, vertex.edges());
context.sendMessage(selectedEdge.targetId(), message);
}
}
Expand All @@ -112,9 +201,11 @@
Iterator<RandomWalkMessage> messages) {
while (messages.hasNext()) {
RandomWalkMessage message = messages.next();
// the last id of path is the previous id
Id preVertexId = message.path().getLast();

if (message.isFinish()) {
this.savePath(vertex, message.path()); // save result
this.savePath(vertex, message.path());

vertex.inactivate();
continue;
Expand All @@ -123,7 +214,7 @@
message.addToPath(vertex);

if (vertex.numEdges() <= 0) {
// there is nowhere to walkfinish eariler
// there is nowhere to walk, finish earlier
message.finish();
context.sendMessage(this.getSourceId(message.path()), message);

Expand All @@ -137,7 +228,7 @@

if (vertex.id().equals(sourceId)) {
// current vertex is the source vertex,no need to send message once more
this.savePath(vertex, message.path()); // save result
this.savePath(vertex, message.path());
} else {
context.sendMessage(sourceId, message);
}
Expand All @@ -146,29 +237,129 @@
continue;
}

vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId()));

// random select one edge and walk
Edge selectedEdge = this.randomSelectEdge(vertex.edges());
Edge selectedEdge = this.randomSelectEdge(preVertexId, message.preVertexAdjacence(),
vertex.edges());
context.sendMessage(selectedEdge.targetId(), message);
}
}

/**
* random select one edge
*/
private Edge randomSelectEdge(Edges edges) {
Edge selectedEdge = null;
int randomNum = random.nextInt(edges.size());
private Edge randomSelectEdge(Id preVertexId, IdList preVertexAdjacenceIdList, Edges edges) {
List<Double> weightList = new ArrayList<>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a mark here: TODO: use primitive array instead, like DoubleArray, in order to reduce memory fragmentation generated during calculations
the same to https://github.com/search?q=repo%3Aapache%2Fincubator-hugegraph-computer+path%3A%2F%5Ecomputer-algorithm%5C%2F%2F++new+ArrayList&type=code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean double[]? I'll submit a new issue for this and try to fix it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome


int i = 0;
Iterator<Edge> iterator = edges.iterator();
while (iterator.hasNext()) {
selectedEdge = iterator.next();
if (i == randomNum) {
Edge edge = iterator.next();
// calculate edge weight
double weight = this.getEdgeWeight(edge);
Double finalWeight = this.calculateEdgeWeight(preVertexId, preVertexAdjacenceIdList,
edge.targetId(), weight);
weightList.add(finalWeight);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a mark here: TODO: improve to avoid OOM

}

int selectedIndex = this.randomSelectIndex(weightList);
Edge selectedEdge = this.selectEdge(edges.iterator(), selectedIndex);
return selectedEdge;
}

/**
* get the weight of an edge by its weight property
*/
private double getEdgeWeight(Edge edge) {
Value property = edge.property(this.weightProperty);
if (property == null) {
property = new DoubleValue(this.defaultWeight);
}

if (!property.isNumber()) {
throw new ComputerException("The value of %s must be a numeric value, " +

Check warning on line 280 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L280

Added line #L280 was not covered by tests
"actual got '%s'",
this.weightProperty, property.string());

Check warning on line 282 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L282

Added line #L282 was not covered by tests
}

// weight threshold truncation
DoubleValue weight = (DoubleValue) property;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

based on a double value like:

  1. double weight = this.defaultWeight;
  2. weight = property.doubleValue() if checked ok
  3. do truncation
  4. return weight

if (weight.doubleValue() < this.minWeightThreshold) {
weight = new DoubleValue(this.minWeightThreshold);
}
if (weight.doubleValue() > this.maxWeightThreshold) {
weight = new DoubleValue(this.maxWeightThreshold);

Check warning on line 291 in computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

View check run for this annotation

Codecov / codecov/patch

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java#L291

Added line #L291 was not covered by tests
}
return weight.doubleValue();
}

/**
* calculate edge weight
*/
private Double calculateEdgeWeight(Id preVertexId, IdList preVertexAdjacenceIdList,
Id nextVertexId, double weight) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also keep double finalWeight and return double?

/*
* 3 types of vertices.
* 1. current vertex, called v
* 2. previous vertex, called t
* 3. current vertex outer vertex, called x(x1, x2.. xn)
*
* Definition of weight correction coefficient α:
* if distance(t, x) = 0, then α = 1.0 / returnFactor
* if distance(t, x) = 1, then α = 1.0
* if distance(t, x) = 2, then α = 1.0 / inOutFactor
*
* Final edge weight π(v, x) = α * edgeWeight
*/
Double finalWeight = 0.0;
if (preVertexId != null && preVertexId.equals(nextVertexId)) {
// distance(t, x) = 0
finalWeight = 1.0 / this.returnFactor * weight;
} else if (preVertexAdjacenceIdList != null &&
preVertexAdjacenceIdList.contains(nextVertexId)) {
// distance(t, x) = 1
finalWeight = 1.0 * weight;
} else {
// distance(t, x) = 2
finalWeight = 1.0 / this.inOutFactor * weight;
}
return finalWeight;
}

/**
 * Randomly select an index, where index {@code i} is chosen with a
 * probability proportional to {@code weightList.get(i)}.
 *
 * A random point is drawn from [0, totalWeight) and the index of the
 * cumulative-weight interval containing that point is returned.
 *
 * @param weightList weights, one entry per candidate
 * @return the selected index, or 0 when no interval contains the random
 *         point (for example the list is empty or every weight is 0)
 */
private int randomSelectIndex(List<Double> weightList) {
    double totalWeight = weightList.stream().mapToDouble(Double::doubleValue).sum();
    double randomNum = random.nextDouble() * totalWeight; // [0, totalWeight)

    // determine which interval the random number falls into
    double cumulativeWeight = 0;
    // The loop header is the only place the index advances: an extra
    // increment inside the body would skip every other weight, so those
    // candidates could never be selected nor counted in the cumulative sum.
    for (int i = 0; i < weightList.size(); ++i) {
        cumulativeWeight += weightList.get(i);
        if (randomNum < cumulativeWeight) {
            return i;
        }
    }
    // no interval matched, fall back to the first index
    return 0;
}

/**
 * Walk the iterator and return the edge found at the given position.
 *
 * If the iterator is exhausted before reaching {@code selectedIndex}, the
 * last edge that was read is returned; if the iterator yields nothing,
 * {@code null} is returned.
 *
 * @param iterator      edges to scan, consumed up to the selected edge
 * @param selectedIndex zero-based position of the wanted edge
 * @return the edge at {@code selectedIndex}, otherwise the last edge read,
 *         or {@code null} for an empty iterator
 */
private Edge selectEdge(Iterator<Edge> iterator, int selectedIndex) {
    Edge current = null;
    for (int position = 0; iterator.hasNext(); position++) {
        current = iterator.next();
        if (position == selectedIndex) {
            // reached the requested position, stop consuming the iterator
            return current;
        }
    }
    return current;
}

Expand All @@ -177,7 +368,7 @@
*/
private Id getSourceId(IdList path) {
// the first id of path is the source id
return path.get(0);
return path.getFirst();
}

/**
Expand Down
Loading
Loading