Merge pull request #16393 from github/redsun82/lfs
Bazel: improved lazy lfs files
Showing 4 changed files with 221 additions and 0 deletions.
.lfsconfig
@@ -0,0 +1,5 @@
[lfs]
# codeql is publicly forked by many users, and we don't want any LFS file polluting their working
# copies. We therefore exclude everything by default.
# For files required by bazel builds, use rules in `misc/bazel/lfs.bzl` to download them on demand.
fetchinclude = /nothing
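
Files excluded this way remain as Git LFS pointers in the working copy. For context, a pointer is a small text file like the following (a representative example per the LFS pointer spec; the hash and size are made up), and detecting this header is exactly how the probing script below distinguishes pointers from regular files:

    version https://git-lfs.github.com/spec/v1
    oid sha256:98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4
    size 1048576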
Empty file.
misc/bazel/internal/git_lfs_probe.py
@@ -0,0 +1,122 @@
#!/usr/bin/env python3

"""
Probe LFS files.
For each source file provided as argument, this will print:
* "local", if the source file is not an LFS pointer
* the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise
"""

import sys
import pathlib
import subprocess
import os
import shutil
import json
import urllib.request
from urllib.parse import urlparse
import re
import base64
from dataclasses import dataclass


@dataclass
class Endpoint:
    href: str
    headers: dict[str, str]

    def update_headers(self, d: dict[str, str]):
        self.headers.update((k.capitalize(), v) for k, v in d.items())


sources = [pathlib.Path(arg).resolve() for arg in sys.argv[1:]]
source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources))
source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip()


def get_env(s, sep="="):
    # parse a "key<sep>value" block (such as `git lfs env` output) into a dict, keeping first occurrences
    ret = {}
    for m in re.finditer(fr'(.*?){sep}(.*)', s, re.M):
        ret.setdefault(*m.groups())
    return ret


def git(*args, **kwargs):
    return subprocess.run(("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs).stdout.strip()


def get_endpoint():
    lfs_env = get_env(subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir))
    endpoint = next(v for k, v in lfs_env.items() if k.startswith('Endpoint'))
    endpoint, _, _ = endpoint.partition(' ')
    ssh_endpoint = lfs_env.get("  SSH")
    endpoint = Endpoint(endpoint, {
        "Content-Type": "application/vnd.git-lfs+json",
        "Accept": "application/vnd.git-lfs+json",
    })
    if ssh_endpoint:
        # see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md
        server, _, path = ssh_endpoint.partition(":")
        ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")))
        assert ssh_command, "no ssh command found"
        resp = json.loads(subprocess.check_output([ssh_command, server, "git-lfs-authenticate", path, "download"]))
        endpoint.href = resp.get("href", endpoint.href)
        endpoint.update_headers(resp.get("header", {}))
    url = urlparse(endpoint.href)
    # this is how actions/checkout persists credentials
    # see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63
    auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader")
    endpoint.update_headers(get_env(auth, sep=": "))
    if "GITHUB_TOKEN" in os.environ:
        endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
    if "Authorization" not in endpoint.headers:
        # last chance: use git credentials (possibly backed by a credential helper like the one installed by gh)
        # see https://git-scm.com/docs/git-credential
        credentials = get_env(git("credential", "fill", check=True,
                                  # drop leading / from url.path
                                  input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n"))
        auth = base64.b64encode(f'{credentials["username"]}:{credentials["password"]}'.encode()).decode('ascii')
        endpoint.headers["Authorization"] = f"Basic {auth}"
    return endpoint


# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
def get_locations(objects):
    endpoint = get_endpoint()
    indexes = [i for i, o in enumerate(objects) if o]
    ret = ["local" for _ in objects]
    req = urllib.request.Request(
        f"{endpoint.href}/objects/batch",
        headers=endpoint.headers,
        data=json.dumps({
            "operation": "download",
            "transfers": ["basic"],
            "objects": [o for o in objects if o],
            "hash_algo": "sha256",
        }).encode("ascii"),
    )
    with urllib.request.urlopen(req) as resp:
        data = json.load(resp)
    assert len(data["objects"]) == len(indexes), f"received {len(data['objects'])} objects, expected {len(indexes)}"
    for i, resp in zip(indexes, data["objects"]):
        ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
    return ret


def get_lfs_object(path):
    # return {"oid": ..., "size": ...} for an LFS pointer file, None for a regular file
    with open(path, 'rb') as fileobj:
        lfs_header = "version https://git-lfs.github.com/spec".encode()
        actual_header = fileobj.read(len(lfs_header))
        if lfs_header != actual_header:
            return None
        data = get_env(fileobj.read().decode('ascii'), sep=' ')
        assert data['oid'].startswith('sha256:'), f"unknown oid type: {data['oid']}"
        _, _, sha256 = data['oid'].partition(':')
        size = int(data['size'])
        return {"oid": sha256, "size": size}


objects = [get_lfs_object(src) for src in sources]
for resp in get_locations(objects):
    print(resp)
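
To make the batch exchange concrete: the script POSTs one request to `{endpoint.href}/objects/batch` covering all pointer files and pairs each response entry back with its input. A rough sketch of the two JSON bodies, following the basic-transfers document linked in the code (the oid and href values here are made up):

    # request (POST {endpoint.href}/objects/batch)
    {"operation": "download", "transfers": ["basic"], "hash_algo": "sha256",
     "objects": [{"oid": "98ea6e4f...", "size": 1048576}]}

    # response (the script assumes one entry per requested object, in order)
    {"objects": [{"oid": "98ea6e4f...",
                  "actions": {"download": {"href": "https://example.org/some-transient-url"}}}]}

Each LFS file thus yields an output line of the form `<sha256> <transient url>`, while non-pointer files print `local`; `lfs_smudge` in `misc/bazel/lfs.bzl` below relies on exactly this format.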
misc/bazel/lfs.bzl
@@ -0,0 +1,94 @@
def lfs_smudge(repository_ctx, srcs, extract = False, stripPrefix = None):
    for src in srcs:
        repository_ctx.watch(src)
    script = Label("//misc/bazel/internal:git_lfs_probe.py")
    python = repository_ctx.which("python3") or repository_ctx.which("python")
    if not python:
        fail("Neither python3 nor python executables found")
    repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs]))
    res = repository_ctx.execute([python, script] + srcs, quiet = True)
    if res.return_code != 0:
        fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr))
    for src, loc in zip(srcs, res.stdout.splitlines()):
        if loc == "local":
            if extract:
                repository_ctx.report_progress("extracting local %s" % src.basename)
                repository_ctx.extract(src, stripPrefix = stripPrefix)
            else:
                repository_ctx.report_progress("symlinking local %s" % src.basename)
                repository_ctx.symlink(src, src.basename)
        else:
            sha256, _, url = loc.partition(" ")
            if extract:
                # we can't use skylib's `paths.split_extension`, as that only gets the last extension, so `.tar.gz`
                # or similar wouldn't work
                # it doesn't matter if the file is something like some.name.zip and possible_extension == "name.zip":
                # download_and_extract will just append ".name.zip" to its internal temporary name, so extraction works
                possible_extension = ".".join(src.basename.rsplit(".", 2)[-2:])
                repository_ctx.report_progress("downloading and extracting remote %s" % src.basename)
                repository_ctx.download_and_extract(url, sha256 = sha256, stripPrefix = stripPrefix, type = possible_extension)
            else:
                repository_ctx.report_progress("downloading remote %s" % src.basename)
                repository_ctx.download(url, src.basename, sha256 = sha256)

def _download_and_extract_lfs(repository_ctx):
    attr = repository_ctx.attr
    src = repository_ctx.path(attr.src)
    if attr.build_file_content and attr.build_file:
        fail("Only one of build_file and build_file_content may be specified for rule @%s" % repository_ctx.name)
    lfs_smudge(repository_ctx, [src], extract = True, stripPrefix = attr.strip_prefix)
    if attr.build_file_content:
        repository_ctx.file("BUILD.bazel", attr.build_file_content)
    elif attr.build_file:
        repository_ctx.symlink(attr.build_file, "BUILD.bazel")

def _download_lfs(repository_ctx):
    attr = repository_ctx.attr
    if int(bool(attr.srcs)) + int(bool(attr.dir)) != 1:
        fail("Exactly one of `srcs` and `dir` must be defined for @%s" % repository_ctx.name)
    if attr.srcs:
        srcs = [repository_ctx.path(src) for src in attr.srcs]
    else:
        dir = repository_ctx.path(attr.dir)
        if not dir.is_dir:
            fail("`dir` is not a directory in @%s" % repository_ctx.name)
        srcs = [f for f in dir.readdir() if not f.is_dir]
    lfs_smudge(repository_ctx, srcs)

    # with bzlmod the name is qualified with `~` separators, and we want the base name here
    name = repository_ctx.name.split("~")[-1]
    repository_ctx.file("BUILD.bazel", """
exports_files({files})

filegroup(
    name = "{name}",
    srcs = {files},
    visibility = ["//visibility:public"],
)
""".format(name = name, files = repr([src.basename for src in srcs])))

lfs_archive = repository_rule(
    doc = "Export the contents of an on-demand LFS archive. The corresponding path should be excluded from LFS " +
          "fetching in `.lfsconfig`.",
    implementation = _download_and_extract_lfs,
    attrs = {
        "src": attr.label(mandatory = True, doc = "Local path to the LFS archive to extract."),
        "build_file_content": attr.string(doc = "The content for the BUILD file for this repository. " +
                                                "Either build_file or build_file_content can be specified, but not both."),
        "build_file": attr.label(doc = "The file to use as the BUILD file for this repository. " +
                                       "Either build_file or build_file_content can be specified, but not both."),
        "strip_prefix": attr.string(default = "", doc = "A directory prefix to strip from the extracted files."),
    },
)

lfs_files = repository_rule(
    doc = "Export LFS files for on-demand download. Exactly one of `srcs` and `dir` must be defined. The " +
          "corresponding paths should be excluded from LFS fetching in `.lfsconfig`.",
    implementation = _download_lfs,
    attrs = {
        "srcs": attr.label_list(doc = "Local paths to the LFS files to export."),
        "dir": attr.label(doc = "Local path to a directory containing LFS files to export. Only the direct " +
                                "contents of the directory are exported."),
    },
)
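
For illustration, the two rules might be instantiated from a WORKSPACE macro or module extension roughly as follows (a minimal sketch; the repository names and labels are hypothetical, not taken from this commit):

    load("//misc/bazel:lfs.bzl", "lfs_archive", "lfs_files")

    def lfs_deps():
        # hypothetical: expose two LFS-tracked files for on-demand download
        lfs_files(
            name = "my_lfs_files",
            srcs = ["//some/pkg:model.bin", "//some/pkg:data.csv"],
        )

        # hypothetical: download and extract an LFS-tracked archive on demand
        lfs_archive(
            name = "my_lfs_archive",
            src = "//some/pkg:dep.tar.gz",
            strip_prefix = "dep-1.0",
            build_file = "//some/pkg:dep.BUILD.bazel",
        )

With `lfs_files`, the generated BUILD file exports each file individually and as a filegroup named after the repository, so the files above would be addressable as `@my_lfs_files//:model.bin` or collectively as `@my_lfs_files//:my_lfs_files`.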