Skip to content

Commit

Permalink
webserver: support authentication on the webserver
Browse files Browse the repository at this point in the history
Support requesting authentication on the web server. This is intended to
limit exposure from local redirects, open proxies and/or call out
functionality.
  • Loading branch information
hnousiainen committed Sep 19, 2024
1 parent 82dc55e commit 2409307
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 10 deletions.
12 changes: 10 additions & 2 deletions golang/pghoard_postgres_command_go.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ func run() (int, error) {
verPtr := flag.Bool("version", false, "show program version")
hostPtr := flag.String("host", PGHOARD_HOST, "pghoard service host")
portPtr := flag.Int("port", PGHOARD_PORT, "pghoard service port")
usernamePtr := flag.String("username", "", "pghoard service username")
passwordPtr := flag.String("password", "", "pghoard service password")
sitePtr := flag.String("site", "", "pghoard backup site")
xlogPtr := flag.String("xlog", "", "xlog file name")
outputPtr := flag.String("output", "", "output file")
Expand Down Expand Up @@ -82,7 +84,7 @@ func run() (int, error) {
retry_seconds := *riPtr
for {
attempt += 1
rc, err := restore_command(url, *outputPtr, *xlogPtr)
rc, err := restore_command(url, *outputPtr, *xlogPtr, *usernamePtr, *passwordPtr)
if rc != EXIT_RESTORE_FAIL {
return rc, err
}
Expand All @@ -101,13 +103,16 @@ func archive_command(url string) (int, error) {
return EXIT_ABORT, errors.New("archive_command not yet implemented")
}

func restore_command(url string, output string, xlog string) (int, error) {
func restore_command(url string, output string, xlog string, username string, password string) (int, error) {
var output_path string
var req *http.Request
var err error

if output == "" {
req, err = http.NewRequest("HEAD", url, nil)
if username != "" && password != "" {
req.SetBasicAuth(username, password)
}
} else {
/* Construct absolute path for output - postgres calls this command with a relative path to its xlog
directory. Note that os.path.join strips preceding components if a new components starts with a
Expand Down Expand Up @@ -136,6 +141,9 @@ func restore_command(url string, output string, xlog string) (int, error) {
}
req, err = http.NewRequest("GET", url, nil)
req.Header.Set("x-pghoard-target-path", output_path)
if username != "" && password != "" {
req.SetBasicAuth(username, password)
}
}

client := &http.Client{}
Expand Down
2 changes: 2 additions & 0 deletions pghoard/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def set_and_check_config_defaults(config, *, check_commands=True, check_pgdata=T
config.setdefault("backup_location", None)
config.setdefault("http_address", PGHOARD_HOST)
config.setdefault("http_port", PGHOARD_PORT)
config.setdefault("webserver_username", None)
config.setdefault("webserver_password", None)
config.setdefault("alert_file_dir", config.get("backup_location") or os.getcwd())
config.setdefault("json_state_file_path", "/var/lib/pghoard/pghoard_state.json")
config.setdefault("maintenance_mode_file", "/var/lib/pghoard/maintenance_mode_file")
Expand Down
7 changes: 5 additions & 2 deletions pghoard/object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Optional

from requests import Session
from requests.auth import HTTPBasicAuth
from rohmu import dates


Expand Down Expand Up @@ -72,14 +73,16 @@ def get_file_bytes(self, name):


class HTTPRestore(ObjectStore):
def __init__(self, host, port, site, pgdata=None):
def __init__(self, host, port, site, pgdata=None, *, username=None, password=None):
super().__init__(storage=None, prefix=None, site=site, pgdata=pgdata)
self.host = host
self.port = port
self.session = Session()
if username and password:
self.session.auth = HTTPBasicAuth(username, password)

def _url(self, path):
return "http://{host}:{port}/{site}/{path}".format(host=self.host, port=self.port, site=self.site, path=path)
return f"http://{self.host}:{self.port}/{self.site}/{path}"

def list_basebackups(self):
response = self.session.get(self._url("basebackup"))
Expand Down
14 changes: 12 additions & 2 deletions pghoard/postgres_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import argparse
import base64
import os
import socket
import sys
Expand Down Expand Up @@ -45,10 +46,17 @@ def __init__(self, message, exit_code=EXIT_FAIL):
self.exit_code = exit_code


def http_request(host, port, method, path, headers=None):
def http_request(host, port, method, path, headers=None, *, username=None, password=None):
conn = HTTPConnection(host=host, port=port)
if headers is not None:
headers = headers.copy()
else:
headers = {}
if username is not None and password is not None:
auth_str = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode()
headers["Authorization"] = f"Basic {auth_str}"
try:
conn.request(method, path, headers=headers or {})
conn.request(method, path, headers=headers)
resp = conn.getresponse()
finally:
conn.close()
Expand Down Expand Up @@ -112,6 +120,8 @@ def main(args=None):
parser.add_argument("--version", action="version", help="show program version", version=version.__version__)
parser.add_argument("--host", type=str, default=PGHOARD_HOST, help="pghoard service host")
parser.add_argument("--port", type=int, default=PGHOARD_PORT, help="pghoard service port")
parser.add_argument("--username", type=str, help="pghoard service username")
parser.add_argument("--password", type=str, help="pghoard service password")
parser.add_argument("--site", type=str, required=True, help="pghoard backup site")
parser.add_argument("--xlog", type=str, required=True, help="xlog file name")
parser.add_argument("--output", type=str, help="output file")
Expand Down
19 changes: 16 additions & 3 deletions pghoard/restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def create_recovery_conf(
site,
*,
port=PGHOARD_PORT,
webserver_username=None,
webserver_password=None,
primary_conninfo=None,
recovery_end_command=None,
recovery_target_action=None,
Expand All @@ -113,6 +115,13 @@ def create_recovery_conf(
"--xlog",
"%f",
]
if webserver_username and webserver_password:
restore_command.extend([
"--username",
webserver_username,
"--password",
webserver_password,
])
with open(os.path.join(dirpath, "PG_VERSION"), "r") as fp:
v = Version(fp.read().strip())
pg_version = v.major if v.major >= 10 else float(f"{v.major}.{v.minor}")
Expand Down Expand Up @@ -213,9 +222,11 @@ def generic_args(require_config=True, require_site=False):

cmd.add_argument("--site", help="pghoard site", required=require_site)

def host_port_args():
def host_port_user_args():
cmd.add_argument("--host", help="pghoard repository host", default=PGHOARD_HOST)
cmd.add_argument("--port", help="pghoard repository port", default=PGHOARD_PORT)
cmd.add_argument("--username", help="pghoard repository username")
cmd.add_argument("--password", help="pghoard repository password")

def target_args():
cmd.add_argument("--basebackup", help="pghoard basebackup", default="latest")
Expand Down Expand Up @@ -266,7 +277,7 @@ def target_args():
)

cmd = add_cmd(self.list_basebackups_http)
host_port_args()
host_port_user_args()
generic_args(require_config=False, require_site=True)

cmd = add_cmd(self.list_basebackups)
Expand All @@ -280,7 +291,7 @@ def target_args():

def list_basebackups_http(self, arg):
"""List available basebackups from a HTTP source"""
self.storage = HTTPRestore(arg.host, arg.port, arg.site)
self.storage = HTTPRestore(arg.host, arg.port, arg.site, username=arg.username, password=arg.password)
self.storage.show_basebackup_list(verbose=arg.verbose)

def _get_site_prefix(self, site):
Expand Down Expand Up @@ -609,6 +620,8 @@ def _get_basebackup(
dirpath=pgdata,
site=site,
port=self.config["http_port"],
webserver_username=self.config.get("webserver_username"),
webserver_password=self.config.get("webserver_password"),
primary_conninfo=primary_conninfo,
recovery_end_command=recovery_end_command,
recovery_target_action=recovery_target_action,
Expand Down
23 changes: 23 additions & 0 deletions pghoard/webserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Copyright (c) 2016 Ohmu Ltd
See LICENSE for details
"""
import base64
import ipaddress
import logging
import os
Expand Down Expand Up @@ -242,6 +243,22 @@ class RequestHandler(BaseHTTPRequestHandler):
disable_nagle_algorithm = True
server_version = "pghoard/" + __version__
server: OwnHTTPServer
_expected_auth_header = None

def _authentication_check(self):
if self.server.config.get("webserver_username") and self.server.config.get("webserver_password"):
if self._expected_auth_header is None:
auth_data_raw = self.server.config["webserver_username"] + ":" + self.server.config["webserver_password"]
auth_data_b64 = base64.b64encode(auth_data_raw.encode("utf-8")).decode()
self._expected_auth_header = f"Basic {auth_data_b64}"
if self.headers.get("Authorization") != self._expected_auth_header:
self.send_response(401)
self.send_header("WWW-Authenticate", 'Basic realm="pghoard"')
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(b"Authentication required")
return False
return True

@contextmanager
def _response_handler(self, method):
Expand Down Expand Up @@ -645,12 +662,16 @@ def handle_archival_request(self, site, filename, filetype):
raise HttpResponse(status=201)

def do_PUT(self):
if not self._authentication_check():
return
with self._response_handler("PUT") as path:
site, obtype, obname = self._parse_request(path)
assert obtype in ("basebackup", "xlog", "timeline")
self.handle_archival_request(site, obname, obtype)

def do_HEAD(self):
if not self._authentication_check():
return
with self._response_handler("HEAD") as path:
site, obtype, obname = self._parse_request(path)
if self.headers.get("x-pghoard-target-path"):
Expand All @@ -664,6 +685,8 @@ def do_HEAD(self):
raise HttpResponse(status=200, headers=headers)

def do_GET(self):
if not self._authentication_check():
return
with self._response_handler("GET") as path:
site, obtype, obname = self._parse_request(path)
if obtype == "basebackup":
Expand Down
13 changes: 12 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ def fixture_pghoard(db, tmpdir, request):
yield from pghoard_base(db, tmpdir, request)


@pytest.fixture(name="pghoard_with_userauth")
def fixture_pghoard_with_userauth(db, tmpdir, request):
yield from pghoard_base(db, tmpdir, request, username="testuser", password="testpass")


@pytest.fixture(name="pghoard_ipv4_hostname")
def fixture_pghoard_ipv4_hostname(db, tmpdir, request):
yield from pghoard_base(db, tmpdir, request, listen_http_address="localhost")
Expand Down Expand Up @@ -362,7 +367,9 @@ def pghoard_base(
active_backup_mode="pg_receivexlog",
slot_name=None,
compression_count=None,
listen_http_address="127.0.0.1"
listen_http_address="127.0.0.1",
username=None,
password=None
):
test_site = request.function.__name__

Expand Down Expand Up @@ -418,6 +425,10 @@ def pghoard_base(
if compression_count is not None:
config["compression"]["thread_count"] = compression_count

if username is not None and password is not None:
config["webserver_username"] = username
config["webserver_password"] = password

confpath = os.path.join(str(tmpdir), "config.json")
with open(confpath, "w") as fp:
json.dump(config, fp)
Expand Down
59 changes: 59 additions & 0 deletions test/test_webserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Copyright (c) 2015 Ohmu Ltd
See LICENSE for details
"""
import base64
import json
import logging
import os
Expand Down Expand Up @@ -40,6 +41,19 @@ def http_restore(pghoard):
return HTTPRestore("localhost", pghoard.config["http_port"], site=pghoard.test_site, pgdata=pgdata)


@pytest.fixture
def http_restore_with_userauth(pghoard_with_userauth):
pgdata = get_pg_wal_directory(pghoard_with_userauth.config["backup_sites"][pghoard_with_userauth.test_site])
return HTTPRestore(
"localhost",
pghoard_with_userauth.config["http_port"],
site=pghoard_with_userauth.test_site,
pgdata=pgdata,
username=pghoard_with_userauth.config["webserver_username"],
password=pghoard_with_userauth.config["webserver_password"]
)


class TestWebServer:
def test_requesting_status(self, pghoard):
pghoard.write_backup_state_to_json_file()
Expand Down Expand Up @@ -774,6 +788,51 @@ def test_uncontrolled_target_path(self, pghoard):
status = conn.getresponse().status
assert status == 400

def test_requesting_status_with_user_authentiction(self, pghoard_with_userauth):
pghoard_with_userauth.write_backup_state_to_json_file()
conn = HTTPConnection(host="127.0.0.1", port=pghoard_with_userauth.config["http_port"])
conn.request("GET", "/status")
response = conn.getresponse()
assert response.status == 401

username = pghoard_with_userauth.config["webserver_username"]
password = pghoard_with_userauth.config["webserver_password"]
auth_str = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode()
headers = {"Authorization": f"Basic {auth_str}"}

conn = HTTPConnection(host="127.0.0.1", port=pghoard_with_userauth.config["http_port"])
conn.request("GET", "/status", headers=headers)
response = conn.getresponse()
assert response.status == 200

response_parsed = json.loads(response.read().decode("utf-8"))
# "startup_time": "2016-06-23T14:53:25.840787",
assert response_parsed["startup_time"] is not None

conn.request("GET", "/status/somesite", headers=headers)
response = conn.getresponse()
assert response.status == 400

conn.request("GET", "/somesite/status", headers=headers)
response = conn.getresponse()
assert response.status == 404

conn.request("GET", "/{}/status".format(pghoard_with_userauth.test_site), headers=headers)
response = conn.getresponse()
assert response.status == 501

def test_basebackups_with_user_authentication(self, capsys, db, http_restore_with_userauth, pghoard_with_userauth): # pylint: disable=redefined-outer-name
final_location = self._run_and_wait_basebackup(pghoard_with_userauth, db, "pipe")
backups = http_restore_with_userauth.list_basebackups()
assert len(backups) == 1
assert backups[0]["size"] > 0
assert backups[0]["name"] == os.path.join(pghoard_with_userauth.test_site, "basebackup", os.path.basename(final_location))
# make sure they show up on the printable listing, too
http_restore_with_userauth.show_basebackup_list()
out, _ = capsys.readouterr()
assert "{} MB".format(int(backups[0]["metadata"]["original-file-size"]) // (1024 ** 2)) in out
assert backups[0]["name"] in out


@pytest.fixture(name="download_results_processor")
def fixture_download_results_processor() -> DownloadResultsProcessor:
Expand Down

0 comments on commit 2409307

Please sign in to comment.