diff --git a/misc/bazel/internal/git_lfs_probe.py b/misc/bazel/internal/git_lfs_probe.py index d22747e8547e..6ceb89c8ab0a 100755 --- a/misc/bazel/internal/git_lfs_probe.py +++ b/misc/bazel/internal/git_lfs_probe.py @@ -2,9 +2,10 @@ """ Probe lfs files. -For each source file provided as output, this will print: +For each source file provided as input, this will print: * "local", if the source file is not an LFS pointer * the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise +If --hash-only is provided, the transient URL will not be fetched and printed """ import sys @@ -19,6 +20,14 @@ import base64 from dataclasses import dataclass from typing import Dict +import argparse + + +def options(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--hash-only", action="store_true") + p.add_argument("sources", type=pathlib.Path, nargs="+") + return p.parse_args() @dataclass @@ -30,7 +39,8 @@ def update_headers(self, d: Dict[str, str]): self.headers.update((k.capitalize(), v) for k, v in d.items()) -sources = [pathlib.Path(arg).resolve() for arg in sys.argv[1:]] +opts = options() +sources = [p.resolve() for p in opts.sources] source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources)) source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip() @@ -60,7 +70,12 @@ def get_endpoint(): server, _, path = ssh_endpoint.partition(":") ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh"))) assert ssh_command, "no ssh command found" - resp = json.loads(subprocess.check_output([ssh_command, server, "git-lfs-authenticate", path, "download"])) + resp = json.loads(subprocess.check_output([ssh_command, + "-oStrictHostKeyChecking=accept-new", + server, + "git-lfs-authenticate", + path, + "download"])) endpoint.href = resp.get("href", endpoint) endpoint.update_headers(resp.get("header", {})) url = urlparse(endpoint.href) @@ -84,11 +99,15 @@ def get_endpoint(): # see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md def get_locations(objects): ret = ["local" for _ in objects] - endpoint = get_endpoint() indexes = [i for i, o in enumerate(objects) if o] if not indexes: # all objects are local, do not send an empty request as that would be an error return ret + if opts.hash_only: + for i in indexes: + ret[i] = objects[i]["oid"] + return ret + endpoint = get_endpoint() data = { "operation": "download", "transfers": ["basic"], diff --git a/misc/bazel/lfs.bzl b/misc/bazel/lfs.bzl index 4ba66c9dbfc6..3a496ea9530c 100644 --- a/misc/bazel/lfs.bzl +++ b/misc/bazel/lfs.bzl @@ -1,36 +1,44 @@ def lfs_smudge(repository_ctx, srcs, extract = False, stripPrefix = None): - for src in srcs: - repository_ctx.watch(src) - script = Label("//misc/bazel/internal:git_lfs_probe.py") python = repository_ctx.which("python3") or repository_ctx.which("python") if not python: fail("Neither python3 nor python executables found") - repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs])) - res = repository_ctx.execute([python, script] + srcs, quiet = True) - if res.return_code != 0: - fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr)) - promises = [] - for src, loc in zip(srcs, res.stdout.splitlines()): - if loc == "local": - if extract: - repository_ctx.report_progress("extracting local %s" % src.basename) - repository_ctx.extract(src, stripPrefix = stripPrefix) - else: - repository_ctx.report_progress("symlinking local %s" % src.basename) - repository_ctx.symlink(src, src.basename) + script = Label("//misc/bazel/internal:git_lfs_probe.py") + + def probe(srcs, hash_only = False): + repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs])) + cmd = [python, script] + if hash_only: + cmd.append("--hash-only") + cmd.extend(srcs) + res = repository_ctx.execute(cmd, quiet = True) + if res.return_code != 0: + fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr)) + return res.stdout.splitlines() + + for src in srcs: + repository_ctx.watch(src) + infos = probe(srcs, hash_only = True) + remote = [] + for src, info in zip(srcs, infos): + if info == "local": + repository_ctx.report_progress("symlinking local %s" % src.basename) + repository_ctx.symlink(src, src.basename) else: - sha256, _, url = loc.partition(" ") - if extract: - # we can't use skylib's `paths.split_extension`, as that only gets the last extension, so `.tar.gz` - # or similar wouldn't work - # it doesn't matter if file is something like some.name.zip and possible_extension == "name.zip", - # download_and_extract will just append ".name.zip" its internal temporary name, so extraction works - possible_extension = ".".join(src.basename.rsplit(".", 2)[-2:]) - repository_ctx.report_progress("downloading and extracting remote %s" % src.basename) - repository_ctx.download_and_extract(url, sha256 = sha256, stripPrefix = stripPrefix, type = possible_extension) - else: + repository_ctx.report_progress("trying cache for remote %s" % src.basename) + res = repository_ctx.download([], src.basename, sha256 = info, allow_fail = True) + if not res.success: + remote.append(src) + if remote: + infos = probe(remote) + for src, info in zip(remote, infos): + sha256, _, url = info.partition(" ") repository_ctx.report_progress("downloading remote %s" % src.basename) repository_ctx.download(url, src.basename, sha256 = sha256) + if extract: + for src in srcs: + repository_ctx.report_progress("extracting %s" % src.basename) + repository_ctx.extract(src.basename, stripPrefix = stripPrefix) + repository_ctx.delete(src.basename) def _download_and_extract_lfs(repository_ctx): attr = repository_ctx.attr