From 19fe8939ecd2d1dddeb0b53204a91f19a6184cf1 Mon Sep 17 00:00:00 2001 From: gcassella Date: Mon, 6 Nov 2023 11:06:03 +0000 Subject: [PATCH 1/8] psd_inv_cholesky -> psd_inv --- ferminet/curvature_tags_and_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ferminet/curvature_tags_and_blocks.py b/ferminet/curvature_tags_and_blocks.py index 123c61e..75c5767 100644 --- a/ferminet/curvature_tags_and_blocks.py +++ b/ferminet/curvature_tags_and_blocks.py @@ -24,7 +24,7 @@ Scalar = kfac_jax.utils.Scalar Numeric = kfac_jax.utils.Numeric -vmap_psd_inv_cholesky = jax.vmap(kfac_jax.utils.psd_inv_cholesky, (0, None), 0) +vmap_psd_inv = jax.vmap(kfac_jax.utils.psd_inv, (0, None), 0) vmap_matmul = jax.vmap(jnp.matmul, in_axes=(0, 0), out_axes=0) From 21a395329e1252bd66cab0ce9bd56eaf2f08ccf4 Mon Sep 17 00:00:00 2001 From: gcassella Date: Mon, 6 Nov 2023 11:14:35 +0000 Subject: [PATCH 2/8] Correct type hints in curvature_tags_and_blocks.py --- ferminet/curvature_tags_and_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ferminet/curvature_tags_and_blocks.py b/ferminet/curvature_tags_and_blocks.py index 75c5767..44eca3a 100644 --- a/ferminet/curvature_tags_and_blocks.py +++ b/ferminet/curvature_tags_and_blocks.py @@ -141,8 +141,8 @@ def _update_cache( self, state: kfac_jax.TwoKroneckerFactored.State, identity_weight: kfac_jax.utils.Numeric, - exact_powers: set[kfac_jax.utils.Scalar], - approx_powers: set[kfac_jax.utils.Scalar], + exact_powers: Set[kfac_jax.utils.Scalar], + approx_powers: Set[kfac_jax.utils.Scalar], eigenvalues: bool, ) -> kfac_jax.TwoKroneckerFactored.State: del eigenvalues From ff2b7c60274f25603fc18866f85f191832c27034 Mon Sep 17 00:00:00 2001 From: Gino Cassella Date: Mon, 6 Nov 2023 13:21:15 +0000 Subject: [PATCH 3/8] Update ci-build.yaml -> python=3.10 --- .github/workflows/ci-build.yaml | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 13e8fbd..e0b81ff 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: [3.10] steps: - uses: actions/checkout@v2 @@ -55,15 +55,11 @@ jobs: matrix: include: - name-prefix: "all tests" - python-version: 3.8 - os: ubuntu-latest - package-overrides: "none" - - name-prefix: "all tests" - python-version: 3.9 + python-version: 3.10 os: ubuntu-latest package-overrides: "none" - name-prefix: "with internal numpy" - python-version: 3.9 + python-version: 3.10 os: ubuntu-latest # Test with numpy version that matches Google-internal version package-overrides: "numpy==1.21.5" @@ -95,6 +91,6 @@ jobs: run: | python -m pytest - name: Run multi-device tests - if: matrix.python-version == 3.8 + if: matrix.python-version == 3.10 run: | FERMINET_CHEX_N_CPU_DEVICES=2 python -m pytest ferminet/tests/train_test.py From 313d0ee294d7d6822cc80d153ad84922b4718d4a Mon Sep 17 00:00:00 2001 From: Gino Cassella Date: Mon, 6 Nov 2023 13:23:16 +0000 Subject: [PATCH 4/8] 3.10 -> '3.10' --- .github/workflows/ci-build.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index e0b81ff..c5019e9 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.10] + python-version: ['3.10'] steps: - uses: actions/checkout@v2 @@ -55,11 +55,11 @@ jobs: matrix: include: - name-prefix: "all tests" - python-version: 3.10 + python-version: '3.10' os: ubuntu-latest package-overrides: "none" - name-prefix: "with internal numpy" - python-version: 3.10 + python-version: '3.10' os: ubuntu-latest # Test with numpy version that matches Google-internal version package-overrides: "numpy==1.21.5" @@ -91,6 +91,6 @@ jobs: run: | python -m pytest - name: Run multi-device tests - if: matrix.python-version == 3.10 + if: matrix.python-version == '3.10' run: | FERMINET_CHEX_N_CPU_DEVICES=2 python -m pytest ferminet/tests/train_test.py From d5c8966e16f9ee2cb3d5e8302f081dffc807130f Mon Sep 17 00:00:00 2001 From: gcassella Date: Mon, 6 Nov 2023 13:35:14 +0000 Subject: [PATCH 5/8] Silence spurious pytype error due to attr inheritance --- ferminet/psiformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ferminet/psiformer.py b/ferminet/psiformer.py index c5e058e..5a1e52d 100644 --- a/ferminet/psiformer.py +++ b/ferminet/psiformer.py @@ -401,7 +401,7 @@ def make_fermi_net( heads_dim=heads_dim, mlp_hidden_dims=mlp_hidden_dims, use_layer_norm=use_layer_norm, - ) + ) # pytype: disable=wrong-keyword-args psiformer_layers = make_psiformer_layers(nspins, charges.shape[0], options) From ef641d6398b315d13751ec51a1cb9ca23a09a5ed Mon Sep 17 00:00:00 2001 From: Gino Cassella Date: Mon, 6 Nov 2023 13:38:48 +0000 Subject: [PATCH 6/8] Add 3.11 test, comment internal numpy --- .github/workflows/ci-build.yaml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index c5019e9..c9fe779 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -58,11 +58,15 @@ jobs: python-version: '3.10' os: ubuntu-latest package-overrides: "none" - - name-prefix: "with internal numpy" - python-version: '3.10' + - name-prefix: "all tests" + python-version: '3.11' os: ubuntu-latest - # Test with numpy version that matches Google-internal version - package-overrides: "numpy==1.21.5" + package-overrides: "none" + # - name-prefix: "with internal numpy" + # python-version: '3.10' + # os: ubuntu-latest + # Test with numpy version that matches Google-internal version + # package-overrides: "numpy==1.21.5" steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} From a9f82497b756b24a1b7de24b69e9542357b5be18 Mon Sep 17 00:00:00 2001 From: Gino Cassella Date: Tue, 7 Nov 2023 11:10:15 +0000 Subject: [PATCH 7/8] Delete internal numpy CI --- .github/workflows/ci-build.yaml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index c9fe779..a589840 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -62,11 +62,6 @@ jobs: python-version: '3.11' os: ubuntu-latest package-overrides: "none" - # - name-prefix: "with internal numpy" - # python-version: '3.10' - # os: ubuntu-latest - # Test with numpy version that matches Google-internal version - # package-overrides: "numpy==1.21.5" steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} From fdaca31f6b9e4c8f10881de55a8aa7cef16316bc Mon Sep 17 00:00:00 2001 From: Gino Cassella Date: Tue, 7 Nov 2023 11:10:56 +0000 Subject: [PATCH 8/8] Revert Set -> set --- ferminet/curvature_tags_and_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ferminet/curvature_tags_and_blocks.py b/ferminet/curvature_tags_and_blocks.py index 44eca3a..75c5767 100644 --- a/ferminet/curvature_tags_and_blocks.py +++ b/ferminet/curvature_tags_and_blocks.py @@ -141,8 +141,8 @@ def _update_cache( self, state: kfac_jax.TwoKroneckerFactored.State, identity_weight: kfac_jax.utils.Numeric, - exact_powers: Set[kfac_jax.utils.Scalar], - approx_powers: Set[kfac_jax.utils.Scalar], + exact_powers: set[kfac_jax.utils.Scalar], + approx_powers: set[kfac_jax.utils.Scalar], eigenvalues: bool, ) -> kfac_jax.TwoKroneckerFactored.State: del eigenvalues