diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 13e8fbd..a589840 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: ['3.10'] steps: - uses: actions/checkout@v2 @@ -55,18 +55,13 @@ jobs: matrix: include: - name-prefix: "all tests" - python-version: 3.8 + python-version: '3.10' os: ubuntu-latest package-overrides: "none" - name-prefix: "all tests" - python-version: 3.9 + python-version: '3.11' os: ubuntu-latest package-overrides: "none" - - name-prefix: "with internal numpy" - python-version: 3.9 - os: ubuntu-latest - # Test with numpy version that matches Google-internal version - package-overrides: "numpy==1.21.5" steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -95,6 +90,6 @@ jobs: run: | python -m pytest - name: Run multi-device tests - if: matrix.python-version == 3.8 + if: matrix.python-version == '3.10' run: | FERMINET_CHEX_N_CPU_DEVICES=2 python -m pytest ferminet/tests/train_test.py diff --git a/ferminet/curvature_tags_and_blocks.py b/ferminet/curvature_tags_and_blocks.py index 123c61e..75c5767 100644 --- a/ferminet/curvature_tags_and_blocks.py +++ b/ferminet/curvature_tags_and_blocks.py @@ -24,7 +24,7 @@ Scalar = kfac_jax.utils.Scalar Numeric = kfac_jax.utils.Numeric -vmap_psd_inv_cholesky = jax.vmap(kfac_jax.utils.psd_inv_cholesky, (0, None), 0) +vmap_psd_inv = jax.vmap(kfac_jax.utils.psd_inv, (0, None), 0) vmap_matmul = jax.vmap(jnp.matmul, in_axes=(0, 0), out_axes=0) diff --git a/ferminet/psiformer.py b/ferminet/psiformer.py index c5e058e..5a1e52d 100644 --- a/ferminet/psiformer.py +++ b/ferminet/psiformer.py @@ -401,7 +401,7 @@ def make_fermi_net( heads_dim=heads_dim, mlp_hidden_dims=mlp_hidden_dims, use_layer_norm=use_layer_norm, - ) + ) # pytype: disable=wrong-keyword-args psiformer_layers = make_psiformer_layers(nspins, charges.shape[0], options)