diff --git a/source/op/prod_env_mat_multi_device.cc b/source/op/prod_env_mat_multi_device.cc index 5ee43d2af3..7037a00a6c 100644 --- a/source/op/prod_env_mat_multi_device.cc +++ b/source/op/prod_env_mat_multi_device.cc @@ -577,6 +577,15 @@ class ProdEnvMatAOp : public OpKernel { mesh_tensor.flat().data(), mesh_tensor_size, nloc, nei_mode, rcut_r, max_cpy_trial, max_nnei_trial); + // max_nbor_size may be changed after _prepare_coord_nlist_gpu + // So we need to update the uint64_temp tensor if necessary + if (uint64_temp.NumElements() < int_64(nloc) * max_nbor_size * 2) { + TensorShape uint64_shape; + uint64_shape.AddDim(int_64(nloc) * max_nbor_size * 2); + OP_REQUIRES_OK(context, context->allocate_temp( + DT_UINT64, uint64_shape, &uint64_temp)); + array_longlong = uint64_temp.flat().data(); + } // launch the gpu(nv) compute function deepmd::prod_env_mat_a_gpu(em, em_deriv, rij, nlist, coord, type, gpu_inlist, array_int, array_longlong, @@ -875,6 +884,16 @@ class ProdEnvMatROp : public OpKernel { mesh_tensor.flat().data(), mesh_tensor_size, nloc, nei_mode, rcut, max_cpy_trial, max_nnei_trial); + // max_nbor_size may be changed after _prepare_coord_nlist_gpu + // So we need to update the uint64_temp tensor if necessary + if (uint64_temp.NumElements() < int_64(nloc) * max_nbor_size * 2) { + TensorShape uint64_shape; + uint64_shape.AddDim(int_64(nloc) * max_nbor_size * 2); + OP_REQUIRES_OK(context, context->allocate_temp( + DT_UINT64, uint64_shape, &uint64_temp)); + array_longlong = uint64_temp.flat().data(); + } + // launch the gpu(nv) compute function deepmd::prod_env_mat_r_gpu(em, em_deriv, rij, nlist, coord, type, gpu_inlist, array_int, array_longlong, @@ -1221,6 +1240,16 @@ class ProdEnvMatAMixOp : public OpKernel { mesh_tensor.flat().data(), mesh_tensor_size, nloc, nei_mode, rcut_r, max_cpy_trial, max_nnei_trial); + // max_nbor_size may be changed after _prepare_coord_nlist_gpu + // So we need to update the uint64_temp tensor if necessary + if (uint64_temp.NumElements() < int_64(nloc) * max_nbor_size * 2) { + TensorShape uint64_shape; + uint64_shape.AddDim(int_64(nloc) * max_nbor_size * 2); + OP_REQUIRES_OK(context, context->allocate_temp( + DT_UINT64, uint64_shape, &uint64_temp)); + array_longlong = uint64_temp.flat().data(); + } + // launch the gpu(nv) compute function deepmd::prod_env_mat_a_gpu(em, em_deriv, rij, nlist, coord, type, gpu_inlist, array_int, array_longlong,