Skip to content

Commit

Permalink
Updates (AliciaCurth#24)
Browse files Browse the repository at this point in the history
* bump version

* Bump version: 0.1.6

* Workflows update: Scheduled runs, updated python versions (#2)

* update Py versions

* py3.10 fix

* update deps

* drop 3.10

* updare reqs

* debug jaxlib

* debug jaxlib

* windows debug

* debug windows

* debug windows

* optional JAX, Torch

* debug windows

* debug windows

* bump version

* Update .pre-commit-config.yaml

* Update .pre-commit-config.yaml

* update workflows (#3)

* update workflows

* debug release

* cleanup

* Bugfixing (#4)

* fix PyTorch API

* GPU fixes
enable more tests

* bugfixing

* train/eval review

* update JAX (#5)
  • Loading branch information
bcebere authored Sep 3, 2022
1 parent bb4da46 commit 21782d9
Show file tree
Hide file tree
Showing 31 changed files with 488 additions and 401 deletions.
73 changes: 73 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
name: Package release

on:
release:
types: [created]


jobs:
deploy_osx:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
os: [macos-latest]

steps:
- uses: actions/checkout@v2
with:
submodules: true
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: ${GITHUB_WORKSPACE}/.github/workflows/scripts/release_osx.sh

deploy_linux:
strategy:
matrix:
python-version:
- cp37-cp37m
- cp38-cp38
- cp39-cp39
- cp10-cp10

runs-on: ubuntu-latest
container: quay.io/pypa/manylinux2014_x86_64
steps:
- uses: actions/checkout@v1
with:
submodules: true
- name: Set target Python version PATH
run: |
echo "/opt/python/${{ matrix.python-version }}/bin" >> $GITHUB_PATH
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: ${GITHUB_WORKSPACE}/.github/workflows/scripts/release_linux.sh

deploy_windows:
runs-on: windows-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v2
with:
submodules: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
../../.github/workflows/scripts/release_windows.bat
16 changes: 16 additions & 0 deletions .github/workflows/scripts/release_linux.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash

set -e

yum makecache -y
yum install centos-release-scl -y
yum-config-manager --enable rhel-server-rhscl-7-rpms
yum install llvm-toolset-7.0 python3 python3-devel -y

# Python
python3 -m pip install --upgrade pip
python3 -m pip install setuptools wheel twine auditwheel

# Publish
python3 -m pip wheel . -w dist/ --no-deps
twine upload --verbose --skip-existing dist/*
9 changes: 9 additions & 0 deletions .github/workflows/scripts/release_osx.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/sh

export MACOSX_DEPLOYMENT_TARGET=10.14

python -m pip install --upgrade pip
pip install setuptools wheel twine auditwheel

python3 setup.py build bdist_wheel --plat-name macosx_10_14_x86_64 --dist-dir wheel
twine upload --skip-existing wheel/*
7 changes: 7 additions & 0 deletions .github/workflows/scripts/release_windows.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
echo on

python -m pip install --upgrade pip
pip install setuptools wheel twine auditwheel

pip wheel . -w wheel/ --no-deps
twine upload --skip-existing wheel/*
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9']
python-version: ['3.7', '3.8', '3.9', "3.10"]
os: [macos-latest, ubuntu-latest, windows-latest]
steps:
- uses: actions/checkout@v2
Expand Down
4 changes: 2 additions & 2 deletions catenets/models/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import jax.numpy as jnp
import numpy as onp
from jax import grad, jit, random
from jax.experimental import optimizers, stax
from jax.experimental.stax import Dense, Elu, Relu, Sigmoid
from jax.example_libraries import optimizers, stax
from jax.example_libraries.stax import Dense, Elu, Relu, Sigmoid
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import ParameterGrid

Expand Down
2 changes: 1 addition & 1 deletion catenets/models/jax/disentangled_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jax.numpy as jnp
import numpy as onp
from jax import grad, jit, random
from jax.experimental import optimizers
from jax.example_libraries import optimizers

import catenets.logger as log
from catenets.models.constants import (
Expand Down
11 changes: 9 additions & 2 deletions catenets/models/jax/flextenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@
import jax.numpy as jnp
import numpy as onp
from jax import grad, jit, random
from jax.experimental import optimizers
from jax.experimental.stax import Dense, Sigmoid, elu, glorot_normal, normal, serial
from jax.example_libraries import optimizers
from jax.example_libraries.stax import (
Dense,
Sigmoid,
elu,
glorot_normal,
normal,
serial,
)

import catenets.logger as log
from catenets.models.constants import (
Expand Down
4 changes: 2 additions & 2 deletions catenets/models/jax/offsetnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import jax.numpy as jnp
import numpy as onp
from jax import grad, jit, random
from jax.experimental import optimizers
from jax.experimental.stax import sigmoid
from jax.example_libraries import optimizers
from jax.example_libraries.stax import sigmoid

import catenets.logger as log
from catenets.models.constants import (
Expand Down
2 changes: 1 addition & 1 deletion catenets/models/jax/representation_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jax.numpy as jnp
import numpy as onp
from jax import grad, jit, random
from jax.experimental import optimizers
from jax.example_libraries import optimizers

import catenets.logger as log
from catenets.models.constants import (
Expand Down
14 changes: 7 additions & 7 deletions catenets/models/jax/rnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as onp
import pandas as pd
from jax import grad, jit, random
from jax.experimental import optimizers
from jax.example_libraries import optimizers
from sklearn.model_selection import StratifiedKFold

import catenets.logger as log
Expand Down Expand Up @@ -135,7 +135,7 @@ def __init__(
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
nonlin: str = DEFAULT_NONLIN,
binary_y: bool = False
binary_y: bool = False,
) -> None:
# settings
self.binary_y = binary_y
Expand Down Expand Up @@ -243,7 +243,7 @@ def train_r_net(
seed: int = DEFAULT_SEED,
return_val_loss: bool = False,
nonlin: str = DEFAULT_NONLIN,
binary_y: bool = False
binary_y: bool = False,
) -> Any:
# get shape of data
n, d = X.shape
Expand Down Expand Up @@ -288,7 +288,7 @@ def train_r_net(
n_iter_print=n_iter_print,
seed=seed,
nonlin=nonlin,
binary_y=binary_y
binary_y=binary_y,
)
if data_split:
# keep only prediction data
Expand Down Expand Up @@ -333,7 +333,7 @@ def train_r_net(
n_iter_print=n_iter_print,
seed=seed,
nonlin=nonlin,
binary_y=binary_y
binary_y=binary_y,
)

log.debug("Training second stage.")
Expand Down Expand Up @@ -415,7 +415,7 @@ def _train_and_predict_r_stage1(
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
nonlin: str = DEFAULT_NONLIN,
binary_y: bool = False
binary_y: bool = False,
) -> Any:
if len(w.shape) > 1:
w = w.reshape((len(w),))
Expand Down Expand Up @@ -443,7 +443,7 @@ def _train_and_predict_r_stage1(
n_iter_print=n_iter_print,
seed=seed,
nonlin=nonlin,
binary_y=binary_y
binary_y=binary_y,
)
mu_hat = predict_fun_out(params_out, X_pred)

Expand Down
2 changes: 1 addition & 1 deletion catenets/models/jax/snet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jax.numpy as jnp
import numpy as onp
from jax import grad, jit, random
from jax.experimental import optimizers
from jax.example_libraries import optimizers

import catenets.logger as log
from catenets.models.constants import (
Expand Down
2 changes: 1 addition & 1 deletion catenets/models/jax/tnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jax.numpy as jnp
import numpy as onp
from jax import grad, jit, random
from jax.experimental import optimizers
from jax.example_libraries import optimizers

import catenets.logger as log
from catenets.models.constants import (
Expand Down
Loading

0 comments on commit 21782d9

Please sign in to comment.