Skip to content

Commit

Permalink
Merge pull request #278 from ACEsuit/co/chspl
Browse files Browse the repository at this point in the history
Update ChunkSplitters to v3
  • Loading branch information
cortner authored Oct 22, 2024
2 parents 05f95bb + 8bdfea1 commit 7b223a3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ AtomsBase = "0.4.1"
AtomsBuilder = "0.2.0"
AtomsCalculators = "0.2"
AtomsCalculatorsUtilities = "0.1"
ChunkSplitters = "< 3"
ChunkSplitters = "3.0"
EquivariantModels = "0.0.5"
ExtXYZ = "0.2.0"
Interpolations = "0.15"
Expand Down
8 changes: 4 additions & 4 deletions src/models/calculators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,10 @@ function energy_forces_virial(
init_f() = AtomsCalculators.zero_forces(at, V)
init_v() = AtomsCalculators.zero_virial(at, V)

E_F_V = Folds.sum(collect(chunks(domain, ntasks)),
E_F_V = Folds.sum(collect(index_chunks(domain; n = ntasks)),
executor;
init = [init_e(), init_f(), init_v()],
) do (sub_domain, _)
) do sub_domain

energy = init_e()
forces = init_f()
Expand Down Expand Up @@ -196,10 +196,10 @@ function pullback_EFV(Δefv,
# assumes that the loss is dimensionless and that the
# gradient w.r.t. parameters therefore must also be dimensionless

g_vec = Folds.sum(collect(chunks(domain, ntasks)),
g_vec = Folds.sum(collect(index_chunks(domain; n = ntasks)),
executor;
init = zeros(TP, length(ps_vec)),
) do (sub_domain, _)
) do sub_domain

g_loc = zeros(TP, length(ps_vec))
for i in sub_domain
Expand Down

0 comments on commit 7b223a3

Please sign in to comment.