Skip to content

Commit

Permalink
Merge pull request #688 from MilesCranmer/install-all-extensions
Browse files Browse the repository at this point in the history
Create `load_all_packages` to install Julia extensions
  • Loading branch information
MilesCranmer authored Aug 1, 2024
2 parents 3aee19e + 23eafbe commit d7e87b4
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ADD ./pysr /pysr/pysr
RUN pip3 install --no-cache-dir .

# Install Julia pre-requisites:
RUN python3 -c 'import pysr'
RUN python3 -c 'import pysr; pysr.load_all_packages()'

# metainformation
LABEL org.opencontainers.image.authors = "Miles Cranmer"
Expand Down
2 changes: 2 additions & 0 deletions pysr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .deprecated import best, best_callable, best_row, best_tex, install, pysr
from .export_jax import sympy2jax
from .export_torch import sympy2torch
from .julia_extensions import load_all_packages
from .sr import PySRRegressor

# This file is created by setuptools_scm during the build process:
Expand All @@ -19,6 +20,7 @@
"sympy2jax",
"sympy2torch",
"install",
"load_all_packages",
"PySRRegressor",
"best",
"best_callable",
Expand Down
11 changes: 11 additions & 0 deletions pysr/julia_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ def load_required_packages(
load_package("ClusterManagers", "34f1f09b-3a8b-5176-ab39-66d58a4d544e")


def load_all_packages():
"""Install and load all Julia extensions available to PySR."""
load_required_packages(
turbo=True, bumper=True, enable_autodiff=True, cluster_manager="slurm"
)


# TODO: Refactor this file so we can install all packages at once using `juliapkg`,
# ideally parameterizable via the regular Python extras API


def isinstalled(uuid_s: str):
return jl.haskey(Pkg.dependencies(), jl.Base.UUID(uuid_s))

Expand Down
7 changes: 6 additions & 1 deletion pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import sympy # type: ignore
from sklearn.utils.estimator_checks import check_estimator

from pysr import PySRRegressor, install, jl
from pysr import PySRRegressor, install, jl, load_all_packages
from pysr.export_latex import sympy2latex
from pysr.feature_selection import _handle_feature_selection, run_feature_selection
from pysr.julia_helpers import init_julia
Expand Down Expand Up @@ -739,6 +739,11 @@ def test_param_groupings(self):
# Check the sets are equal:
self.assertSetEqual(set(params), set(regressor_params))

def test_load_all_packages(self):
"""Test we can load all packages at once."""
load_all_packages()
self.assertTrue(jl.seval("ClusterManagers isa Module"))


class TestHelpMessages(unittest.TestCase):
"""Test user help messages."""
Expand Down

0 comments on commit d7e87b4

Please sign in to comment.