Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip ssl check #149

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions testgres/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_new_node(name=None, base_dir=None, **kwargs):
return PostgresNode(name=name, base_dir=base_dir, **kwargs)


def get_remote_node(name=None, conn_params=None):
def get_remote_node(name=None):
"""
Simply a wrapper around :class:`.PostgresNode` constructor for remote node.
See :meth:`.PostgresNode.__init__` for details.
Expand All @@ -51,4 +51,4 @@ def get_remote_node(name=None, conn_params=None):
ssh_key=None,
username=default_username())
"""
return get_new_node(name=name, conn_params=conn_params)
return get_new_node(name=name)
90 changes: 56 additions & 34 deletions testgres/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@

from .operations.os_ops import ConnectionParams
from .operations.local_ops import LocalOperations
from .operations.remote_ops import RemoteOperations

InternalError = pglib.InternalError
ProgrammingError = pglib.ProgrammingError
Expand Down Expand Up @@ -128,7 +127,8 @@ def __repr__(self):


class PostgresNode(object):
def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionParams = ConnectionParams(), bin_dir=None, prefix=None):
def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionParams = ConnectionParams(),
bin_dir=None, prefix=None):
"""
PostgresNode constructor.

Expand All @@ -152,21 +152,17 @@ def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionP
self.name = name or generate_app_name()
if testgres_config.os_ops:
self.os_ops = testgres_config.os_ops
elif conn_params.ssh_key:
self.os_ops = RemoteOperations(conn_params)
else:
self.os_ops = LocalOperations(conn_params)

self.host = self.os_ops.host
self.port = port or reserve_port()

self.ssh_key = self.os_ops.ssh_key

# defaults for __exit__()
self.cleanup_on_good_exit = testgres_config.node_cleanup_on_good_exit
self.cleanup_on_bad_exit = testgres_config.node_cleanup_on_bad_exit
self.shutdown_max_attempts = 3

self.port = port or self.os_ops.port or reserve_port()

# NOTE: for compatibility
self.utils_log_name = self.utils_log_file
self.pg_log_name = self.pg_log_file
Expand Down Expand Up @@ -492,7 +488,7 @@ def init(self, initdb_params=None, cached=True, **kwargs):
os_ops=self.os_ops,
params=initdb_params,
bin_path=self.bin_dir,
cached=False)
cached=cached)

# initialize default config files
self.default_conf(**kwargs)
Expand Down Expand Up @@ -722,9 +718,9 @@ def slow_start(self, replica=False, dbname='template1', username=None, max_attem
OperationalError},
max_attempts=max_attempts)

def start(self, params=[], wait=True):
def start(self, params=None, wait: bool = True) -> 'PostgresNode':
"""
Starts the PostgreSQL node using pg_ctl if node has not been started.
Starts the PostgreSQL node using pg_ctl if the node has not been started.
By default, it waits for the operation to complete before returning.
Optionally, it can return immediately without waiting for the start operation
to complete by setting the `wait` parameter to False.
Expand All @@ -736,6 +732,8 @@ def start(self, params=[], wait=True):
Returns:
This instance of :class:`.PostgresNode`.
"""
if params is None:
params = []
if self.is_started:
return self

Expand All @@ -745,34 +743,49 @@ def start(self, params=[], wait=True):
"-w" if wait else '-W', # --wait or --no-wait
"start"] + params # yapf: disable

startup_retries = 5
while True:
max_retries = 5
sleep_interval = 5 # seconds

for attempt in range(max_retries):
try:
exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True)
if error and 'does not exist' in error:
raise Exception
break # Exit the loop if successful
except Exception as e:
files = self._collect_special_files()
if any(len(file) > 1 and 'Is another postmaster already '
'running on port' in file[1].decode() for
file in files):
logging.warning("Detected an issue with connecting to port {0}. "
"Trying another port after a 5-second sleep...".format(self.port))
self.port = reserve_port()
options = {'port': str(self.port)}
self.set_auto_conf(options)
startup_retries -= 1
time.sleep(5)
continue

msg = 'Cannot start node'
raise_from(StartNodeException(msg, files), e)
break
if self._handle_port_conflict():
if attempt < max_retries - 1:
logging.info(f"Retrying start operation (Attempt {attempt + 2}/{max_retries})...")
time.sleep(sleep_interval)
continue
else:
logging.error("Reached maximum retry attempts. Unable to start node.")
raise StartNodeException("Cannot start node after multiple attempts",
self._collect_special_files()) from e
raise StartNodeException("Cannot start node", self._collect_special_files()) from e

self._maybe_start_logger()
self.is_started = True
return self

def stop(self, params=[], wait=True):
def _handle_port_conflict(self) -> bool:
"""
Checks for a port conflict and attempts to resolve it by changing the port.
Returns True if the port was changed, False otherwise.
"""
files = self._collect_special_files()
if any(len(file) > 1 and 'Is another postmaster already running on port' in file[1].decode() for file in files):
dmitry-lipetsk marked this conversation as resolved.
Show resolved Hide resolved
logging.warning(f"Port conflict detected on port {self.port}.")
if self._should_free_port:
logging.warning("Port reservation skipped due to _should_free_port setting.")
return False
self.port = reserve_port()
dmitry-lipetsk marked this conversation as resolved.
Show resolved Hide resolved
self.set_auto_conf({'port': str(self.port)})
logging.info(f"Port changed to {self.port}.")
return True
return False

def stop(self, params=None, wait=True):
"""
Stops the PostgreSQL node using pg_ctl if the node has been started.

