Skip to content

Commit

Permalink
config can be passed entirely through parameters (#394)
Browse files Browse the repository at this point in the history
* update version to 2.4.0

* Remove duplicate scripts

* The configuration can be passed entirely through parameters

* fix

* fix
  • Loading branch information
Teingi authored Aug 26, 2024
1 parent 6138999 commit f8ab11b
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 32 deletions.
4 changes: 2 additions & 2 deletions common/config_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def build_configuration(self):
global_ssh_port = self.input_with_default("oceanbase host ssh_port", "22")
global_home_path = self.input_with_default("oceanbase install home_path", const.OB_INSTALL_DIR_DEFAULT)
default_data_dir = os.path.join(global_home_path, "store")
global_data_dir = self.input_with_default("oceanbase data_dir", default_data_dir)
global_redo_dir = self.input_with_default("oceanbase redo_dir", default_data_dir)
global_data_dir = default_data_dir
global_redo_dir = default_data_dir
tenant_sys_config = {"user": self.sys_tenant_user, "password": self.sys_tenant_password}
global_config = {"ssh_username": global_ssh_username, "ssh_password": global_ssh_password, "ssh_port": global_ssh_port, "ssh_key_file": "", "home_path": global_home_path, "data_dir": global_data_dir, "redo_dir": global_redo_dir}
new_config = {"obcluster": {"ob_cluster_name": ob_cluster_name, "db_host": self.db_host, "db_port": self.db_port, "tenant_sys": tenant_sys_config, "servers": {"nodes": nodes_config, "global": global_config}}}
Expand Down
87 changes: 87 additions & 0 deletions common/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,93 @@ def passwd_format(passwd):
return "'{}'".format(passwd.replace("'", "'\"'\"'"))


class ConfigOptionsParserUtil(object):
def __init__(self):
self.config_dict = {}
self.key_mapping = {
'db_host': 'obcluster.db_host',
'db_port': 'obcluster.db_port',
'tenant_sys.user': 'obcluster.tenant_sys.user',
'tenant_sys.password': 'obcluster.tenant_sys.password',
'ssh_username': 'obcluster.servers.global.ssh_username',
'ssh_password': 'obcluster.servers.global.ssh_password',
'ssh_port': 'obcluster.servers.global.ssh_port',
'home_path': 'obcluster.servers.global.home_path',
'obproxy_home_path': 'obproxy.servers.global.home_path',
}

def set_nested_value(self, d, keys, value):
"""Recursively set the value in a nested dictionary."""
if len(keys) > 1:
if 'nodes' in keys[0]:
try:
# Handle nodes
parts = keys[0].split('[')
base_key = parts[0]
index = int(parts[1].rstrip(']'))
if base_key not in d:
d[base_key] = []
while len(d[base_key]) <= index:
d[base_key].append({})
self.set_nested_value(d[base_key][index], keys[1:], value)
except (IndexError, ValueError) as e:
raise ValueError(f"Invalid node index in key '{keys[0]}'") from e
else:
if keys[0] not in d:
d[keys[0]] = {}
d[keys[0]] = self.set_nested_value(d[keys[0]], keys[1:], value)
else:
d[keys[0]] = value
return d

def parse_config(self, input_array):
for item in input_array:
try:
key, value = item.split('=', 1)
# Map short keys to full keys if needed
if key in self.key_mapping:
key = self.key_mapping[key]
keys = key.split('.')
self.set_nested_value(self.config_dict, keys, value)
except ValueError:
raise ValueError(f"Invalid input format for item '{item}'")

self.config_dict = self.add_default_values(self.config_dict)
return self.config_dict

def add_default_values(self, d):
if isinstance(d, dict):
for k, v in d.items():
if k == 'login':
if 'password' not in v:
v['password'] = ''
elif k == 'tenant_sys':
if 'password' not in v:
v['password'] = ''
elif k == 'global':
if 'ssh_username' not in v:
v['ssh_username'] = ''
if 'ssh_password' not in v:
v['ssh_password'] = ''
elif k == 'servers':
# Ensure 'nodes' is present and initialized as an empty list
if 'nodes' not in v:
v['nodes'] = []
if 'global' not in v:
v['global'] = {}
self.add_default_values(v['global'])
for node in v['nodes']:
if isinstance(node, dict):
self.add_default_values(node)
elif isinstance(v, dict):
self.add_default_values(v)
elif isinstance(v, list):
for node in v:
if isinstance(node, dict):
self.add_default_values(node)
return d


class DirectoryUtil(object):

@staticmethod
Expand Down
25 changes: 14 additions & 11 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from __future__ import absolute_import, division, print_function
import os
from common.tool import DirectoryUtil
from common.tool import ConfigOptionsParserUtil, DirectoryUtil
from stdio import SafeStdio
import oyaml as yaml
import pathlib
Expand Down Expand Up @@ -148,17 +148,20 @@ def load_config_with_defaults(self, defaults_dict):

class ConfigManager(Manager):

def __init__(self, config_file=None, stdio=None):
def __init__(self, config_file=None, stdio=None, config_env_list=[]):
default_config_path = os.path.join(os.path.expanduser("~"), ".obdiag", "config.yml")

if config_file is None or not os.path.exists(config_file):
config_file = default_config_path
pathlib.Path(os.path.dirname(default_config_path)).mkdir(parents=True, exist_ok=True)
with open(default_config_path, 'w') as f:
f.write(DEFAULT_CONFIG_DATA)
super(ConfigManager, self).__init__(config_file, stdio)
self.config_file = config_file
self.config_data = self.load_config()
if config_env_list is None or len(config_env_list) == 0:
if config_file is None or not os.path.exists(config_file):
config_file = default_config_path
pathlib.Path(os.path.dirname(default_config_path)).mkdir(parents=True, exist_ok=True)
with open(default_config_path, 'w') as f:
f.write(DEFAULT_CONFIG_DATA)
super(ConfigManager, self).__init__(config_file, stdio)
self.config_file = config_file
self.config_data = self.load_config()
else:
parser = ConfigOptionsParserUtil()
self.config_data = parser.parse_config(config_env_list)

def _safe_get(self, dictionary, *keys, default=None):
"""Safe way to retrieve nested values from dictionaries"""
Expand Down
11 changes: 2 additions & 9 deletions core.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,12 @@
from colorama import Fore, Style
from common.config_helper import ConfigHelper

from common.tool import Util
from common.tool import TimeUtils


class ObdiagHome(object):

def __init__(self, stdio=None, config_path=os.path.expanduser('~/.obdiag/config.yml'), inner_config_change_map=None):
def __init__(self, stdio=None, config_path=os.path.expanduser('~/.obdiag/config.yml'), inner_config_change_map=None, custom_config_env_list=None):
self._optimize_manager = None
self.stdio = None
self._stdio_func = None
Expand All @@ -80,13 +79,7 @@ def __init__(self, stdio=None, config_path=os.path.expanduser('~/.obdiag/config.
if self.inner_config_manager.config.get("obdiag") is not None and self.inner_config_manager.config.get("obdiag").get("logger") is not None and self.inner_config_manager.config.get("obdiag").get("logger").get("silent") is not None:
stdio.set_silent(self.inner_config_manager.config.get("obdiag").get("logger").get("silent"))
self.set_stdio(stdio)
if config_path:
if os.path.exists(os.path.abspath(config_path)):
config_path = config_path
else:
stdio.error('The option you provided with -c: {0} is not exist.'.format(config_path))
return
self.config_manager = ConfigManager(config_path, stdio)
self.config_manager = ConfigManager(config_path, stdio, custom_config_env_list)
if (
self.inner_config_manager.config.get("obdiag") is not None
and self.inner_config_manager.config.get("obdiag").get("basic") is not None
Expand Down
Loading

0 comments on commit f8ab11b

Please sign in to comment.