Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

config can be passed entirely through parameters #394

Merged
merged 9 commits into from
Aug 26, 2024
4 changes: 2 additions & 2 deletions common/config_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def build_configuration(self):
global_ssh_port = self.input_with_default("oceanbase host ssh_port", "22")
global_home_path = self.input_with_default("oceanbase install home_path", const.OB_INSTALL_DIR_DEFAULT)
default_data_dir = os.path.join(global_home_path, "store")
global_data_dir = self.input_with_default("oceanbase data_dir", default_data_dir)
global_redo_dir = self.input_with_default("oceanbase redo_dir", default_data_dir)
global_data_dir = default_data_dir
global_redo_dir = default_data_dir
tenant_sys_config = {"user": self.sys_tenant_user, "password": self.sys_tenant_password}
global_config = {"ssh_username": global_ssh_username, "ssh_password": global_ssh_password, "ssh_port": global_ssh_port, "ssh_key_file": "", "home_path": global_home_path, "data_dir": global_data_dir, "redo_dir": global_redo_dir}
new_config = {"obcluster": {"ob_cluster_name": ob_cluster_name, "db_host": self.db_host, "db_port": self.db_port, "tenant_sys": tenant_sys_config, "servers": {"nodes": nodes_config, "global": global_config}}}
Expand Down
87 changes: 87 additions & 0 deletions common/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,93 @@ def passwd_format(passwd):
return "'{}'".format(passwd.replace("'", "'\"'\"'"))


class ConfigOptionsParserUtil(object):
def __init__(self):
self.config_dict = {}
self.key_mapping = {
'db_host': 'obcluster.db_host',
'db_port': 'obcluster.db_port',
'tenant_sys.user': 'obcluster.tenant_sys.user',
'tenant_sys.password': 'obcluster.tenant_sys.password',
'ssh_username': 'obcluster.servers.global.ssh_username',
'ssh_password': 'obcluster.servers.global.ssh_password',
'ssh_port': 'obcluster.servers.global.ssh_port',
'home_path': 'obcluster.servers.global.home_path',
'obproxy_home_path': 'obproxy.servers.global.home_path',
}

def set_nested_value(self, d, keys, value):
"""Recursively set the value in a nested dictionary."""
if len(keys) > 1:
if 'nodes' in keys[0]:
try:
# Handle nodes
parts = keys[0].split('[')
base_key = parts[0]
index = int(parts[1].rstrip(']'))
if base_key not in d:
d[base_key] = []
while len(d[base_key]) <= index:
d[base_key].append({})
self.set_nested_value(d[base_key][index], keys[1:], value)
except (IndexError, ValueError) as e:
raise ValueError(f"Invalid node index in key '{keys[0]}'") from e
else:
if keys[0] not in d:
d[keys[0]] = {}
d[keys[0]] = self.set_nested_value(d[keys[0]], keys[1:], value)
else:
d[keys[0]] = value
return d

def parse_config(self, input_array):
for item in input_array:
try:
key, value = item.split('=', 1)
# Map short keys to full keys if needed
if key in self.key_mapping:
key = self.key_mapping[key]
keys = key.split('.')
self.set_nested_value(self.config_dict, keys, value)
except ValueError:
raise ValueError(f"Invalid input format for item '{item}'")

self.config_dict = self.add_default_values(self.config_dict)
return self.config_dict

def add_default_values(self, d):
if isinstance(d, dict):
for k, v in d.items():
if k == 'login':
if 'password' not in v:
v['password'] = ''
elif k == 'tenant_sys':
if 'password' not in v:
v['password'] = ''
elif k == 'global':
if 'ssh_username' not in v:
v['ssh_username'] = ''
if 'ssh_password' not in v:
v['ssh_password'] = ''
elif k == 'servers':
# Ensure 'nodes' is present and initialized as an empty list
if 'nodes' not in v:
v['nodes'] = []
if 'global' not in v:
v['global'] = {}
self.add_default_values(v['global'])
for node in v['nodes']:
if isinstance(node, dict):
self.add_default_values(node)
elif isinstance(v, dict):
self.add_default_values(v)
elif isinstance(v, list):
for node in v:
if isinstance(node, dict):
self.add_default_values(node)
return d


class DirectoryUtil(object):

@staticmethod
Expand Down
25 changes: 14 additions & 11 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from __future__ import absolute_import, division, print_function
import os
from common.tool import DirectoryUtil
from common.tool import ConfigOptionsParserUtil, DirectoryUtil
from stdio import SafeStdio
import oyaml as yaml
import pathlib
Expand Down Expand Up @@ -148,17 +148,20 @@ def load_config_with_defaults(self, defaults_dict):

class ConfigManager(Manager):

def __init__(self, config_file=None, stdio=None):
def __init__(self, config_file=None, stdio=None, config_env_list=[]):
default_config_path = os.path.join(os.path.expanduser("~"), ".obdiag", "config.yml")

if config_file is None or not os.path.exists(config_file):
config_file = default_config_path
pathlib.Path(os.path.dirname(default_config_path)).mkdir(parents=True, exist_ok=True)
with open(default_config_path, 'w') as f:
f.write(DEFAULT_CONFIG_DATA)
super(ConfigManager, self).__init__(config_file, stdio)
self.config_file = config_file
self.config_data = self.load_config()
if config_env_list is None or len(config_env_list) == 0:
if config_file is None or not os.path.exists(config_file):
config_file = default_config_path
pathlib.Path(os.path.dirname(default_config_path)).mkdir(parents=True, exist_ok=True)
with open(default_config_path, 'w') as f:
f.write(DEFAULT_CONFIG_DATA)
super(ConfigManager, self).__init__(config_file, stdio)
self.config_file = config_file
self.config_data = self.load_config()
else:
parser = ConfigOptionsParserUtil()
self.config_data = parser.parse_config(config_env_list)

def _safe_get(self, dictionary, *keys, default=None):
"""Safe way to retrieve nested values from dictionaries"""
Expand Down
11 changes: 2 additions & 9 deletions core.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,12 @@
from colorama import Fore, Style
from common.config_helper import ConfigHelper

from common.tool import Util
from common.tool import TimeUtils


class ObdiagHome(object):

def __init__(self, stdio=None, config_path=os.path.expanduser('~/.obdiag/config.yml'), inner_config_change_map=None):
def __init__(self, stdio=None, config_path=os.path.expanduser('~/.obdiag/config.yml'), inner_config_change_map=None, custom_config_env_list=None):
self._optimize_manager = None
self.stdio = None
self._stdio_func = None
Expand All @@ -80,13 +79,7 @@ def __init__(self, stdio=None, config_path=os.path.expanduser('~/.obdiag/config.
if self.inner_config_manager.config.get("obdiag") is not None and self.inner_config_manager.config.get("obdiag").get("logger") is not None and self.inner_config_manager.config.get("obdiag").get("logger").get("silent") is not None:
stdio.set_silent(self.inner_config_manager.config.get("obdiag").get("logger").get("silent"))
self.set_stdio(stdio)
if config_path:
if os.path.exists(os.path.abspath(config_path)):
config_path = config_path
else:
stdio.error('The option you provided with -c: {0} is not exist.'.format(config_path))
return
self.config_manager = ConfigManager(config_path, stdio)
self.config_manager = ConfigManager(config_path, stdio, custom_config_env_list)
if (
self.inner_config_manager.config.get("obdiag") is not None
and self.inner_config_manager.config.get("obdiag").get("basic") is not None
Expand Down
Loading
Loading