Skip to content

Commit

Permalink
Try to exclude Enzyme from coverage when CLAD_ENABLE_ENZYME_BACKEND=Off
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-penev committed Oct 5, 2024
1 parent b6fea1b commit 5de9471
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,9 +577,11 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
code);
}

#ifdef CLAD_ENABLE_ENZYME_BACKEND
// Gradient Structure for Reverse Mode Enzyme
template <unsigned N> struct EnzymeGradient { double d_arr[N]; };
} // namespace clad
#endif
#endif // CLAD_DIFFERENTIATOR

// Enable clad after the header was included.
Expand Down
2 changes: 2 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ namespace clad {
// Function to Differentiate with Clad as Backend
void DifferentiateWithClad();

#ifdef CLAD_ENABLE_ENZYME_BACKEND
// Function to Differentiate with Enzyme as Backend
void DifferentiateWithEnzyme();
#endif

/// Tries to find and build call to user-provided `_forw` function.
clang::Expr* BuildCallToCustomForwPassFn(
Expand Down
4 changes: 4 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

if (!m_DiffReq.use_enzyme)
DifferentiateWithClad();
#ifdef CLAD_ENABLE_ENZYME_BACKEND
else
DifferentiateWithEnzyme();
#endif

gradientBody = endBlock();
m_Derivative->setBody(gradientBody);
Expand Down Expand Up @@ -664,6 +666,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActOnEndOfDerivedFnBody();
}

#ifdef CLAD_ENABLE_ENZYME_BACKEND
void ReverseModeVisitor::DifferentiateWithEnzyme() {
unsigned numParams = m_DiffReq->getNumParams();
auto origParams = m_DiffReq->parameters();
Expand Down Expand Up @@ -772,6 +775,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(enzymeCall);
}
}
#endif

StmtDiff ReverseModeVisitor::VisitCXXStdInitializerListExpr(
const clang::CXXStdInitializerListExpr* ILE) {
Expand Down

0 comments on commit 5de9471

Please sign in to comment.