From 170e2231d4857dfbef03b4a7ccd46e9aa8a961f3 Mon Sep 17 00:00:00 2001 From: Paolo Tranquilli Date: Fri, 17 May 2024 16:24:38 +0100 Subject: [PATCH 1/2] Bazel: allow LFS rules to use cached downloads without internet If the cache is prefilled, LFS rules were still trying to query LFS urls. Now the strategy is to first try to fetch the files from the repository cache (which is possible by providing an empty url list and `allow_fail` to `repository_ctx.download`), and only run the LFS protocol if that fails. Technically this is possible by enhancing `git_lfs_probe.py` with a `--hash-only` flag. This is also an optimization where no unneeded access is done (including the slightly slow SSH call) if the repository cache is warm. --- misc/bazel/internal/git_lfs_probe.py | 19 +++++++-- misc/bazel/lfs.bzl | 60 ++++++++++++++++------------ 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/misc/bazel/internal/git_lfs_probe.py b/misc/bazel/internal/git_lfs_probe.py index d22747e8547e..47c2c80b1f43 100755 --- a/misc/bazel/internal/git_lfs_probe.py +++ b/misc/bazel/internal/git_lfs_probe.py @@ -2,9 +2,10 @@ """ Probe lfs files. 
-For each source file provided as output, this will print: +For each source file provided as input, this will print: * "local", if the source file is not an LFS pointer * the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise +If --hash-only is provided, the transient URL will not be fetched and printed """ import sys @@ -19,6 +20,13 @@ import base64 from dataclasses import dataclass from typing import Dict +import argparse + +def options(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--hash-only", action="store_true") + p.add_argument("sources", type=pathlib.Path, nargs="+") + return p.parse_args() @dataclass @@ -30,7 +38,8 @@ def update_headers(self, d: Dict[str, str]): self.headers.update((k.capitalize(), v) for k, v in d.items()) -sources = [pathlib.Path(arg).resolve() for arg in sys.argv[1:]] +opts = options() +sources = [p.resolve() for p in opts.sources] source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources)) source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip() @@ -84,11 +93,15 @@ def get_endpoint(): # see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md def get_locations(objects): ret = ["local" for _ in objects] - endpoint = get_endpoint() indexes = [i for i, o in enumerate(objects) if o] if not indexes: # all objects are local, do not send an empty request as that would be an error return ret + if opts.hash_only: + for i in indexes: + ret[i] = objects[i]["oid"] + return ret + endpoint = get_endpoint() data = { "operation": "download", "transfers": ["basic"], diff --git a/misc/bazel/lfs.bzl b/misc/bazel/lfs.bzl index 4ba66c9dbfc6..3a496ea9530c 100644 --- a/misc/bazel/lfs.bzl +++ b/misc/bazel/lfs.bzl @@ -1,36 +1,44 @@ def lfs_smudge(repository_ctx, srcs, extract = False, stripPrefix = None): - for src in srcs: - repository_ctx.watch(src) - 
script = Label("//misc/bazel/internal:git_lfs_probe.py") python = repository_ctx.which("python3") or repository_ctx.which("python") if not python: fail("Neither python3 nor python executables found") - repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs])) - res = repository_ctx.execute([python, script] + srcs, quiet = True) - if res.return_code != 0: - fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr)) - promises = [] - for src, loc in zip(srcs, res.stdout.splitlines()): - if loc == "local": - if extract: - repository_ctx.report_progress("extracting local %s" % src.basename) - repository_ctx.extract(src, stripPrefix = stripPrefix) - else: - repository_ctx.report_progress("symlinking local %s" % src.basename) - repository_ctx.symlink(src, src.basename) + script = Label("//misc/bazel/internal:git_lfs_probe.py") + + def probe(srcs, hash_only = False): + repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs])) + cmd = [python, script] + if hash_only: + cmd.append("--hash-only") + cmd.extend(srcs) + res = repository_ctx.execute(cmd, quiet = True) + if res.return_code != 0: + fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr)) + return res.stdout.splitlines() + + for src in srcs: + repository_ctx.watch(src) + infos = probe(srcs, hash_only = True) + remote = [] + for src, info in zip(srcs, infos): + if info == "local": + repository_ctx.report_progress("symlinking local %s" % src.basename) + repository_ctx.symlink(src, src.basename) else: - sha256, _, url = loc.partition(" ") - if extract: - # we can't use skylib's `paths.split_extension`, as that only gets the last extension, so `.tar.gz` - # or similar wouldn't work - # it doesn't matter if file is something like some.name.zip and possible_extension == "name.zip", - # download_and_extract will just append ".name.zip" its internal 
temporary name, so extraction works - possible_extension = ".".join(src.basename.rsplit(".", 2)[-2:]) - repository_ctx.report_progress("downloading and extracting remote %s" % src.basename) - repository_ctx.download_and_extract(url, sha256 = sha256, stripPrefix = stripPrefix, type = possible_extension) - else: + repository_ctx.report_progress("trying cache for remote %s" % src.basename) + res = repository_ctx.download([], src.basename, sha256 = info, allow_fail = True) + if not res.success: + remote.append(src) + if remote: + infos = probe(remote) + for src, info in zip(remote, infos): + sha256, _, url = info.partition(" ") repository_ctx.report_progress("downloading remote %s" % src.basename) repository_ctx.download(url, src.basename, sha256 = sha256) + if extract: + for src in srcs: + repository_ctx.report_progress("extracting %s" % src.basename) + repository_ctx.extract(src.basename, stripPrefix = stripPrefix) + repository_ctx.delete(src.basename) def _download_and_extract_lfs(repository_ctx): attr = repository_ctx.attr From d01d657f89cb5e4f476ac689f5d0fa89b003eeb6 Mon Sep 17 00:00:00 2001 From: Paolo Tranquilli Date: Fri, 17 May 2024 16:39:18 +0100 Subject: [PATCH 2/2] Bazel: accept new SSH keys in `git_lfs_probe.py` --- misc/bazel/internal/git_lfs_probe.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/misc/bazel/internal/git_lfs_probe.py b/misc/bazel/internal/git_lfs_probe.py index 47c2c80b1f43..6ceb89c8ab0a 100755 --- a/misc/bazel/internal/git_lfs_probe.py +++ b/misc/bazel/internal/git_lfs_probe.py @@ -22,6 +22,7 @@ from typing import Dict import argparse + def options(): p = argparse.ArgumentParser(description=__doc__) p.add_argument("--hash-only", action="store_true") @@ -69,7 +70,12 @@ def get_endpoint(): server, _, path = ssh_endpoint.partition(":") ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh"))) assert ssh_command, "no ssh command found" - resp = 
json.loads(subprocess.check_output([ssh_command, server, "git-lfs-authenticate", path, "download"])) + resp = json.loads(subprocess.check_output([ssh_command, + "-oStrictHostKeyChecking=accept-new", + server, + "git-lfs-authenticate", + path, + "download"])) endpoint.href = resp.get("href", endpoint) endpoint.update_headers(resp.get("header", {})) url = urlparse(endpoint.href)