Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 7th No.1】 Integrate PaddlePaddle as a New Backend #704

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ jobs:
- name: "Run Torch tests"
run: coverage run --append -m pysr test torch
if: ${{ matrix.test-id == 'main' }}
- name: "Install Paddle"
run: pip install paddle # (optional import)
if: ${{ matrix.test-id == 'main' }}
- name: "Run Paddle tests"
run: coverage run --append -m pysr test paddle
if: ${{ matrix.test-id == 'main' }}
- name: "Coveralls"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down Expand Up @@ -192,7 +198,7 @@ jobs:
pip install .
pip install mypy
- name: "Install additional dependencies"
run: python -m pip install jax jaxlib torch
run: python -m pip install jax jaxlib torch paddlepaddle
if: ${{ matrix.python-version != '3.8' }}
- name: "Run mypy"
run: python -m mypy --install-types --non-interactive pysr
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/CI_Windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,7 @@ jobs:
run: pip install torch # (optional import)
- name: "Run Torch tests"
run: python -m pysr test torch
- name: "Install Paddle"
run: pip install paddlepaddle # (optional import)
- name: "Run Paddle tests"
run: python -m pysr test paddle
4 changes: 4 additions & 0 deletions .github/workflows/CI_mac.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,7 @@ jobs:
run: pip install torch # (optional import)
- name: "Run Torch tests"
run: python -m pysr test torch
- name: "Install Paddle"
run: pip install paddlepaddle # (optional import)
- name: "Run Paddle tests"
run: python -m pysr test paddle
115 changes: 115 additions & 0 deletions examples/pysr_demo_paddle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import os

import numpy as np
import paddle
from paddle import nn
from paddle.io import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

os.environ["PYTHON_JULIACALL_THREADS"] = "1"

rstate = np.random.RandomState(0)

N = 100000
Nt = 10
X = 6 * rstate.rand(N, Nt, 5) - 3
y_i = X[..., 0] ** 2 + 6 * np.cos(2 * X[..., 2])
y = np.sum(y_i, axis=1) / y_i.shape[1]
z = y**2


hidden = 128
total_steps = 50_000


def mlp(size_in, size_out, act=nn.ReLU):
return nn.Sequential(
nn.Linear(size_in, hidden),
act(),
nn.Linear(hidden, hidden),
act(),
nn.Linear(hidden, hidden),
act(),
nn.Linear(hidden, size_out),
)


class SumNet(nn.Layer):
def __init__(self):
super().__init__()

########################################################
# The same inductive bias as above!
self.g = mlp(5, 1)
self.f = mlp(1, 1)

def forward(self, x):
y_i = self.g(x)[:, :, 0]
y = paddle.sum(y_i, axis=1, keepdim=True) / y_i.shape[1]
z = self.f(y)
return z[:, 0]


Xt = paddle.to_tensor(X).astype("float32")
zt = paddle.to_tensor(z).astype("float32")
X_train, X_test, z_train, z_test = train_test_split(Xt, zt, random_state=0)
train_set = TensorDataset([X_train, z_train])
train = DataLoader(train_set, batch_size=128, shuffle=True)
test_set = TensorDataset([X_test, z_test])
test = DataLoader(test_set, batch_size=256)

paddle.seed(0)

model = SumNet()
max_lr = 1e-2
model = paddle.Model(model)
scheduler = paddle.optimizer.lr.OneCycleLR(
max_learning_rate=max_lr, total_steps=total_steps, divide_factor=1e4
)
optim = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
model.prepare(optim, paddle.nn.MSELoss())
model.fit(train, test, num_iters=total_steps, eval_freq=1000)

np.random.seed(0)
idx = np.random.randint(0, 10000, size=1000)

X_for_pysr = Xt[idx]
y_i_for_pysr = model.network.g(X_for_pysr)[:, :, 0]
y_for_pysr = paddle.sum(y_i_for_pysr, axis=1) / y_i_for_pysr.shape[1]
z_for_pysr = zt[idx] # Use true values.


nnet_recordings = {
"g_input": X_for_pysr.detach().cpu().numpy().reshape(-1, 5),
"g_output": y_i_for_pysr.detach().cpu().numpy().reshape(-1),
"f_input": y_for_pysr.detach().cpu().numpy().reshape(-1, 1),
"f_output": z_for_pysr.detach().cpu().numpy().reshape(-1),
}

