Skip to content

Commit

Permalink
Add user authentication support to archive_sync command as well
Browse files Browse the repository at this point in the history
The previous modification, 8c1c2b0,
added support for user authentication on the local webserver but missed
archive_sync command.
  • Loading branch information
hnousiainen committed Dec 12, 2024
1 parent b93b132 commit d8579a2
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 45 deletions.
17 changes: 11 additions & 6 deletions pghoard/archive_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import os
import sys

import requests
from requests import Session
from requests.auth import HTTPBasicAuth
from rohmu.errors import InvalidConfigurationError

from pghoard.common import get_pg_wal_directory
Expand All @@ -34,19 +35,23 @@ def __init__(self):
self.site = None
self.backup_site = None
self.base_url = None
self.session = None

def set_config(self, config_file, site):
self.config = config.read_json_config_file(config_file, check_commands=False)
self.site = config.get_site_from_config(self.config, site)
self.backup_site = self.config["backup_sites"][self.site]
self.base_url = "http://127.0.0.1:{}/{}".format(self.config["http_port"], self.site)
self.session = Session()
if self.config.get("webserver_username") and self.config.get("webserver_password"):
self.session.auth = HTTPBasicAuth(self.config["webserver_username"], self.config["webserver_password"])

def get_current_wal_file(self):
# identify the (must be) local database
return wal.get_current_lsn(self.backup_site["nodes"][0]).walfile_name

def get_first_required_wal_segment(self):
resp = requests.get("{base}/basebackup".format(base=self.base_url))
resp = self.session.get("{base}/basebackup".format(base=self.base_url))
if resp.status_code != 200:
self.log.error("Error looking up basebackups")
return None, None
Expand Down Expand Up @@ -106,7 +111,7 @@ def check_and_upload_missing_local_files(self, max_hash_checks):
archive_type = "WAL"

if archive_type:
resp = requests.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
resp = self.session.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
if resp.status_code == 200:
remote_hash = resp.headers.get("metadata-hash")
hash_algorithm = resp.headers.get("metadata-hash-algorithm")
Expand Down Expand Up @@ -147,7 +152,7 @@ def check_and_upload_missing_local_files(self, max_hash_checks):
need_archival.append(wal_file)

