diff --git a/test/call_tests.jl b/test/call_tests.jl index 4e8396c7..9bcdedc4 100644 --- a/test/call_tests.jl +++ b/test/call_tests.jl @@ -29,28 +29,28 @@ Test.@testset "Call Tests" begin nvars_ = Int[2] aug_steers = Bool[false, true] inplaces = Bool[false, true] - adb_list = - AbstractDifferentiation.AbstractBackend[AbstractDifferentiation.ZygoteBackend(), - # AbstractDifferentiation.ReverseDiffBackend(), - # AbstractDifferentiation.ForwardDiffBackend(), - ] - adtypes = ADTypes.AbstractADType[ADTypes.AutoZygote(), - # ADTypes.AutoEnzyme(Enzyme.Forward), + adb_list = AbstractDifferentiation.AbstractBackend[ + # AbstractDifferentiation.ZygoteBackend(), + # AbstractDifferentiation.ReverseDiffBackend(), + # AbstractDifferentiation.ForwardDiffBackend(), + ] + adtypes = ADTypes.AbstractADType[ADTypes.AutoEnzyme(Enzyme.Forward), # ADTypes.AutoEnzyme(Enzyme.Reverse), + # ADTypes.AutoZygote(), # ADTypes.AutoReverseDiff(), # ADTypes.AutoForwardDiff(), ] compute_modes = ContinuousNormalizingFlows.ComputeMode[ - ContinuousNormalizingFlows.ADVecJacVectorMode( - AbstractDifferentiation.ZygoteBackend(), - ), - ContinuousNormalizingFlows.ADJacVecVectorMode( - AbstractDifferentiation.ZygoteBackend(), - ), - ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()), + # ContinuousNormalizingFlows.ADVecJacVectorMode( + # AbstractDifferentiation.ZygoteBackend(), + # ), + # ContinuousNormalizingFlows.ADJacVecVectorMode( + # AbstractDifferentiation.ZygoteBackend(), + # ), + # ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()), + # ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()), + # ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), + # ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()), # ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoEnzyme(Enzyme.Reverse)), ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoEnzyme(Enzyme.Forward)), # ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoEnzyme(Enzyme.Reverse)), diff --git a/test/fit_tests.jl b/test/fit_tests.jl index 7ac12439..855ed017 100644 --- a/test/fit_tests.jl +++ b/test/fit_tests.jl @@ -26,23 +26,23 @@ Test.@testset "Fit Tests" begin nvars_ = Int[2] aug_steers = Bool[false, true] inplaces = Bool[false, true] - adtypes = ADTypes.AbstractADType[ADTypes.AutoZygote(), - # ADTypes.AutoEnzyme(Enzyme.Forward), + adtypes = ADTypes.AbstractADType[ADTypes.AutoEnzyme(Enzyme.Forward), # ADTypes.AutoEnzyme(Enzyme.Reverse), + # ADTypes.AutoZygote(), # ADTypes.AutoReverseDiff(), # ADTypes.AutoForwardDiff(), ] compute_modes = ContinuousNormalizingFlows.ComputeMode[ - ContinuousNormalizingFlows.ADVecJacVectorMode( - AbstractDifferentiation.ZygoteBackend(), - ), - ContinuousNormalizingFlows.ADJacVecVectorMode( - AbstractDifferentiation.ZygoteBackend(), - ), - ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()), + # ContinuousNormalizingFlows.ADVecJacVectorMode( + # AbstractDifferentiation.ZygoteBackend(), + # ), + # ContinuousNormalizingFlows.ADJacVecVectorMode( + # AbstractDifferentiation.ZygoteBackend(), + # ), + # ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()), + # ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()), + # ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), + # ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()), # ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoEnzyme(Enzyme.Reverse)), ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoEnzyme(Enzyme.Forward)), # ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoEnzyme(Enzyme.Reverse)),