Skip to content

Commit

Permalink
[CI] test_download: check whether self-update downloaded from the exp…
Browse files Browse the repository at this point in the history
…ected url
  • Loading branch information
glandium committed Sep 7, 2024
1 parent 2b49f1e commit a1fb9a6
Showing 1 changed file with 66 additions and 27 deletions.
93 changes: 66 additions & 27 deletions CI/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import urllib.parse
import urllib.request
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -84,6 +85,9 @@ def get_version(x, **kwargs):
return result
return result.removeprefix("git-cinnabar ")

def first(list):
return checked_call(next, iter(list))

executor = ThreadPoolExecutor(max_workers=1)

def get_pkg():
Expand Down Expand Up @@ -236,16 +240,21 @@ def get_pkg():
if last_tag in urls:
shutil.copytree(cwd / v, update_dir, symlinks=True, dirs_exist_ok=True)
git_cinnabar_v = update_dir / git_cinnabar.name
status += assert_eq(
Result(
check_output,
[git_cinnabar_v, "self-update"],
stderr=subprocess.STDOUT,
),
"WARNING Did not find an update to install."
if v in (head, head_version)
else f"Installing update from {urls[last_tag]}",
)
with proxy.capture_log() as log:
status += assert_eq(
Result(
check_output,
[git_cinnabar_v, "self-update"],
stderr=subprocess.STDOUT,
),
"WARNING Did not find an update to install."
if v in (head, head_version)
else f"Installing update from {urls[last_tag]}",
)
if t == head:
status += assert_eq(log, [])
else:
status += assert_eq(Result(first, log), urls[last_tag])
status += assert_eq(
Result(get_version, [git_cinnabar_v, "--version"]),
full_versions[head] if t == head else full_versions[last_tag],
Expand All @@ -254,16 +263,21 @@ def get_pkg():

if head_branch in urls:
shutil.copytree(cwd / v, update_dir, symlinks=True, dirs_exist_ok=True)
status += assert_eq(
Result(
check_output,
[git_cinnabar_v, "self-update", "--branch", head_branch],
stderr=subprocess.STDOUT,
),
"WARNING Did not find an update to install."
if v in (head, head_version)
else f"Installing update from {urls[head_branch]}",
)
with proxy.capture_log() as log:
status += assert_eq(
Result(
check_output,
[git_cinnabar_v, "self-update", "--branch", head_branch],
stderr=subprocess.STDOUT,
),
"WARNING Did not find an update to install."
if v in (head, head_version)
else f"Installing update from {urls[head_branch]}",
)
if t == head:
status += assert_eq(log, [])
else:
status += assert_eq(Result(first, log), urls[head_branch])
status += assert_eq(
Result(get_version, [git_cinnabar_v, "--version"]),
full_versions[head],
Expand Down Expand Up @@ -438,6 +452,7 @@ def checked_call(f, *args, **kwargs):
class ProxyServer(http.server.ThreadingHTTPServer):
def __init__(self):
super().__init__(("localhost", 0), ProxyHTTPRequestHandler)
self.log = None
self.mappings = {}
self.url = f"http://localhost:{self.server_port}"
self.thread = Thread(target=self.serve_forever)
Expand All @@ -452,13 +467,34 @@ def __init__(self):
this_script.with_name("selfsigned.key"),
)

def map(self, url, content):
u = urllib.parse.urlparse(url)
@contextmanager
def capture_log(self):
assert self.log is None
self.log = []
yield self.log
self.log = None

def log_url(self, url_elements):
if self.log is not None:
host, port, path = url_elements
url = f"https://{host}"
if port != 443:
url += f":{port}"
url += path
self.log.append(url)

@staticmethod
def urlsplit(url):
u = urllib.parse.urlsplit(url)
assert u.scheme == "https"
path = u.path
if u.query:
path = f"{path}?{u.query}"
self.mappings.setdefault((u.hostname, u.port or 443), {})[path] = content
return (u.hostname, u.port or 443, path)

def map(self, url, content):
host, port, path = self.urlsplit(url)
self.mappings.setdefault((host, port), {})[path] = content


class ProxyHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
Expand All @@ -468,15 +504,17 @@ def do_CONNECT(self):
self.send_response_only(200)
self.end_headers()

mappings = self.server.mappings.get((host, port))
if mappings:
mappings = self.server.mappings.get((host, port), {})
if mappings or self.server.log is not None:
self.handle_locally(mappings, host, port)
else:
self.pass_through(host, port)

def handle_locally(self, mappings, host, port):
with self.server.context.wrap_socket(self.connection, server_side=True) as sock:
HTTPRequestHandler(sock, self.client_address, (mappings, host, port))
HTTPRequestHandler(
sock, self.client_address, (self.server, mappings, host, port)
)

def pass_through(self, host, port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
Expand All @@ -501,8 +539,9 @@ def relay(src, dest):

class HTTPRequestHandler(http.server.BaseHTTPRequestHandler):
def do_GET(self):
mappings, host, port = self.server
server, mappings, host, port = self.server
content = mappings.get(self.path)
server.log_url((host, port, self.path))
if content:
self.send_content(content)
else:
Expand Down

0 comments on commit a1fb9a6

Please sign in to comment.