Skip to content

Commit

Permalink
Merge pull request #16522 from github/redsun82/lfs
Browse files Browse the repository at this point in the history
Bazel: allow LFS rules to use cached downloads without internet
  • Loading branch information
redsun82 authored May 21, 2024
2 parents 13a7d9a + d01d657 commit 9d21e2c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 30 deletions.
27 changes: 23 additions & 4 deletions misc/bazel/internal/git_lfs_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

"""
Probe lfs files.
For each source file provided as output, this will print:
For each source file provided as input, this will print:
* "local", if the source file is not an LFS pointer
* the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise
If --hash-only is provided, the transient URL will not be fetched and printed
"""

import sys
Expand All @@ -19,6 +20,14 @@
import base64
from dataclasses import dataclass
from typing import Dict
import argparse


def options(argv=None):
    """Parse the command line of the LFS probe.

    Args:
        argv: argument strings to parse. When ``None`` (the default),
            argparse falls back to ``sys.argv[1:]``, so existing callers
            that invoke ``options()`` keep their behaviour.

    Returns:
        An ``argparse.Namespace`` with ``hash_only`` (``True`` when
        ``--hash-only`` was given) and ``sources`` (a non-empty list of
        ``pathlib.Path`` objects, one per positional argument).
    """
    # The module docstring doubles as the --help description.
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("--hash-only", action="store_true")
    p.add_argument("sources", type=pathlib.Path, nargs="+")
    return p.parse_args(argv)


@dataclass
Expand All @@ -30,7 +39,8 @@ def update_headers(self, d: Dict[str, str]):
self.headers.update((k.capitalize(), v) for k, v in d.items())


sources = [pathlib.Path(arg).resolve() for arg in sys.argv[1:]]
opts = options()
sources = [p.resolve() for p in opts.sources]
source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources))
source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip()

Expand Down Expand Up @@ -60,7 +70,12 @@ def get_endpoint():
server, _, path = ssh_endpoint.partition(":")
ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")))
assert ssh_command, "no ssh command found"
resp = json.loads(subprocess.check_output([ssh_command, server, "git-lfs-authenticate", path, "download"]))
resp = json.loads(subprocess.check_output([ssh_command,
"-oStrictHostKeyChecking=accept-new",
server,
"git-lfs-authenticate",
path,
"download"]))
endpoint.href = resp.get("href", endpoint)
endpoint.update_headers(resp.get("header", {}))
url = urlparse(endpoint.href)
Expand All @@ -84,11 +99,15 @@ def get_endpoint():
# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
def get_locations(objects):
ret = ["local" for _ in objects]
endpoint = get_endpoint()
indexes = [i for i, o in enumerate(objects) if o]
if not indexes:
# all objects are local, do not send an empty request as that would be an error
return ret
if opts.hash_only:
for i in indexes:
ret[i] = objects[i]["oid"]
return ret
endpoint = get_endpoint()
data = {
"operation": "download",
"transfers": ["basic"],
Expand Down
60 changes: 34 additions & 26 deletions misc/bazel/lfs.bzl
Original file line number Diff line number Diff line change
@@ -1,36 +1,44 @@
def lfs_smudge(repository_ctx, srcs, extract = False, stripPrefix = None):
for src in srcs:
repository_ctx.watch(src)
script = Label("//misc/bazel/internal:git_lfs_probe.py")
python = repository_ctx.which("python3") or repository_ctx.which("python")
if not python:
fail("Neither python3 nor python executables found")
repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs]))
res = repository_ctx.execute([python, script] + srcs, quiet = True)
if res.return_code != 0:
fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr))
promises = []
for src, loc in zip(srcs, res.stdout.splitlines()):
if loc == "local":
if extract:
repository_ctx.report_progress("extracting local %s" % src.basename)
repository_ctx.extract(src, stripPrefix = stripPrefix)
else:
repository_ctx.report_progress("symlinking local %s" % src.basename)
repository_ctx.symlink(src, src.basename)
script = Label("//misc/bazel/internal:git_lfs_probe.py")

def probe(srcs, hash_only = False):
repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs]))
cmd = [python, script]
if hash_only:
cmd.append("--hash-only")
cmd.extend(srcs)
res = repository_ctx.execute(cmd, quiet = True)
if res.return_code != 0:
fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr))
return res.stdout.splitlines()

for src in srcs:
repository_ctx.watch(src)
infos = probe(srcs, hash_only = True)
remote = []
for src, info in zip(srcs, infos):
if info == "local":
repository_ctx.report_progress("symlinking local %s" % src.basename)
repository_ctx.symlink(src, src.basename)
else:
sha256, _, url = loc.partition(" ")
if extract:
# we can't use skylib's `paths.split_extension`, as that only gets the last extension, so `.tar.gz`
# or similar wouldn't work
# it doesn't matter if file is something like some.name.zip and possible_extension == "name.zip",
# download_and_extract will just append ".name.zip" to its internal temporary name, so extraction works
possible_extension = ".".join(src.basename.rsplit(".", 2)[-2:])
repository_ctx.report_progress("downloading and extracting remote %s" % src.basename)
repository_ctx.download_and_extract(url, sha256 = sha256, stripPrefix = stripPrefix, type = possible_extension)
else:
repository_ctx.report_progress("trying cache for remote %s" % src.basename)
res = repository_ctx.download([], src.basename, sha256 = info, allow_fail = True)
if not res.success:
remote.append(src)
if remote:
infos = probe(remote)
for src, info in zip(remote, infos):
sha256, _, url = info.partition(" ")
repository_ctx.report_progress("downloading remote %s" % src.basename)
repository_ctx.download(url, src.basename, sha256 = sha256)
if extract:
for src in srcs:
repository_ctx.report_progress("extracting %s" % src.basename)
repository_ctx.extract(src.basename, stripPrefix = stripPrefix)
repository_ctx.delete(src.basename)

def _download_and_extract_lfs(repository_ctx):
attr = repository_ctx.attr
Expand Down

0 comments on commit 9d21e2c

Please sign in to comment.