diff --git a/src/RandomFeatureMaps.jl b/src/RandomFeatureMaps.jl index f159ecc..b89d4bd 100644 --- a/src/RandomFeatureMaps.jl +++ b/src/RandomFeatureMaps.jl @@ -128,7 +128,7 @@ end _rof(rof::RandomOrientationFeatures, T1::Rigid, T2::Rigid) = norms(T1(rof.FA) .- T2(rof.FB); dims=1) function (rof::RandomOrientationFeatures)(T1::Rigid, T2::Rigid; pairdim::Union{Nothing,Int}=nothing) - if dims isa Int + if pairdim isa Int T1, T2 = batchunsqueeze(T1, dims=pairdim+1), batchunsqueeze(T2, dims=pairdim) end _rof(rof, T1, T2) diff --git a/test/GraphNeuralNetworksExt.jl b/test/GraphNeuralNetworksExt.jl index 9d54a48..0cc9e61 100644 --- a/test/GraphNeuralNetworksExt.jl +++ b/test/GraphNeuralNetworksExt.jl @@ -9,7 +9,7 @@ using GraphNeuralNetworks g = ones(Bool, n, n) graph = GNNGraph(g, graph_type=:dense) rof = RandomOrientationFeatures(dim, 0.1f0) - @test rof(rigid, graph) == reshape(rof(rigid; dims=1), dim, :) + @test rof(rigid, graph) == reshape(rof(rigid; pairdim=1), dim, :) end @testset "Random edges" begin @@ -20,7 +20,7 @@ using GraphNeuralNetworks graph = GNNGraph(g, graph_type=:dense) rof = RandomOrientationFeatures(dim, 0.1f0) @test size(rof(rigid, rigid, graph), 2) == count(g) - @test rof(rigid, graph) == reshape(rof(rigid; dims=1), dim, :)[:,findall(vec(g))] + @test rof(rigid, graph) == reshape(rof(rigid; pairdim=1), dim, :)[:,findall(vec(g))] @test rof(rigid, graph) == rof(rigid, rigid, graph) end diff --git a/test/runtests.jl b/test/runtests.jl index e48fc1a..7fd176f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,8 +18,8 @@ using BatchedTransformations k = 5 rof = RandomOrientationFeatures(dim, 0.1f0) rigid = rand_rigid(Float32, (n,k)) - @test rof(rigid, dims=1) |> size == (dim, n, n, k) - @test rof((rand(3,3,n,k), rand(3,1,n,k)), dims=1) |> size == (dim, n, n, k) + @test rof(rigid, pairdim=1) |> size == (dim, n, n, k) + @test rof((rand(3,3,n,k), rand(3,1,n,k)), pairdim=1) |> size == (dim, n, n, k) @test rof(rigid) == rof(rigid, rigid) end