# Save the data for later use:
import pickle as pkl

with open("nnet_recordings.pkl", "wb") as f:
pkl.dump(nnet_recordings, f)

import pickle as pkl

nnet_recordings = pkl.load(open("nnet_recordings.pkl", "rb"))
f_input = nnet_recordings["f_input"]
f_output = nnet_recordings["f_output"]
g_input = nnet_recordings["g_input"]
g_output = nnet_recordings["g_output"]


rstate = np.random.RandomState(0)
f_sample_idx = rstate.choice(f_input.shape[0], size=500, replace=False)
from pysr import PySRRegressor

model = PySRRegressor(
niterations=50,
binary_operators=["+", "-", "*"],
unary_operators=["cos", "square"],
)
model.fit(g_input[f_sample_idx], g_output[f_sample_idx])

model.equations_[["complexity", "loss", "equation"]]
144 changes: 144 additions & 0 deletions examples/pysr_demo_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import os
from multiprocessing import cpu_count

import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.model_selection import train_test_split
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

from pysr import PySRRegressor

os.environ["PYTHON_JULIACALL_THREADS"] = "1"
rstate = np.random.RandomState(0)

N = 100000
Nt = 10
X = 6 * rstate.rand(N, Nt, 5) - 3
y_i = X[..., 0] ** 2 + 6 * np.cos(2 * X[..., 2])
y = np.sum(y_i, axis=1) / y_i.shape[1]
z = y**2


hidden = 128
total_steps = 50_000


def mlp(size_in, size_out, act=nn.ReLU):
return nn.Sequential(
nn.Linear(size_in, hidden),
act(),
nn.Linear(hidden, hidden),
act(),
nn.Linear(hidden, hidden),
act(),
nn.Linear(hidden, size_out),
)


class SumNet(pl.LightningModule):
def __init__(self):
super().__init__()

########################################################
# The same inductive bias as above!
self.g = mlp(5, 1)
self.f = mlp(1, 1)

def forward(self, x):
y_i = self.g(x)[:, :, 0]
y = torch.sum(y_i, dim=1, keepdim=True) / y_i.shape[1]
z = self.f(y)
return z[:, 0]

########################################################

# PyTorch Lightning bookkeeping:
def training_step(self, batch, batch_idx):
x, z = batch
predicted_z = self(x)
loss = F.mse_loss(predicted_z, z)
return loss

def validation_step(self, batch, batch_idx):
return self.training_step(batch, batch_idx)

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.max_lr)
scheduler = {
"scheduler": torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=self.max_lr,
total_steps=self.trainer.estimated_stepping_batches,
final_div_factor=1e4,
),
"interval": "step",
}
return [optimizer], [scheduler]


Xt = torch.tensor(X).float()
zt = torch.tensor(z).float()
X_train, X_test, z_train, z_test = train_test_split(Xt, zt, random_state=0)
train_set = TensorDataset(X_train, z_train)
train = DataLoader(
train_set, batch_size=128, num_workers=cpu_count(), shuffle=True, pin_memory=True
)
test_set = TensorDataset(X_test, z_test)
test = DataLoader(test_set, batch_size=256, num_workers=cpu_count(), pin_memory=True)

pl.seed_everything(0)
model = SumNet()
model.total_steps = total_steps
model.max_lr = 1e-2

trainer = pl.Trainer(max_steps=total_steps, accelerator="gpu", devices=1)
trainer.fit(model, train_dataloaders=train, val_dataloaders=test)


np.random.seed(0)
idx = np.random.randint(0, 10000, size=1000)

X_for_pysr = Xt[idx]
y_i_for_pysr = model.g(X_for_pysr)[:, :, 0]
y_for_pysr = torch.sum(y_i_for_pysr, dim=1) / y_i_for_pysr.shape[1]
z_for_pysr = zt[idx] # Use true values.

X_for_pysr.shape, y_i_for_pysr.shape


nnet_recordings = {
"g_input": X_for_pysr.detach().cpu().numpy().reshape(-1, 5),
"g_output": y_i_for_pysr.detach().cpu().numpy().reshape(-1),
"f_input": y_for_pysr.detach().cpu().numpy().reshape(-1, 1),
"f_output": z_for_pysr.detach().cpu().numpy().reshape(-1),
}

