From 5ce89211b1d735cbf19404b3ad191a0d1fd0da7f Mon Sep 17 00:00:00 2001 From: Heikki Nousiainen Date: Tue, 24 Sep 2024 16:32:21 +0300 Subject: [PATCH] Add user authentication support to archive_sync command as well The previous modification, 8c1c2b0e0cb75f9411d4ea2317f006e0e6909d57, added support for user authentication on the local webserver but missed archive_sync command. --- pghoard/archive_sync.py | 17 ++++++++----- test/test_archivesync.py | 54 +++++++++++++++++++--------------------- test/test_webserver.py | 27 ++++++++++++-------- 3 files changed, 53 insertions(+), 45 deletions(-) diff --git a/pghoard/archive_sync.py b/pghoard/archive_sync.py index 542c5633..2e92f946 100644 --- a/pghoard/archive_sync.py +++ b/pghoard/archive_sync.py @@ -10,7 +10,8 @@ import os import sys -import requests +from requests import Session +from requests.auth import HTTPBasicAuth from rohmu.errors import InvalidConfigurationError from pghoard.common import get_pg_wal_directory @@ -34,19 +35,23 @@ def __init__(self): self.site = None self.backup_site = None self.base_url = None + self.session = None def set_config(self, config_file, site): self.config = config.read_json_config_file(config_file, check_commands=False) self.site = config.get_site_from_config(self.config, site) self.backup_site = self.config["backup_sites"][self.site] self.base_url = "http://127.0.0.1:{}/{}".format(self.config["http_port"], self.site) + self.session = Session() + if self.config.get("webserver_username") and self.config.get("webserver_password"): + self.session.auth = HTTPBasicAuth(self.config["webserver_username"], self.config["webserver_password"]) def get_current_wal_file(self): # identify the (must be) local database return wal.get_current_lsn(self.backup_site["nodes"][0]).walfile_name def get_first_required_wal_segment(self): - resp = requests.get("{base}/basebackup".format(base=self.base_url)) + resp = self.session.get("{base}/basebackup".format(base=self.base_url)) if resp.status_code != 200: self.log.error("Error looking up basebackups") return None, None @@ -106,7 +111,7 @@ def check_and_upload_missing_local_files(self, max_hash_checks): archive_type = "WAL" if archive_type: - resp = requests.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file)) + resp = self.session.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file)) if resp.status_code == 200: remote_hash = resp.headers.get("metadata-hash") hash_algorithm = resp.headers.get("metadata-hash-algorithm") @@ -147,7 +152,7 @@ def check_and_upload_missing_local_files(self, max_hash_checks): need_archival.append(wal_file) for wal_file in sorted(need_archival): # sort oldest to newest - resp = requests.put("{base}/archive/{file}".format(base=self.base_url, file=wal_file)) + resp = self.session.put("{base}/archive/{file}".format(base=self.base_url, file=wal_file)) archive_type = "TIMELINE" if ".history" in wal_file else "WAL" if resp.status_code != 201: self.log.error("%s file %r archival failed with status code %r", archive_type, wal_file, resp.status_code) @@ -175,7 +180,7 @@ def check_wal_archive_integrity(self, new_backup_on_failure): # Decrement one segment if we're on a valid timeline current_lsn = current_lsn.previous_walfile_start_lsn wal_file = current_lsn.walfile_name - resp = requests.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file)) + resp = self.session.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file)) if resp.status_code == 200: self.log.info("%s file %r correctly archived", archive_type, wal_file) file_count += 1 @@ -201,7 +206,7 @@ def check_wal_archive_integrity(self, new_backup_on_failure): current_lsn = current_lsn.at_timeline(current_lsn.timeline_id - 1) def request_basebackup(self): - resp = requests.put("{base}/archive/basebackup".format(base=self.base_url)) + resp = self.session.put("{base}/archive/basebackup".format(base=self.base_url)) if resp.status_code != 201: self.log.error("Request for a new backup for site: %r failed", self.site) else: diff --git a/test/test_archivesync.py b/test/test_archivesync.py index b119808d..48c9f662 100644 --- a/test/test_archivesync.py +++ b/test/test_archivesync.py @@ -1,6 +1,6 @@ import hashlib import os -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest @@ -43,9 +43,7 @@ def requests_head_call_return(*args, **kwargs): # pylint: disable=unused-argume return HTTPResult(status_code) -@patch("requests.head") -@patch("requests.put") -def test_check_wal_archive_integrity(requests_put_mock, requests_head_mock, tmpdir): +def test_check_wal_archive_integrity(tmpdir): from pghoard.archive_sync import ArchiveSync, SyncError # Instantiate a fake PG data directory @@ -57,59 +55,57 @@ def test_check_wal_archive_integrity(requests_put_mock, requests_head_mock, tmpd write_json_file(config_file, {"http_port": 8080, "backup_sites": {"foo": {"pg_data_directory": pg_data_directory}}}) arsy = ArchiveSync() arsy.set_config(config_file, site="foo") - requests_put_mock.return_value = HTTPResult(201) # So the backup requests succeeds - requests_head_mock.side_effect = requests_head_call_return + arsy.session.put = Mock(return_value=HTTPResult(201)) # So the backup requests succeeds + arsy.session.head = Mock(side_effect=requests_head_call_return) # Check integrity within same timeline arsy.get_current_wal_file = Mock(return_value="00000005000000000000008F") arsy.get_first_required_wal_segment = Mock(return_value=("00000005000000000000008C", 90300)) assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0 - assert requests_head_mock.call_count == 3 - assert requests_put_mock.call_count == 0 + assert arsy.session.head.call_count == 3 + assert arsy.session.put.call_count == 0 # Check integrity when timeline has changed - requests_head_mock.call_count = 0 - requests_put_mock.call_count = 0 + arsy.session.head.call_count = 0 + arsy.session.put.call_count = 0 arsy.get_current_wal_file = Mock(return_value="000000090000000000000008") arsy.get_first_required_wal_segment = Mock(return_value=("000000080000000000000005", 90300)) assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0 - assert requests_head_mock.call_count == 4 + assert arsy.session.head.call_count == 4 - requests_head_mock.call_count = 0 - requests_put_mock.call_count = 0 + arsy.session.head.call_count = 0 + arsy.session.put.call_count = 0 arsy.get_current_wal_file = Mock(return_value="000000030000000000000008") arsy.get_first_required_wal_segment = Mock(return_value=("000000030000000000000005", 90300)) with pytest.raises(SyncError): arsy.check_wal_archive_integrity(new_backup_on_failure=False) - assert requests_put_mock.call_count == 0 + assert arsy.session.put.call_count == 0 assert arsy.check_wal_archive_integrity(new_backup_on_failure=True) == 0 - assert requests_put_mock.call_count == 1 + assert arsy.session.put.call_count == 1 - requests_head_mock.call_count = 0 - requests_put_mock.call_count = 0 + arsy.session.head.call_count = 0 + arsy.session.put.call_count = 0 arsy.get_current_wal_file = Mock(return_value="000000070000000000000002") arsy.get_first_required_wal_segment = Mock(return_value=("000000060000000000000001", 90300)) assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0 - assert requests_put_mock.call_count == 0 + assert arsy.session.put.call_count == 0 - requests_head_mock.call_count = 0 - requests_put_mock.call_count = 0 + arsy.session.head.call_count = 0 + arsy.session.put.call_count = 0 arsy.get_current_wal_file = Mock(return_value="000000020000000B00000000") arsy.get_first_required_wal_segment = Mock(return_value=("000000020000000A000000FD", 90200)) assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0 - assert requests_put_mock.call_count == 0 + assert arsy.session.put.call_count == 0 - requests_head_mock.call_count = 0 - requests_put_mock.call_count = 0 + arsy.session.head.call_count = 0 + arsy.session.put.call_count = 0 arsy.get_current_wal_file = Mock(return_value="000000020000000B00000000") arsy.get_first_required_wal_segment = Mock(return_value=("000000020000000A000000FD", 90300)) assert arsy.check_wal_archive_integrity(new_backup_on_failure=True) == 0 - assert requests_put_mock.call_count == 1 + assert arsy.session.put.call_count == 1 -@patch("requests.head") -@patch("requests.put") -def test_check_and_upload_missing_local_files(requests_put_mock, requests_head_mock, tmpdir): +def test_check_and_upload_missing_local_files(tmpdir): from pghoard.archive_sync import ArchiveSync data_dir = str(tmpdir) @@ -157,8 +153,8 @@ def requests_put(*args, **kwargs): # pylint: disable=unused-argument write_json_file(config_file, {"http_port": 8080, "backup_sites": {"foo": {"pg_data_directory": data_dir}}}) arsy = ArchiveSync() arsy.set_config(config_file, site="foo") - requests_put_mock.side_effect = requests_put - requests_head_mock.side_effect = requests_head + arsy.session.put = Mock(side_effect=requests_put) + arsy.session.head = Mock(side_effect=requests_head) arsy.get_current_wal_file = Mock(return_value="00000000000000000000001A") arsy.get_first_required_wal_segment = Mock(return_value=("000000000000000000000001", 90300)) diff --git a/test/test_webserver.py b/test/test_webserver.py index 7c7f0976..c72f9066 100644 --- a/test/test_webserver.py +++ b/test/test_webserver.py @@ -242,7 +242,7 @@ def _switch_wal(self, db, count): conn.close() return start_wal, end_wal - def test_archive_sync(self, db, pghoard, pg_version: str): + def _test_archive_sync(self, db, pghoard, pg_version: str): log = logging.getLogger("test_archive_sync") store = pghoard.transfer_agents[0].get_object_storage(pghoard.test_site) @@ -350,36 +350,43 @@ def write_dummy_wal(inc): db.run_pg() db.run_cmd("pg_ctl", "-D", db.pgdata, "promote") time.sleep(5) # TODO: instead of sleeping, poll the db until ready - # we should have a single timeline file in pg_xlog/pg_wal now + # we should have one or more timeline file in pg_xlog/pg_wal now pg_wal_timelines = {f for f in os.listdir(pg_wal_dir) if wal.TIMELINE_RE.match(f)} assert len(pg_wal_timelines) > 0 - # but there should be nothing archived as archive_command wasn't setup + # but there should be one less archived as archive_command wasn't setup/active archived_timelines = set(list_archive("timeline")) - assert len(archived_timelines) == 0 + assert len(archived_timelines) == len(pg_wal_timelines) - 1 # let's hit archive sync arsy.run(["--site", pghoard.test_site, "--config", pghoard.config_path]) # now we should have an archived timeline archived_timelines = set(list_archive("timeline")) assert archived_timelines.issuperset(pg_wal_timelines) - assert "00000002.history" in archived_timelines # let's take a new basebackup self._run_and_wait_basebackup(pghoard, db, "basic") + # nuke archives and resync them for name in list_archive(folder="timeline"): store.delete_key(os.path.join(pghoard.test_site, "timeline", name)) for name in list_archive(folder="xlog"): store.delete_key(os.path.join(pghoard.test_site, "xlog", name)) - self._switch_wal(db, 1) + + start_wal, _ = self._switch_wal(db, 1) + pg_wals = {f for f in os.listdir(pg_wal_dir) if wal.WAL_RE.match(f) and f > start_wal} + pg_wal_timelines = {f for f in os.listdir(pg_wal_dir) if wal.TIMELINE_RE.match(f)} arsy.run(["--site", pghoard.test_site, "--config", pghoard.config_path]) archived_wals = set(list_archive("xlog")) - # assume the same timeline file as before and one to three wal files - assert len(archived_wals) >= 1 - assert len(archived_wals) <= 3 + assert archived_wals.issuperset(pg_wals) archived_timelines = set(list_archive("timeline")) - assert list(archived_timelines) == ["00000002.history"] + assert archived_timelines.issuperset(pg_wal_timelines) + + def test_archive_sync(self, db, pghoard, pg_version: str): + self._test_archive_sync(db, pghoard, pg_version) + + def test_archive_sync_with_userauth(self, db, pghoard_with_userauth, pg_version: str): + self._test_archive_sync(db, pghoard_with_userauth, pg_version) def test_archive_command_with_invalid_file(self, pghoard): # only WAL and timeline (.history) files can be archived