Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(WIP) enzyme integration #1125

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ if(ENABLE_CUDA AND ${CMAKE_VERSION} VERSION_LESS 3.18.0)
message(FATAL_ERROR "Serac requires CMake version 3.18.0+ when CUDA is enabled.")
endif()

# N.B. leave compilers unspecified when configuring CMake to ensure that
# the clang/clang++ binaries appropriate for enzyme are chosen.
# Guard against an undefined LLVM_BIN_DIR: without this check the compiler
# silently resolves to the bogus path "/clang" and configuration fails later
# with a confusing error.
if(DEFINED LLVM_BIN_DIR)
    set(CMAKE_C_COMPILER "${LLVM_BIN_DIR}/clang")
    set(CMAKE_CXX_COMPILER "${LLVM_BIN_DIR}/clang++")
else()
    message(FATAL_ERROR "LLVM_BIN_DIR must point at the LLVM bin directory "
                        "providing the clang/clang++ used to build Enzyme.")
endif()

project(serac LANGUAGES CXX C)

# MPI is required in Serac.
Expand Down
17 changes: 17 additions & 0 deletions cmake/thirdparty/SetupSeracThirdParty.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -559,3 +559,20 @@ if (NOT SERAC_THIRD_PARTY_LIBRARIES_FOUND)
endforeach()
endif()
endif()


#------------------------------------------------------------------------------
# Enzyme
#------------------------------------------------------------------------------
# By default, download a pinned release of Enzyme via FetchContent.
# Developers working against a local Enzyme checkout can set
# -DSERAC_ENZYME_SOURCE_DIR=/path/to/Enzyme/enzyme to use it instead.
# (Previously this was `if (FALSE)` with a hard-coded /home/sam/... path,
# which only worked on one developer's machine.)
set(SERAC_ENZYME_SOURCE_DIR "" CACHE PATH
    "Optional path to a local Enzyme source tree (the 'enzyme' subdirectory)")

if(NOT SERAC_ENZYME_SOURCE_DIR)
    include(FetchContent)

    FetchContent_Declare(
      enzyme
      URL https://github.com/EnzymeAD/Enzyme/archive/refs/tags/v0.0.110.tar.gz
      SOURCE_SUBDIR enzyme
    )
    FetchContent_MakeAvailable(enzyme)
else()
    add_subdirectory("${SERAC_ENZYME_SOURCE_DIR}" ${PROJECT_BINARY_DIR}/tmp/enzyme)
endif()
9 changes: 8 additions & 1 deletion src/serac/numerics/functional/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,15 @@ blt_add_library(
HEADERS ${functional_headers} ${functional_detail_headers}
SOURCES ${functional_sources}
DEPENDS_ON ${functional_depends}
)

# serac_functional executes Enzyme-instrumented code, which links against the
# LLVM OpenMP runtime; without an explicit OpenMP dependency the resulting
# binaries fail at load time with:
#   "error while loading shared libraries: libomp.so: cannot open
#    shared object file: No such file or directory"
target_link_libraries(serac_functional PUBLIC OpenMP::OpenMP_CXX)
target_link_libraries(serac_functional PUBLIC ClangEnzymeFlags)

# -enzyme-loose-types relaxes Enzyme's type-analysis failures into warnings;
# PUBLIC so downstream targets that instantiate these templates get it too.
target_compile_options(serac_functional PUBLIC -mllvm -enzyme-loose-types)

install(FILES ${functional_headers} DESTINATION include/serac/numerics/functional )
install(FILES ${functional_detail_headers} DESTINATION include/serac/numerics/functional/detail )
Expand Down
2 changes: 1 addition & 1 deletion src/serac/numerics/functional/detail/triangle_H1.inl
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ struct finite_element<mfem::Geometry::TRIANGLE, H1<p, c> > {
}

