Skip to content

Commit

Permalink
fix install bugs
Browse files Browse the repository at this point in the history
Signed-off-by: Sijie Shen <[email protected]>
  • Loading branch information
ds-ssj committed May 16, 2024
1 parent da615e4 commit 6d9b900
Show file tree
Hide file tree
Showing 4 changed files with 508 additions and 26 deletions.
2 changes: 1 addition & 1 deletion scripts/gart
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ while [ : ]; do
shift 2
;;
--db-type)
db_type=$2
db_type=$(echo "$2" | awk '{print tolower($0)}')
shift 2
;;
--enable-bulkload)
Expand Down
5 changes: 5 additions & 0 deletions scripts/install-mysql.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,9 @@ binlog-do-db=ldbc # change the name to your database
EOT

# For some reason, the mysql user's home directory is not set correctly
# see https://stackoverflow.com/questions/62987154/mysql-wont-start-error-su-warning-cannot-change-directory-to-nonexistent
sudo usermod -d /var/lib/mysql/ mysql

# Restart mysql to apply the changes
sudo service mysql restart
62 changes: 37 additions & 25 deletions scripts/update_kafka_config_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import argparse
import os
import etcd3
import sys

from urllib.parse import urlparse
import yaml
import shutil
import sys

import etcd3
import yaml


def get_parser():
Expand All @@ -28,6 +30,16 @@ def get_parser():
return parser


def get_etcd_client(etcd_endpoint):
if not etcd_endpoint.startswith(("http://", "https://")):
etcd_endpoint = "http://" + etcd_endpoint

parsed_url = urlparse(etcd_endpoint)
etcd_host = parsed_url.netloc.split(":")[0]
etcd_port = parsed_url.port
return etcd3.client(host=etcd_host, port=etcd_port)


if __name__ == "__main__":
arg_parser = get_parser()
args = arg_parser.parse_args()
Expand All @@ -38,7 +50,7 @@ def get_parser():
print("KAFKA_HOME is not set")
sys.exit(1)

db_type = args.db_type
db_type = args.db_type.lower()

kafka_server = args.kafka_endpoint
kafka_port = kafka_server.split(":")[1]
Expand All @@ -60,10 +72,17 @@ def get_parser():
kafka_home + "/config/connect-debezium-postgresql.properties.tmp"
)

etcd_endpoint = args.etcd_endpoint
etcd_client = get_etcd_client(args.etcd_endpoint)

enable_bulkload = (
args.enable_bulkload == "1"
or args.enable_bulkload == 1
or args.enable_bulkload is True
or args.enable_bulkload.lower() == "true"
)

with open(kafka_config_file_name, "r") as file, open(
temp_file_name, "w"
with open(kafka_config_file_name, "r", encoding="UTF-8") as file, open(
temp_file_name, "w", encoding="UTF-8"
) as temp_file:
for line in file:
if line.startswith("database.hostname"):
Expand All @@ -79,14 +98,15 @@ def get_parser():
elif line.startswith("database.include.list"):
temp_file.write(f"database.include.list={args.db_name}\n")
elif line.startswith("table.include.list"):
if not etcd_endpoint.startswith(("http://", "https://")):
etcd_endpoint = "http://" + etcd_endpoint
parsed_url = urlparse(etcd_endpoint)
etcd_host = parsed_url.netloc.split(":")[0]
etcd_port = parsed_url.port
etcd_client = etcd3.client(host=etcd_host, port=etcd_port)
rg_mapping_key = args.etcd_prefix + "gart_rg_mapping_yaml"
rg_mapping_str = etcd_client.get(rg_mapping_key)[0].decode("utf-8")

raw_value = etcd_client.get(rg_mapping_key)
if raw_value is None or raw_value[0] is None:
print(f"ERROR: Key {rg_mapping_key} not found in etcd")
exit(1)
else:
rg_mapping_str = raw_value[0].decode("utf-8")

graph_schema = yaml.load(rg_mapping_str, Loader=yaml.SafeLoader)

# Extract the 'vertex_types' list from the dictionary
Expand Down Expand Up @@ -129,8 +149,7 @@ def get_parser():
# build a dict (table_name, src/dst_type_names)
edge_table_src_dst_type_mapping = {}

for idx in range(len(edge_table_names)):
edge_table_name = edge_table_names[idx]
for idx, edge_table_name in enumerate(edge_table_names):
if edge_table_name not in both_vertex_edge_table_names:
continue

Expand Down Expand Up @@ -161,10 +180,9 @@ def get_parser():
both_vertex_edge_table_names
):
break
for idx in range(len(both_vertex_edge_table_names)):
for idx, edge_table_name in enumerate(both_vertex_edge_table_names):
if both_vertex_edge_table_placed[idx] == 1:
continue
edge_table_name = both_vertex_edge_table_names[idx]
src_dst_type_names = edge_table_src_dst_type_mapping[
edge_table_name
]
Expand All @@ -189,13 +207,7 @@ def get_parser():
new_line = ",".join(new_line_list)
temp_file.write(f"table.include.list={new_line}\n")
elif line.startswith("snapshot.mode"):
if (
args.enable_bulkload == "1"
or args.enable_bulkload == 1
or args.enable_bulkload == True
or args.enable_bulkload == "True"
or args.enable_bulkload == "true"
):
if enable_bulkload:
if db_type == "postgresql":
temp_file.write("snapshot.mode=always\n")
else:
Expand Down
Loading

0 comments on commit 6d9b900

Please sign in to comment.