diff --git a/Project.toml b/Project.toml index e011f6b..0e668a6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "RandomFeatureMaps" uuid = "780baa95-dd42-481b-93db-80fe3d88832c" authors = ["murrellb and contributors"] -version = "0.1.1" +version = "0.2.0" [deps] BatchedTransformations = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" [weakdeps] GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" @@ -14,9 +15,10 @@ GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" GraphNeuralNetworksExt = "GraphNeuralNetworks" [compat] -BatchedTransformations = "0.4" -Flux = "0.13, 0.14" +BatchedTransformations = "0.5" +Functors = "0.4" GraphNeuralNetworks = "0.6" +Optimisers = "0.3" julia = "1.9" [extras] diff --git a/README.md b/README.md index d22f020..ba1a6f1 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ # RandomFeatureMaps -[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://MurrellGroup.github.io/RandomFeatureMaps.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/RandomFeatureMaps.jl/dev/) [![Build Status](https://github.com/MurrellGroup/RandomFeatureMaps.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/RandomFeatureMaps.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/MurrellGroup/RandomFeatureMaps.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/RandomFeatureMaps.jl) diff --git a/ext/GraphNeuralNetworksExt.jl b/ext/GraphNeuralNetworksExt.jl index b139b6f..710b780 100644 --- a/ext/GraphNeuralNetworksExt.jl +++ b/ext/GraphNeuralNetworksExt.jl @@ -4,15 +4,15 @@ using RandomFeatureMaps using BatchedTransformations using GraphNeuralNetworks +using RandomFeatureMaps: norms + subt(xi, xj, e) = xj .- xi -function (rof::RandomOrientationFeatures)(rigid::Rigid, graph::GNNGraph) - points1 = rigid * rof.FA - points2 = rigid * rof.FB - diffs = apply_edges(subt, graph, xi=points2, xj=points1) - return dropdims(sqrt.(sum(abs2, diffs; dims=1)); dims=1) +function (rof::RandomOrientationFeatures)(T1::Rigid, T2::Rigid, graph::GNNGraph) + @assert length(batchsize(linear(T1))) == 1 && batchsize(linear(T1)) == batchsize(linear(T2)) + diffs = apply_edges(subt, graph, xj=T1(rof.FA), xi=T2(rof.FB)) + norms(diffs; dims=1) end -# deprecated -(rof::RandomOrientationFeatures)(graph::GNNGraph, rigid) = rof(rigid, graph) +(rof::RandomOrientationFeatures)(T, graph::GNNGraph) = rof(T, T, graph) end \ No newline at end of file diff --git a/src/RandomFeatureMaps.jl b/src/RandomFeatureMaps.jl index 5de3a4d..f159ecc 100644 --- a/src/RandomFeatureMaps.jl +++ b/src/RandomFeatureMaps.jl @@ -2,17 +2,24 @@ module RandomFeatureMaps export RandomFourierFeatures export RandomOrientationFeatures -export rand_rigid, construct_rigid +export pairwiserof +export rand_rigid, get_rigid -using Flux: @functor, Optimisers, unsqueeze +using Functors: @functor +import Optimisers using BatchedTransformations +sumdrop(f, A::AbstractArray; dims) = dropdims(sum(f, A; dims); dims) +norms(A::AbstractArray; dims) = sqrt.(sumdrop(abs2, A; dims)) + """ RandomFourierFeatures(n => m, σ) Maps `n`-dimensional data and projects it to `m`-dimensional random fourier features. +This type has no trainable parameters. + ## Examples ```jldoctest @@ -42,8 +49,8 @@ function RandomFourierFeatures((d1, d2)::Pair{<:Integer, <:Integer}, σ::Abstrac end function (rff::RandomFourierFeatures{T})(X::AbstractMatrix{T}) where T<:Real - WtX = rff.W'X - return [cos.(WtX); sin.(WtX)] + Y = rff.W'X + return [cos.(Y); sin.(Y)] end function (rff::RandomFourierFeatures{T})(X::AbstractArray{T}) where T<:Real @@ -53,32 +60,54 @@ function (rff::RandomFourierFeatures{T})(X::AbstractArray{T}) where T<:Real return Y end -rand_rigid(T::Type, batch_size::Dims) = rand(T, Rigid, 3, batch_size) - -function construct_rigid(R::AbstractArray, t::AbstractArray) - batch_size = size(R)[3:end] - t = reshape(t, 3, 1, batch_size...) - Translation(t) ∘ Rotation(R) -end """ RandomOrientationFeatures -Can be called on rigid transformations to create pairwise maps of random orientation features. +Holds two random matrices which are used to embed rigid transformations. + This type has no trainable parameters. +## Methods + +- `(::RandomOrientationFeatures)(rigid1, rigid2)`: returns the distances between the corresponding +rigid transformations, embedded using the two random matrices of the random orientation features. + +- `(::RandomOrientationFeatures)(rigid1, rigid2; dims::Int)`: unsqueezes batch dimension `dim+1` +of `rigid1` and `dim` of `rigid2` to broadcast the `rof` call and produce a pairwise map. + +- `(::RandomOrientationFeatures)(rigid1, rigid2, graph::GraphNeuralNetworks.GNNGraph)`: similar to +the first method, but takes two sets rigid transformations of equal size and unrolls a graph to +get the pairs of rigid transformations. Equivalent to the second method (with broadcasted dimensions +flattened) when the graph is complete. + +Each of these have single rigid argument methods for when `rigid1 == rigid2`, i.e. `rof(rigid)` + ## Examples ```jldoctest -julia> rof = RandomOrientationFeatures(4, 0.1); +julia> rof = RandomOrientationFeatures(10, 0.1f0); + +julia> rigid = rand_rigid(Float32, (2, 3)); + +julia> rof(rigid, rigid) |> size +(10, 4, 3) + +julia> rigid1, rigid2 = rand_rigid(Float32, (4, 2)), rand_rigid(Float32, (3, 2)); -julia> rigid = (randn(3, 3, 2), randn(3, 1, 2)); # cba to make it orthonormal and whatevs +julia> rof(rigid1, rigid2; dims=1) |> size +(10, 4, 3, 2) -julia> rof(rigid) |> size -(4, 2, 2) +julia> using GraphNeuralNetworks + +julia> graph = GNNGraph(rand(Bool, 4, 4), graph_type=:dense) + +julia> rigid = rand_rigid(Float32, (4,)); + +julia> rof(rigid, graph) |> size ``` """ -struct RandomOrientationFeatures{T<:Real,A<:AbstractMatrix{T}} +struct RandomOrientationFeatures{A<:AbstractArray{<:Real}} FA::A FB::A end @@ -96,13 +125,35 @@ function RandomOrientationFeatures(dim::Integer, σ::AbstractFloat) return RandomOrientationFeatures(randn(typeof(σ), 3, dim) * σ, randn(typeof(σ), 3, dim) * σ) end -function (rof::RandomOrientationFeatures)(rigid::Rigid) - points1 = rigid * rof.FA - points2 = rigid * rof.FB - diffs = unsqueeze(points1, dims=4) .- unsqueeze(points2, dims=3) - return dropdims(sqrt.(sum(abs2, diffs; dims=1)); dims=1) +_rof(rof::RandomOrientationFeatures, T1::Rigid, T2::Rigid) = norms(T1(rof.FA) .- T2(rof.FB); dims=1) + +function (rof::RandomOrientationFeatures)(T1::Rigid, T2::Rigid; pairdim::Union{Nothing,Int}=nothing) + if dims isa Int + T1, T2 = batchunsqueeze(T1, dims=pairdim+1), batchunsqueeze(T2, dims=pairdim) + end + _rof(rof, T1, T2) +end + +(rof::RandomOrientationFeatures)(T; kwargs...) = rof(T, T; kwargs...) + + +rand_rigid(T::Type, batch_size::Dims) = rand(T, Rigid, 3, batch_size) + +""" + get_rigid(R::AbstractArray, t::AbstractArray) + +Converts a rotation `R` and translation `t` to a `BatchedTransformations.Rigid`, designed to +handle batch dimensions. + +The transformation gets applied according to `NNlib.batched_mul(R, x) .+ t` +""" +function get_rigid(R::AbstractArray, t::AbstractArray) + batch_size = size(R)[3:end] + t = reshape(t, 3, 1, batch_size...) + Translation(t) ∘ Rotation(R) end -(rof::RandomOrientationFeatures)((R, t)::Tuple{AbstractArray,AbstractArray}, args...) = rof(construct_rigid(R, t), args...) +(rof::RandomOrientationFeatures)(T1::Tuple, T2::Tuple, args...; kwargs...) = + rof(get_rigid(T1...), get_rigid(T2...), args...; kwargs...) end diff --git a/test/GraphNeuralNetworksExt.jl b/test/GraphNeuralNetworksExt.jl index ba990ee..9d54a48 100644 --- a/test/GraphNeuralNetworksExt.jl +++ b/test/GraphNeuralNetworksExt.jl @@ -2,27 +2,26 @@ using GraphNeuralNetworks @testset "GraphNeuralNetworksExt.jl" begin - @testset "Equivalence" begin - m = 10 - n = 8 + @testset "Complete graph equivalence" begin + dim = 10 + n = 5 rigid = rand_rigid(Float32, (n,)) g = ones(Bool, n, n) graph = GNNGraph(g, graph_type=:dense) - rof = RandomOrientationFeatures(m, 0.1f0) - @test rof(rigid, graph) == reshape(rof(rigid), m, :) - - @test rof((rand(3,3,n), rand(3,1,n)), graph) |> size == (m, n^2) # deprecated + rof = RandomOrientationFeatures(dim, 0.1f0) + @test rof(rigid, graph) == reshape(rof(rigid; dims=1), dim, :) end - @testset "Random graph" begin - m = 10 - n = 8 + @testset "Random edges" begin + dim = 10 + n = 5 rigid = rand_rigid(Float32, (n,)) g = rand(Bool, n, n) graph = GNNGraph(g, graph_type=:dense) - rof = RandomOrientationFeatures(m, 0.1f0) - @test size(rof(rigid, graph), 2) == count(g) - @test rof(rigid, graph) == reshape(rof(rigid), m, :)[:,findall(vec(g))] + rof = RandomOrientationFeatures(dim, 0.1f0) + @test size(rof(rigid, rigid, graph), 2) == count(g) + @test rof(rigid, graph) == reshape(rof(rigid; dims=1), dim, :)[:,findall(vec(g))] + @test rof(rigid, graph) == rof(rigid, rigid, graph) end -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index 62f9228..e48fc1a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,9 +13,14 @@ using BatchedTransformations end @testset "RandomOrientationFeatures" begin - rof = RandomOrientationFeatures(10, 0.1f0) - @test rof(rand_rigid(Float32, (4,2))) |> size == (10, 4, 4, 2) - @test rof((rand(3,3,4,2), rand(3,1,4,2))) |> size == (10, 4, 4, 2) # deprecated + dim = 10 + n = 8 + k = 5 + rof = RandomOrientationFeatures(dim, 0.1f0) + rigid = rand_rigid(Float32, (n,k)) + @test rof(rigid, dims=1) |> size == (dim, n, n, k) + @test rof((rand(3,3,n,k), rand(3,1,n,k)), dims=1) |> size == (dim, n, n, k) + @test rof(rigid) == rof(rigid, rigid) end include("GraphNeuralNetworksExt.jl")