diff --git a/.gitignore b/.gitignore index 3d149b6..b8b320d 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,7 @@ coverage.xml *.cover .hypothesis/ .pytest_cache/ +tox.*.ini # Translations *.mo diff --git a/README.md b/README.md index b6feb95..b0363d2 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,9 @@ Options: --mysql-charset TEXT MySQL database and table character set [default: utf8mb4] --mysql-collation TEXT MySQL database and table collation + --mysql-ssl-ca PATH Path to SSL CA certificate file. + --mysql-ssl-cert PATH Path to SSL certificate file. + --mysql-ssl-key PATH Path to SSL key file. -S, --skip-ssl Disable MySQL connection encryption. -c, --chunk INTEGER Chunk reading/writing SQL records -l, --log-file PATH Log file diff --git a/docs/README.rst b/docs/README.rst index 3eb421e..0ce2bdf 100644 --- a/docs/README.rst +++ b/docs/README.rst @@ -45,7 +45,10 @@ Connection Options - ``-h, --mysql-host TEXT``: MySQL host. Defaults to localhost. - ``-P, --mysql-port INTEGER``: MySQL port. Defaults to 3306. - ``--mysql-charset TEXT``: MySQL database and table character set. The default is utf8mb4. -- ``--mysql-collation TEXT``: MySQL database and table collation +- ``--mysql-collation TEXT``: MySQL database and table collation. +- ``--mysql-ssl-ca PATH``: Path to SSL CA certificate file. +- ``--mysql-ssl-cert PATH``: Path to SSL certificate file. +- ``--mysql-ssl-key PATH``: Path to SSL key file. - ``-S, --skip-ssl``: Disable MySQL connection encryption. Other Options diff --git a/src/mysql_to_sqlite3/cli.py b/src/mysql_to_sqlite3/cli.py index cd9bc26..9ed48ea 100644 --- a/src/mysql_to_sqlite3/cli.py +++ b/src/mysql_to_sqlite3/cli.py @@ -126,6 +126,9 @@ default=None, help="MySQL database and table collation", ) +@click.option("--mysql-ssl-ca", type=click.Path(), help="Path to SSL CA certificate file.") +@click.option("--mysql-ssl-cert", type=click.Path(), help="Path to SSL certificate file.") +@click.option("--mysql-ssl-key", type=click.Path(), help="Path to SSL key file.") @click.option("-S", "--skip-ssl", is_flag=True, help="Disable MySQL connection encryption.") @click.option( "-c", @@ -171,6 +174,9 @@ def cli( mysql_port: int, mysql_charset: str, mysql_collation: str, + mysql_ssl_ca: t.Optional[str], + mysql_ssl_cert: t.Optional[str], + mysql_ssl_key: t.Optional[str], skip_ssl: bool, chunk: int, log_file: t.Union[str, "os.PathLike[t.Any]"], @@ -219,6 +225,9 @@ def cli( mysql_port=mysql_port, mysql_charset=mysql_charset, mysql_collation=mysql_collation, + mysql_ssl_ca=mysql_ssl_ca, + mysql_ssl_cert=mysql_ssl_cert, + mysql_ssl_key=mysql_ssl_key, mysql_ssl_disabled=skip_ssl, chunk=chunk, json_as_text=json_as_text, diff --git a/src/mysql_to_sqlite3/transporter.py b/src/mysql_to_sqlite3/transporter.py index c6151de..84294bd 100644 --- a/src/mysql_to_sqlite3/transporter.py +++ b/src/mysql_to_sqlite3/transporter.py @@ -100,6 +100,12 @@ def __init__(self, **kwargs: tx.Unpack[MySQLtoSQLiteParams]) -> None: if self._without_tables and self._without_data: raise ValueError("Unable to continue without transferring data or creating tables!") + self._mysql_ssl_ca = kwargs.get("mysql_ssl_ca") or None + + self._mysql_ssl_cert = kwargs.get("mysql_ssl_cert") or None + + self._mysql_ssl_key = kwargs.get("mysql_ssl_key") or None + self._mysql_ssl_disabled = bool(kwargs.get("mysql_ssl_disabled", False)) self._current_chunk_number = 0 @@ -135,6 +141,9 @@ def __init__(self, **kwargs: tx.Unpack[MySQLtoSQLiteParams]) -> None: password=self._mysql_password, host=self._mysql_host, port=self._mysql_port, + ssl_ca=self._mysql_ssl_ca, + ssl_cert=self._mysql_ssl_cert, + ssl_key=self._mysql_ssl_key, ssl_disabled=self._mysql_ssl_disabled, charset=self._mysql_charset, collation=self._mysql_collation, diff --git a/src/mysql_to_sqlite3/types.py b/src/mysql_to_sqlite3/types.py index 2a28f2a..395146a 100644 --- a/src/mysql_to_sqlite3/types.py +++ b/src/mysql_to_sqlite3/types.py @@ -26,6 +26,9 @@ class MySQLtoSQLiteParams(tx.TypedDict): mysql_port: int mysql_charset: t.Optional[str] mysql_collation: t.Optional[str] + mysql_ssl_ca: t.Optional[str] + mysql_ssl_cert: t.Optional[str] + mysql_ssl_key: t.Optional[str] mysql_ssl_disabled: t.Optional[bool] mysql_tables: t.Optional[t.Sequence[str]] mysql_user: str diff --git a/tests/conftest.py b/tests/conftest.py index 855a9db..c2185c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,13 @@ import json import os import socket +import subprocess +import threading import typing as t +import uuid from codecs import open from contextlib import contextmanager -from os.path import abspath, dirname, isfile, join +from os.path import abspath, basename, dirname, isfile, join from pathlib import Path from random import choice from string import ascii_lowercase, ascii_uppercase, digits @@ -31,6 +34,7 @@ from sqlalchemy_utils import database_exists, drop_database from . import database, factories, models +from .utils import generate_ssl_certs, stream_logs def pytest_addoption(parser: "Parser"): @@ -70,6 +74,27 @@ def pytest_addoption(parser: "Parser"): help="The TCP port of the MySQL server.", ) + parser.addoption( + "--mysql-ssl-ca", + dest="mysql_ssl_ca", + default=None, + help="Path to SSL CA certificate file.", + ) + + parser.addoption( + "--mysql-ssl-cert", + dest="mysql_ssl_cert", + default=None, + help="Path to SSL certificate file.", + ) + + parser.addoption( + "--mysql-ssl-key", + dest="mysql_ssl_key", + default=None, + help="Path to SSL key file.", + ) + parser.addoption( "--no-docker", dest="use_docker", @@ -159,10 +184,35 @@ class MySQLCredentials(t.NamedTuple): host: str port: int database: str + ssl_ca: t.Optional[str] = None + ssl_cert: t.Optional[str] = None + ssl_key: t.Optional[str] = None @pytest.fixture(scope="session") -def mysql_credentials(pytestconfig: Config) -> MySQLCredentials: +def mysql_credentials(request, pytestconfig: Config, tmp_path_factory: pytest.TempPathFactory) -> MySQLCredentials: + ssl_credentials = { + "ssl_ca": pytestconfig.getoption("mysql_ssl_ca") or None, + "ssl_cert": pytestconfig.getoption("mysql_ssl_cert") or None, + "ssl_key": pytestconfig.getoption("mysql_ssl_key") or None, + } + + if hasattr(request, "param") and request.param == "ssl": + certs_dir = tmp_path_factory.getbasetemp() / "certs" + if not certs_dir.exists(): + certs_dir.mkdir(parents=True) + generate_ssl_certs(certs_dir) + + # FIXED: docker perms + subprocess.call(["chmod", "0644", str(certs_dir / "ca-key.pem")]) + subprocess.call(["chmod", "0644", str(certs_dir / "server-key.pem")]) + + ssl_credentials = { + "ssl_ca": str(certs_dir / "ca.pem"), + "ssl_cert": str(certs_dir / "server-cert.pem"), + "ssl_key": str(certs_dir / "server-key.pem"), + } + db_credentials_file: str = abspath(join(dirname(__file__), "db_credentials.json")) if isfile(db_credentials_file): with open(db_credentials_file, "r", "utf-8") as fh: @@ -173,6 +223,9 @@ def mysql_credentials(pytestconfig: Config) -> MySQLCredentials: database=db_credentials["mysql_database"], host=db_credentials["mysql_host"], port=db_credentials["mysql_port"], + ssl_ca=db_credentials.get("mysql_ssl_ca") or ssl_credentials["ssl_ca"], + ssl_cert=db_credentials.get("mysql_ssl_cert") or ssl_credentials["ssl_cert"], + ssl_key=db_credentials.get("mysql_ssl_key") or ssl_credentials["ssl_key"], ) port: int = pytestconfig.getoption("mysql_port") or 3306 @@ -188,6 +241,7 @@ def mysql_credentials(pytestconfig: Config) -> MySQLCredentials: database=pytestconfig.getoption("mysql_database") or "test_db", host=pytestconfig.getoption("mysql_host") or "0.0.0.0", port=port, + **ssl_credentials, ) @@ -197,6 +251,7 @@ def mysql_instance(mysql_credentials: MySQLCredentials, pytestconfig: Config) -> mysql_connection: t.Optional[t.Union[PooledMySQLConnection, MySQLConnection, CMySQLConnection]] = None mysql_available: bool = False mysql_connection_retries: int = 15 # failsafe + ssl_args = {} db_credentials_file = abspath(join(dirname(__file__), "db_credentials.json")) if isfile(db_credentials_file): @@ -222,9 +277,37 @@ def mysql_instance(mysql_credentials: MySQLCredentials, pytestconfig: Config) -> except (HTTPError, NotFound) as err: pytest.fail(str(err)) + ssl_cmds = [] + ssl_volumes = {} + host_certs_dir = None + container_certs_dir = "/etc/mysql/certs" + + if mysql_credentials.ssl_ca: + host_certs_dir = dirname(mysql_credentials.ssl_ca) + ssl_cmds.append(f"--ssl-ca={container_certs_dir}/{basename(mysql_credentials.ssl_ca)}") + ssl_args["ssl_ca"] = mysql_credentials.ssl_ca + + if mysql_credentials.ssl_cert: + host_certs_dir = dirname(mysql_credentials.ssl_cert) + ssl_cmds.append(f"--ssl-cert={container_certs_dir}/{basename(mysql_credentials.ssl_cert)}") + ssl_args["ssl_cert"] = f"{host_certs_dir}/client-cert.pem" + + if mysql_credentials.ssl_key: + host_certs_dir = dirname(mysql_credentials.ssl_key) + ssl_cmds.append(f"--ssl-key={container_certs_dir}/{basename(mysql_credentials.ssl_key)}") + ssl_args["ssl_key"] = f"{host_certs_dir}/client-key.pem" + + if host_certs_dir: + ssl_volumes[host_certs_dir] = {"bind": container_certs_dir, "mode": "ro"} + + if ssl_args: + ssl_args["ssl_verify_cert"] = True + + container_name = f"pytest_mysql_to_sqlite3_{uuid.uuid4().hex[:10]}" + container = client.containers.run( image=docker_mysql_image, - name="pytest_mysql_to_sqlite3", + name=container_name, ports={"3306/tcp": (mysql_credentials.host, f"{mysql_credentials.port}/tcp")}, environment={ "MYSQL_RANDOM_ROOT_PASSWORD": "yes", @@ -232,16 +315,25 @@ def mysql_instance(mysql_credentials: MySQLCredentials, pytestconfig: Config) -> "MYSQL_PASSWORD": mysql_credentials.password, "MYSQL_DATABASE": mysql_credentials.database, }, + volumes=ssl_volumes, command=[ "--character-set-server=utf8mb4", "--collation-server=utf8mb4_unicode_ci", - ], + ] + + ssl_cmds, detach=True, auto_remove=True, ) + log_thread = threading.Thread(target=stream_logs, args=(container,)) + # The thread will terminate when the main program terminates + log_thread.daemon = True + log_thread.start() + while not mysql_available and mysql_connection_retries > 0: try: + print(f"Attempt #{mysql_connection_retries} to connect to MySQL...") + mysql_connection = mysql.connector.connect( user=mysql_credentials.user, password=mysql_credentials.password, @@ -249,6 +341,7 @@ def mysql_instance(mysql_credentials: MySQLCredentials, pytestconfig: Config) -> port=mysql_credentials.port, charset="utf8mb4", collation="utf8mb4_unicode_ci", + **ssl_args, ) except mysql.connector.Error as err: if err.errno == errorcode.CR_SERVER_LOST: @@ -270,6 +363,10 @@ def mysql_instance(mysql_credentials: MySQLCredentials, pytestconfig: Config) -> if use_docker and container is not None: container.kill() + # Wait for the log thread to finish (optional) + if "log_thread" in locals() and log_thread.is_alive(): + log_thread.join(timeout=5) + @pytest.fixture(scope="session") def mysql_database( diff --git a/tests/func/test_cli.py b/tests/func/test_cli.py index 6130ca1..d9c960d 100644 --- a/tests/func/test_cli.py +++ b/tests/func/test_cli.py @@ -1,8 +1,11 @@ import os +import subprocess import typing as t from datetime import datetime +from pathlib import Path from random import choice, sample +import mysql.connector import pytest from click.testing import CliRunner, Result from faker import Faker @@ -578,3 +581,180 @@ def test_version(self, cli_runner: CliRunner) -> None: "tqdm", } ) + + @pytest.mark.parametrize("mysql_credentials", ["ssl"], indirect=True) + def test_ssl_connection( + self, + cli_runner: CliRunner, + sqlite_database: "os.PathLike[t.Any]", + mysql_credentials: MySQLCredentials, + mysql_database: Database, + tmp_path_factory: pytest.TempPathFactory, + ): + certs_dir = tmp_path_factory.getbasetemp() / "certs" + + result: Result = cli_runner.invoke( + mysql2sqlite, + [ + "-f", + str(sqlite_database), + "-d", + mysql_credentials.database, + "-u", + mysql_credentials.user, + "--mysql-password", + mysql_credentials.password, + "-h", + mysql_credentials.host, + "-P", + str(mysql_credentials.port), + "--mysql-ssl-ca", + str(certs_dir / "ca.pem"), + "--mysql-ssl-cert", + str(certs_dir / "client-cert.pem"), + "--mysql-ssl-key", + str(certs_dir / "client-key.pem"), + ], + ) + + assert result.exit_code == 0 + + @pytest.mark.parametrize("mysql_credentials", ["ssl"], indirect=True) + def test_ssl_connection_missing_ca( + self, + cli_runner: CliRunner, + sqlite_database: "os.PathLike[t.Any]", + mysql_credentials: MySQLCredentials, + mysql_database: Database, + tmp_path_factory: pytest.TempPathFactory, + ): + certs_dir = tmp_path_factory.getbasetemp() / "certs" + + result: Result = cli_runner.invoke( + mysql2sqlite, + [ + "-f", + str(sqlite_database), + "-d", + mysql_credentials.database, + "-u", + mysql_credentials.user, + "--mysql-password", + mysql_credentials.password, + "-h", + mysql_credentials.host, + "-P", + str(mysql_credentials.port), + "--mysql-ssl-cert", + str(certs_dir / "client-cert.pem"), + "--mysql-ssl-key", + str(certs_dir / "client-key.pem"), + ], + ) + + assert result.exit_code == 0 + + @pytest.mark.parametrize("mysql_credentials", ["ssl"], indirect=True) + def test_ssl_connection_missing_cert( + self, + cli_runner: CliRunner, + sqlite_database: "os.PathLike[t.Any]", + mysql_credentials: MySQLCredentials, + mysql_database: Database, + tmp_path_factory: pytest.TempPathFactory, + ): + certs_dir = tmp_path_factory.getbasetemp() / "certs" + + result: Result = cli_runner.invoke( + mysql2sqlite, + [ + "-f", + str(sqlite_database), + "-d", + mysql_credentials.database, + "-u", + mysql_credentials.user, + "--mysql-password", + mysql_credentials.password, + "-h", + mysql_credentials.host, + "-P", + str(mysql_credentials.port), + "--mysql-ssl-ca", + str(certs_dir / "ca.pem"), + "--mysql-ssl-key", + str(certs_dir / "client-key.pem"), + ], + ) + + assert result.exit_code > 0 + assert "ssl_key and ssl_cert need to be both set, or neither" in result.output + + @pytest.mark.parametrize("mysql_credentials", ["ssl"], indirect=True) + def test_ssl_connection_missing_key( + self, + cli_runner: CliRunner, + sqlite_database: "os.PathLike[t.Any]", + mysql_credentials: MySQLCredentials, + mysql_database: Database, + tmp_path_factory: pytest.TempPathFactory, + ): + certs_dir = tmp_path_factory.getbasetemp() / "certs" + + result: Result = cli_runner.invoke( + mysql2sqlite, + [ + "-f", + str(sqlite_database), + "-d", + mysql_credentials.database, + "-u", + mysql_credentials.user, + "--mysql-password", + mysql_credentials.password, + "-h", + mysql_credentials.host, + "-P", + str(mysql_credentials.port), + "--mysql-ssl-ca", + str(certs_dir / "ca.pem"), + "--mysql-ssl-cert", + str(certs_dir / "client-cert.pem"), + ], + ) + + assert result.exit_code > 0 + assert "ssl_key and ssl_cert need to be both set, or neither" in result.output + + @pytest.mark.parametrize("mysql_credentials", ["ssl"], indirect=True) + def test_ssl_connection_only_ca( + self, + cli_runner: CliRunner, + sqlite_database: "os.PathLike[t.Any]", + mysql_credentials: MySQLCredentials, + mysql_database: Database, + tmp_path_factory: pytest.TempPathFactory, + ): + certs_dir = tmp_path_factory.getbasetemp() / "certs" + + result: Result = cli_runner.invoke( + mysql2sqlite, + [ + "-f", + str(sqlite_database), + "-d", + mysql_credentials.database, + "-u", + mysql_credentials.user, + "--mysql-password", + mysql_credentials.password, + "-h", + mysql_credentials.host, + "-P", + str(mysql_credentials.port), + "--mysql-ssl-ca", + str(certs_dir / "ca.pem"), + ], + ) + + assert result.exit_code == 0 diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..9e196f9 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,104 @@ +import subprocess +from pathlib import Path + + +def stream_logs(container): + for line in container.logs(stream=True, follow=True): + print(f"Container log: {line.decode('utf-8').strip()}") + + +def generate_ssl_certs(certs_dir: Path): + certs_dir.mkdir(parents=True, exist_ok=True) + + # File paths + ca_key = certs_dir / "ca-key.pem" + ca_cert = certs_dir / "ca.pem" + server_key = certs_dir / "server-key.pem" + server_cert = certs_dir / "server-cert.pem" + client_key = certs_dir / "client-key.pem" + client_cert = certs_dir / "client-cert.pem" + server_csr = certs_dir / "server.csr" + client_csr = certs_dir / "client.csr" + + # Create CA key and certificate + subprocess.run(["openssl", "genrsa", "-out", str(ca_key), "2048"], check=True) + + subprocess.run( + [ + "openssl", + "req", + "-new", + "-x509", + "-key", + str(ca_key), + "-out", + str(ca_cert), + "-days", + "3650", + "-subj", + "/CN=MySQL Test CA", + ], + check=True, + ) + + # Create server key and CSR (Certificate Signing Request) + subprocess.run(["openssl", "genrsa", "-out", str(server_key), "2048"], check=True) + + subprocess.run( + ["openssl", "req", "-new", "-key", str(server_key), "-out", str(server_csr), "-subj", "/CN=MySQL Server"], + check=True, + ) + + # Sign the server CSR with the CA to create the server certificate + subprocess.run( + [ + "openssl", + "x509", + "-req", + "-in", + str(server_csr), + "-CA", + str(ca_cert), + "-CAkey", + str(ca_key), + "-CAcreateserial", + "-out", + str(server_cert), + "-days", + "3650", + ], + check=True, + ) + + # Create client key and CSR + subprocess.run(["openssl", "genrsa", "-out", str(client_key), "2048"], check=True) + + subprocess.run( + ["openssl", "req", "-new", "-key", str(client_key), "-out", str(client_csr), "-subj", "/CN=MySQL Client"], + check=True, + ) + + # Sign the client CSR with the CA to create the client certificate + subprocess.run( + [ + "openssl", + "x509", + "-req", + "-in", + str(client_csr), + "-CA", + str(ca_cert), + "-CAkey", + str(ca_key), + "-CAcreateserial", + "-out", + str(client_cert), + "-days", + "3650", + ], + check=True, + ) + + # Clean up CSR files + server_csr.unlink() + client_csr.unlink()