Skip to content

Commit

Permalink
Fix 4
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-penev committed Oct 5, 2024
1 parent aa54112 commit 7d1e84f
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 10 deletions.
2 changes: 0 additions & 2 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,10 +577,8 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
code);
}

#ifdef CLAD_ENABLE_ENZYME_BACKEND
// Gradient Structure for Reverse Mode Enzyme
template <unsigned N> struct EnzymeGradient { double d_arr[N]; };
#endif
} // namespace clad
#endif // CLAD_DIFFERENTIATOR

Expand Down
2 changes: 0 additions & 2 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,8 @@ namespace clad {
// Function to Differentiate with Clad as Backend
void DifferentiateWithClad();

#ifdef CLAD_ENABLE_ENZYME_BACKEND
// Function to Differentiate with Enzyme as Backend
void DifferentiateWithEnzyme();
#endif

/// Tries to find and build call to user-provided `_forw` function.
clang::Expr* BuildCallToCustomForwPassFn(
Expand Down
6 changes: 0 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,14 +455,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

Stmt* gradientBody = nullptr;

#ifdef CLAD_ENABLE_ENZYME_BACKEND
if (!m_DiffReq.use_enzyme)
DifferentiateWithClad();
else
DifferentiateWithEnzyme();
#else
DifferentiateWithClad();
#endif

gradientBody = endBlock();
m_Derivative->setBody(gradientBody);
Expand Down Expand Up @@ -668,7 +664,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActOnEndOfDerivedFnBody();
}

#ifdef CLAD_ENABLE_ENZYME_BACKEND
void ReverseModeVisitor::DifferentiateWithEnzyme() {
unsigned numParams = m_DiffReq->getNumParams();
auto origParams = m_DiffReq->parameters();
Expand Down Expand Up @@ -777,7 +772,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(enzymeCall);
}
}
#endif

StmtDiff ReverseModeVisitor::VisitCXXStdInitializerListExpr(
const clang::CXXStdInitializerListExpr* ILE) {
Expand Down

0 comments on commit 7d1e84f

Please sign in to comment.