diff --git a/Project.toml b/Project.toml index 9651550..e011f6b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,21 +1,27 @@ name = "RandomFeatureMaps" uuid = "780baa95-dd42-481b-93db-80fe3d88832c" authors = ["murrellb and contributors"] -version = "0.1.0" +version = "0.1.1" [deps] -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +BatchedTransformations = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" + +[weakdeps] +GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" + +[extensions] +GraphNeuralNetworksExt = "GraphNeuralNetworks" [compat] -Functors = "0.4" -NNlib = "0.9" -Optimisers = "0.3" +BatchedTransformations = "0.4" +Flux = "0.13, 0.14" +GraphNeuralNetworks = "0.6" julia = "1.9" [extras] +GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test", "GraphNeuralNetworks"] diff --git a/ext/GraphNeuralNetworksExt.jl b/ext/GraphNeuralNetworksExt.jl new file mode 100644 index 0000000..b139b6f --- /dev/null +++ b/ext/GraphNeuralNetworksExt.jl @@ -0,0 +1,18 @@ +module GraphNeuralNetworksExt + +using RandomFeatureMaps +using BatchedTransformations +using GraphNeuralNetworks + +subt(xi, xj, e) = xj .- xi +function (rof::RandomOrientationFeatures)(rigid::Rigid, graph::GNNGraph) + points1 = rigid * rof.FA + points2 = rigid * rof.FB + diffs = apply_edges(subt, graph, xi=points2, xj=points1) + return dropdims(sqrt.(sum(abs2, diffs; dims=1)); dims=1) +end + +# deprecated +(rof::RandomOrientationFeatures)(graph::GNNGraph, rigid) = rof(rigid, graph) + +end \ No newline at end of file diff --git a/src/RandomFeatureMaps.jl b/src/RandomFeatureMaps.jl index 1c6b836..5de3a4d 100644 --- a/src/RandomFeatureMaps.jl +++ b/src/RandomFeatureMaps.jl @@ -2,34 +2,37 @@ module RandomFeatureMaps export RandomFourierFeatures export RandomOrientationFeatures +export rand_rigid, construct_rigid -using NNlib: batched_mul -using Functors: @functor -import Optimisers +using Flux: @functor, Optimisers, unsqueeze + +using BatchedTransformations """ RandomFourierFeatures(n => m, σ) Maps `n`-dimensional data and projects it to `m`-dimensional random fourier features. -## Example +## Examples ```jldoctest julia> rff = RandomFourierFeatures(2 => 4, 1.0); # maps 2D data to 4D julia> rff(rand(2, 3)) |> size # 3 samples (4, 3) + +julia> rff(rand(2, 3, 5)) |> size # extra batch dim +(4, 3, 5) ``` """ -struct RandomFourierFeatures{T <: Real, A <: AbstractMatrix{T}} +struct RandomFourierFeatures{T<:Real,A<:AbstractMatrix{T}} W::A end @functor RandomFourierFeatures Optimisers.trainable(::RandomFourierFeatures) = (;) # no trainable parameters -RandomFourierFeatures(dims::Pair{<:Integer, <:Integer}, σ::Real) = - RandomFourierFeatures(dims, float(σ)) +RandomFourierFeatures(dims::Pair{<:Integer, <:Integer}, σ::Real) = RandomFourierFeatures(dims, float(σ)) # d1: input dimension, d2: output dimension (d1 => d2) function RandomFourierFeatures((d1, d2)::Pair{<:Integer, <:Integer}, σ::AbstractFloat) @@ -38,25 +41,44 @@ function RandomFourierFeatures((d1, d2)::Pair{<:Integer, <:Integer}, σ::Abstrac return RandomFourierFeatures(randn(typeof(σ), d1, d2 ÷ 2) * σ * oftype(σ, 2π)) end -function (rff::RandomFourierFeatures{T})(X::AbstractMatrix{T}) where T <: Real +function (rff::RandomFourierFeatures{T})(X::AbstractMatrix{T}) where T<:Real WtX = rff.W'X return [cos.(WtX); sin.(WtX)] end -function (rff::RandomFourierFeatures{T})(X::AbstractArray{T}) where T <: Real +function (rff::RandomFourierFeatures{T})(X::AbstractArray{T}) where T<:Real X′ = reshape(X, size(X, 1), :) Y′ = rff(X′) Y = reshape(Y′, :, size(X)[2:end]...) return Y end +rand_rigid(T::Type, batch_size::Dims) = rand(T, Rigid, 3, batch_size) + +function construct_rigid(R::AbstractArray, t::AbstractArray) + batch_size = size(R)[3:end] + t = reshape(t, 3, 1, batch_size...) + Translation(t) ∘ Rotation(R) +end + """ - RandomOrientationFeatures(m, σ) + RandomOrientationFeatures + +Can be called on rigid transformations to create pairwise maps of random orientation features. +This type has no trainable parameters. -Projects rigid transformations them to `m` features. -These will be the pairwise distances between points. +## Examples + +```jldoctest +julia> rof = RandomOrientationFeatures(4, 0.1); + +julia> rigid = (randn(3, 3, 2), randn(3, 1, 2)); # cba to make it orthonormal and whatevs + +julia> rof(rigid) |> size +(4, 2, 2) +``` """ -struct RandomOrientationFeatures{A <: AbstractArray{<:Real}} +struct RandomOrientationFeatures{T<:Real,A<:AbstractMatrix{T}} FA::A FB::A end @@ -64,33 +86,23 @@ end @functor RandomOrientationFeatures Optimisers.trainable(::RandomOrientationFeatures) = (;) # no trainable parameters -# should it just have a single array? such that pairwise distances require two of these +""" + RandomOrientationFeatures(m, σ) + +Creates a `RandomOrientationFeatures` instance, mapping to `m` features. +""" function RandomOrientationFeatures(dim::Integer, σ::AbstractFloat) isfinite(σ) && σ > 0 || throw(ArgumentError("scale must be finite and positive")) - return RandomOrientationFeatures(randn(typeof(σ), 3, dim, 1) * σ, randn(typeof(σ), 3, dim, 1) * σ) + return RandomOrientationFeatures(randn(typeof(σ), 3, dim) * σ, randn(typeof(σ), 3, dim) * σ) end -### For non-graph version, with batch dim -function transform_rigid(x::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) where T - x′ = reshape(x, 3, size(x, 2), :) - R′ = reshape(R, 3, 3, :) - t′ = reshape(t, 3, 1, :) - y′ = batched_mul(R′, x′) .+ t′ - y = reshape(y′, 3, size(x, 2), size(R)[3:end]...) - return y -end - -function (rof::RandomOrientationFeatures)(rigid::Tuple{AbstractArray, AbstractArray}) - dim = size(rof.FA, 2) - Nr, batch... = size(rigid[1])[3:end] - p1 = reshape(transform_rigid(rof.FA, rigid...), 3, dim, Nr, batch...) - p2 = reshape(transform_rigid(rof.FB, rigid...), 3, dim, Nr, batch...) - return dropdims(sqrt.(sum(abs2, - reshape(p1, 3, dim, Nr, 1, batch...) .- - reshape(p2, 3, dim, 1, Nr, batch...), - dims=1)), dims=1) +function (rof::RandomOrientationFeatures)(rigid::Rigid) + points1 = rigid * rof.FA + points2 = rigid * rof.FB + diffs = unsqueeze(points1, dims=4) .- unsqueeze(points2, dims=3) + return dropdims(sqrt.(sum(abs2, diffs; dims=1)); dims=1) end -### TODO: GRAPH version - see https://github.com/MurrellGroup/RandomFeatures.jl/blob/main/src/RandomFeatures.jl#L91 +(rof::RandomOrientationFeatures)((R, t)::Tuple{AbstractArray,AbstractArray}, args...) = rof(construct_rigid(R, t), args...) end diff --git a/test/GraphNeuralNetworksExt.jl b/test/GraphNeuralNetworksExt.jl new file mode 100644 index 0000000..ba990ee --- /dev/null +++ b/test/GraphNeuralNetworksExt.jl @@ -0,0 +1,28 @@ +using GraphNeuralNetworks + +@testset "GraphNeuralNetworksExt.jl" begin + + @testset "Equivalence" begin + m = 10 + n = 8 + rigid = rand_rigid(Float32, (n,)) + g = ones(Bool, n, n) + graph = GNNGraph(g, graph_type=:dense) + rof = RandomOrientationFeatures(m, 0.1f0) + @test rof(rigid, graph) == reshape(rof(rigid), m, :) + + @test rof((rand(3,3,n), rand(3,1,n)), graph) |> size == (m, n^2) # deprecated + end + + @testset "Random graph" begin + m = 10 + n = 8 + rigid = rand_rigid(Float32, (n,)) + g = rand(Bool, n, n) + graph = GNNGraph(g, graph_type=:dense) + rof = RandomOrientationFeatures(m, 0.1f0) + @test size(rof(rigid, graph), 2) == count(g) + @test rof(rigid, graph) == reshape(rof(rigid), m, :)[:,findall(vec(g))] + end + +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 8ef4702..62f9228 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,12 +1,23 @@ using RandomFeatureMaps using Test +using BatchedTransformations + @testset "RandomFeatureMaps.jl" begin - rff = RandomFourierFeatures(10 => 20, 0.1f0) - x = randn(Float32, 10, 4) - @test rff(x) |> size == (20, 4) - @test rff(reshape(x, 10, 2, 2)) |> size == (20, 2, 2) - rof = RandomOrientationFeatures(10, 0.1f0) - @test rof((randn(Float32, 3, 3, 4, 2), randn(Float32, 3, 1, 4, 2))) |> size == (10, 4, 4, 2) + @testset "RandomFourierFeatures" begin + rff = RandomFourierFeatures(10 => 20, 0.1f0) + x = randn(Float32, 10, 4) + @test rff(x) |> size == (20, 4) + @test rff(reshape(x, 10, 2, 2)) |> size == (20, 2, 2) + end + + @testset "RandomOrientationFeatures" begin + rof = RandomOrientationFeatures(10, 0.1f0) + @test rof(rand_rigid(Float32, (4,2))) |> size == (10, 4, 4, 2) + @test rof((rand(3,3,4,2), rand(3,1,4,2))) |> size == (10, 4, 4, 2) # deprecated + end + + include("GraphNeuralNetworksExt.jl") + end