diff --git a/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java index 2fdf2dde2..a74cd4888 100644 --- a/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java +++ b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java @@ -17,43 +17,90 @@ package org.apache.hugegraph.computer.algorithm.sampling; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Random; + import org.apache.hugegraph.computer.core.common.exception.ComputerException; import org.apache.hugegraph.computer.core.config.Config; import org.apache.hugegraph.computer.core.graph.edge.Edge; import org.apache.hugegraph.computer.core.graph.edge.Edges; import org.apache.hugegraph.computer.core.graph.id.Id; +import org.apache.hugegraph.computer.core.graph.value.DoubleValue; import org.apache.hugegraph.computer.core.graph.value.IdList; import org.apache.hugegraph.computer.core.graph.value.IdListList; +import org.apache.hugegraph.computer.core.graph.value.Value; import org.apache.hugegraph.computer.core.graph.vertex.Vertex; import org.apache.hugegraph.computer.core.worker.Computation; import org.apache.hugegraph.computer.core.worker.ComputationContext; import org.apache.hugegraph.util.Log; import org.slf4j.Logger; -import java.util.Iterator; -import java.util.Random; - public class RandomWalk implements Computation { private static final Logger LOG = Log.logger(RandomWalk.class); - public static final String OPTION_WALK_PER_NODE = "randomwalk.walk_per_node"; - public static final String OPTION_WALK_LENGTH = "randomwalk.walk_length"; + public static final String OPTION_WALK_PER_NODE = "random_walk.walk_per_node"; + public static final String OPTION_WALK_LENGTH = "random_walk.walk_length"; + + public static final String 
OPTION_WEIGHT_PROPERTY = "random_walk.weight_property"; + public static final String OPTION_DEFAULT_WEIGHT = "random_walk.default_weight"; + public static final String OPTION_MIN_WEIGHT_THRESHOLD = "random_walk.min_weight_threshold"; + public static final String OPTION_MAX_WEIGHT_THRESHOLD = "random_walk.max_weight_threshold"; + + public static final String OPTION_RETURN_FACTOR = "random_walk.return_factor"; + public static final String OPTION_INOUT_FACTOR = "random_walk.inout_factor"; /** - * number of times per vertex(source vertex) walks + * Random + */ + private Random random; + + /** + * Number of times per vertex(source vertex) walks */ private Integer walkPerNode; /** - * walk length + * Walk length */ private Integer walkLength; /** - * random + * Weight property, related to the walking probability */ - private Random random; + private String weightProperty; + + /** + * Weight used for an edge that has no weight property + * Default 1 + */ + private Double defaultWeight; + + /** + * Weight less than this threshold will be truncated. + * Default 0 + */ + private Double minWeightThreshold; + + /** + * Weight greater than this threshold will be truncated. + * Default Double.MAX_VALUE + */ + private Double maxWeightThreshold; + + /** + * Controls the probability of re-walk to a previously walked vertex. + * Default 1 + */ + private Double returnFactor; + + /** + * Controls whether to walk inward or outward. 
+ * Default 1 + */ + private Double inOutFactor; @Override public String category() { @@ -67,23 +114,63 @@ public String name() { @Override public void init(Config config) { + this.random = new Random(); + this.walkPerNode = config.getInt(OPTION_WALK_PER_NODE, 3); if (this.walkPerNode <= 0) { throw new ComputerException("The param %s must be greater than 0, " + - "actual got '%s'", - OPTION_WALK_PER_NODE, this.walkPerNode); + "actual got '%s'", + OPTION_WALK_PER_NODE, this.walkPerNode); } - LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_PER_NODE, walkPerNode); this.walkLength = config.getInt(OPTION_WALK_LENGTH, 3); if (this.walkLength <= 0) { throw new ComputerException("The param %s must be greater than 0, " + - "actual got '%s'", - OPTION_WALK_LENGTH, this.walkLength); + "actual got '%s'", + OPTION_WALK_LENGTH, this.walkLength); } - LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_LENGTH, walkLength); - this.random = new Random(); + this.weightProperty = config.getString(OPTION_WEIGHT_PROPERTY, ""); + + this.defaultWeight = config.getDouble(OPTION_DEFAULT_WEIGHT, 1); + if (this.defaultWeight <= 0) { + throw new ComputerException("The param %s must be greater than 0, " + + "actual got '%s'", + OPTION_DEFAULT_WEIGHT, this.defaultWeight); + } + + this.minWeightThreshold = config.getDouble(OPTION_MIN_WEIGHT_THRESHOLD, 0.0); + if (this.minWeightThreshold < 0) { + throw new ComputerException("The param %s must be greater than or equal 0, " + + "actual got '%s'", + OPTION_MIN_WEIGHT_THRESHOLD, this.minWeightThreshold); + } + + this.maxWeightThreshold = config.getDouble(OPTION_MAX_WEIGHT_THRESHOLD, Double.MAX_VALUE); + if (this.maxWeightThreshold < 0) { + throw new ComputerException("The param %s must be greater than or equal 0, " + + "actual got '%s'", + OPTION_MAX_WEIGHT_THRESHOLD, this.maxWeightThreshold); + } + + if (this.minWeightThreshold > this.maxWeightThreshold) { + throw new ComputerException("%s must be greater than or equal %s, ", + 
OPTION_MAX_WEIGHT_THRESHOLD, OPTION_MIN_WEIGHT_THRESHOLD); + } + + this.returnFactor = config.getDouble(OPTION_RETURN_FACTOR, 1); + if (this.returnFactor <= 0) { + throw new ComputerException("The param %s must be greater than 0, " + + "actual got '%s'", + OPTION_RETURN_FACTOR, this.returnFactor); + } + + this.inOutFactor = config.getDouble(OPTION_INOUT_FACTOR, 1); + if (this.inOutFactor <= 0) { + throw new ComputerException("The param %s must be greater than 0, " + + "actual got '%s'", + OPTION_INOUT_FACTOR, this.inOutFactor); + } } @Override @@ -95,14 +182,16 @@ public void compute0(ComputationContext context, Vertex vertex) { if (vertex.numEdges() <= 0) { // isolated vertex - this.savePath(vertex, message.path()); // save result + this.savePath(vertex, message.path()); vertex.inactivate(); return; } + vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId())); + for (int i = 0; i < walkPerNode; ++i) { // random select one edge and walk - Edge selectedEdge = this.randomSelectEdge(vertex.edges()); + Edge selectedEdge = this.randomSelectEdge(null, null, vertex.edges()); context.sendMessage(selectedEdge.targetId(), message); } } @@ -112,9 +201,11 @@ public void compute(ComputationContext context, Vertex vertex, Iterator messages) { while (messages.hasNext()) { RandomWalkMessage message = messages.next(); + // the last id of path is the previous id + Id preVertexId = message.path().getLast(); if (message.isFinish()) { - this.savePath(vertex, message.path()); // save result + this.savePath(vertex, message.path()); vertex.inactivate(); continue; @@ -123,7 +214,7 @@ public void compute(ComputationContext context, Vertex vertex, message.addToPath(vertex); if (vertex.numEdges() <= 0) { - // there is nowhere to walk,finish eariler + // there is nowhere to walk, finish earlier message.finish(); context.sendMessage(this.getSourceId(message.path()), message); @@ -137,7 +228,7 @@ public void compute(ComputationContext context, Vertex vertex, if 
(vertex.id().equals(sourceId)) { // current vertex is the source vertex,no need to send message once more - this.savePath(vertex, message.path()); // save result + this.savePath(vertex, message.path()); } else { context.sendMessage(sourceId, message); } @@ -146,8 +237,11 @@ public void compute(ComputationContext context, Vertex vertex, continue; } + vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId())); + // random select one edge and walk - Edge selectedEdge = this.randomSelectEdge(vertex.edges()); + Edge selectedEdge = this.randomSelectEdge(preVertexId, message.preVertexAdjacence(), + vertex.edges()); context.sendMessage(selectedEdge.targetId(), message); } } @@ -155,20 +249,121 @@ public void compute(ComputationContext context, Vertex vertex, /** * random select one edge */ - private Edge randomSelectEdge(Edges edges) { - Edge selectedEdge = null; - int randomNum = random.nextInt(edges.size()); + private Edge randomSelectEdge(Id preVertexId, IdList preVertexAdjacenceIdList, Edges edges) { + // TODO: use primitive array instead, like DoubleArray, + // in order to reduce memory fragmentation generated during calculations + List weightList = new ArrayList<>(); - int i = 0; Iterator iterator = edges.iterator(); while (iterator.hasNext()) { - selectedEdge = iterator.next(); - if (i == randomNum) { + Edge edge = iterator.next(); + // calculate edge weight + double weight = this.getEdgeWeight(edge); + double finalWeight = this.calculateEdgeWeight(preVertexId, preVertexAdjacenceIdList, + edge.targetId(), weight); + // TODO: improve to avoid OOM + weightList.add(finalWeight); + } + + int selectedIndex = this.randomSelectIndex(weightList); + Edge selectedEdge = this.selectEdge(edges.iterator(), selectedIndex); + return selectedEdge; + } + + /** + * get the weight of an edge by its weight property + */ + private double getEdgeWeight(Edge edge) { + double weight = this.defaultWeight; + + Value property = edge.property(this.weightProperty); + if 
(property != null) { + if (!property.isNumber()) { + throw new ComputerException("The value of %s must be a numeric value, " + + "actual got '%s'", + this.weightProperty, property.string()); + } + + weight = ((DoubleValue) property).doubleValue(); + } + + // weight threshold truncation + if (weight < this.minWeightThreshold) { + weight = this.minWeightThreshold; + } + if (weight > this.maxWeightThreshold) { + weight = this.maxWeightThreshold; + } + return weight; + } + + /** + * calculate edge weight + */ + private double calculateEdgeWeight(Id preVertexId, IdList preVertexAdjacenceIdList, + Id nextVertexId, double weight) { + /* + * 3 types of vertices. + * 1. current vertex, called v + * 2. previous vertex, called t + * 3. current vertex outer vertex, called x(x1, x2.. xn) + * + * Definition of weight correction coefficient α: + * if distance(t, x) = 0, then α = 1.0 / returnFactor + * if distance(t, x) = 1, then α = 1.0 + * if distance(t, x) = 2, then α = 1.0 / inOutFactor + * + * Final edge weight π(v, x) = α * edgeWeight + */ + double finalWeight = 0.0; + if (preVertexId != null && preVertexId.equals(nextVertexId)) { + // distance(t, x) = 0 + finalWeight = 1.0 / this.returnFactor * weight; + } else if (preVertexAdjacenceIdList != null && + preVertexAdjacenceIdList.contains(nextVertexId)) { + // distance(t, x) = 1 + finalWeight = 1.0 * weight; + } else { + // distance(t, x) = 2 + finalWeight = 1.0 / this.inOutFactor * weight; + } + return finalWeight; + } + + /** + * random select index + */ + private int randomSelectIndex(List weightList) { + int selectedIndex = 0; + double totalWeight = weightList.stream().mapToDouble(Double::doubleValue).sum(); + double randomNum = random.nextDouble() * totalWeight; // [0, totalWeight) + + // determine which interval the random number falls into + double cumulativeWeight = 0; + for (int i = 0; i < weightList.size(); ++i) { + cumulativeWeight += weightList.get(i); + if (randomNum < cumulativeWeight) { + selectedIndex = i; 
break; } - i++; } + return selectedIndex; + } + + /** + * select edge from iterator by index + */ + private Edge selectEdge(Iterator iterator, int selectedIndex) { + Edge selectedEdge = null; + int index = 0; + while (iterator.hasNext()) { + selectedEdge = iterator.next(); + if (index == selectedIndex) { + break; + } + index++; + } return selectedEdge; } diff --git a/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkMessage.java b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkMessage.java index bf32ee75c..6d92781ac 100644 --- a/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkMessage.java +++ b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkMessage.java @@ -17,6 +17,10 @@ package org.apache.hugegraph.computer.algorithm.sampling; +import java.io.IOException; +import java.util.List; + +import org.apache.hugegraph.computer.core.graph.id.Id; import org.apache.hugegraph.computer.core.graph.value.BooleanValue; import org.apache.hugegraph.computer.core.graph.value.IdList; import org.apache.hugegraph.computer.core.graph.value.Value; @@ -24,34 +28,39 @@ import org.apache.hugegraph.computer.core.io.RandomAccessInput; import org.apache.hugegraph.computer.core.io.RandomAccessOutput; -import java.io.IOException; -import java.util.List; - public class RandomWalkMessage implements Value.CustomizeValue> { /** - * random walk path + * Previous vertex adjacent(out edge) vertex id list + */ + private final IdList preVertexAdjacence; + + /** + * Random walk path */ private final IdList path; /** - * finish flag + * Finish flag */ private BooleanValue isFinish; public RandomWalkMessage() { + this.preVertexAdjacence = new IdList(); this.path = new IdList(); this.isFinish = new BooleanValue(false); } @Override public void read(RandomAccessInput in) throws IOException { + this.preVertexAdjacence.read(in); 
this.path.read(in); this.isFinish.read(in); } @Override public void write(RandomAccessOutput out) throws IOException { + this.preVertexAdjacence.write(out); this.path.write(out); this.isFinish.write(out); } @@ -61,6 +70,14 @@ public List value() { return this.path.value(); } + public IdList preVertexAdjacence() { + return this.preVertexAdjacence; + } + + public void addToPreVertexAdjacence(Id vertexId) { + this.preVertexAdjacence.add(vertexId); + } + public IdList path() { return this.path; } diff --git a/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkOutput.java b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkOutput.java index ad43d5bd7..47ef2d58f 100644 --- a/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkOutput.java +++ b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkOutput.java @@ -17,23 +17,23 @@ package org.apache.hugegraph.computer.algorithm.sampling; +import java.util.ArrayList; +import java.util.List; + import org.apache.hugegraph.computer.core.graph.value.IdListList; import org.apache.hugegraph.computer.core.graph.vertex.Vertex; import org.apache.hugegraph.computer.core.output.hg.HugeGraphOutput; -import java.util.ArrayList; -import java.util.List; - public class RandomWalkOutput extends HugeGraphOutput> { @Override protected void prepareSchema() { this.client().schema().propertyKey(this.name()) - .asText() - .writeType(this.writeType()) - .valueList() - .ifNotExist() - .create(); + .asText() + .writeType(this.writeType()) + .valueList() + .ifNotExist() + .create(); } @Override diff --git a/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkParams.java b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkParams.java index 273d7fd67..a8d9fd817 100644 --- 
a/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkParams.java +++ b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkParams.java @@ -17,24 +17,26 @@ package org.apache.hugegraph.computer.algorithm.sampling; +import java.util.Map; + import org.apache.hugegraph.computer.algorithm.AlgorithmParams; import org.apache.hugegraph.computer.core.config.ComputerOptions; import org.apache.hugegraph.computer.core.graph.value.IdListList; -import java.util.Map; - public class RandomWalkParams implements AlgorithmParams { @Override public void setAlgorithmParameters(Map params) { this.setIfAbsent(params, ComputerOptions.WORKER_COMPUTATION_CLASS, - RandomWalk.class.getName()); + RandomWalk.class.getName()); this.setIfAbsent(params, ComputerOptions.ALGORITHM_MESSAGE_CLASS, - RandomWalkMessage.class.getName()); + RandomWalkMessage.class.getName()); this.setIfAbsent(params, ComputerOptions.ALGORITHM_RESULT_CLASS, - IdListList.class.getName()); + IdListList.class.getName()); + this.setIfAbsent(params, ComputerOptions.INPUT_FILTER_CLASS, + EXTRACTALLPROPERTYINPUTFILTER_CLASS_NAME); this.setIfAbsent(params, ComputerOptions.OUTPUT_CLASS, - RandomWalkOutput.class.getName()); + RandomWalkOutput.class.getName()); this.setIfAbsent(params, RandomWalk.OPTION_WALK_PER_NODE, "3"); this.setIfAbsent(params, RandomWalk.OPTION_WALK_LENGTH, "3"); diff --git a/computer-test/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkTest.java b/computer-test/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkTest.java index 5c09af0a7..2d1a00c6e 100644 --- a/computer-test/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkTest.java +++ b/computer-test/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkTest.java @@ -17,8 +17,11 @@ package org.apache.hugegraph.computer.algorithm.sampling; -import com.google.common.collect.ImmutableList; 
-import com.google.common.collect.ImmutableMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + import org.apache.hugegraph.computer.algorithm.AlgorithmTestBase; import org.apache.hugegraph.computer.core.config.ComputerOptions; import org.apache.hugegraph.computer.core.graph.id.Id; @@ -34,13 +37,13 @@ import org.junit.Test; import org.slf4j.Logger; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Set; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; public class RandomWalkTest extends AlgorithmTestBase { + private static final String PROPERTY_KEY = "frequency"; + private static final Map> EXPECT_WALK_PATH = ImmutableMap.of( "F", ImmutableList.of( @@ -58,15 +61,21 @@ public static void setup() { HugeClient client = client(); SchemaManager schema = client.schema(); + schema.propertyKey(PROPERTY_KEY) + .asDouble() + .ifNotExist() + .create(); schema.vertexLabel("user") - .useCustomizeStringId() - .ifNotExist() - .create(); + .useCustomizeStringId() + .ifNotExist() + .create(); schema.edgeLabel("know") - .sourceLabel("user") - .targetLabel("user") - .ifNotExist() - .create(); + .sourceLabel("user") + .targetLabel("user") + .properties(PROPERTY_KEY) + .nullableKeys(PROPERTY_KEY) + .ifNotExist() + .create(); GraphManager graph = client.graph(); Vertex vA = graph.addVertex(T.LABEL, "user", T.ID, "A"); @@ -80,17 +89,17 @@ public static void setup() { Vertex vF = graph.addVertex(T.LABEL, "user", T.ID, "F"); Vertex vG = graph.addVertex(T.LABEL, "user", T.ID, "G"); - vA.addEdge("know", vB); + vA.addEdge("know", vB, PROPERTY_KEY, 9); vA.addEdge("know", vC); - vA.addEdge("know", vD); - vB.addEdge("know", vC); + vA.addEdge("know", vD, PROPERTY_KEY, 3); + vB.addEdge("know", vC, PROPERTY_KEY, 2); vC.addEdge("know", vA); - vC.addEdge("know", vE); - vD.addEdge("know", vA); - vD.addEdge("know", vC); - vE.addEdge("know", vD); + 
vC.addEdge("know", vE, PROPERTY_KEY, 2); + vD.addEdge("know", vA, PROPERTY_KEY, 7); + vD.addEdge("know", vC, PROPERTY_KEY, 1); + vE.addEdge("know", vD, PROPERTY_KEY, 8); - vF.addEdge("know", vG); + vF.addEdge("know", vG, PROPERTY_KEY, 5); } @AfterClass @@ -108,14 +117,36 @@ public static class RandomWalkTestParams extends RandomWalkParams { private static Integer WALK_PER_NODE = 3; private static Integer WALK_LENGTH = 3; + private static String WEIGHT_PROPERTY = PROPERTY_KEY; + private static Double DEFAULT_WEIGHT = 1.0; + private static Double MIN_WEIGHT_THRESHOLD = 3.0; + private static Double MAX_WEIGHT_THRESHOLD = 7.0; + + private static Double RETURN_FACTOR = 2.0; + private static Double INOUT_FACTOR = 1.0 / 2.0; + @Override public void setAlgorithmParameters(Map params) { this.setIfAbsent(params, ComputerOptions.OUTPUT_CLASS, - RandomWalkTest.RandomWalkTestOutput.class.getName()); + RandomWalkTest.RandomWalkTestOutput.class.getName()); this.setIfAbsent(params, RandomWalk.OPTION_WALK_PER_NODE, - WALK_PER_NODE.toString()); + WALK_PER_NODE.toString()); this.setIfAbsent(params, RandomWalk.OPTION_WALK_LENGTH, - WALK_LENGTH.toString()); + WALK_LENGTH.toString()); + + this.setIfAbsent(params, RandomWalk.OPTION_WEIGHT_PROPERTY, + WEIGHT_PROPERTY); + this.setIfAbsent(params, RandomWalk.OPTION_DEFAULT_WEIGHT, + DEFAULT_WEIGHT.toString()); + this.setIfAbsent(params, RandomWalk.OPTION_MIN_WEIGHT_THRESHOLD, + MIN_WEIGHT_THRESHOLD.toString()); + this.setIfAbsent(params, RandomWalk.OPTION_MAX_WEIGHT_THRESHOLD, + MAX_WEIGHT_THRESHOLD.toString()); + + this.setIfAbsent(params, RandomWalk.OPTION_RETURN_FACTOR, + RETURN_FACTOR.toString()); + this.setIfAbsent(params, RandomWalk.OPTION_INOUT_FACTOR, + INOUT_FACTOR.toString()); super.setAlgorithmParameters(params); }