Add GraphNeuralNetworks dep, use Flux, BatchedTransformations
AntonOresten committed Sep 16, 2024
1 parent 9cb21ed commit 0b70448
Showing 5 changed files with 124 additions and 49 deletions.
22 changes: 14 additions & 8 deletions Project.toml
@@ -1,21 +1,27 @@
name = "RandomFeatureMaps"
uuid = "780baa95-dd42-481b-93db-80fe3d88832c"
authors = ["murrellb <[email protected]> and contributors"]
version = "0.1.0"
version = "0.1.1"

[deps]
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
BatchedTransformations = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"

[weakdeps]
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"

[extensions]
GraphNeuralNetworksExt = "GraphNeuralNetworks"

[compat]
Functors = "0.4"
NNlib = "0.9"
Optimisers = "0.3"
BatchedTransformations = "0.4"
Flux = "0.13, 0.14"
GraphNeuralNetworks = "0.6"
julia = "1.9"

[extras]
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "GraphNeuralNetworks"]
18 changes: 18 additions & 0 deletions ext/GraphNeuralNetworksExt.jl
@@ -0,0 +1,18 @@
module GraphNeuralNetworksExt

using RandomFeatureMaps
using BatchedTransformations
using GraphNeuralNetworks

subt(xi, xj, e) = xj .- xi
function (rof::RandomOrientationFeatures)(rigid::Rigid, graph::GNNGraph)
    points1 = rigid * rof.FA
    points2 = rigid * rof.FB
    diffs = apply_edges(subt, graph, xi=points2, xj=points1)
    return dropdims(sqrt.(sum(abs2, diffs; dims=1)); dims=1)
end

# deprecated
(rof::RandomOrientationFeatures)(graph::GNNGraph, rigid) = rof(rigid, graph)

end
82 changes: 47 additions & 35 deletions src/RandomFeatureMaps.jl
@@ -2,34 +2,37 @@ module RandomFeatureMaps

export RandomFourierFeatures
export RandomOrientationFeatures
export rand_rigid, construct_rigid

using NNlib: batched_mul
using Functors: @functor
import Optimisers
using Flux: @functor, Optimisers, unsqueeze

using BatchedTransformations

"""
RandomFourierFeatures(n => m, σ)
Maps `n`-dimensional data to `m`-dimensional random Fourier features.
## Example
## Examples
```jldoctest
julia> rff = RandomFourierFeatures(2 => 4, 1.0); # maps 2D data to 4D
julia> rff(rand(2, 3)) |> size # 3 samples
(4, 3)
julia> rff(rand(2, 3, 5)) |> size # extra batch dim
(4, 3, 5)
```
"""
struct RandomFourierFeatures{T <: Real, A <: AbstractMatrix{T}}
struct RandomFourierFeatures{T<:Real,A<:AbstractMatrix{T}}
W::A
end

@functor RandomFourierFeatures
Optimisers.trainable(::RandomFourierFeatures) = (;) # no trainable parameters

RandomFourierFeatures(dims::Pair{<:Integer, <:Integer}, σ::Real) =
    RandomFourierFeatures(dims, float(σ))
RandomFourierFeatures(dims::Pair{<:Integer, <:Integer}, σ::Real) = RandomFourierFeatures(dims, float(σ))

# d1: input dimension, d2: output dimension (d1 => d2)
function RandomFourierFeatures((d1, d2)::Pair{<:Integer, <:Integer}, σ::AbstractFloat)
@@ -38,59 +41,68 @@ function RandomFourierFeatures((d1, d2)::Pair{<:Integer, <:Integer}, σ::AbstractFloat)
    return RandomFourierFeatures(randn(typeof(σ), d1, d2 ÷ 2) * σ * oftype(σ, 2π))
end

function (rff::RandomFourierFeatures{T})(X::AbstractMatrix{T}) where T <: Real
function (rff::RandomFourierFeatures{T})(X::AbstractMatrix{T}) where T<:Real
    WtX = rff.W'X
    return [cos.(WtX); sin.(WtX)]
end

function (rff::RandomFourierFeatures{T})(X::AbstractArray{T}) where T <: Real
function (rff::RandomFourierFeatures{T})(X::AbstractArray{T}) where T<:Real
    X′ = reshape(X, size(X, 1), :)
    Y′ = rff(X′)
    Y = reshape(Y′, :, size(X)[2:end]...)
    return Y
end

rand_rigid(T::Type, batch_size::Dims) = rand(T, Rigid, 3, batch_size)

function construct_rigid(R::AbstractArray, t::AbstractArray)
    batch_size = size(R)[3:end]
    t = reshape(t, 3, 1, batch_size...)
    Translation(t) ∘ Rotation(R)
end

"""
RandomOrientationFeatures(m, σ)
RandomOrientationFeatures
Can be called on rigid transformations to create pairwise maps of random orientation features.
This type has no trainable parameters.
Projects rigid transformations to `m` features.
These will be the pairwise distances between points.
## Examples
```jldoctest
julia> rof = RandomOrientationFeatures(4, 0.1);
julia> rigid = (randn(3, 3, 2), randn(3, 1, 2)); # rotations not orthonormal, but fine for a size check
julia> rof(rigid) |> size
(4, 2, 2)
```
"""
struct RandomOrientationFeatures{A <: AbstractArray{<:Real}}
struct RandomOrientationFeatures{T<:Real,A<:AbstractMatrix{T}}
    FA::A
    FB::A
end

@functor RandomOrientationFeatures
Optimisers.trainable(::RandomOrientationFeatures) = (;) # no trainable parameters

# TODO: should this hold a single array, so that pairwise distances require two instances?
"""
RandomOrientationFeatures(m, σ)
Creates a `RandomOrientationFeatures` instance, mapping to `m` features.
"""
function RandomOrientationFeatures(dim::Integer, σ::AbstractFloat)
    isfinite(σ) && σ > 0 || throw(ArgumentError("scale must be finite and positive"))
    return RandomOrientationFeatures(randn(typeof(σ), 3, dim, 1) * σ, randn(typeof(σ), 3, dim, 1) * σ)
    return RandomOrientationFeatures(randn(typeof(σ), 3, dim) * σ, randn(typeof(σ), 3, dim) * σ)
end

### For non-graph version, with batch dim
function transform_rigid(x::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) where T
    x′ = reshape(x, 3, size(x, 2), :)
    R′ = reshape(R, 3, 3, :)
    t′ = reshape(t, 3, 1, :)
    y′ = batched_mul(R′, x′) .+ t′
    y = reshape(y′, 3, size(x, 2), size(R)[3:end]...)
    return y
end

function (rof::RandomOrientationFeatures)(rigid::Tuple{AbstractArray, AbstractArray})
    dim = size(rof.FA, 2)
    Nr, batch... = size(rigid[1])[3:end]
    p1 = reshape(transform_rigid(rof.FA, rigid...), 3, dim, Nr, batch...)
    p2 = reshape(transform_rigid(rof.FB, rigid...), 3, dim, Nr, batch...)
    return dropdims(sqrt.(sum(abs2,
        reshape(p1, 3, dim, Nr, 1, batch...) .-
        reshape(p2, 3, dim, 1, Nr, batch...),
        dims=1)), dims=1)
function (rof::RandomOrientationFeatures)(rigid::Rigid)
    points1 = rigid * rof.FA
    points2 = rigid * rof.FB
    diffs = unsqueeze(points1, dims=4) .- unsqueeze(points2, dims=3)
    return dropdims(sqrt.(sum(abs2, diffs; dims=1)); dims=1)
end

### TODO: GRAPH version - see https://github.com/MurrellGroup/RandomFeatures.jl/blob/main/src/RandomFeatures.jl#L91
(rof::RandomOrientationFeatures)((R, t)::Tuple{AbstractArray,AbstractArray}, args...) = rof(construct_rigid(R, t), args...)

end
28 changes: 28 additions & 0 deletions test/GraphNeuralNetworksExt.jl
@@ -0,0 +1,28 @@
using GraphNeuralNetworks

@testset "GraphNeuralNetworksExt.jl" begin

@testset "Equivalence" begin
m = 10
n = 8
rigid = rand_rigid(Float32, (n,))
g = ones(Bool, n, n)
graph = GNNGraph(g, graph_type=:dense)
rof = RandomOrientationFeatures(m, 0.1f0)
@test rof(rigid, graph) == reshape(rof(rigid), m, :)

@test rof((rand(3,3,n), rand(3,1,n)), graph) |> size == (m, n^2) # deprecated
end

@testset "Random graph" begin
m = 10
n = 8
rigid = rand_rigid(Float32, (n,))
g = rand(Bool, n, n)
graph = GNNGraph(g, graph_type=:dense)
rof = RandomOrientationFeatures(m, 0.1f0)
@test size(rof(rigid, graph), 2) == count(g)
@test rof(rigid, graph) == reshape(rof(rigid), m, :)[:,findall(vec(g))]
end

end
23 changes: 17 additions & 6 deletions test/runtests.jl
@@ -1,12 +1,23 @@
using RandomFeatureMaps
using Test

using BatchedTransformations

@testset "RandomFeatureMaps.jl" begin
    rff = RandomFourierFeatures(10 => 20, 0.1f0)
    x = randn(Float32, 10, 4)
    @test rff(x) |> size == (20, 4)
    @test rff(reshape(x, 10, 2, 2)) |> size == (20, 2, 2)

    rof = RandomOrientationFeatures(10, 0.1f0)
    @test rof((randn(Float32, 3, 3, 4, 2), randn(Float32, 3, 1, 4, 2))) |> size == (10, 4, 4, 2)
    @testset "RandomFourierFeatures" begin
        rff = RandomFourierFeatures(10 => 20, 0.1f0)
        x = randn(Float32, 10, 4)
        @test rff(x) |> size == (20, 4)
        @test rff(reshape(x, 10, 2, 2)) |> size == (20, 2, 2)
    end

    @testset "RandomOrientationFeatures" begin
        rof = RandomOrientationFeatures(10, 0.1f0)
        @test rof(rand_rigid(Float32, (4,2))) |> size == (10, 4, 4, 2)
        @test rof((rand(3,3,4,2), rand(3,1,4,2))) |> size == (10, 4, 4, 2) # deprecated
    end

    include("GraphNeuralNetworksExt.jl")

end

2 comments on commit 0b70448

@AntonOresten
Member Author


@JuliaRegistrator register

Release notes:

  • Add a GraphNeuralNetworks extension so that graphs can be passed along with rigid transformations in RandomOrientationFeatures calls (see the usage sketch after this list).
  • Use BatchedTransformations to represent rigid transformations internally. The old tuple-based methods still work but are deprecated.
  • Replace the Functors, Optimisers, and NNlib dependencies with Flux.
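
A minimal usage sketch of the new API, based on the tests added in this commit; the graph method requires GraphNeuralNetworks to be loaded so the package extension activates, and the sizes in the comments follow the test expectations:

```julia
using RandomFeatureMaps
using BatchedTransformations
using GraphNeuralNetworks  # loading this activates GraphNeuralNetworksExt

n = 8                                  # number of rigid frames
rigid = rand_rigid(Float32, (n,))      # batch of random Rigid transformations
rof = RandomOrientationFeatures(10, 0.1f0)

size(rof(rigid))                       # (10, n, n): dense pairwise distance features

graph = GNNGraph(ones(Bool, n, n), graph_type=:dense)
size(rof(rigid, graph))                # (10, n^2): one column per edge of the graph

# Deprecated tuple form (rotations, translations) still works:
R, t = randn(Float32, 3, 3, n), randn(Float32, 3, 1, n)
size(rof((R, t)))                      # (10, n, n)
```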

@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/115265

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.1.1 -m "<description of version>" 0b704488869b996255fba20c36091f7b91857a01
git push origin v0.1.1
