Skip to content

Commit

Permalink
Overhaul ROF calling, with optional "pairdim" keyword argument
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Sep 18, 2024
1 parent 0b70448 commit 55a66e5
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 52 deletions.
10 changes: 6 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
name = "RandomFeatureMaps"
uuid = "780baa95-dd42-481b-93db-80fe3d88832c"
authors = ["murrellb <[email protected]> and contributors"]
version = "0.1.1"
version = "0.2.0"

[deps]
BatchedTransformations = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"

[weakdeps]
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
Expand All @@ -14,9 +15,10 @@ GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
GraphNeuralNetworksExt = "GraphNeuralNetworks"

[compat]
BatchedTransformations = "0.4"
Flux = "0.13, 0.14"
BatchedTransformations = "0.5"
Functors = "0.4"
GraphNeuralNetworks = "0.6"
Optimisers = "0.3"
julia = "1.9"

[extras]
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# RandomFeatureMaps

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://MurrellGroup.github.io/RandomFeatureMaps.jl/stable/)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/RandomFeatureMaps.jl/dev/)
[![Build Status](https://github.com/MurrellGroup/RandomFeatureMaps.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/RandomFeatureMaps.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/MurrellGroup/RandomFeatureMaps.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/RandomFeatureMaps.jl)
Expand Down
14 changes: 7 additions & 7 deletions ext/GraphNeuralNetworksExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ using RandomFeatureMaps
using BatchedTransformations
using GraphNeuralNetworks

using RandomFeatureMaps: norms

subt(xi, xj, e) = xj .- xi
function (rof::RandomOrientationFeatures)(rigid::Rigid, graph::GNNGraph)
points1 = rigid * rof.FA
points2 = rigid * rof.FB
diffs = apply_edges(subt, graph, xi=points2, xj=points1)
return dropdims(sqrt.(sum(abs2, diffs; dims=1)); dims=1)
function (rof::RandomOrientationFeatures)(T1::Rigid, T2::Rigid, graph::GNNGraph)
@assert length(batchsize(linear(T1))) == 1 && batchsize(linear(T1)) == batchsize(linear(T2))
diffs = apply_edges(subt, graph, xj=T1(rof.FA), xi=T2(rof.FB))
norms(diffs; dims=1)
end

# deprecated
(rof::RandomOrientationFeatures)(graph::GNNGraph, rigid) = rof(rigid, graph)
(rof::RandomOrientationFeatures)(T, graph::GNNGraph) = rof(T, T, graph)

end
97 changes: 74 additions & 23 deletions src/RandomFeatureMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,24 @@ module RandomFeatureMaps

export RandomFourierFeatures
export RandomOrientationFeatures
export rand_rigid, construct_rigid
export pairwiserof
export rand_rigid, get_rigid

using Flux: @functor, Optimisers, unsqueeze
using Functors: @functor
import Optimisers

using BatchedTransformations

sumdrop(f, A::AbstractArray; dims) = dropdims(sum(f, A; dims); dims)
norms(A::AbstractArray; dims) = sqrt.(sumdrop(abs2, A; dims))

"""
RandomFourierFeatures(n => m, σ)
Maps `n`-dimensional data and projects it to `m`-dimensional random fourier features.
This type has no trainable parameters.
## Examples
```jldoctest
Expand Down Expand Up @@ -42,8 +49,8 @@ function RandomFourierFeatures((d1, d2)::Pair{<:Integer, <:Integer}, σ::Abstrac
end

function (rff::RandomFourierFeatures{T})(X::AbstractMatrix{T}) where T<:Real
WtX = rff.W'X
return [cos.(WtX); sin.(WtX)]
Y = rff.W'X
return [cos.(Y); sin.(Y)]
end

function (rff::RandomFourierFeatures{T})(X::AbstractArray{T}) where T<:Real
Expand All @@ -53,32 +60,54 @@ function (rff::RandomFourierFeatures{T})(X::AbstractArray{T}) where T<:Real
return Y
end

rand_rigid(T::Type, batch_size::Dims) = rand(T, Rigid, 3, batch_size)

function construct_rigid(R::AbstractArray, t::AbstractArray)
batch_size = size(R)[3:end]
t = reshape(t, 3, 1, batch_size...)
Translation(t) Rotation(R)
end

"""
RandomOrientationFeatures
Can be called on rigid transformations to create pairwise maps of random orientation features.
Holds two random matrices which are used to embed rigid transformations.
This type has no trainable parameters.
## Methods
- `(::RandomOrientationFeatures)(rigid1, rigid2)`: returns the distances between the corresponding
rigid transformations, embedded using the two random matrices of the random orientation features.
- `(::RandomOrientationFeatures)(rigid1, rigid2; dims::Int)`: unsqueezes batch dimension `dim+1`
of `rigid1` and `dim` of `rigid2` to broadcast the `rof` call and produce a pairwise map.
- `(::RandomOrientationFeatures)(rigid1, rigid2, graph::GraphNeuralNetworks.GNNGraph)`: similar to
the first method, but takes two sets rigid transformations of equal size and unrolls a graph to
get the pairs of rigid transformations. Equivalent to the second method (with broadcasted dimensions
flattened) when the graph is complete.
Each of these have single rigid argument methods for when `rigid1 == rigid2`, i.e. `rof(rigid)`
## Examples
```jldoctest

Check failure on line 88 in src/RandomFeatureMaps.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:88-108 ```jldoctest julia> rof = RandomOrientationFeatures(10, 0.1f0); julia> rigid = rand_rigid(Float32, (2, 3)); julia> rof(rigid, rigid) |> size (10, 4, 3) julia> rigid1, rigid2 = rand_rigid(Float32, (4, 2)), rand_rigid(Float32, (3, 2)); julia> rof(rigid1, rigid2; dims=1) |> size (10, 4, 3, 2) julia> using GraphNeuralNetworks julia> graph = GNNGraph(rand(Bool, 4, 4), graph_type=:dense) julia> rigid = rand_rigid(Float32, (4,)); julia> rof(rigid, graph) |> size ``` Subexpression: rof(rigid, rigid) |> size Evaluated output: ERROR: UndefVarError: `dims` not defined Stacktrace: [1] #_#5 @ ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:131 [inlined] [2] (::RandomOrientationFeatures{Matrix{Float32}})(T1::BatchedTransformations.Rigid{BatchedTransformations.Translation{Array{Float32, 4}}, BatchedTransformations.Rotation{Array{Float32, 4}}}, T2::BatchedTransformations.Rigid{BatchedTransformations.Translation{Array{Float32, 4}}, BatchedTransformations.Rotation{Array{Float32, 4}}}) @ RandomFeatureMaps ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:130 [3] top-level scope @ none:1 Expected output: (10, 4, 3) diff = Warning: Diff output requires color. (10, 4, 3)ERROR: UndefVarError: `dims` not defined Stacktrace: [1] #_#5 @ ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:131 [inlined] [2] (::RandomOrientationFeatures{Matrix{Float32}})(T1::BatchedTransformations.Rigid{BatchedTransformations.Translation{Array{Float32, 4}}, BatchedTransformations.Rotation{Array{Float32, 4}}}, T2::BatchedTransformations.Rigid{BatchedTransformations.Translation{Array{Float32, 4}}, BatchedTransformations.Rotation{Array{Float32, 4}}}) @ RandomFeatureMaps ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:130 [3] top-level scope @ none:1

Check failure on line 88 in src/RandomFeatureMaps.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:88-108 ```jldoctest julia> rof = RandomOrientationFeatures(10, 0.1f0); julia> rigid = rand_rigid(Float32, (2, 3)); julia> rof(rigid, rigid) |> size (10, 4, 3) julia> rigid1, rigid2 = rand_rigid(Float32, (4, 2)), rand_rigid(Float32, (3, 2)); julia> rof(rigid1, rigid2; dims=1) |> size (10, 4, 3, 2) julia> using GraphNeuralNetworks julia> graph = GNNGraph(rand(Bool, 4, 4), graph_type=:dense) julia> rigid = rand_rigid(Float32, (4,)); julia> rof(rigid, graph) |> size ``` Subexpression: rof(rigid1, rigid2; dims=1) |> size Evaluated output: ERROR: MethodError: no method matching (::RandomOrientationFeatures{Matrix{Float32}})(::BatchedTransformations.Rigid{BatchedTransformations.Translation{Array{Float32, 4}}, BatchedTransformations.Rotation{Array{Float32, 4}}}, ::BatchedTransformations.Rigid{BatchedTransformations.Translation{Array{Float32, 4}}, BatchedTransformations.Rotation{Array{Float32, 4}}}; dims::Int64) Closest candidates are: (::RandomOrientationFeatures)(::BatchedTransformations.Rigid, ::BatchedTransformations.Rigid; pairdim) got unsupported keyword argument "dims" @ RandomFeatureMaps ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:130 (::RandomOrientationFeatures)(::Any; kwargs...) @ RandomFeatureMaps ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:137 Stacktrace: [1] kwerr(::@NamedTuple{dims::Int64}, ::RandomOrientationFeatures{Matrix{Float32}}, ::BatchedTransformations.Rigid{BatchedTransformations.Translation{Array{Float32, 4}}, BatchedTransformations.Rotation{Array{Float32, 4}}}, ::BatchedTransformations.Rigid{BatchedTransformations.Translation{Array{Float32, 4}}, BatchedTransformations.Rotation{Array{Float32, 4}}}) @ Base ./error.jl:165 [2] top-level scope @ none:1 Expected output: (10, 4, 3, 2) diff = Warning: Diff output requires color. (10, 4, 3, 2)ERROR: MethodError: no method matching (::RandomOrientationFeatures{Matrix{Float32}})(::BatchedTransformations.Rigid{BatchedTransformations.Translation{Array{Float32, 4}}, BatchedTransformations.Rotation{Array{Float32, 4}}}, ::BatchedTransformations.Rigid{BatchedTransformations.Translation{Array{Float32, 4}}, BatchedTransformations.Rotation{Array{Float32, 4}}}; dims::Int64) Closest candidates are: (::RandomOrientationFeatures)(::BatchedTransformations.Rigid, ::BatchedTransformations.Rigid; pairdim) got unsupported keyword argument "dims" @ RandomFeatureMaps ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:130 (::RandomOrientationFeatures)(::Any; kwargs...) @ RandomFeatureMaps ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:137 Stacktrace: [1] kwerr(::@NamedTuple{dims::Int64}, ::RandomOrientationFeatures{Matrix{Float32}}, ::BatchedTransformations.Rigid{BatchedTransformations.Translation{Array{Float32, 4}}, BatchedTransformations.Rotation{Array{Float32, 4}}}, ::BatchedTransformations.Rigid{BatchedTransformations.Translation{Array{Float32, 4}}, BatchedTransformations.Rotation{Array{Float32, 4}}}) @ Base ./error.jl:165 [2] top-level scope @ none:1

Check failure on line 88 in src/RandomFeatureMaps.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:88-108 ```jldoctest julia> rof = RandomOrientationFeatures(10, 0.1f0); julia> rigid = rand_rigid(Float32, (2, 3)); julia> rof(rigid, rigid) |> size (10, 4, 3) julia> rigid1, rigid2 = rand_rigid(Float32, (4, 2)), rand_rigid(Float32, (3, 2)); julia> rof(rigid1, rigid2; dims=1) |> size (10, 4, 3, 2) julia> using GraphNeuralNetworks julia> graph = GNNGraph(rand(Bool, 4, 4), graph_type=:dense) julia> rigid = rand_rigid(Float32, (4,)); julia> rof(rigid, graph) |> size ``` Subexpression: using GraphNeuralNetworks Evaluated output: ERROR: ArgumentError: Package GraphNeuralNetworks not found in current path. - Run `import Pkg; Pkg.add("GraphNeuralNetworks")` to install the GraphNeuralNetworks package. Stacktrace: [1] macro expansion @ ./loading.jl:1772 [inlined] [2] macro expansion @ ./lock.jl:267 [inlined] [3] __require(into::Module, mod::Symbol) @ Base ./loading.jl:1753 [4] #invoke_in_world#3 @ ./essentials.jl:926 [inlined] [5] invoke_in_world @ ./essentials.jl:923 [inlined] [6] require(into::Module, mod::Symbol) @ Base ./loading.jl:1746 Expected output: diff = Warning: Diff output requires color. ERROR: ArgumentError: Package GraphNeuralNetworks not found in current path. - Run `import Pkg; Pkg.add("GraphNeuralNetworks")` to install the GraphNeuralNetworks package. Stacktrace: [1] macro expansion @ ./loading.jl:1772 [inlined] [2] macro expansion @ ./lock.jl:267 [inlined] [3] __require(into::Module, mod::Symbol) @ Base ./loading.jl:1753 [4] #invoke_in_world#3 @ ./essentials.jl:926 [inlined] [5] invoke_in_world @ ./essentials.jl:923 [inlined] [6] require(into::Module, mod::Symbol) @ Base ./loading.jl:1746

Check failure on line 88 in src/RandomFeatureMaps.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:88-108 ```jldoctest julia> rof = RandomOrientationFeatures(10, 0.1f0); julia> rigid = rand_rigid(Float32, (2, 3)); julia> rof(rigid, rigid) |> size (10, 4, 3) julia> rigid1, rigid2 = rand_rigid(Float32, (4, 2)), rand_rigid(Float32, (3, 2)); julia> rof(rigid1, rigid2; dims=1) |> size (10, 4, 3, 2) julia> using GraphNeuralNetworks julia> graph = GNNGraph(rand(Bool, 4, 4), graph_type=:dense) julia> rigid = rand_rigid(Float32, (4,)); julia> rof(rigid, graph) |> size ``` Subexpression: graph = GNNGraph(rand(Bool, 4, 4), graph_type=:dense) Evaluated output: ERROR: UndefVarError: `GNNGraph` not defined Stacktrace: [1] top-level scope @ none:1 Expected output: diff = Warning: Diff output requires color. ERROR: UndefVarError: `GNNGraph` not defined Stacktrace: [1] top-level scope @ none:1

Check failure on line 88 in src/RandomFeatureMaps.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/RandomFeatureMaps.jl/RandomFeatureMaps.jl/src/RandomFeatureMaps.jl:88-108 ```jldoctest julia> rof = RandomOrientationFeatures(10, 0.1f0); julia> rigid = rand_rigid(Float32, (2, 3)); julia> rof(rigid, rigid) |> size (10, 4, 3) julia> rigid1, rigid2 = rand_rigid(Float32, (4, 2)), rand_rigid(Float32, (3, 2)); julia> rof(rigid1, rigid2; dims=1) |> size (10, 4, 3, 2) julia> using GraphNeuralNetworks julia> graph = GNNGraph(rand(Bool, 4, 4), graph_type=:dense) julia> rigid = rand_rigid(Float32, (4,)); julia> rof(rigid, graph) |> size ``` Subexpression: rof(rigid, graph) |> size Evaluated output: ERROR: UndefVarError: `graph` not defined Stacktrace: [1] top-level scope @ none:1 Expected output: diff = Warning: Diff output requires color. ERROR: UndefVarError: `graph` not defined Stacktrace: [1] top-level scope @ none:1
julia> rof = RandomOrientationFeatures(4, 0.1);
julia> rof = RandomOrientationFeatures(10, 0.1f0);
julia> rigid = rand_rigid(Float32, (2, 3));
julia> rof(rigid, rigid) |> size
(10, 4, 3)
julia> rigid1, rigid2 = rand_rigid(Float32, (4, 2)), rand_rigid(Float32, (3, 2));
julia> rigid = (randn(3, 3, 2), randn(3, 1, 2)); # cba to make it orthonormal and whatevs
julia> rof(rigid1, rigid2; dims=1) |> size
(10, 4, 3, 2)
julia> rof(rigid) |> size
(4, 2, 2)
julia> using GraphNeuralNetworks
julia> graph = GNNGraph(rand(Bool, 4, 4), graph_type=:dense)
julia> rigid = rand_rigid(Float32, (4,));
julia> rof(rigid, graph) |> size
```
"""
struct RandomOrientationFeatures{T<:Real,A<:AbstractMatrix{T}}
struct RandomOrientationFeatures{A<:AbstractArray{<:Real}}
FA::A
FB::A
end
Expand All @@ -96,13 +125,35 @@ function RandomOrientationFeatures(dim::Integer, σ::AbstractFloat)
return RandomOrientationFeatures(randn(typeof(σ), 3, dim) * σ, randn(typeof(σ), 3, dim) * σ)
end

function (rof::RandomOrientationFeatures)(rigid::Rigid)
points1 = rigid * rof.FA
points2 = rigid * rof.FB
diffs = unsqueeze(points1, dims=4) .- unsqueeze(points2, dims=3)
return dropdims(sqrt.(sum(abs2, diffs; dims=1)); dims=1)
_rof(rof::RandomOrientationFeatures, T1::Rigid, T2::Rigid) = norms(T1(rof.FA) .- T2(rof.FB); dims=1)

function (rof::RandomOrientationFeatures)(T1::Rigid, T2::Rigid; pairdim::Union{Nothing,Int}=nothing)
if dims isa Int
T1, T2 = batchunsqueeze(T1, dims=pairdim+1), batchunsqueeze(T2, dims=pairdim)
end
_rof(rof, T1, T2)
end

(rof::RandomOrientationFeatures)(T; kwargs...) = rof(T, T; kwargs...)


rand_rigid(T::Type, batch_size::Dims) = rand(T, Rigid, 3, batch_size)

"""
get_rigid(R::AbstractArray, t::AbstractArray)
Converts a rotation `R` and translation `t` to a `BatchedTransformations.Rigid`, designed to
handle batch dimensions.
The transformation gets applied according to `NNlib.batched_mul(R, x) .+ t`
"""
function get_rigid(R::AbstractArray, t::AbstractArray)
batch_size = size(R)[3:end]
t = reshape(t, 3, 1, batch_size...)
Translation(t) Rotation(R)
end

(rof::RandomOrientationFeatures)((R, t)::Tuple{AbstractArray,AbstractArray}, args...) = rof(construct_rigid(R, t), args...)
(rof::RandomOrientationFeatures)(T1::Tuple, T2::Tuple, args...; kwargs...) =
rof(get_rigid(T1...), get_rigid(T2...), args...; kwargs...)

end
27 changes: 13 additions & 14 deletions test/GraphNeuralNetworksExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,26 @@ using GraphNeuralNetworks

@testset "GraphNeuralNetworksExt.jl" begin

@testset "Equivalence" begin
m = 10
n = 8
@testset "Complete graph equivalence" begin
dim = 10
n = 5
rigid = rand_rigid(Float32, (n,))
g = ones(Bool, n, n)
graph = GNNGraph(g, graph_type=:dense)
rof = RandomOrientationFeatures(m, 0.1f0)
@test rof(rigid, graph) == reshape(rof(rigid), m, :)

@test rof((rand(3,3,n), rand(3,1,n)), graph) |> size == (m, n^2) # deprecated
rof = RandomOrientationFeatures(dim, 0.1f0)
@test rof(rigid, graph) == reshape(rof(rigid; dims=1), dim, :)
end

@testset "Random graph" begin
m = 10
n = 8
@testset "Random edges" begin
dim = 10
n = 5
rigid = rand_rigid(Float32, (n,))
g = rand(Bool, n, n)
graph = GNNGraph(g, graph_type=:dense)
rof = RandomOrientationFeatures(m, 0.1f0)
@test size(rof(rigid, graph), 2) == count(g)
@test rof(rigid, graph) == reshape(rof(rigid), m, :)[:,findall(vec(g))]
rof = RandomOrientationFeatures(dim, 0.1f0)
@test size(rof(rigid, rigid, graph), 2) == count(g)
@test rof(rigid, graph) == reshape(rof(rigid; dims=1), dim, :)[:,findall(vec(g))]
@test rof(rigid, graph) == rof(rigid, rigid, graph)
end

end
end
11 changes: 8 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@ using BatchedTransformations
end

@testset "RandomOrientationFeatures" begin
rof = RandomOrientationFeatures(10, 0.1f0)
@test rof(rand_rigid(Float32, (4,2))) |> size == (10, 4, 4, 2)
@test rof((rand(3,3,4,2), rand(3,1,4,2))) |> size == (10, 4, 4, 2) # deprecated
dim = 10
n = 8
k = 5
rof = RandomOrientationFeatures(dim, 0.1f0)
rigid = rand_rigid(Float32, (n,k))
@test rof(rigid, dims=1) |> size == (dim, n, n, k)
@test rof((rand(3,3,n,k), rand(3,1,n,k)), dims=1) |> size == (dim, n, n, k)
@test rof(rigid) == rof(rigid, rigid)
end

include("GraphNeuralNetworksExt.jl")
Expand Down

0 comments on commit 55a66e5

Please sign in to comment.