Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bazel: allow LFS rules to use cached downloads without internet #16522

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions misc/bazel/internal/git_lfs_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

"""
Probe lfs files.
For each source file provided as output, this will print:
For each source file provided as input, this will print:
* "local", if the source file is not an LFS pointer
* the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise
If --hash-only is provided, the transient URL will not be fetched and printed
"""

import sys
Expand All @@ -19,6 +20,14 @@
import base64
from dataclasses import dataclass
from typing import Dict
import argparse


def options():
p = argparse.ArgumentParser(description=__doc__)
p.add_argument("--hash-only", action="store_true")
p.add_argument("sources", type=pathlib.Path, nargs="+")
return p.parse_args()


@dataclass
Expand All @@ -30,7 +39,8 @@ def update_headers(self, d: Dict[str, str]):
self.headers.update((k.capitalize(), v) for k, v in d.items())


sources = [pathlib.Path(arg).resolve() for arg in sys.argv[1:]]
opts = options()
sources = [p.resolve() for p in opts.sources]
source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources))
source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip()

Expand Down Expand Up @@ -60,7 +70,12 @@ def get_endpoint():
server, _, path = ssh_endpoint.partition(":")
ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")))
assert ssh_command, "no ssh command found"
resp = json.loads(subprocess.check_output([ssh_command, server, "git-lfs-authenticate", path, "download"]))
resp = json.loads(subprocess.check_output([ssh_command,
"-oStrictHostKeyChecking=accept-new",
server,
"git-lfs-authenticate",
path,
"download"]))
endpoint.href = resp.get("href", endpoint)
endpoint.update_headers(resp.get("header", {}))
url = urlparse(endpoint.href)
Expand All @@ -84,11 +99,15 @@ def get_endpoint():
# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
def get_locations(objects):
ret = ["local" for _ in objects]
endpoint = get_endpoint()
indexes = [i for i, o in enumerate(objects) if o]
if not indexes:
# all objects are local, do not send an empty request as that would be an error
return ret
if opts.hash_only:
for i in indexes:
ret[i] = objects[i]["oid"]
return ret
endpoint = get_endpoint()
data = {
"operation": "download",
"transfers": ["basic"],
Expand Down
60 changes: 34 additions & 26 deletions misc/bazel/lfs.bzl
Original file line number Diff line number Diff line change
@@ -1,36 +1,44 @@
def lfs_smudge(repository_ctx, srcs, extract = False, stripPrefix = None):
for src in srcs:
repository_ctx.watch(src)
script = Label("//misc/bazel/internal:git_lfs_probe.py")
python = repository_ctx.which("python3") or repository_ctx.which("python")
if not python:
fail("Neither python3 nor python executables found")
repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs]))
res = repository_ctx.execute([python, script] + srcs, quiet = True)
if res.return_code != 0:
fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr))
promises = []
for src, loc in zip(srcs, res.stdout.splitlines()):
if loc == "local":
if extract:
repository_ctx.report_progress("extracting local %s" % src.basename)
repository_ctx.extract(src, stripPrefix = stripPrefix)
else:
repository_ctx.report_progress("symlinking local %s" % src.basename)
repository_ctx.symlink(src, src.basename)
script = Label("//misc/bazel/internal:git_lfs_probe.py")

def probe(srcs, hash_only = False):
repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs]))
cmd = [python, script]
if hash_only:
cmd.append("--hash-only")
cmd.extend(srcs)
res = repository_ctx.execute(cmd, quiet = True)
if res.return_code != 0:
fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr))
return res.stdout.splitlines()

for src in srcs:
repository_ctx.watch(src)
infos = probe(srcs, hash_only = True)
remote = []
for src, info in zip(srcs, infos):
if info == "local":
repository_ctx.report_progress("symlinking local %s" % src.basename)
repository_ctx.symlink(src, src.basename)
else:
sha256, _, url = loc.partition(" ")
if extract:
# we can't use skylib's `paths.split_extension`, as that only gets the last extension, so `.tar.gz`
# or similar wouldn't work
# it doesn't matter if file is something like some.name.zip and possible_extension == "name.zip",
# download_and_extract will just append ".name.zip" its internal temporary name, so extraction works
possible_extension = ".".join(src.basename.rsplit(".", 2)[-2:])
repository_ctx.report_progress("downloading and extracting remote %s" % src.basename)
repository_ctx.download_and_extract(url, sha256 = sha256, stripPrefix = stripPrefix, type = possible_extension)
else:
repository_ctx.report_progress("trying cache for remote %s" % src.basename)
res = repository_ctx.download([], src.basename, sha256 = info, allow_fail = True)
if not res.success:
remote.append(src)
if remote:
infos = probe(remote)
for src, info in zip(remote, infos):
sha256, _, url = info.partition(" ")
repository_ctx.report_progress("downloading remote %s" % src.basename)
repository_ctx.download(url, src.basename, sha256 = sha256)
if extract:
for src in srcs:
repository_ctx.report_progress("extracting %s" % src.basename)
repository_ctx.extract(src.basename, stripPrefix = stripPrefix)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this now slower on CI (with an empty repository cache) where we'll get a separate download and extract step, or is download_and_extract performance-wise 1. download 2. extract with no overlap anyways?

Copy link
Contributor Author

@redsun82 redsun82 May 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't really hard measure it, but it did not seem to make a difference locally.

repository_ctx.delete(src.basename)

def _download_and_extract_lfs(repository_ctx):
attr = repository_ctx.attr
Expand Down
Loading