template <typename in_t, int q>
static auto batch_apply_shape_fn(int j, tensor<in_t, q*(q + 1) / 2> input, const TensorProductQuadratureRule<q>&)
static auto batch_apply_shape_fn(int j, tensor<in_t, q*(q+1) / 2> input, const TensorProductQuadratureRule<q>&)
{
using source_t = decltype(get<0>(get<0>(in_t{})) + dot(get<1>(get<0>(in_t{})), tensor<double, 2>{}));
using flux_t = decltype(get<0>(get<1>(in_t{})) + dot(get<1>(get<1>(in_t{})), tensor<double, 2>{}));
Expand Down
94 changes: 66 additions & 28 deletions src/serac/numerics/functional/domain_integral_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "serac/numerics/functional/quadrature_data.hpp"
#include "serac/numerics/functional/function_signature.hpp"
#include "serac/numerics/functional/differentiate_wrt.hpp"
#include "serac/numerics/functional/enzyme_wrapper.hpp"

#include <RAJA/index/RangeSegment.hpp>
#include <RAJA/RAJA.hpp>
Expand Down Expand Up @@ -90,24 +91,36 @@ template <typename lambda, typename coords_type, typename... T, typename qpt_dat
/// @brief invoke the q-function with the arguments in @p arg_tuple expanded,
/// by forwarding to apply_qf_helper with an integer sequence over the trial spaces.
/// @note the rendered diff had fused the old and new formatting of the return
/// statement into two consecutive return statements; only one is kept here.
SERAC_HOST_DEVICE auto apply_qf(lambda&& qf, double t, coords_type&& x_q, qpt_data_type&& qpt_data,
                                const serac::tuple<T...>& arg_tuple)
{
  return apply_qf_helper(qf, t, x_q, qpt_data, arg_tuple,
                         std::make_integer_sequence<int, static_cast<int>(sizeof...(T))>{});
}

/// @brief compute (at compile time) the type of the derivative of the q-function
/// output with respect to the i-th trial-space input, in @p dim spatial dimensions.
/// @note the rendered diff contained both the old (make_dual_wrt/get_gradient) and
/// new (impl::nested) return statements; the old unreachable one is removed here,
/// along with the stray ';' after the closing brace.
template <int i, int dim, typename... trials, typename lambda, typename qpt_data_type>
auto get_derivative_type(lambda qf, qpt_data_type&& qpt_data)
{
  using qf_arguments = serac::tuple<typename QFunctionArgument<trials, serac::Dimension<dim> >::type...>;
  using output_type  = decltype(apply_qf(qf, double{}, serac::tuple<tensor<double, dim>, tensor<double, dim, dim> >{},
                                        qpt_data, qf_arguments{}));
  return typename impl::nested<output_type, decltype(type<i>(qf_arguments{}))>::type{};
}

/// @brief transform a (value, gradient) pair at a quadrature point from parent-element
/// coordinates to physical coordinates: values are unchanged, gradients are pushed
/// forward with the inverse Jacobian (H1-style transformation).
/// @param qf_input the (value, gradient) pair on the parent element
/// @param invJ the inverse of the isoparametric Jacobian at this quadrature point
/// @note a stray, duplicated template header line from the old diff hunk preceded
/// this function and is removed here — it would have attached to this definition
/// and broken compilation.
template <typename T1, typename T2, int dim>
SERAC_HOST_DEVICE auto parent_to_physical(const tuple<T1, T2>& qf_input, const tensor<double, dim, dim>& invJ)
{
  return tuple{get<0>(qf_input), dot(get<1>(qf_input), invJ)};
}

/// @brief transform a q-function (source, flux) output at a quadrature point from
/// physical coordinates back to parent-element coordinates, folding in the
/// Jacobian determinant so the result is ready for quadrature integration.
/// @param qf_output the (source, flux) pair in physical coordinates
/// @param invJ the inverse Jacobian at this quadrature point
/// @param detJ the Jacobian determinant at this quadrature point
template <typename T1, typename T2, int dim>
SERAC_HOST_DEVICE auto physical_to_parent(const tuple<T1, T2>& qf_output, const tensor<double, dim, dim>& invJ,
                                          double detJ)
{
  // assumes family == Family::H1 for now
  auto source_term = get<0>(qf_output) * detJ;
  auto flux_term   = dot(get<1>(qf_output), transpose(invJ)) * detJ;
  return tuple{source_term, flux_term};
}

/// @brief evaluate the q-function at each of the n quadrature points of one element
/// (variant without per-quadrature-point state).
///
/// For each point: gather position and Jacobian from the batched arrays, map the
/// interpolated inputs from parent to physical coordinates, call the q-function,
/// then map its output back to parent coordinates (scaled by det(J)) so downstream
/// integration kernels need no further transformation.
///
/// @note the rendered diff fused the old call (`outputs[i] = qf(...)`) with the new
/// parent/physical-transform lines, and a fold marker hid the J_q gather loop; this
/// body keeps only the new version, with the gather loop reconstructed from the
/// identical pattern in batch_apply_qf_derivative below.
template <typename lambda, int dim, int n, typename... T>
SERAC_HOST_DEVICE auto batch_apply_qf_no_qdata(lambda qf, double t, const tensor<double, dim, n>& x,
                                               const tensor<double, dim, dim, n>& J, const T&... inputs)
{
  using position_t  = serac::tuple<tensor<double, dim>, tensor<double, dim, dim> >;
  using return_type = decltype(qf(double{}, position_t{}, T{}[0]...));

  tensor<return_type, n> outputs{};
  for (int i = 0; i < n; i++) {
    // gather the position and (transposed-storage) Jacobian for quadrature point i
    tensor<double, dim>      x_q;
    tensor<double, dim, dim> J_q;
    for (int j = 0; j < dim; j++) {
      for (int k = 0; k < dim; k++) {
        J_q[j][k] = J(k, j, i);
      }
      x_q[j] = x(j, i);
    }
    double                   detJ_q = det(J_q);
    tensor<double, dim, dim> invJ_q = inv(J_q);
    auto qf_output = qf(t, serac::tuple{x_q, J_q}, parent_to_physical(inputs[i], invJ_q)...);
    outputs[i]     = physical_to_parent(qf_output, invJ_q, detJ_q);
  }
  return outputs;
}

/// @brief compute, at each of the n quadrature points of one element, the derivative
/// of the (parent-coordinate) q-function output with respect to the input selected
/// by @p differentiation_index, writing the results into @p doutputs.
///
/// The q-function itself works in physical coordinates, so it is wrapped in a small
/// lambda that applies parent_to_physical on the way in and physical_to_parent on the
/// way out; jacfwd then differentiates that wrapped function (forward mode —
/// presumably via Enzyme, see enzyme_wrapper.hpp; confirm there).
///
/// @param doutputs destination array of n derivative entries (also returned)
template <uint32_t differentiation_index, typename derivative_type, typename lambda, int dim, int n, typename... T>
SERAC_HOST_DEVICE auto batch_apply_qf_derivative(derivative_type* doutputs, lambda qf, double t,
                                                 const tensor<double, dim, n>& x,
                                                 const tensor<double, dim, dim, n>& J, const tensor<T, n>&... inputs)
{
  for (int q = 0; q < n; q++) {
    // gather the position and Jacobian for quadrature point q
    // (note: J is stored with its last index over quadrature points, transposed)
    tensor<double, dim>      position;
    tensor<double, dim, dim> jacobian;
    for (int row = 0; row < dim; row++) {
      position[row] = x(row, q);
      for (int col = 0; col < dim; col++) {
        jacobian[row][col] = J(col, row, q);
      }
    }

    double                   jacobian_det = det(jacobian);
    tensor<double, dim, dim> jacobian_inv = inv(jacobian);

    // the q-function, re-expressed with parent-coordinate inputs and outputs
    auto parent_space_qf = [&](const auto&... arg) {
      auto physical_output = qf(t, serac::tuple{position, jacobian}, parent_to_physical(arg, jacobian_inv)...);
      return physical_to_parent(physical_output, jacobian_inv, jacobian_det);
    };

    doutputs[q] = jacfwd<differentiation_index>(parent_space_qf, inputs[q]...);
  }

  return doutputs;
}

template <typename lambda, int dim, int n, typename qpt_data_type, typename... T>
SERAC_HOST_DEVICE auto batch_apply_qf(lambda qf, double t, const tensor<double, dim, n>& x,
const tensor<double, dim, dim, n>& J, qpt_data_type* qpt_data, bool update_state,
Expand Down Expand Up @@ -153,7 +194,9 @@ template <uint32_t differentiation_index, int Q, mfem::Geometry::Type geom, type
typename trial_element_tuple, typename lambda_type, typename state_type, typename derivative_type,
int... indices>
void evaluation_kernel_impl(trial_element_tuple trial_elements, test_element, double t,
const std::vector<const double*>& inputs, double* outputs, const double* positions,
const std::vector<const double*>& inputs_e,
double* outputs,
const double* positions,
const double* jacobians, lambda_type qf,
[[maybe_unused]] axom::ArrayView<state_type, 2> qf_state,
[[maybe_unused]] derivative_type* qf_derivatives, const int* elements,
Expand All @@ -169,7 +212,7 @@ void evaluation_kernel_impl(trial_element_tuple trial_elements, test_element, do
[[maybe_unused]] auto qpts_per_elem = num_quadrature_points(geom, Q);

[[maybe_unused]] tuple u = {
reinterpret_cast<const typename decltype(type<indices>(trial_elements))::dof_type*>(inputs[indices])...};
reinterpret_cast<const typename decltype(type<indices>(trial_elements))::dof_type*>(inputs_e[indices])...};

// for each element in the domain
for (uint32_t e = 0; e < num_elements; ++e) {
Expand All @@ -179,37 +222,21 @@ void evaluation_kernel_impl(trial_element_tuple trial_elements, test_element, do

//[[maybe_unused]] static constexpr trial_element_tuple trial_element_tuple{};
// batch-calculate values / derivatives of each trial space, at each quadrature point
[[maybe_unused]] tuple qf_inputs = {promote_each_to_dual_when<indices == differentiation_index>(
get<indices>(trial_elements).interpolate(get<indices>(u)[elements[e]], rule))...};

// use J_e to transform values / derivatives on the parent element
// to the to the corresponding values / derivatives on the physical element
(parent_to_physical<get<indices>(trial_elements).family>(get<indices>(qf_inputs), J_e), ...);
tuple qf_inputs = {get<indices>(trial_elements).interpolate(get<indices>(u)[elements[e]], rule)...};

// (batch) evalute the q-function at each quadrature point
//
// note: the weird immediately-invoked lambda expression is
// a workaround for a bug in GCC(<12.0) where it fails to
// decide which function overload to use, and crashes
auto qf_outputs = [&]() {
if constexpr (std::is_same_v<state_type, Nothing>) {
return batch_apply_qf_no_qdata(qf, t, x_e, J_e, get<indices>(qf_inputs)...);
} else {
return batch_apply_qf(qf, t, x_e, J_e, &qf_state(e, 0), update_state, get<indices>(qf_inputs)...);
}
}();

// use J to transform sources / fluxes on the physical element
// back to the corresponding sources / fluxes on the parent element
physical_to_parent<test_element::family>(qf_outputs, J_e);
// TODO: reenable internal_variables
auto qf_outputs = batch_apply_qf_no_qdata(qf, t, x_e, J_e, get<indices>(qf_inputs)...);

// write out the q-function derivatives after applying the
// physical_to_parent transformation, so that those transformations
// won't need to be applied in the action_of_gradient and element_gradient kernels
if constexpr (differentiation_index != serac::NO_DIFFERENTIATION) {
for (int q = 0; q < leading_dimension(qf_outputs); q++) {
qf_derivatives[e * uint32_t(qpts_per_elem) + uint32_t(q)] = get_gradient(qf_outputs[q]);
}
batch_apply_qf_derivative<differentiation_index>(qf_derivatives + e * uint32_t(qpts_per_elem), qf, t, x_e, J_e, get<indices>(qf_inputs)...);
}

// (batch) integrate the material response against the test-space basis functions
Expand All @@ -229,6 +256,17 @@ SERAC_HOST_DEVICE auto chain_rule(const S& dfdx, const T& dx)
}

if constexpr (!is_QOI) {
#if 0

serac::tuple<
serac::tuple< serac::tensor<serac::tensor<double, 2>, 2>, serac::tensor<serac::tensor<double, 2, 2>, 2> >,
serac::tuple< serac::tensor<serac::tensor<double, 2>, 2, 2>, serac::tensor<serac::tensor<double, 2, 2>, 2, 2>>
>,
serac::tuple< serac::tensor<double, 2>, serac::tensor<double, 2, 2> >

#endif


return serac::tuple{serac::chain_rule(serac::get<0>(serac::get<0>(dfdx)), serac::get<0>(dx)) +
serac::chain_rule(serac::get<1>(serac::get<0>(dfdx)), serac::get<1>(dx)),
serac::chain_rule(serac::get<0>(serac::get<1>(dfdx)), serac::get<0>(dx)) +
Expand Down
Loading