for wal_file in sorted(need_archival): # sort oldest to newest
resp = requests.put("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
resp = self.session.put("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
archive_type = "TIMELINE" if ".history" in wal_file else "WAL"
if resp.status_code != 201:
self.log.error("%s file %r archival failed with status code %r", archive_type, wal_file, resp.status_code)
Expand Down Expand Up @@ -175,7 +180,7 @@ def check_wal_archive_integrity(self, new_backup_on_failure):
# Decrement one segment if we're on a valid timeline
current_lsn = current_lsn.previous_walfile_start_lsn
wal_file = current_lsn.walfile_name
resp = requests.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
resp = self.session.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
if resp.status_code == 200:
self.log.info("%s file %r correctly archived", archive_type, wal_file)
file_count += 1
Expand All @@ -201,7 +206,7 @@ def check_wal_archive_integrity(self, new_backup_on_failure):
current_lsn = current_lsn.at_timeline(current_lsn.timeline_id - 1)

def request_basebackup(self):
resp = requests.put("{base}/archive/basebackup".format(base=self.base_url))
resp = self.session.put("{base}/archive/basebackup".format(base=self.base_url))
if resp.status_code != 201:
self.log.error("Request for a new backup for site: %r failed", self.site)
else:
Expand Down
54 changes: 25 additions & 29 deletions test/test_archivesync.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import hashlib
import os
from unittest.mock import Mock, patch
from unittest.mock import Mock

import pytest

Expand Down Expand Up @@ -43,9 +43,7 @@ def requests_head_call_return(*args, **kwargs): # pylint: disable=unused-argume
return HTTPResult(status_code)


@patch("requests.head")
@patch("requests.put")
def test_check_wal_archive_integrity(requests_put_mock, requests_head_mock, tmpdir):
def test_check_wal_archive_integrity(tmpdir):
from pghoard.archive_sync import ArchiveSync, SyncError

# Instantiate a fake PG data directory
Expand All @@ -57,59 +55,57 @@ def test_check_wal_archive_integrity(requests_put_mock, requests_head_mock, tmpd
write_json_file(config_file, {"http_port": 8080, "backup_sites": {"foo": {"pg_data_directory": pg_data_directory}}})
arsy = ArchiveSync()
arsy.set_config(config_file, site="foo")
requests_put_mock.return_value = HTTPResult(201) # So the backup requests succeeds
requests_head_mock.side_effect = requests_head_call_return
arsy.session.put = Mock(return_value=HTTPResult(201)) # So the backup requests succeeds
arsy.session.head = Mock(side_effect=requests_head_call_return)

# Check integrity within same timeline
arsy.get_current_wal_file = Mock(return_value="00000005000000000000008F")
arsy.get_first_required_wal_segment = Mock(return_value=("00000005000000000000008C", 90300))
assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0
assert requests_head_mock.call_count == 3
assert requests_put_mock.call_count == 0
assert arsy.session.head.call_count == 3
assert arsy.session.put.call_count == 0

# Check integrity when timeline has changed
requests_head_mock.call_count = 0
requests_put_mock.call_count = 0
arsy.session.head.call_count = 0
arsy.session.put.call_count = 0
arsy.get_current_wal_file = Mock(return_value="000000090000000000000008")
arsy.get_first_required_wal_segment = Mock(return_value=("000000080000000000000005", 90300))
assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0
assert requests_head_mock.call_count == 4
assert arsy.session.head.call_count == 4

requests_head_mock.call_count = 0
requests_put_mock.call_count = 0
arsy.session.head.call_count = 0
arsy.session.put.call_count = 0
arsy.get_current_wal_file = Mock(return_value="000000030000000000000008")
arsy.get_first_required_wal_segment = Mock(return_value=("000000030000000000000005", 90300))
with pytest.raises(SyncError):
arsy.check_wal_archive_integrity(new_backup_on_failure=False)
assert requests_put_mock.call_count == 0
assert arsy.session.put.call_count == 0
assert arsy.check_wal_archive_integrity(new_backup_on_failure=True) == 0
assert requests_put_mock.call_count == 1
assert arsy.session.put.call_count == 1

requests_head_mock.call_count = 0
requests_put_mock.call_count = 0
arsy.session.head.call_count = 0
arsy.session.put.call_count = 0
arsy.get_current_wal_file = Mock(return_value="000000070000000000000002")
arsy.get_first_required_wal_segment = Mock(return_value=("000000060000000000000001", 90300))
assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0
assert requests_put_mock.call_count == 0
assert arsy.session.put.call_count == 0

requests_head_mock.call_count = 0
requests_put_mock.call_count = 0
arsy.session.head.call_count = 0
arsy.session.put.call_count = 0
arsy.get_current_wal_file = Mock(return_value="000000020000000B00000000")
arsy.get_first_required_wal_segment = Mock(return_value=("000000020000000A000000FD", 90200))
assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0
assert requests_put_mock.call_count == 0
assert arsy.session.put.call_count == 0

requests_head_mock.call_count = 0
requests_put_mock.call_count = 0
arsy.session.head.call_count = 0
arsy.session.put.call_count = 0
arsy.get_current_wal_file = Mock(return_value="000000020000000B00000000")
arsy.get_first_required_wal_segment = Mock(return_value=("000000020000000A000000FD", 90300))
assert arsy.check_wal_archive_integrity(new_backup_on_failure=True) == 0
assert requests_put_mock.call_count == 1
assert arsy.session.put.call_count == 1


@patch("requests.head")
@patch("requests.put")
def test_check_and_upload_missing_local_files(requests_put_mock, requests_head_mock, tmpdir):
def test_check_and_upload_missing_local_files(tmpdir):
from pghoard.archive_sync import ArchiveSync

data_dir = str(tmpdir)
Expand Down Expand Up @@ -157,8 +153,8 @@ def requests_put(*args, **kwargs): # pylint: disable=unused-argument
write_json_file(config_file, {"http_port": 8080, "backup_sites": {"foo": {"pg_data_directory": data_dir}}})
arsy = ArchiveSync()
arsy.set_config(config_file, site="foo")
requests_put_mock.side_effect = requests_put
requests_head_mock.side_effect = requests_head
arsy.session.put = Mock(side_effect=requests_put)
arsy.session.head = Mock(side_effect=requests_head)
arsy.get_current_wal_file = Mock(return_value="00000000000000000000001A")
arsy.get_first_required_wal_segment = Mock(return_value=("000000000000000000000001", 90300))

Expand Down
27 changes: 17 additions & 10 deletions test/test_webserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def list_archive(folder):
self._run_and_wait_basebackup(pghoard, db, "pipe")

# force a couple of wal segment switches
start_wal, _ = self._switch_wal(db, 4)
start_wal, end_wal = self._switch_wal(db, 4)
# we should have at least 4 WAL files now (there may be more in
# case other tests created them -- we share a single postresql
# cluster between all tests)
Expand Down Expand Up @@ -350,36 +350,43 @@ def write_dummy_wal(inc):
db.run_pg()
db.run_cmd("pg_ctl", "-D", db.pgdata, "promote")
time.sleep(5) # TODO: instead of sleeping, poll the db until ready
# we should have a single timeline file in pg_xlog/pg_wal now
# we should have one or more timeline file in pg_xlog/pg_wal now
pg_wal_timelines = {f for f in os.listdir(pg_wal_dir) if wal.TIMELINE_RE.match(f)}
assert len(pg_wal_timelines) > 0
# but there should be nothing archived as archive_command wasn't setup
# but there should be one less archived as archive_command wasn't setup/active
archived_timelines = set(list_archive("timeline"))
assert len(archived_timelines) == 0
assert len(archived_timelines) == len(pg_wal_timelines) - 1
# let's hit archive sync
arsy.run(["--site", pghoard.test_site, "--config", pghoard.config_path])
# now we should have an archived timeline
archived_timelines = set(list_archive("timeline"))
assert archived_timelines.issuperset(pg_wal_timelines)
assert "00000002.history" in archived_timelines

# let's take a new basebackup
self._run_and_wait_basebackup(pghoard, db, "basic")

# nuke archives and resync them
for name in list_archive(folder="timeline"):
store.delete_key(os.path.join(pghoard.test_site, "timeline", name))
for name in list_archive(folder="xlog"):
store.delete_key(os.path.join(pghoard.test_site, "xlog", name))
self._switch_wal(db, 1)

start_wal, _ = self._switch_wal(db, 1)
pg_wals = {f for f in os.listdir(pg_wal_dir) if wal.WAL_RE.match(f) and f > start_wal}
pg_wal_timelines = {f for f in os.listdir(pg_wal_dir) if wal.TIMELINE_RE.match(f)}

arsy.run(["--site", pghoard.test_site, "--config", pghoard.config_path])

archived_wals = set(list_archive("xlog"))
# assume the same timeline file as before and one to three wal files
assert len(archived_wals) >= 1
assert len(archived_wals) <= 3
assert archived_wals.issuperset(pg_wals)
archived_timelines = set(list_archive("timeline"))
assert list(archived_timelines) == ["00000002.history"]
assert archived_timelines.issuperset(pg_wal_timelines)

def test_archive_sync(self, db, pghoard):
self._test_archive_sync(db, pghoard)

def test_archive_sync_with_userauth(self, db, pghoard_with_userauth):
self._test_archive_sync(db, pghoard_with_userauth)

def test_archive_command_with_invalid_file(self, pghoard):
# only WAL and timeline (.history) files can be archived
Expand Down

0 comments on commit d8579a2

Please sign in to comment.