Skip to content

Commit

Permalink
Update install.py
Browse files Browse the repository at this point in the history
  • Loading branch information
juliagmt-google authored Nov 12, 2024
1 parent 37ac07c commit 4f7124a
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,36 @@
REPO_ROOT = Path(__file__).parent


import argparse
import os
import subprocess
import sys
from pathlib import Path

from userbenchmark import list_userbenchmarks
from utils import generate_pkg_constraints, get_pkg_versions, TORCH_DEPS
from utils.python_utils import pip_install_requirements

REPO_ROOT = Path(__file__).parent


def install_algorithmic_efficiency_deps():
"""Installs algorithmic efficiency dependencies."""
print("Installing algorithmic efficiency dependencies...")
commands = [
"pip3 install -e '.[jax_cpu]'",
"pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'",
"pip3 install -e '.[full]'",
]
for command in commands:
try:
subprocess.check_call(command, shell=True)
except subprocess.CalledProcessError as e:
print(f"Error installing algorithmic efficiency dependencies: {e}")
return False
return True


if __name__ == "__main__":
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument(
Expand Down Expand Up @@ -77,6 +107,10 @@
if args.check_only:
exit(0)

print("start installing deps for algorithmic_efficiency")
install_algorithmic_efficiency_deps()
print("done installing deps for algorithmic_efficiency")

if args.userbenchmark:
# Install userbenchmark dependencies if exists
userbenchmark_dir = REPO_ROOT.joinpath("userbenchmark", args.userbenchmark)
Expand Down

0 comments on commit 4f7124a

Please sign in to comment.