# Save the data for later use:
import pickle as pkl

with open("nnet_recordings.pkl", "wb") as f:
pkl.dump(nnet_recordings, f)

import pickle as pkl

nnet_recordings = pkl.load(open("nnet_recordings.pkl", "rb"))
f_input = nnet_recordings["f_input"]
f_output = nnet_recordings["f_output"]
g_input = nnet_recordings["g_input"]
g_output = nnet_recordings["g_output"]


rstate = np.random.RandomState(0)
f_sample_idx = rstate.choice(f_input.shape[0], size=500, replace=False)

model = PySRRegressor(
niterations=50,
binary_operators=["+", "-", "*"],
unary_operators=["cos", "square"],
)
model.fit(g_input[f_sample_idx], g_output[f_sample_idx])

model.equations_[["complexity", "loss", "equation"]]
15 changes: 7 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,25 @@ build-backend = "setuptools.build_meta"
[project]
name = "pysr"
version = "0.19.4"
authors = [
{name = "Miles Cranmer", email = "[email protected]"},
]
authors = [{ name = "Miles Cranmer", email = "[email protected]" }]
description = "Simple and efficient symbolic regression"
readme = {file = "README.md", content-type = "text/markdown"}
license = {file = "LICENSE"}
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
"License :: OSI Approved :: Apache Software License"
"License :: OSI Approved :: Apache Software License",
]
dynamic = ["dependencies"]

[tool.setuptools]
packages = ["pysr", "pysr._cli", "pysr.test"]
include-package-data = false
package-data = {pysr = ["juliapkg.json"]}
package-data = { pysr = ["juliapkg.json"] }

[tool.setuptools.dynamic]
dependencies = {file = "requirements.txt"}
dependencies = { file = "requirements.txt" }

[tool.isort]
profile = "black"
Expand All @@ -38,6 +36,7 @@ dev-dependencies = [
"mypy>=1.10.0",
"jax[cpu]>=0.4.26",
"torch>=2.3.0",
"paddlepaddle>=2.6.1",
"pandas-stubs>=2.2.1.240316",
"types-pytz>=2024.1.0.20240417",
"types-openpyxl>=3.1.0.20240428",
Expand Down
2 changes: 2 additions & 0 deletions pysr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from . import sklearn_monkeypatch
from .deprecated import best, best_callable, best_row, best_tex, install, pysr
from .export_jax import sympy2jax
from .export_paddle import sympy2paddle
from .export_torch import sympy2torch
from .julia_extensions import load_all_packages
from .sr import PySRRegressor
Expand All @@ -19,6 +20,7 @@
"sklearn_monkeypatch",
"sympy2jax",
"sympy2torch",
"sympy2paddle",
"install",
"load_all_packages",
"PySRRegressor",
Expand Down
7 changes: 5 additions & 2 deletions pysr/_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
runtests,
runtests_dev,
runtests_jax,
runtests_paddle,
runtests_startup,
runtests_torch,
)
Expand Down Expand Up @@ -48,7 +49,7 @@ def _install(julia_project, quiet, precompile):
)


TEST_OPTIONS = {"main", "jax", "torch", "cli", "dev", "startup"}
TEST_OPTIONS = {"main", "jax", "torch", "paddle", "cli", "dev", "startup"}


@pysr.command("test")
Expand All @@ -63,7 +64,7 @@ def _install(julia_project, quiet, precompile):
def _tests(tests, expressions):
"""Run parts of the PySR test suite.

Choose from main, jax, torch, cli, dev, and startup. You can give multiple tests, separated by commas.
Choose from main, jax, torch, paddle, cli, dev, and startup. You can give multiple tests, separated by commas.
"""
test_cases = []
for test in tests.split(","):
Expand All @@ -73,6 +74,8 @@ def _tests(tests, expressions):
test_cases.extend(runtests_jax(just_tests=True))
elif test == "torch":
test_cases.extend(runtests_torch(just_tests=True))
elif test == "paddle":
test_cases.extend(runtests_paddle(just_tests=True))
elif test == "cli":
runtests_cli = get_runtests_cli()
test_cases.extend(runtests_cli(just_tests=True))
Expand Down
Loading