diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index cca1cd5cf..4710e3d00 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -577,9 +577,11 @@ CUDA_HOST_DEVICE T push(tape& to, ArgsT... val) { code); } +#ifdef CLAD_ENABLE_ENZYME_BACKEND // Gradient Structure for Reverse Mode Enzyme template struct EnzymeGradient { double d_arr[N]; }; } // namespace clad +#endif #endif // CLAD_DIFFERENTIATOR // Enable clad after the header was included. diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index ad9981bb1..05935f50f 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -95,8 +95,10 @@ namespace clad { // Function to Differentiate with Clad as Backend void DifferentiateWithClad(); +#ifdef CLAD_ENABLE_ENZYME_BACKEND // Function to Differentiate with Enzyme as Backend void DifferentiateWithEnzyme(); +#endif /// Tries to find and build call to user-provided `_forw` function. clang::Expr* BuildCallToCustomForwPassFn( diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 4d600caf7..a40584d31 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -457,8 +457,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!m_DiffReq.use_enzyme) DifferentiateWithClad(); +#ifdef CLAD_ENABLE_ENZYME_BACKEND else DifferentiateWithEnzyme(); +#endif gradientBody = endBlock(); m_Derivative->setBody(gradientBody); @@ -664,6 +666,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource->ActOnEndOfDerivedFnBody(); } +#ifdef CLAD_ENABLE_ENZYME_BACKEND void ReverseModeVisitor::DifferentiateWithEnzyme() { unsigned numParams = m_DiffReq->getNumParams(); auto origParams = m_DiffReq->parameters(); @@ -772,6 +775,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(enzymeCall); } } +#endif StmtDiff ReverseModeVisitor::VisitCXXStdInitializerListExpr( const clang::CXXStdInitializerListExpr* ILE) {