Skip to content

Commit

Permalink
Fix 2 typos and add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Dec 13, 2023
1 parent d3beb29 commit 2ad570b
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/ManifoldsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,7 @@ export ×,
change_representer!,
copy,
copyto!,
default_estimation_method,
default_inverse_retraction_method,
default_retraction_method,
default_vector_transport_method,
Expand Down
5 changes: 2 additions & 3 deletions src/estimation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ Method for estimation based on geodesic interpolation that is restricted to some
"""
struct GeodesicInterpolationWithinRadius{T} <: AbstractEstimationMethod
radius::T

function GeodesicInterpolationWithinRadius(radius::T) where {T}
radius > 0 && return new{T}(radius)
return throw(
Expand All @@ -81,5 +80,5 @@ The exceptional functions are
"""
default_estimation_method(M::AbstractManifold)

default_estimation_method(M::AbstractManifold, f, T) = get_default_estimation_method(M, f)
default_estimation_method(M::AbstractManifold, f) = get_default_estimation_method(M)
default_estimation_method(M::AbstractManifold, f, T) = default_estimation_method(M, f)
default_estimation_method(M::AbstractManifold, f) = default_estimation_method(M)
54 changes: 54 additions & 0 deletions test/default_manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ Base.getindex(x::MatrixVectorTransport, i) = x.m[:, i]

Base.size(x::MatrixVectorTransport) = (size(x.m, 2),)

ManifoldsBase.default_estimation_method(::DefaultManifold) = GradientDescentEstimation()

@testset "Testing Default (Euclidean)" begin
M = ManifoldsBase.DefaultManifold(3)
types = [
Expand Down Expand Up @@ -887,4 +889,56 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),)
@test isapprox(vee(MC, p, [1 + 2im, 3 + 4im, 5 + 6im]), [1, 3, 5, 2, 4, 6])
@test isapprox(hat(MC, p, [1, 3, 5, 2, 4, 6]), [1 + 2im, 3 + 4im, 5 + 6im])
end

@testset "Estimation Method defaults" begin
M = ManifoldsBase.DefaultManifold(3)
@test default_estimation_method(M) == GradientDescentEstimation()
# fallbacks
@test default_estimation_method(M, manifold_dimension) ==
default_estimation_method(M)
@test default_estimation_method(M, manifold_dimension, DefaultPoint) ==
default_estimation_method(M)
# Retraction
@test default_estimation_method(M, retract) == default_retraction_method(M)
@test default_estimation_method(M, retract, DefaultPoint) ==
default_retraction_method(M)
@test default_estimation_method(M, retract!) == default_retraction_method(M)
@test default_estimation_method(M, retract!, DefaultPoint) ==
default_retraction_method(M)
# Inverse Retraction
@test default_estimation_method(M, inverse_retract) ==
default_inverse_retraction_method(M)
@test default_estimation_method(M, inverse_retract, DefaultPoint) ==
default_inverse_retraction_method(M)
@test default_estimation_method(M, inverse_retract!) ==
default_inverse_retraction_method(M)
@test default_estimation_method(M, inverse_retract!, DefaultPoint) ==
default_inverse_retraction_method(M)
# Vector Transsports – all 3: to
@test default_estimation_method(M, vector_transport_to) ==
default_vector_transport_method(M)
@test default_estimation_method(M, vector_transport_to, DefaultPoint) ==
default_vector_transport_method(M)
@test default_estimation_method(M, vector_transport_to!) ==
default_vector_transport_method(M)
@test default_estimation_method(M, vector_transport_to!, DefaultPoint) ==
default_vector_transport_method(M)
# along
@test default_estimation_method(M, vector_transport_along) ==
default_vector_transport_method(M)
@test default_estimation_method(M, vector_transport_along, DefaultPoint) ==
default_vector_transport_method(M)
@test default_estimation_method(M, vector_transport_along!) ==
default_vector_transport_method(M)
@test default_estimation_method(M, vector_transport_along!, DefaultPoint) ==
default_vector_transport_method(M)
@test default_estimation_method(M, vector_transport_direction) ==
default_vector_transport_method(M)
@test default_estimation_method(M, vector_transport_direction, DefaultPoint) ==
default_vector_transport_method(M)
@test default_estimation_method(M, vector_transport_direction!) ==
default_vector_transport_method(M)
@test default_estimation_method(M, vector_transport_direction!, DefaultPoint) ==
default_vector_transport_method(M)
end
end
4 changes: 4 additions & 0 deletions test/manifold_fallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,7 @@ end
@test_throws DomainError ODEExponentialRetraction(ExponentialRetraction(), B)
@test_throws ErrorException PadeRetraction(0)
end

@testset "Estimation errors" begin
@test_throws DomainError GeodesicInterpolationWithinRadius(-1)
end

0 comments on commit 2ad570b

Please sign in to comment.