From d55680c162fd47be563bd25b4d5b21122ff10887 Mon Sep 17 00:00:00 2001 From: diaohancai <36229835+diaohancai@users.noreply.github.com> Date: Sat, 28 Oct 2023 18:12:33 +0800 Subject: [PATCH] feat(algorithm): support random walk in computer (#274) --- .../algorithm/sampling/RandomWalk.java | 190 ++++++++++++++++++ .../algorithm/sampling/RandomWalkMessage.java | 79 ++++++++ .../algorithm/sampling/RandomWalkOutput.java | 48 +++++ .../algorithm/sampling/RandomWalkParams.java | 42 ++++ .../algorithm/AlgorithmTestSuite.java | 4 +- .../algorithm/sampling/RandomWalkTest.java | 149 ++++++++++++++ 6 files changed, 511 insertions(+), 1 deletion(-) create mode 100644 computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java create mode 100644 computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkMessage.java create mode 100644 computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkOutput.java create mode 100644 computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkParams.java create mode 100644 computer-test/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkTest.java diff --git a/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java new file mode 100644 index 000000000..33d738440 --- /dev/null +++ b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to You under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hugegraph.computer.algorithm.sampling; + +import org.apache.hugegraph.computer.core.common.exception.ComputerException; +import org.apache.hugegraph.computer.core.config.Config; +import org.apache.hugegraph.computer.core.graph.edge.Edge; +import org.apache.hugegraph.computer.core.graph.edge.Edges; +import org.apache.hugegraph.computer.core.graph.id.Id; +import org.apache.hugegraph.computer.core.graph.value.IdList; +import org.apache.hugegraph.computer.core.graph.value.IdListList; +import org.apache.hugegraph.computer.core.graph.vertex.Vertex; +import org.apache.hugegraph.computer.core.worker.Computation; +import org.apache.hugegraph.computer.core.worker.ComputationContext; +import org.apache.hugegraph.util.Log; +import org.slf4j.Logger; + +import java.util.Iterator; +import java.util.Random; + +public class RandomWalk implements Computation { + + private static final Logger LOG = Log.logger(RandomWalk.class); + + public static final String OPTION_WALK_PER_NODE = "randomwalk.walk_per_node"; + public static final String OPTION_WALK_LENGTH = "randomwalk.walk_length"; + + /** + * number of times per vertex(source vertex) walks + */ + private Integer walkPerNode; + + /** + * walk length + */ + private Integer walkLength; + + /** + * random + */ + private Random random; + + @Override + public String category() { + return "sampling"; + } + + @Override + public String name() { + return "random_walk"; + } + + @Override + public void init(Config config) { + this.walkPerNode = config.getInt(OPTION_WALK_PER_NODE, 3); + if (this.walkPerNode <= 0) { + throw new ComputerException("The param %s must be greater than 0, " + + "actual got '%s'", + OPTION_WALK_PER_NODE, this.walkPerNode); + } + LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_PER_NODE, walkPerNode); + + this.walkLength = config.getInt(OPTION_WALK_LENGTH, 3); + if (this.walkLength <= 0) { + throw new ComputerException("The param %s must be greater than 0, " + + "actual got '%s'", + OPTION_WALK_LENGTH, this.walkLength); + } + LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_LENGTH, walkLength); + + this.random = new Random(); + } + + @Override + public void compute0(ComputationContext context, Vertex vertex) { + vertex.value(new IdListList()); + + RandomWalkMessage message = new RandomWalkMessage(); + message.addToPath(vertex); + + if (vertex.numEdges() <= 0) { + // isolated vertex + this.savePath(vertex, message.path()); // save result + vertex.inactivate(); + return; + } + + for (int i = 0; i < walkPerNode; ++i) { + // random select one edge and walk + Edge selectedEdge = this.randomSelectEdge(vertex.edges()); + context.sendMessage(selectedEdge.targetId(), message); + } + } + + @Override + public void compute(ComputationContext context, Vertex vertex, + Iterator messages) { + while (messages.hasNext()) { + RandomWalkMessage message = messages.next(); + + if (message.isFinish()) { + this.savePath(vertex, message.path()); // save result + + vertex.inactivate(); + continue; + } + + message.addToPath(vertex); + + if (vertex.numEdges() <= 0) { + // there is nowhere to walk,finish eariler + message.finish(); + context.sendMessage(this.getSourceId(message.path()), message); + + vertex.inactivate(); + continue; + } + + if (message.path().size() >= this.walkLength + 1) { + message.finish(); + Id sourceId = this.getSourceId(message.path()); + + if (vertex.id().equals(sourceId)) { + // current vertex is the source vertex,no need to send message once more + this.savePath(vertex, message.path()); // save result + } else { + context.sendMessage(sourceId, message); + } + + vertex.inactivate(); + continue; + } + + // random select one edge and walk + Edge selectedEdge = this.randomSelectEdge(vertex.edges()); + context.sendMessage(selectedEdge.targetId(), message); + } + } + + /** + * random select one edge + */ + private Edge randomSelectEdge(Edges edges) { + Edge selectedEdge = null; + int randomNum = random.nextInt(edges.size()); + + int i = 0; + Iterator iterator = edges.iterator(); + while (iterator.hasNext()) { + selectedEdge = iterator.next(); + if (i == randomNum) { + break; + } + i++; + } + + return selectedEdge; + } + + /** + * get source id of path + */ + private Id getSourceId(IdList path) { + // the first id of path is the source id + return path.get(0); + } + + /** + * save path + */ + private void savePath(Vertex sourceVertex, IdList path) { + IdListList curValue = sourceVertex.value(); + curValue.add(path.copy()); + } +} diff --git a/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkMessage.java b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkMessage.java new file mode 100644 index 000000000..bf32ee75c --- /dev/null +++ b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkMessage.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to You under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hugegraph.computer.algorithm.sampling; + +import org.apache.hugegraph.computer.core.graph.value.BooleanValue; +import org.apache.hugegraph.computer.core.graph.value.IdList; +import org.apache.hugegraph.computer.core.graph.value.Value; +import org.apache.hugegraph.computer.core.graph.vertex.Vertex; +import org.apache.hugegraph.computer.core.io.RandomAccessInput; +import org.apache.hugegraph.computer.core.io.RandomAccessOutput; + +import java.io.IOException; +import java.util.List; + +public class RandomWalkMessage implements Value.CustomizeValue> { + + /** + * random walk path + */ + private final IdList path; + + /** + * finish flag + */ + private BooleanValue isFinish; + + public RandomWalkMessage() { + this.path = new IdList(); + this.isFinish = new BooleanValue(false); + } + + @Override + public void read(RandomAccessInput in) throws IOException { + this.path.read(in); + this.isFinish.read(in); + } + + @Override + public void write(RandomAccessOutput out) throws IOException { + this.path.write(out); + this.isFinish.write(out); + } + + @Override + public List value() { + return this.path.value(); + } + + public IdList path() { + return this.path; + } + + public void addToPath(Vertex vertex) { + this.path.add(vertex.id()); + } + + public boolean isFinish() { + return this.isFinish.boolValue(); + } + + public void finish() { + this.isFinish = new BooleanValue(true); + } +} diff --git a/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkOutput.java b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkOutput.java new file mode 100644 index 000000000..ad43d5bd7 --- /dev/null +++ b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkOutput.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to You under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hugegraph.computer.algorithm.sampling; + +import org.apache.hugegraph.computer.core.graph.value.IdListList; +import org.apache.hugegraph.computer.core.graph.vertex.Vertex; +import org.apache.hugegraph.computer.core.output.hg.HugeGraphOutput; + +import java.util.ArrayList; +import java.util.List; + +public class RandomWalkOutput extends HugeGraphOutput> { + + @Override + protected void prepareSchema() { + this.client().schema().propertyKey(this.name()) + .asText() + .writeType(this.writeType()) + .valueList() + .ifNotExist() + .create(); + } + + @Override + protected List value(Vertex vertex) { + IdListList value = vertex.value(); + List propValues = new ArrayList<>(); + for (int i = 0; i < value.size(); i++) { + propValues.add(value.get(i).toString()); + } + return propValues; + } +} diff --git a/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkParams.java b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkParams.java new file mode 100644 index 000000000..273d7fd67 --- /dev/null +++ b/computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkParams.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to You under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hugegraph.computer.algorithm.sampling; + +import org.apache.hugegraph.computer.algorithm.AlgorithmParams; +import org.apache.hugegraph.computer.core.config.ComputerOptions; +import org.apache.hugegraph.computer.core.graph.value.IdListList; + +import java.util.Map; + +public class RandomWalkParams implements AlgorithmParams { + + @Override + public void setAlgorithmParameters(Map params) { + this.setIfAbsent(params, ComputerOptions.WORKER_COMPUTATION_CLASS, + RandomWalk.class.getName()); + this.setIfAbsent(params, ComputerOptions.ALGORITHM_MESSAGE_CLASS, + RandomWalkMessage.class.getName()); + this.setIfAbsent(params, ComputerOptions.ALGORITHM_RESULT_CLASS, + IdListList.class.getName()); + this.setIfAbsent(params, ComputerOptions.OUTPUT_CLASS, + RandomWalkOutput.class.getName()); + + this.setIfAbsent(params, RandomWalk.OPTION_WALK_PER_NODE, "3"); + this.setIfAbsent(params, RandomWalk.OPTION_WALK_LENGTH, "3"); + } +} diff --git a/computer-test/src/main/java/org/apache/hugegraph/computer/algorithm/AlgorithmTestSuite.java b/computer-test/src/main/java/org/apache/hugegraph/computer/algorithm/AlgorithmTestSuite.java index ad0fc9465..41e9174ce 100644 --- a/computer-test/src/main/java/org/apache/hugegraph/computer/algorithm/AlgorithmTestSuite.java +++ b/computer-test/src/main/java/org/apache/hugegraph/computer/algorithm/AlgorithmTestSuite.java @@ -28,6 +28,7 @@ import org.apache.hugegraph.computer.algorithm.community.wcc.WccTest; import org.apache.hugegraph.computer.algorithm.path.rings.RingsDetectionTest; import org.apache.hugegraph.computer.algorithm.path.rings.RingsDetectionWithFilterTest; +import org.apache.hugegraph.computer.algorithm.sampling.RandomWalkTest; import org.junit.runner.RunWith; import org.junit.runners.Suite; @@ -43,7 +44,8 @@ RingsDetectionWithFilterTest.class, ClusteringCoefficientTest.class, ClosenessCentralityTest.class, - BetweennessCentralityTest.class + BetweennessCentralityTest.class, + RandomWalkTest.class }) public class AlgorithmTestSuite { } diff --git a/computer-test/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkTest.java b/computer-test/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkTest.java new file mode 100644 index 000000000..5c09af0a7 --- /dev/null +++ b/computer-test/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalkTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to You under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hugegraph.computer.algorithm.sampling; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.hugegraph.computer.algorithm.AlgorithmTestBase; +import org.apache.hugegraph.computer.core.config.ComputerOptions; +import org.apache.hugegraph.computer.core.graph.id.Id; +import org.apache.hugegraph.driver.GraphManager; +import org.apache.hugegraph.driver.HugeClient; +import org.apache.hugegraph.driver.SchemaManager; +import org.apache.hugegraph.structure.constant.T; +import org.apache.hugegraph.structure.graph.Vertex; +import org.apache.hugegraph.testutil.Assert; +import org.apache.hugegraph.util.Log; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class RandomWalkTest extends AlgorithmTestBase { + + private static final Map> EXPECT_WALK_PATH = + ImmutableMap.of( + "F", ImmutableList.of( + "[F, G]", + "[F, G]", + "[F, G]"), + "G", ImmutableList.of("[G]"), + "I", ImmutableList.of("[I]") + ); + + @BeforeClass + public static void setup() { + clearAll(); + + HugeClient client = client(); + SchemaManager schema = client.schema(); + + schema.vertexLabel("user") + .useCustomizeStringId() + .ifNotExist() + .create(); + schema.edgeLabel("know") + .sourceLabel("user") + .targetLabel("user") + .ifNotExist() + .create(); + + GraphManager graph = client.graph(); + Vertex vA = graph.addVertex(T.LABEL, "user", T.ID, "A"); + Vertex vB = graph.addVertex(T.LABEL, "user", T.ID, "B"); + Vertex vC = graph.addVertex(T.LABEL, "user", T.ID, "C"); + Vertex vD = graph.addVertex(T.LABEL, "user", T.ID, "D"); + Vertex vE = graph.addVertex(T.LABEL, "user", T.ID, "E"); + + Vertex vI = graph.addVertex(T.LABEL, "user", T.ID, "I"); + + Vertex vF = graph.addVertex(T.LABEL, "user", T.ID, "F"); + Vertex vG = graph.addVertex(T.LABEL, "user", T.ID, "G"); + + vA.addEdge("know", vB); + vA.addEdge("know", vC); + vA.addEdge("know", vD); + vB.addEdge("know", vC); + vC.addEdge("know", vA); + vC.addEdge("know", vE); + vD.addEdge("know", vA); + vD.addEdge("know", vC); + vE.addEdge("know", vD); + + vF.addEdge("know", vG); + } + + @AfterClass + public static void clear() { + clearAll(); + } + + @Test + public void testRunAlgorithm() throws InterruptedException { + runAlgorithm(RandomWalkTestParams.class.getName()); + } + + public static class RandomWalkTestParams extends RandomWalkParams { + + private static Integer WALK_PER_NODE = 3; + private static Integer WALK_LENGTH = 3; + + @Override + public void setAlgorithmParameters(Map params) { + this.setIfAbsent(params, ComputerOptions.OUTPUT_CLASS, + RandomWalkTest.RandomWalkTestOutput.class.getName()); + this.setIfAbsent(params, RandomWalk.OPTION_WALK_PER_NODE, + WALK_PER_NODE.toString()); + this.setIfAbsent(params, RandomWalk.OPTION_WALK_LENGTH, + WALK_LENGTH.toString()); + + super.setAlgorithmParameters(params); + } + } + + public static class RandomWalkTestOutput extends RandomWalkOutput { + + private static final Logger LOG = Log.logger(RandomWalkTestOutput.class); + + @Override + public List value( + org.apache.hugegraph.computer.core.graph.vertex.Vertex vertex) { + List pathList = super.value(vertex); + LOG.info("vertex: {}, walk path: {}", vertex.id(), pathList); + + this.assertResult(vertex.id(), pathList); + return pathList; + } + + private void assertResult(Id id, List path) { + Set keys = RandomWalkTest.EXPECT_WALK_PATH.keySet(); + if (keys.contains(id.string())) { + List expect = RandomWalkTest.EXPECT_WALK_PATH + .getOrDefault(id.toString(), new ArrayList<>()); + Assert.assertEquals(expect, path); + } else { + Assert.assertEquals(RandomWalkTestParams.WALK_PER_NODE.intValue(), path.size()); + } + } + } +}