Add user authentication support to archive_sync command as well
The previous commit, 8c1c2b0, added support for user authentication on the
local webserver but missed the archive_sync command.
hnousiainen committed Sep 24, 2024
1 parent 69d1115 commit 59fab60
Showing 2 changed files with 31 additions and 19 deletions.
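
For context, the credentials come from the same webserver_username and
webserver_password configuration keys that the local webserver uses. A minimal
sketch of the client-side behavior, mirroring set_config() in the diff below;
the port, site name, WAL file name, and credentials are placeholder values:

from requests import Session
from requests.auth import HTTPBasicAuth

# Hypothetical config values; pghoard reads these from its JSON config file.
config = {
    "http_port": 16000,
    "webserver_username": "pghoard",
    "webserver_password": "secret",
}

session = Session()
# Only attach credentials when both keys are set, as set_config() does below.
if config.get("webserver_username") and config.get("webserver_password"):
    session.auth = HTTPBasicAuth(config["webserver_username"], config["webserver_password"])

# Every request made through the session now carries an HTTP basic auth
# Authorization header (requires a pghoard webserver listening on that port).
resp = session.head("http://127.0.0.1:{}/example-site/archive/000000010000000000000001".format(config["http_port"]))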
17 changes: 11 additions & 6 deletions pghoard/archive_sync.py
@@ -10,7 +10,8 @@
 import os
 import sys
 
-import requests
+from requests import Session
+from requests.auth import HTTPBasicAuth
 from rohmu.errors import InvalidConfigurationError
 
 from pghoard.common import get_pg_wal_directory
@@ -34,19 +35,23 @@ def __init__(self):
         self.site = None
         self.backup_site = None
         self.base_url = None
+        self.session = None
 
     def set_config(self, config_file, site):
         self.config = config.read_json_config_file(config_file, check_commands=False)
         self.site = config.get_site_from_config(self.config, site)
         self.backup_site = self.config["backup_sites"][self.site]
         self.base_url = "http://127.0.0.1:{}/{}".format(self.config["http_port"], self.site)
+        self.session = Session()
+        if self.config.get("webserver_username") and self.config.get("webserver_password"):
+            self.session.auth = HTTPBasicAuth(self.config["webserver_username"], self.config["webserver_password"])
 
     def get_current_wal_file(self):
         # identify the (must be) local database
         return wal.get_current_lsn(self.backup_site["nodes"][0]).walfile_name
 
     def get_first_required_wal_segment(self):
-        resp = requests.get("{base}/basebackup".format(base=self.base_url))
+        resp = self.session.get("{base}/basebackup".format(base=self.base_url))
         if resp.status_code != 200:
             self.log.error("Error looking up basebackups")
             return None, None
@@ -106,7 +111,7 @@ def check_and_upload_missing_local_files(self, max_hash_checks):
                 archive_type = "WAL"
 
             if archive_type:
-                resp = requests.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
+                resp = self.session.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
                 if resp.status_code == 200:
                     remote_hash = resp.headers.get("metadata-hash")
                     hash_algorithm = resp.headers.get("metadata-hash-algorithm")
@@ -147,7 +152,7 @@ def check_and_upload_missing_local_files(self, max_hash_checks):
                 need_archival.append(wal_file)
 
         for wal_file in sorted(need_archival):  # sort oldest to newest
-            resp = requests.put("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
+            resp = self.session.put("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
             archive_type = "TIMELINE" if ".history" in wal_file else "WAL"
             if resp.status_code != 201:
                 self.log.error("%s file %r archival failed with status code %r", archive_type, wal_file, resp.status_code)
@@ -175,7 +180,7 @@ def check_wal_archive_integrity(self, new_backup_on_failure):
                 # Decrement one segment if we're on a valid timeline
                 current_lsn = current_lsn.previous_walfile_start_lsn
             wal_file = current_lsn.walfile_name
-            resp = requests.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
+            resp = self.session.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
             if resp.status_code == 200:
                 self.log.info("%s file %r correctly archived", archive_type, wal_file)
                 file_count += 1
@@ -201,7 +206,7 @@ def check_wal_archive_integrity(self, new_backup_on_failure):
                 current_lsn = current_lsn.at_timeline(current_lsn.timeline_id - 1)
 
     def request_basebackup(self):
-        resp = requests.put("{base}/archive/basebackup".format(base=self.base_url))
+        resp = self.session.put("{base}/archive/basebackup".format(base=self.base_url))
         if resp.status_code != 201:
             self.log.error("Request for a new backup for site: %r failed", self.site)
         else:
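Because the session is configured once in set_config(), every HTTP call
ArchiveSync makes against the local webserver (basebackup lookups, archive
HEAD checks, archival PUTs, and basebackup requests) is now authenticated,
and the command-line invocation is unchanged. A short usage sketch based on
the calls in the test below; the site name and config path are placeholders:

from pghoard.archive_sync import ArchiveSync

# Credentials are picked up from the config file's webserver_username /
# webserver_password keys; there are no new command-line flags.
arsy = ArchiveSync()
arsy.run(["--site", "example-site", "--config", "/etc/pghoard/pghoard.json"])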
33 changes: 20 additions & 13 deletions test/test_webserver.py
@@ -242,7 +242,7 @@ def _switch_wal(self, db, count):
         conn.close()
         return start_wal, end_wal
 
-    def test_archive_sync(self, db, pghoard):
+    def _test_archive_sync(self, db, pghoard):
         log = logging.getLogger("test_archive_sync")
         store = pghoard.transfer_agents[0].get_object_storage(pghoard.test_site)
 
@@ -267,13 +267,13 @@ def list_archive(folder):
         self._run_and_wait_basebackup(pghoard, db, "pipe")
 
         # force a couple of wal segment switches
-        start_wal, _ = self._switch_wal(db, 4)
+        start_wal, end_wal = self._switch_wal(db, 4)
         # we should have at least 4 WAL files now (there may be more in
         # case other tests created them -- we share a single postgresql
         # cluster between all tests)
         pg_wal_dir = get_pg_wal_directory(pghoard.config["backup_sites"][pghoard.test_site])
         pg_wals = {f for f in os.listdir(pg_wal_dir) if wal.WAL_RE.match(f) and f > start_wal}
-        assert len(pg_wals) >= 4
+        assert len(pg_wals) == int(end_wal, 16) - int(start_wal, 16)
 
         # create a couple of "recycled" xlog files that we must ignore
         last_wal = sorted(pg_wals)[-1]
@@ -291,7 +291,7 @@ def write_dummy_wal(inc):
         # check what we have archived, there should be at least the three
         # above WALs that are NOT there at the moment
         archived_wals = set(list_archive("xlog"))
-        assert len(pg_wals - archived_wals) >= 4
+        assert len(pg_wals - archived_wals) >= 3
         # now perform an archive sync
         arsy = ArchiveSync()
         arsy.run(["--site", pghoard.test_site, "--config", pghoard.config_path])
@@ -346,36 +346,43 @@ def write_dummy_wal(inc):
         db.run_pg()
         db.run_cmd("pg_ctl", "-D", db.pgdata, "promote")
         time.sleep(5)  # TODO: instead of sleeping, poll the db until ready
-        # we should have a single timeline file in pg_xlog/pg_wal now
+        # we should have one or more timeline file in pg_xlog/pg_wal now
         pg_wal_timelines = {f for f in os.listdir(pg_wal_dir) if wal.TIMELINE_RE.match(f)}
         assert len(pg_wal_timelines) > 0
-        # but there should be nothing archived as archive_command wasn't setup
+        # but there should be one less archived as archive_command wasn't setup/active
         archived_timelines = set(list_archive("timeline"))
-        assert len(archived_timelines) == 0
+        assert len(archived_timelines) == len(pg_wal_timelines) - 1
         # let's hit archive sync
         arsy.run(["--site", pghoard.test_site, "--config", pghoard.config_path])
         # now we should have an archived timeline
         archived_timelines = set(list_archive("timeline"))
-        assert archived_timelines.issuperset(pg_wal_timelines)
+        assert "00000002.history" in archived_timelines
 
         # let's take a new basebackup
         self._run_and_wait_basebackup(pghoard, db, "basic")
 
         # nuke archives and resync them
         for name in list_archive(folder="timeline"):
             store.delete_key(os.path.join(pghoard.test_site, "timeline", name))
         for name in list_archive(folder="xlog"):
             store.delete_key(os.path.join(pghoard.test_site, "xlog", name))
-        self._switch_wal(db, 1)
+
+        start_wal, _ = self._switch_wal(db, 1)
+        pg_wals = {f for f in os.listdir(pg_wal_dir) if wal.WAL_RE.match(f) and f > start_wal}
+        pg_wal_timelines = {f for f in os.listdir(pg_wal_dir) if wal.TIMELINE_RE.match(f)}
+
         arsy.run(["--site", pghoard.test_site, "--config", pghoard.config_path])
 
         archived_wals = set(list_archive("xlog"))
-        # assume the same timeline file as before and one to three wal files
-        assert len(archived_wals) >= 1
-        assert len(archived_wals) <= 3
+        assert archived_wals.issuperset(pg_wals)
         archived_timelines = set(list_archive("timeline"))
-        assert list(archived_timelines) == ["00000002.history"]
+        assert archived_timelines.issuperset(pg_wal_timelines)
 
+    def test_archive_sync(self, db, pghoard):
+        self._test_archive_sync(db, pghoard)
+
+    def test_archive_sync_with_userauth(self, db, pghoard_with_userauth):
+        self._test_archive_sync(db, pghoard_with_userauth)
 
     def test_archive_command_with_invalid_file(self, pghoard):
         # only WAL and timeline (.history) files can be archived
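The pghoard_with_userauth fixture exercised above is not defined in this diff;
it presumably ships with the webserver-authentication change in 8c1c2b0. A
plausible sketch of its shape, assuming it mirrors the plain pghoard fixture
and only adds the credential keys; make_pghoard_config and run_pghoard are
hypothetical helpers standing in for whatever the test suite actually uses:

import pytest

@pytest.fixture
def pghoard_with_userauth(db, tmpdir):
    # Hypothetical: same setup as the plain pghoard fixture, plus HTTP
    # basic auth credentials for the local webserver.
    config = make_pghoard_config(db, tmpdir)   # hypothetical helper
    config["webserver_username"] = "testuser"  # hypothetical credentials
    config["webserver_password"] = "testpassword"
    yield from run_pghoard(config)             # hypothetical helper (generator)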
