diff --git a/golang/pghoard_postgres_command_go.go b/golang/pghoard_postgres_command_go.go
index 13a5fbd1..e55161aa 100644
--- a/golang/pghoard_postgres_command_go.go
+++ b/golang/pghoard_postgres_command_go.go
@@ -16,6 +16,7 @@ import (
 	"net/http"
 	"os"
 	"path"
+	"regexp"
 	"time"
 )
 
@@ -81,7 +82,7 @@ func run() (int, error) {
 	retry_seconds := *riPtr
 	for {
 		attempt += 1
-		rc, err := restore_command(url, *outputPtr)
+		rc, err := restore_command(url, *outputPtr, *xlogPtr)
 		if rc != EXIT_RESTORE_FAIL {
 			return rc, err
 		}
@@ -100,7 +101,7 @@ func archive_command(url string) (int, error) {
 	return EXIT_ABORT, errors.New("archive_command not yet implemented")
 }
 
-func restore_command(url string, output string) (int, error) {
+func restore_command(url string, output string, xlog string) (int, error) {
 	var output_path string
 	var req *http.Request
 	var err error
@@ -120,6 +121,19 @@ func restore_command(url string, output string) (int, error) {
 		}
 		output_path = path.Join(cwd, output)
 	}
+	xlogNameRe := regexp.MustCompile(`^([A-F0-9]{24}|[A-F0-9]{8}\.history)$`)
+	if xlogNameRe.MatchString(xlog) {
+		// if file ".pghoard.prefetch" exists, just move it to destination
+		xlogPrefetchPath := path.Join(path.Dir(output_path), xlog+".pghoard.prefetch")
+		_, err = os.Stat(xlogPrefetchPath)
+		if err == nil {
+			err := os.Rename(xlogPrefetchPath, output_path)
+			if err != nil {
+				return EXIT_ABORT, err
+			}
+			return EXIT_OK, nil
+		}
+	}
 	req, err = http.NewRequest("GET", url, nil)
 	req.Header.Set("x-pghoard-target-path", output_path)
 }
diff --git a/pghoard/common.py b/pghoard/common.py
index 6f93b7cb..e02e6a84 100644
--- a/pghoard/common.py
+++ b/pghoard/common.py
@@ -448,8 +448,8 @@ class UnhandledThreadException(Exception):
 
 
 class PGHoardThread(Thread):
-    def __init__(self):
-        super().__init__()
+    def __init__(self, name: Optional[str] = None):
+        super().__init__(name=name)
         self.exception: Optional[Exception] = None
 
     def run_safe(self):
diff --git a/pghoard/pghoard.py b/pghoard/pghoard.py
index 396ecda1..4537a805 100644
--- a/pghoard/pghoard.py
+++ b/pghoard/pghoard.py
@@ -47,7 +47,7 @@ from pghoard.receivexlog import PGReceiveXLog
 from pghoard.transfer import (TransferAgent, TransferQueue, UploadEvent, UploadEventProgressTracker)
 from pghoard.walreceiver import WALReceiver
-from pghoard.webserver import WebServer
+from pghoard.webserver import DownloadResultsProcessor, WebServer
 
 
 @dataclass
@@ -149,6 +149,10 @@ def __init__(self, config_path):
         self.webserver = WebServer(
             self.config, self.requested_basebackup_sites, self.compression_queue, self.transfer_queue, self.metrics
         )
+        self.download_results_processor = DownloadResultsProcessor(
+            self.webserver.lock, self.webserver.download_results, self.webserver.pending_download_ops,
+            self.webserver.prefetch_404
+        )
         self.wal_file_deleter = WALFileDeleterThread(
             config=self.config, wal_file_deletion_queue=self.wal_file_deletion_queue, metrics=self.metrics
@@ -701,6 +705,7 @@ def start_threads_on_startup(self):
         self.inotify.start()
         self.upload_tracker.start()
         self.webserver.start()
+        self.download_results_processor.start()
         self.wal_file_deleter.start()
         for compressor in self.compressors:
             compressor.start()
@@ -983,6 +988,8 @@ def _get_all_threads(self):
         if hasattr(self, "webserver"):
             all_threads.append(self.webserver)
+        if hasattr(self, "download_results_processor"):
+            all_threads.append(self.download_results_processor)
         all_threads.extend(self.basebackups.values())
         all_threads.extend(self.receivexlogs.values())
         all_threads.extend(self.walreceivers.values())
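The Go change at the top short-circuits the HTTP fetch whenever a completed prefetch is already sitting next to the requested output. Reduced to a standalone Python sketch (illustrative names, not pghoard's API), the contract is: a finished prefetch is published as `<name>.pghoard.prefetch` in the target directory, and claiming it is a single rename:

```python
import os
import re

# Same pattern the Go code compiles: 24-hex-digit WAL segment names
# or 8-hex-digit timeline history files.
XLOG_NAME_RE = re.compile(r"^([A-F0-9]{24}|[A-F0-9]{8}\.history)$")


def claim_prefetched_file(xlog: str, output_path: str) -> bool:
    """Move <xlog>.pghoard.prefetch into place; True means the HTTP fetch can be skipped."""
    if not XLOG_NAME_RE.match(xlog):
        return False
    prefetch_path = os.path.join(os.path.dirname(output_path), xlog + ".pghoard.prefetch")
    if not os.path.exists(prefetch_path):
        return False
    # rename() is atomic within one filesystem, so the caller never sees a partial file
    os.rename(prefetch_path, output_path)
    return True
```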
diff --git a/pghoard/postgres_command.py b/pghoard/postgres_command.py
index 087f83d5..df753ff5 100644
--- a/pghoard/postgres_command.py
+++ b/pghoard/postgres_command.py
@@ -12,6 +12,8 @@
 import time
 from http.client import BadStatusLine, HTTPConnection, IncompleteRead
 
+from pghoard.wal import TIMELINE_RE, WAL_RE
+
 from . import version
 
 PGHOARD_HOST = "127.0.0.1"
@@ -72,6 +74,12 @@ def restore_command(site, xlog, output, host=PGHOARD_HOST, port=PGHOARD_PORT, re
     # directory. Note that os.path.join strips preceding components if a new component starts with a
     # slash so it's still possible to use this with absolute paths.
     output_path = os.path.join(os.getcwd(), output)
+    if WAL_RE.match(xlog) or TIMELINE_RE.match(xlog):
+        # if file ".pghoard.prefetch" exists, just move it to destination
+        prefetch_path = os.path.join(os.path.dirname(output_path), xlog + ".pghoard.prefetch")
+        if os.path.exists(prefetch_path):
+            os.rename(prefetch_path, output_path)
+            return
     headers = {"x-pghoard-target-path": output_path}
     method = "GET"
     path = "/{}/archive/{}".format(site, xlog)
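The Python `restore_command` above gains the same fast path, reusing pghoard's existing `WAL_RE` and `TIMELINE_RE` patterns instead of a new regex. A quick way to exercise it in isolation (a hedged sketch with illustrative paths, assuming pghoard is importable; with the prefetch file in place the function returns after the rename, so no pghoard webserver needs to be running):

```python
import os

from pghoard.postgres_command import restore_command

wal_name = "000000010000000000000002"  # illustrative segment name
os.makedirs("pg_wal", exist_ok=True)
with open(os.path.join("pg_wal", wal_name + ".pghoard.prefetch"), "wb") as f:
    f.write(b"placeholder")  # a real prefetch file would hold the verified segment

# Claims the prefetched file with a single rename; no HTTP connection is attempted.
restore_command(site="default", xlog=wal_name, output=os.path.join("pg_wal", wal_name))
assert os.path.exists(os.path.join("pg_wal", wal_name))
```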
diff --git a/pghoard/webserver.py b/pghoard/webserver.py
index 2fd29db9..b242f304 100644
--- a/pghoard/webserver.py
+++ b/pghoard/webserver.py
@@ -14,21 +14,31 @@
 from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager, suppress
+from dataclasses import dataclass
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
 from queue import Empty, Queue
 from socketserver import ThreadingMixIn
 from threading import RLock
+from typing import Dict
 
 from rohmu.errors import Error, FileNotFoundFromStorageError
 
 from pghoard import wal
-from pghoard.common import (FileType, FileTypePrefixes, PGHoardThread, get_pg_wal_directory, json_encode)
+from pghoard.common import (CallbackEvent, FileType, FileTypePrefixes, PGHoardThread, get_pg_wal_directory, json_encode)
 from pghoard.compressor import CompressionEvent
 from pghoard.transfer import DownloadEvent, OperationEvents, TransferOperation
 from pghoard.version import __version__
 
 
+@dataclass(frozen=True)
+class PendingDownloadOp:
+    started_at: float
+    target_path: str
+    filetype: str
+    filename: str
+
+
 class PoolMixIn(ThreadingMixIn):
     def process_request(self, request, client_address):
         self.pool.submit(self.process_request_thread, request, client_address)
@@ -39,7 +49,10 @@ class OwnHTTPServer(PoolMixIn, HTTPServer):
     pool = ThreadPoolExecutor(max_workers=10)
     requested_basebackup_sites = None
 
-    def __init__(self, server_address, RequestHandlerClass):
+    def __init__(
+        self, server_address, request_handler, *, config, log, requested_basebackup_sites, compression_queue, transfer_queue,
+        lock, pending_download_ops, download_results, prefetch_404, metrics
+    ):
         # to avoid any kind of regression where the server address is not a legal ip address, catch any ValueError
         try:
             # specifying an empty http_address will make pghoard listen on all IPV4 addresses,
@@ -50,7 +63,26 @@ def __init__(self, server_address, RequestHandlerClass):
             self.address_family = socket.AF_INET6
         except ValueError:
             pass
-        HTTPServer.__init__(self, server_address, RequestHandlerClass)
+        HTTPServer.__init__(self, server_address, request_handler)
+
+        self.config = config
+        self.log = log
+        self.requested_basebackup_sites = requested_basebackup_sites
+        self.compression_queue = compression_queue
+        self.transfer_queue = transfer_queue
+        self.lock = lock
+        self.pending_download_ops = pending_download_ops
+        self.download_results = download_results
+        self.most_recently_served_files = {}
+        # Bounded list of files returned from local disk. Sometimes the file on disk is in some way "bad"
+        # and PostgreSQL doesn't accept it and keeps on requesting it again. If the file was recently served
+        # from disk, serve it from file storage instead because the file there could be different.
+        self.served_from_disk = deque(maxlen=10)
+        # Bounded negative cache for failed prefetch operations - we don't want to try prefetching files that
+        # aren't there. This isn't used for explicit download requests as it's possible that a file appears
+        # later on in the object store.
+        self.prefetch_404 = prefetch_404
+        self.metrics = metrics
 
 
 class HttpResponse(Exception):
@@ -65,6 +97,83 @@ def __init__(self, msg=None, headers=None, status=500):
         super().__init__("{} {}".format(self.__class__.__name__, status))
 
 
+class DownloadResultsProcessor(PGHoardThread):
+    """
+    Processes the download_results queue: validates downloaded WAL and renames the
+    temporary file to its target name (with the ".prefetch" suffix).
+    """
+    def __init__(
+        self, lock: RLock, download_results: Queue, pending_download_ops: Dict[str, PendingDownloadOp], prefetch_404: deque
+    ) -> None:
+        super().__init__(name=self.__class__.__name__)
+        self.log = logging.getLogger("WebServer")
+        self.lock = lock
+        self.download_results = download_results
+        self.pending_download_ops = pending_download_ops
+        self.prefetch_404 = prefetch_404
+        self.running = False
+
+    def run_safe(self) -> None:
+        self.running = True
+        while self.running:
+            try:
+                item = self.download_results.get(block=True, timeout=1.0)
+                self.process_queue_item(item)
+            except Empty:
+                pass
+            except Exception:  # pylint: disable=broad-except
+                self.log.exception("Unhandled exception in %s", self.__class__.__name__)
+
+    def process_queue_item(self, download_result: CallbackEvent) -> None:
+        key = str(download_result.opaque)
+        pending_download_op = self.pending_download_ops.pop(key, None)
+        with self.lock:
+            if not download_result.success:
+                ex = download_result.exception or Error
+                if isinstance(ex, FileNotFoundFromStorageError):
+                    # don't try prefetching this file again
+                    self.prefetch_404.append(key)
+                else:
+                    if pending_download_op:
+                        delta = time.monotonic() - pending_download_op.started_at
+                    else:
+                        delta = -1.0
+                    self.log.warning("Fetching %r failed (%s), took: %.3fs", key, ex.__class__.__name__, delta)
+                return
+
+            if not isinstance(download_result.payload, dict) or "target_path" not in download_result.payload:
+                raise RuntimeError(
+                    f"Invalid payload in callback event: {download_result}, payload: {download_result.payload}"
+                )
+            src_tmp_file_path = download_result.payload["target_path"]
+
+            if not pending_download_op:
+                self.log.warning("Orphaned download operation %r completed: %r", key, download_result)
+                if download_result.success:
+                    with suppress(OSError):
+                        os.unlink(src_tmp_file_path)
+                return
+
+            if os.path.isfile(pending_download_op.target_path):
+                self.log.warning("Target path for %r already exists, skipping", key)
+                return
+            # verify WAL before putting it into place
+            if pending_download_op.filetype == "xlog":
+                try:
+                    wal.verify_wal(wal_name=pending_download_op.filename, filepath=src_tmp_file_path)
+                    self.log.info("WAL verification successful %s", src_tmp_file_path)
+                except ValueError:
+                    self.log.warning("WAL verification failed %s, unlinking file", src_tmp_file_path)
+                    with suppress(OSError):
+                        os.unlink(src_tmp_file_path)
+                    return
+            os.rename(src_tmp_file_path, pending_download_op.target_path)
+            metadata = download_result.payload.get("metadata", {})
+            self.log.info(
+                "Renamed %s to %s. Original upload from %r, hash %s:%s", download_result.payload["target_path"],
+                pending_download_op.target_path, metadata.get("host"), metadata.get("hash-algorithm"), metadata.get("hash")
+            )
+
+
 class WebServer(PGHoardThread):
     def __init__(self, config, requested_basebackup_sites, compression_queue, transfer_queue, metrics):
         super().__init__()
@@ -83,29 +192,23 @@ def __init__(self, config, requested_basebackup_sites, compression_queue, transfer_queue, metrics):
         self._running = False
         self.log.debug("WebServer initialized with address: %r port: %r", self.address, self.port)
         self.is_initialized = threading.Event()
+        self.prefetch_404 = deque(maxlen=32)
 
     def run_safe(self):
         # We bind the port only when we start running
         self._running = True
-        self.server = OwnHTTPServer((self.address, self.port), RequestHandler)
-        self.server.config = self.config  # pylint: disable=attribute-defined-outside-init
-        self.server.log = self.log  # pylint: disable=attribute-defined-outside-init
-        self.server.requested_basebackup_sites = self.requested_basebackup_sites
-        self.server.compression_queue = self.compression_queue  # pylint: disable=attribute-defined-outside-init
-        self.server.transfer_queue = self.transfer_queue  # pylint: disable=attribute-defined-outside-init
-        self.server.lock = self.lock  # pylint: disable=attribute-defined-outside-init
-        self.server.pending_download_ops = self.pending_download_ops  # pylint: disable=attribute-defined-outside-init
-        self.server.download_results = self.download_results  # pylint: disable=attribute-defined-outside-init
-        self.server.most_recently_served_files = {}  # pylint: disable=attribute-defined-outside-init
-        # Bounded list of files returned from local disk. Sometimes the file on disk is in some way "bad"
-        # and PostgreSQL doesn't accept it and keeps on requesting it again. If the file was recently served
-        # from disk serve it from file storage instead because the file there could be different.
-        self.server.served_from_disk = deque(maxlen=10)  # pylint: disable=attribute-defined-outside-init
-        # Bounded negative cache for failed prefetch operations - we don't want to try prefetching files that
-        # aren't there. This isn't used for explicit download requests as it's possible that a file appears
-        # later on in the object store.
-        self.server.prefetch_404 = deque(maxlen=32)  # pylint: disable=attribute-defined-outside-init
-        self.server.metrics = self.metrics  # pylint: disable=attribute-defined-outside-init
+        self.server = OwnHTTPServer((self.address, self.port),
+                                    RequestHandler,
+                                    config=self.config,
+                                    log=self.log,
+                                    requested_basebackup_sites=self.requested_basebackup_sites,
+                                    compression_queue=self.compression_queue,
+                                    transfer_queue=self.transfer_queue,
+                                    lock=self.lock,
+                                    pending_download_ops=self.pending_download_ops,
+                                    download_results=self.download_results,
+                                    prefetch_404=self.prefetch_404,
+                                    metrics=self.metrics)
         self.is_initialized.set()
         self.server.serve_forever()
 
@@ -138,6 +241,7 @@ def running(self, value):
 
 class RequestHandler(BaseHTTPRequestHandler):
     disable_nagle_algorithm = True
     server_version = "pghoard/" + __version__
+    server: OwnHTTPServer
 
     @contextmanager
     def _response_handler(self, method):
@@ -221,23 +325,6 @@ def _parse_request(self, path):
 
         raise HttpResponse("Invalid path {!r}".format(path), status=400)
 
-    def _verify_wal(self, filetype, filename, path):
-        if filetype != "xlog":
-            return
-        try:
-            wal.verify_wal(wal_name=filename, filepath=path)
-        except ValueError as ex:
-            raise HttpResponse(str(ex), status=412)
-
-    def _save_and_verify_restored_file(self, filetype, filename, tmp_target_path, target_path):
-        self._verify_wal(filetype, filename, tmp_target_path)
-        try:
-            with self.server.lock:
-                os.rename(tmp_target_path, target_path)
-        except OSError as ex:
-            fmt = "Unable to write final file to requested location {path!r}: {ex.__class__.__name__}: {ex}"
-            raise HttpResponse(fmt.format(path=target_path, ex=ex), status=409)
-
     def _transfer_agent_op(self, site, filename, filetype, method, *, retries=2):
         start_time = time.time()
 
@@ -283,7 +370,7 @@ def _create_prefetch_operations(self, site, filetype, filename):
             return
 
         xlog_dir = get_pg_wal_directory(self.server.config["backup_sites"][site])
-        names = []
+        prefetch_filenames = []
         if filetype == "timeline":
             tli_num = int(filename.replace(".history", ""), 16)
             for _ in range(prefetch_n):
@@ -291,7 +378,7 @@
                 prefetch_name = "{:08X}.history".format(tli_num)
                 if os.path.isfile(os.path.join(xlog_dir, prefetch_name)):
                     continue
-                names.append(prefetch_name)
+                prefetch_filenames.append(prefetch_name)
         elif filetype == "xlog":
             xlog_num = int(filename, 16)
             for _ in range(prefetch_n):
@@ -306,20 +393,20 @@
                     continue
                 except ValueError as e:
                     self.server.log.debug("(Prefetch) File %s already exists but is invalid: %r", xlog_path, e)
-                names.append(prefetch_name)
+                prefetch_filenames.append(prefetch_name)
 
-        for obname in names:
-            key = self._make_file_key(site, filetype, obname)
+        for prefetch_filename in prefetch_filenames:
+            key = self._make_file_key(site, filetype, prefetch_filename)
             if key in self.server.prefetch_404:
                 continue  # previously failed to prefetch this file, don't try again
-            self._create_fetch_operation(key, site, filetype, obname)
+            self._create_fetch_operation(key, site, filetype, prefetch_filename)
 
     def _create_fetch_operation(self, key, site, filetype, obname, max_age=-1, suppress_error=True):
         with self.server.lock:
            # Don't fetch again if we already have a pending fetch operation, unless the operation
            # has been ongoing longer than the given max age and has potentially become stale
            existing = self.server.pending_download_ops.get(key)
-            if existing and (max_age < 0 or time.monotonic() - existing["started_at"] <= max_age):
+            if existing and (max_age < 0 or time.monotonic() - existing.started_at <= max_age):
                 return
 
             xlog_dir = get_pg_wal_directory(self.server.config["backup_sites"][site])
@@ -343,9 +430,8 @@ def _create_fetch_operation(self, key, site, filetype, obname, max_age=-1, suppress_error=True):
             "Fetching site: %r, filename: %r, filetype: %r, tmp_target_path: %r", site, obname, filetype, tmp_target_path
         )
         target_path = os.path.join(xlog_dir, "{}.pghoard.prefetch".format(obname))
-        self.server.pending_download_ops[key] = dict(
-            started_at=time.monotonic(),
-            target_path=target_path,
+        self.server.pending_download_ops[key] = PendingDownloadOp(
+            started_at=time.monotonic(), target_path=target_path, filetype=filetype, filename=obname
         )
         self.server.transfer_queue.put(
             DownloadEvent(
@@ -358,42 +444,6 @@
             )
         )
 
-    def _process_completed_download_operations(self, timeout=None):
-        while True:
-            try:
-                result = self.server.download_results.get(block=timeout is not None, timeout=timeout)
-                key = result.opaque
-                with self.server.lock:
-                    op = self.server.pending_download_ops.pop(key, None)
-                    if not op:
-                        self.server.log.warning("Orphaned download operation %r completed: %r", key, result)
-                        if result.success:
-                            with suppress(OSError):
-                                os.unlink(result.payload["target_path"])
-                        continue
-                    if result.success:
-                        if os.path.isfile(op["target_path"]):
-                            self.server.log.warning("Target path for %r already exists, skipping", key)
-                            continue
-                        os.rename(result.payload["target_path"], op["target_path"])
-                        metadata = result.payload["metadata"] or {}
-                        self.server.log.info(
-                            "Renamed %s to %s. Original upload from %r, hash %s:%s", result.payload["target_path"],
-                            op["target_path"], metadata.get("host"), metadata.get("hash-algorithm"), metadata.get("hash")
-                        )
-                    else:
-                        ex = result.exception or Error
-                        if isinstance(ex, FileNotFoundFromStorageError):
-                            # don't try prefetching this file again
-                            self.server.prefetch_404.append(key)
-                        else:
-                            self.server.log.warning(
-                                "Fetching %r failed (%s), took: %.3fs", key, ex.__class__.__name__,
-                                time.monotonic() - op["started_at"]
-                            )
-            except Empty:
-                return
-
     def get_status(self, site):
         state_file_path = self.server.config["json_state_file_path"]
         if site is None:
@@ -414,21 +464,19 @@ def get_metrics(self, site):
         else:
             raise HttpResponse(status=501)  # Not Implemented
 
-    def _try_save_and_verify_restored_file(self, filetype, filename, prefetch_target_path, target_path, unlink=True):
+    def _rename(self, src: str, dst: str) -> None:
         try:
-            self._save_and_verify_restored_file(filetype, filename, prefetch_target_path, target_path)
-            self.server.log.info("Renamed %s to %s", prefetch_target_path, target_path)
-            return None
-        except (ValueError, HttpResponse) as e:
-            # Just try loading the file again
-            with suppress(OSError):
-                self.server.log.warning("Verification of prefetch file %s failed: %r", prefetch_target_path, e)
-                if unlink:
-                    os.unlink(prefetch_target_path)
-            return e
+            with self.server.lock:
+                os.rename(src, dst)
+            self.server.log.info("Renamed %s to %s", src, dst)
+        except OSError as e:
+            self.server.log.warning(
+                "Unable to write final file to requested location %s: %s: %r", dst, e.__class__.__name__, e
+            )
+            raise
 
     @staticmethod
-    def _validate_target_path(config, target_path):
+    def _validate_target_path(pg_data_directory: str, target_path: str) -> None:
         # The `restore_command` (postgres_command.py or pghoard_postgres_command_go.go) called by PostgreSQL has
         # prepended the PostgreSQL 'data' directory with the `%p` parameter from the PostgreSQL server, hence here
         # `target_path` is expected to be an absolute path.
@@ -439,12 +487,14 @@ def _validate_target_path(config, target_path):
         # /var/lib/pgsql/11/data/../../../../../etc/passwd, could be actually /etc/passwd
         if not os.path.isabs(target_path):
             raise HttpResponse(f"Invalid xlog file path {target_path}, an absolute path expected", status=400)
-        data_dir = Path(config["pg_data_directory"]).resolve()
+        data_dir = Path(pg_data_directory).resolve()
         xlog_file = Path(target_path).resolve()
         if data_dir not in xlog_file.parents:
             raise HttpResponse(f"Invalid xlog file path {target_path}, it should be in data directory", status=400)
+        if not xlog_file.parent.is_dir():
+            raise HttpResponse(f"Invalid xlog file path {target_path}, parent directory should exist", status=409)
 
-    def get_wal_or_timeline_file(self, site, filename, filetype):
+    def get_wal_or_timeline_file(self, site: str, filename: str, filetype: str) -> None:
         target_path = self.headers.get("x-pghoard-target-path")
         if not target_path:
             raise HttpResponse("x-pghoard-target-path header missing from download", status=400)
@@ -452,20 +502,21 @@
         site_config = self.server.config["backup_sites"][site]
         xlog_dir = get_pg_wal_directory(site_config)
 
-        self._validate_target_path(site_config, target_path)
-        self._process_completed_download_operations()
+        self._validate_target_path(site_config["pg_data_directory"], target_path)
 
         # See if we have already prefetched the file
         prefetch_target_path = os.path.join(xlog_dir, "{}.pghoard.prefetch".format(filename))
         if os.path.exists(prefetch_target_path):
-            ex = self._try_save_and_verify_restored_file(filetype, filename, prefetch_target_path, target_path)
-            if not ex:
-                self._create_prefetch_operations(site, filetype, filename)
-                self.server.most_recently_served_files[filetype] = {
-                    "name": filename,
-                    "time": time.time(),
-                }
-                raise HttpResponse(status=201)
+            if filetype == "xlog":
+                # a WAL file with the ".prefetch" suffix should already be validated, but check it anyway
+                wal.verify_wal(wal_name=filename, filepath=prefetch_target_path)
+            self._rename(prefetch_target_path, target_path)
+            self._create_prefetch_operations(site, filetype, filename)
+            self.server.most_recently_served_files[filetype] = {
+                "name": filename,
+                "time": time.time(),
+            }
+            raise HttpResponse(status=201)
 
         # After reaching a recovery_target and restart of a PG server, PG wants to replay and refetch
         # files from the archive starting from the latest checkpoint. We have potentially fetched these files
@@ -477,10 +528,13 @@
             self.server.log.info(
                 "Requested %r, found it in pg_xlog directory as: %r, returning directly", filename, xlog_path
             )
-            ex = self._try_save_and_verify_restored_file(filetype, filename, xlog_path, target_path, unlink=False)
-            if ex:
-                self.server.log.warning("Found file: %r but it was invalid: %s", xlog_path, ex)
+            try:
+                if filetype == "xlog":
+                    wal.verify_wal(wal_name=filename, filepath=xlog_path)
+            except ValueError as e:
+                self.server.log.warning("Found file: %r but it was invalid: %s", xlog_path, e)
             else:
+                self._rename(xlog_path, target_path)
                 self.server.served_from_disk.append(filename)
                 self.server.most_recently_served_files[filetype] = {
                     "name": filename,
@@ -500,24 +554,21 @@
         start_time = time.monotonic()
         retries = 2
         while (time.monotonic() - start_time) <= 30:
-            self._process_completed_download_operations(timeout=0.01)
             with self.server.lock:
                 if os.path.isfile(prefetch_target_path):
-                    ex = self._try_save_and_verify_restored_file(filetype, filename, prefetch_target_path, target_path)
-                    if not ex:
-                        self.server.most_recently_served_files[filetype] = {
-                            "name": filename,
-                            "time": time.time(),
-                        }
-                        raise HttpResponse(status=201)
-                    elif ex and retries == 0:
-                        raise ex  # pylint: disable=raising-bad-type
-                    retries -= 1
+                    if filetype == "xlog":
+                        wal.verify_wal(wal_name=filename, filepath=prefetch_target_path)
+                    self._rename(prefetch_target_path, target_path)
+                    self.server.most_recently_served_files[filetype] = {
+                        "name": filename,
+                        "time": time.time(),
+                    }
+                    raise HttpResponse(status=201)
             if key in self.server.prefetch_404:
                 raise HttpResponse(status=404)
             with self.server.lock:
                 if key not in self.server.pending_download_ops:
-                    if retries == 0:
+                    if retries <= 0:
                         raise HttpResponse(status=500)
                     retries -= 1
                     self._create_fetch_operation(key, site, filetype, filename, suppress_error=False)
@@ -525,6 +576,7 @@
                 last_schedule_call = time.monotonic()
                 # Replace existing download operation if it has been executing for too long
                 self._create_fetch_operation(key, site, filetype, filename, max_age=10, suppress_error=False)
+            time.sleep(0.05)
 
         raise HttpResponse("TIMEOUT", status=500)
@@ -549,7 +601,11 @@ def handle_archival_request(self, site, filename, filetype):
             self.server.log.debug("xlog_path: %r did not exist, cannot archive, returning 404", xlog_path)
             raise HttpResponse("N/A", status=404)
 
-        self._verify_wal(filetype, filename, xlog_path)
+        if filetype == "xlog":
+            try:
+                wal.verify_wal(wal_name=filename, filepath=xlog_path)
+            except ValueError as ex:
+                raise HttpResponse(str(ex), status=412)
 
         callback_queue = Queue()
         if not self.server.config["backup_sites"][site]["object_storage"]:
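With `_process_completed_download_operations` gone, HTTP request handlers no longer drain `download_results` opportunistically: the new `DownloadResultsProcessor` thread owns that queue, and the serving loop merely polls the filesystem (hence the added `time.sleep(0.05)`). The threading pattern, reduced to its core (a sketch, not pghoard's classes):

```python
import threading
from queue import Empty, Queue


class QueueConsumer(threading.Thread):
    """Minimal dedicated-consumer sketch mirroring DownloadResultsProcessor's loop."""
    def __init__(self, results: Queue) -> None:
        super().__init__(name=self.__class__.__name__)
        self.results = results
        self.running = False

    def run(self) -> None:
        self.running = True
        while self.running:
            try:
                item = self.results.get(block=True, timeout=1.0)
            except Empty:
                continue  # periodic wakeup so a cleared `running` flag is noticed
            print("processing", item)  # the real thread verifies and renames here
```

Moving the queue off the request path means results get processed even when no client is currently asking for a file, and a slow verify/rename no longer stalls an unrelated HTTP response.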
diff --git a/test/test_pghoard.py b/test/test_pghoard.py
index 0d364f5c..1a057148 100644
--- a/test/test_pghoard.py
+++ b/test/test_pghoard.py
@@ -900,6 +900,7 @@ def test_surviving_pg_receivewal_hickup(self, db, pghoard):
         os.makedirs(wal_directory, exist_ok=True)
 
         pghoard.receivexlog_listener(pghoard.test_site, db.user, wal_directory)
+        time.sleep(0.5)  # wait for thread setup
 
         conn = db.connect()
         conn.autocommit = True
@@ -918,6 +919,7 @@
         # stopping the thread is not enough, it's possible that a killed receiver will leave incomplete partial
         # files around; pghoard is capable of cleaning those up but needs to be restarted. For the test it should
         # be OK just to call startup_walk_for_missed_files, so it takes care of cleaning up
+        time.sleep(0.5)  # wait for file processing to finish
         pghoard.startup_walk_for_missed_files()
         n_xlogs = pghoard.transfer_agent_state[pghoard.test_site]["upload"]["xlog"]["xlogs_since_basebackup"]
@@ -930,6 +932,7 @@
         # restart
         pghoard.receivexlog_listener(pghoard.test_site, db.user, wal_directory)
         assert pghoard.receivexlogs[pghoard.test_site].is_alive()
+        time.sleep(0.5)  # wait for thread setup
 
         # We should now process all created segments, not only the ones which were created after pg_receivewal was restarted
         wait_for_xlog(pghoard, n_xlogs + 10)
diff --git a/test/test_webserver.py b/test/test_webserver.py
index af785ca1..f68b7d2c 100644
--- a/test/test_webserver.py
+++ b/test/test_webserver.py
@@ -8,7 +8,9 @@
 import logging
 import os
 import socket
+import threading
 import time
+from collections import deque
 from distutils.version import LooseVersion
 from http.client import HTTPConnection
 from queue import Queue
@@ -20,11 +22,12 @@
 from pghoard import postgres_command, wal
 from pghoard.archive_sync import ArchiveSync
-from pghoard.common import get_pg_wal_directory
+from pghoard.common import CallbackEvent, get_pg_wal_directory
 from pghoard.object_store import HTTPRestore
 from pghoard.pgutil import create_connection_string
 from pghoard.postgres_command import archive_command, restore_command
 from pghoard.restore import Restore
+from pghoard.webserver import DownloadResultsProcessor, PendingDownloadOp
 
 # pylint: disable=attribute-defined-outside-init
 from .base import CONSTANT_TEST_RSA_PRIVATE_KEY, CONSTANT_TEST_RSA_PUBLIC_KEY
@@ -770,3 +773,62 @@ def test_uncontrolled_target_path(self, pghoard):
         conn.request("GET", wal_file, headers=headers)
         status = conn.getresponse().status
         assert status == 400
+
+
+@pytest.fixture(name="download_results_processor")
+def fixture_download_results_processor() -> DownloadResultsProcessor:
+    return DownloadResultsProcessor(threading.RLock(), Queue(), {}, deque())
+
+
+class TestDownloadResultsProcessor:
+    wal_name = "000000060000000000000001"
+
+    def save_wal_and_download_callback(self, pg_wal_dir, download_results_processor, wal_name=None, is_valid_wal=True):
+        if wal_name is None:
+            wal_name = self.wal_name
+        tmp_path = os.path.join(pg_wal_dir, f"{wal_name}.pghoard.tmp")
+        target_path = os.path.join(pg_wal_dir, f"{wal_name}.pghoard.prefetch")
+        assert not os.path.exists(tmp_path)
+        assert not os.path.exists(target_path)
+
+        # save WAL on FS
+        if is_valid_wal:
+            wal_data = wal_header_for_file(wal_name)
+        else:
+            another_wal_name = "000000DD00000000000000DD"
+            assert wal_name != another_wal_name
+            wal_data = wal_header_for_file(another_wal_name)
+        with open(tmp_path, "wb") as out_file:
+            out_file.write(wal_data)
+
+        download_result = CallbackEvent(success=True, payload={"target_path": tmp_path}, opaque=wal_name)
+        pending_op = PendingDownloadOp(
+            started_at=time.monotonic(), target_path=target_path, filetype="xlog", filename=wal_name
+        )
+        download_results_processor.pending_download_ops[wal_name] = pending_op
+        return tmp_path, target_path, download_result
+
+    @pytest.mark.parametrize("empty_pending_download_ops", [True, False])
+    @pytest.mark.parametrize("is_valid_wal", [True, False])
+    def test_rename_wal(self, download_results_processor, tmpdir, is_valid_wal, empty_pending_download_ops):
+        tmp_path, target_path, download_result_item = self.save_wal_and_download_callback(
+            tmpdir, download_results_processor, is_valid_wal=is_valid_wal
+        )
+        if empty_pending_download_ops:
+            download_results_processor.pending_download_ops = {}
+        download_results_processor.process_queue_item(download_result_item)
+        assert os.path.exists(target_path) is (is_valid_wal and not empty_pending_download_ops)
+        assert not os.path.exists(tmp_path)
+
+    def test_dont_overwrite_existing_target_file(self, download_results_processor, tmpdir):
+        tmp_path, target_path, download_result_item = self.save_wal_and_download_callback(tmpdir, download_results_processor)
+        existing_file_data = b"-"
+        with open(target_path, "wb") as out_file:
+            out_file.write(existing_file_data)
+        assert os.path.exists(target_path)
+        assert os.path.exists(tmp_path)
+
+        download_results_processor.process_queue_item(download_result_item)
+        assert os.path.exists(target_path)
+        assert open(target_path, "rb").read() == existing_file_data
+        assert os.path.exists(tmp_path)
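For completeness, the processor can also be driven outside the test suite exactly as the fixture above constructs it. A hedged sketch, assuming `PGHoardThread.run` dispatches to `run_safe` as pghoard's thread base class does:

```python
import threading
from collections import deque
from queue import Queue

from pghoard.webserver import DownloadResultsProcessor

processor = DownloadResultsProcessor(threading.RLock(), Queue(), {}, deque())
processor.start()
# CallbackEvents put on processor.download_results are now verified and renamed
processor.running = False  # the 1-second get() timeout makes shutdown prompt
processor.join()
```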