Expand All @@ -783,6 +796,8 @@ def stop(self, params=[], wait=True):
Returns:
This instance of :class:`.PostgresNode`.
"""
if params is None:
params = []
if not self.is_started:
return self

Expand All @@ -797,6 +812,7 @@ def stop(self, params=[], wait=True):

self._maybe_stop_logger()
self.is_started = False
release_port(self.port)
return self

def kill(self, someone=None):
Expand All @@ -815,7 +831,7 @@ def kill(self, someone=None):
os.kill(self.auxiliary_pids[someone][0], sig)
self.is_started = False

def restart(self, params=[]):
def restart(self, params=None):
"""
Restart this node using pg_ctl.

Expand All @@ -826,6 +842,8 @@ def restart(self, params=[]):
This instance of :class:`.PostgresNode`.
"""

if params is None:
params = []
_params = [
self._get_bin_path("pg_ctl"),
"-D", self.data_dir,
Expand All @@ -847,7 +865,7 @@ def restart(self, params=[]):

return self

def reload(self, params=[]):
def reload(self, params=None):
"""
Asynchronously reload config files using pg_ctl.

Expand All @@ -858,6 +876,8 @@ def reload(self, params=[]):
This instance of :class:`.PostgresNode`.
"""

if params is None:
params = []
_params = [
self._get_bin_path("pg_ctl"),
"-D", self.data_dir,
Expand Down Expand Up @@ -1036,7 +1056,7 @@ def _psql(

# select query source
if query:
if self.os_ops.remote:
if self.os_ops.conn_params.remote:
psql_params.extend(("-c", '"{}"'.format(query)))
else:
psql_params.extend(("-c", query))
Expand Down Expand Up @@ -1620,7 +1640,7 @@ def pgbench_table_checksums(self, dbname="postgres",
return {(table, self.table_checksum(table, dbname))
for table in pgbench_tables}

def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}):
def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options=None):
"""
Update or remove configuration options in the specified configuration file,
updates the options specified in the options dictionary, removes any options
Expand All @@ -1636,6 +1656,8 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}):
Defaults to an empty set.
"""
# parse postgresql.auto.conf
if rm_options is None:
rm_options = {}
path = os.path.join(self.data_dir, config)

lines = self.os_ops.readlines(path)
Expand Down
21 changes: 2 additions & 19 deletions testgres/operations/local_ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import getpass
import logging
import os
import shutil
Expand All @@ -10,7 +9,7 @@
import psutil

from ..exceptions import ExecUtilException
from .os_ops import ConnectionParams, OsOperations, pglib, get_default_encoding
from .os_ops import ConnectionParams, OsOperations, get_default_encoding
from .raise_error import RaiseError
from .helpers import Helpers

Expand Down Expand Up @@ -42,12 +41,7 @@ class LocalOperations(OsOperations):
def __init__(self, conn_params=None):
if conn_params is None:
conn_params = ConnectionParams()
super(LocalOperations, self).__init__(conn_params.username)
self.conn_params = conn_params
self.host = conn_params.host
self.ssh_key = None
self.remote = False
self.username = conn_params.username or getpass.getuser()
super(LocalOperations, self).__init__(conn_params)

@staticmethod
def _process_output(encoding, temp_file_path):
Expand Down Expand Up @@ -329,14 +323,3 @@ def get_pid(self):

def get_process_children(self, pid):
return psutil.Process(pid).children()

# Database control
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
conn = pglib.connect(
host=host,
port=port,
database=dbname,
user=user,
password=password,
)
return conn
42 changes: 37 additions & 5 deletions testgres/operations/os_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import getpass
import locale
import os
import sys

try:
Expand All @@ -12,11 +13,18 @@


class ConnectionParams:
def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None):
def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None, remote=False, skip_ssl=None):
"""
skip_ssl: if is True, the connection to database is established without SSL.
"""
self.remote = remote
self.host = host
self.port = port
self.ssh_key = ssh_key
self.username = username
if skip_ssl is None:
skip_ssl = os.getenv("TESTGRES_SKIP_SSL", False)
self.skip_ssl = skip_ssl


def get_default_encoding():
Expand All @@ -26,9 +34,12 @@ def get_default_encoding():


class OsOperations:
def __init__(self, username=None):
self.ssh_key = None
self.username = username or getpass.getuser()
def __init__(self, conn_params=ConnectionParams()):
self.ssh_key = conn_params.ssh_key
self.username = conn_params.username or getpass.getuser()
self.host = conn_params.host
self.port = conn_params.port
self.conn_params = conn_params

# Command execution
def exec_command(self, cmd, **kwargs):
Expand Down Expand Up @@ -113,6 +124,27 @@ def get_pid(self):
def get_process_children(self, pid):
raise NotImplementedError()

def _get_ssl_options(self):
"""
Determine the SSL options based on available modules.
"""
if self.conn_params.skip_ssl:
if 'psycopg2' in sys.modules:
return {"sslmode": "disable"}
elif 'pg8000' in sys.modules:
return {"ssl_context": None}
return {}

# Database control
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
raise NotImplementedError()
ssl_options = self._get_ssl_options()
conn = pglib.connect(
host=host,
port=port,
database=dbname,
user=user,
password=password,
**ssl_options
)

return conn
Loading