diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 4b54852..e629a75 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -23,14 +23,12 @@ jobs: fail-fast: false matrix: version: - - '1.10' - '1.9' - 'nightly' os: - ubuntu-latest arch: - x64 - - x86 steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -71,6 +69,6 @@ jobs: shell: julia --project=docs --color=yes {0} run: | using Documenter: DocMeta, doctest - using RandomFeatures - DocMeta.setdocmeta!(RandomFeatures, :DocTestSetup, :(using RandomFeatures); recursive=true) - doctest(RandomFeatures) + using RandomFeatureMaps + DocMeta.setdocmeta!(RandomFeatureMaps, :DocTestSetup, :(using RandomFeatureMaps); recursive=true) + doctest(RandomFeatureMaps) diff --git a/.gitignore b/.gitignore index 0887050..31d573e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *.jl.mem /docs/Manifest.toml /docs/build/ +Manifest.toml diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index 867bb70..0000000 --- a/Manifest.toml +++ /dev/null @@ -1,795 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.10.0" -manifest_format = "2.0" -project_hash = "a48a09a36849f1caf255a5d83d7e8c8138a822da" - -[[deps.AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" -weakdeps = ["ChainRulesCore", "Test"] - - [deps.AbstractFFTs.extensions] - AbstractFFTsChainRulesCoreExt = "ChainRulesCore" - AbstractFFTsTestExt = "Test" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" - -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" - -[[deps.BangBang]] -deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"] -git-tree-sha1 = "7aa7ad1682f3d5754e3491bb59b8103cae28e3a3" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.3.40" - - [deps.BangBang.extensions] - BangBangChainRulesCoreExt = "ChainRulesCore" - BangBangDataFramesExt = "DataFrames" - BangBangStaticArraysExt = "StaticArrays" - BangBangStructArraysExt = "StructArrays" - BangBangTypedTablesExt = "TypedTables" - - [deps.BangBang.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[deps.CEnum]] -git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.5.0" - -[[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "4e42872be98fa3343c4f8458cbda8c5c6a6fa97c" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.63.0" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.23.0" -weakdeps = ["SparseArrays"] - - [deps.ChainRulesCore.extensions] - ChainRulesCoreSparseArraysExt = "SparseArrays" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["TOML", "UUIDs"] -git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.14.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+1" - -[[deps.CompositionsBase]] -git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.2" - - [deps.CompositionsBase.extensions] - CompositionsBaseInverseFunctionsExt = "InverseFunctions" - - [deps.CompositionsBase.weakdeps] - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.5" - - [deps.ConstructionBase.extensions] - ConstructionBaseIntervalSetsExt = "IntervalSets" - ConstructionBaseStaticArraysExt = "StaticArrays" - - [deps.ConstructionBase.weakdeps] - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.ContextVariablesX]] -deps = ["Compat", "Logging", "UUIDs"] -git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" -uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" -version = "0.1.3" - -[[deps.DataAPI]] -git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.16.0" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "0f4b5d62a88d8f59003e43c25a8a90de9eb76317" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.18" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DefineSingletons]] -git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.2" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -version = "1.9.1" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.FLoops]] -deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "ffb97765602e3cbe59a0589d237bf07f245a8576" -uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.1" - -[[deps.FLoopsBase]] -deps = ["ContextVariablesX"] -git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" -uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" -version = "0.1.1" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FillArrays]] -deps = ["LinearAlgebra", "Random"] -git-tree-sha1 = "5b93957f6dcd33fc343044af3d48c215be2562f1" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.9.3" - - [deps.FillArrays.extensions] - FillArraysPDMatsExt = "PDMats" - FillArraysSparseArraysExt = "SparseArrays" - FillArraysStatisticsExt = "Statistics" - - [deps.FillArrays.weakdeps] - PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "502d0232ec6430d40b6e5b57637333f32192592e" -uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.14" - - [deps.Flux.extensions] - FluxAMDGPUExt = "AMDGPU" - FluxCUDAExt = "CUDA" - FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] - FluxMetalExt = "Metal" - - [deps.Flux.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - -[[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "8ae30e786837ce0a24f5e2186938bf3251ab94b2" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.8" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "47e4686ec18a9620850bad110b79966132f14283" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.0.2" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.6" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "5d8c5713f38f7bc029e26627b687710ba406d0dd" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.12" - -[[deps.InitialValues]] -git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.3.1" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.JuliaVariables]] -deps = ["MLStyle", "NameResolution"] -git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" -uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" -version = "0.2.4" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "ed7167240f40e62d97c1f5f7735dea6de3cc5c49" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.18" - - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - - [deps.KernelAbstractions.weakdeps] - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "7c6650580b4c3169d9905858160db895bff6d2e2" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.6.1" - - [deps.LLVM.extensions] - BFloat16sExt = "BFloat16s" - - [deps.LLVM.weakdeps] - BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.29+0" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.27" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.MLStyle]] -git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" -uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.17" - -[[deps.MLUtils]] -deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] -git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" -uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.4.4" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.13" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" - -[[deps.MicroCollections]] -deps = ["BangBang", "InitialValues", "Setfield"] -git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" -uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.1.4" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.1.0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" - -[[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "877f15c331337d54cf24c797d5bcb2e48ce21221" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.12" - - [deps.NNlib.extensions] - NNlibAMDGPUExt = "AMDGPU" - NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] - NNlibCUDAExt = "CUDA" - NNlibEnzymeCoreExt = "EnzymeCore" - - [deps.NNlib.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NameResolution]] -deps = ["PrettyPrint"] -git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" -uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" -version = "0.1.5" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OneHotArrays]] -deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] -git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" -uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.5" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+2" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optimisers]] -deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "264b061c1903bc0fe9be77cb9050ebacff66bb63" -uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.3.2" - -[[deps.OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.PrettyPrint]] -git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" -uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" -version = "0.2.0" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.ProgressLogging]] -deps = ["Logging", "SHA", "UUIDs"] -git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" -uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -version = "0.1.4" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.ShowCases]] -git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" -uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" -version = "0.1.0" - -[[deps.SimpleTraits]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" -uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" -version = "0.9.4" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.2.1" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" - -[[deps.SparseInverseSubset]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" -uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" -version = "0.1.2" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.1" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.15" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.3" -weakdeps = ["ChainRulesCore", "Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysChainRulesCoreExt = "ChainRulesCore" - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.7.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "1d77abd07f617c4868c33d4f5b9e1dbb2643c9cf" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.2" - -[[deps.StructArrays]] -deps = ["ConstructionBase", "DataAPI", "Tables"] -git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" -uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.18" -weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] - - [deps.StructArrays.extensions] - StructArraysAdaptExt = "Adapt" - StructArraysGPUArraysCoreExt = "GPUArraysCore" - StructArraysSparseArraysExt = "SparseArrays" - StructArraysStaticArraysExt = "StaticArrays" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.11.1" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.Transducers]] -deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "3064e780dbb8a9296ebb3af8f440f787bb5332af" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.80" - - [deps.Transducers.extensions] - TransducersBlockArraysExt = "BlockArrays" - TransducersDataFramesExt = "DataFrames" - TransducersLazyArraysExt = "LazyArrays" - TransducersOnlineStatsBaseExt = "OnlineStatsBase" - TransducersReferenceablesExt = "Referenceables" - - [deps.Transducers.weakdeps] - BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" - OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" - Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" -uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" - -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.3" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" - -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "4ddb4470e47b0094c93055a3bcae799165cc68f1" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.69" - - [deps.Zygote.extensions] - ZygoteColorsExt = "Colors" - ZygoteDistancesExt = "Distances" - ZygoteTrackerExt = "Tracker" - - [deps.Zygote.weakdeps] - Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" - Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.ZygoteRules]] -deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.5" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index 41cb79a..b8b961e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,10 @@ -name = "RandomFeatures" -uuid = "73bd944d-2787-4b1c-bb49-32a090f669f3" +name = "RandomFeatureMaps" +uuid = "780baa95-dd42-481b-93db-80fe3d88832c" authors = ["murrellb and contributors"] -version = "1.0.0-DEV" +version = "0.1.0" [deps] -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" diff --git a/README.md b/README.md index b5c2a7a..355ff29 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,8 @@ -# RandomFeatures +# RandomFeatureMaps -[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://MurrellGroup.github.io/RandomFeatures.jl/stable/) -[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/RandomFeatures.jl/dev/) -[![Build Status](https://github.com/murrellb/RandomFeatures.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/RandomFeatures.jl/actions/workflows/CI.yml?query=branch%3Amain) -[![Coverage](https://codecov.io/gh/murrellb/RandomFeatures.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/RandomFeatures.jl) +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://MurrellGroup.github.io/RandomFeatureMaps.jl/stable/) +[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/RandomFeatureMaps.jl/dev/) +[![Build Status](https://github.com/murrellb/RandomFeatureMaps.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/RandomFeatureMaps.jl/actions/workflows/CI.yml?query=branch%3Amain) +[![Coverage](https://codecov.io/gh/murrellb/RandomFeatureMaps.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/RandomFeatureMaps.jl) -# This package is deprecated - -[RandomFeatures.jl](https://github.com/CliMA/RandomFeatures.jl) is already taken and registered, so new and registered versions of this package can be found at https://github.com/MurrellGroup/RandomFeatureMaps.jl +This Julia package provides types that can be used as layers in neural networks for creating random feature embeddings of data. \ No newline at end of file diff --git a/docs/Project.toml b/docs/Project.toml index 0bf2bf8..8b1932b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,3 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -RandomFeatures = "73bd944d-2787-4b1c-bb49-32a090f669f3" +RandomFeatureMaps = "73bd944d-2787-4b1c-bb49-32a090f669f3" diff --git a/docs/make.jl b/docs/make.jl index 8e3dff1..9722e1c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,14 +1,14 @@ -using RandomFeatures +using RandomFeatureMaps using Documenter -DocMeta.setdocmeta!(RandomFeatures, :DocTestSetup, :(using RandomFeatures); recursive=true) +DocMeta.setdocmeta!(RandomFeatureMaps, :DocTestSetup, :(using RandomFeatureMaps); recursive=true) makedocs(; - modules=[RandomFeatures], + modules=[RandomFeatureMaps], authors="murrellb and contributors", - sitename="RandomFeatures.jl", + sitename="RandomFeatureMaps.jl", format=Documenter.HTML(; - canonical="https://MurrellGroup.github.io/RandomFeatures.jl", + canonical="https://MurrellGroup.github.io/RandomFeatureMaps.jl", edit_link="main", assets=String[], ), @@ -18,6 +18,6 @@ makedocs(; ) deploydocs(; - repo="github.com/MurrellGroup/RandomFeatures.jl", + repo="github.com/MurrellGroup/RandomFeatureMaps.jl", devbranch="main", ) diff --git a/docs/src/index.md b/docs/src/index.md index e9320c4..cf56241 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,14 +1,14 @@ ```@meta -CurrentModule = RandomFeatures +CurrentModule = RandomFeatureMaps ``` -# RandomFeatures +# RandomFeatureMaps -Documentation for [RandomFeatures](https://github.com/MurrellGroup/RandomFeatures.jl). +Documentation for [RandomFeatureMaps](https://github.com/MurrellGroup/RandomFeatureMaps.jl). ```@index ``` ```@autodocs -Modules = [RandomFeatures] +Modules = [RandomFeatureMaps] ``` diff --git a/src/RandomFeatureMaps.jl b/src/RandomFeatureMaps.jl new file mode 100644 index 0000000..1c6b836 --- /dev/null +++ b/src/RandomFeatureMaps.jl @@ -0,0 +1,96 @@ +module RandomFeatureMaps + +export RandomFourierFeatures +export RandomOrientationFeatures + +using NNlib: batched_mul +using Functors: @functor +import Optimisers + +""" + RandomFourierFeatures(n => m, σ) + +Maps `n`-dimensional data and projects it to `m`-dimensional random fourier features. + +## Example + +```jldoctest +julia> rff = RandomFourierFeatures(2 => 4, 1.0); # maps 2D data to 4D + +julia> rff(rand(2, 3)) |> size # 3 samples +(4, 3) +``` +""" +struct RandomFourierFeatures{T <: Real, A <: AbstractMatrix{T}} + W::A +end + +@functor RandomFourierFeatures +Optimisers.trainable(::RandomFourierFeatures) = (;) # no trainable parameters + +RandomFourierFeatures(dims::Pair{<:Integer, <:Integer}, σ::Real) = + RandomFourierFeatures(dims, float(σ)) + +# d1: input dimension, d2: output dimension (d1 => d2) +function RandomFourierFeatures((d1, d2)::Pair{<:Integer, <:Integer}, σ::AbstractFloat) + iseven(d2) || throw(ArgumentError("dimension must be even")) + isfinite(σ) && σ > 0 || throw(ArgumentError("scale must be finite and positive")) + return RandomFourierFeatures(randn(typeof(σ), d1, d2 ÷ 2) * σ * oftype(σ, 2π)) +end + +function (rff::RandomFourierFeatures{T})(X::AbstractMatrix{T}) where T <: Real + WtX = rff.W'X + return [cos.(WtX); sin.(WtX)] +end + +function (rff::RandomFourierFeatures{T})(X::AbstractArray{T}) where T <: Real + X′ = reshape(X, size(X, 1), :) + Y′ = rff(X′) + Y = reshape(Y′, :, size(X)[2:end]...) + return Y +end + +""" + RandomOrientationFeatures(m, σ) + +Projects rigid transformations them to `m` features. +These will be the pairwise distances between points. +""" +struct RandomOrientationFeatures{A <: AbstractArray{<:Real}} + FA::A + FB::A +end + +@functor RandomOrientationFeatures +Optimisers.trainable(::RandomOrientationFeatures) = (;) # no trainable parameters + +# should it just have a single array? such that pairwise distances require two of these +function RandomOrientationFeatures(dim::Integer, σ::AbstractFloat) + isfinite(σ) && σ > 0 || throw(ArgumentError("scale must be finite and positive")) + return RandomOrientationFeatures(randn(typeof(σ), 3, dim, 1) * σ, randn(typeof(σ), 3, dim, 1) * σ) +end + +### For non-graph version, with batch dim +function transform_rigid(x::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) where T + x′ = reshape(x, 3, size(x, 2), :) + R′ = reshape(R, 3, 3, :) + t′ = reshape(t, 3, 1, :) + y′ = batched_mul(R′, x′) .+ t′ + y = reshape(y′, 3, size(x, 2), size(R)[3:end]...) + return y +end + +function (rof::RandomOrientationFeatures)(rigid::Tuple{AbstractArray, AbstractArray}) + dim = size(rof.FA, 2) + Nr, batch... = size(rigid[1])[3:end] + p1 = reshape(transform_rigid(rof.FA, rigid...), 3, dim, Nr, batch...) + p2 = reshape(transform_rigid(rof.FB, rigid...), 3, dim, Nr, batch...) + return dropdims(sqrt.(sum(abs2, + reshape(p1, 3, dim, Nr, 1, batch...) .- + reshape(p2, 3, dim, 1, Nr, batch...), + dims=1)), dims=1) +end + +### TODO: GRAPH version - see https://github.com/MurrellGroup/RandomFeatures.jl/blob/main/src/RandomFeatures.jl#L91 + +end diff --git a/src/RandomFeatures.jl b/src/RandomFeatures.jl deleted file mode 100644 index 4490b59..0000000 --- a/src/RandomFeatures.jl +++ /dev/null @@ -1,127 +0,0 @@ -module RandomFeatures - -using Flux: Flux, unsqueeze, batched_mul -using Optimisers - -# Random fourier features -# ----------------------- -struct RandomFourierFeatures{T <: Real, A <: AbstractMatrix{T}} - W::A -end - -Flux.@functor RandomFourierFeatures -Optimisers.trainable(::RandomFourierFeatures) = (;) # no trainable parameters - -RandomFourierFeatures(dims::Pair{<:Integer, <:Integer}, σ::Real) = - RandomFourierFeatures(dims, float(σ)) - -# d1: input dimension, d2: output dimension (d1 => d2) -function RandomFourierFeatures((d1, d2)::Pair{<:Integer, <:Integer}, σ::AbstractFloat) - iseven(d2) || throw(ArgumentError("dimension must be even")) - isfinite(σ) && σ > 0 || throw(ArgumentError("scale must be finite and positive")) - return RandomFourierFeatures(randn(typeof(σ), d1, d2 ÷ 2) * σ * oftype(σ, 2π)) -end - -function (rff::RandomFourierFeatures{T})(X::AbstractMatrix{T}) where {T <: Real} - WtX = rff.W'X - return [cos.(WtX); sin.(WtX)] -end - -#Handling batched inputs -function (rff::RandomFourierFeatures{T})(X::AbstractArray{T,3}) where {T <: Real} - WtX = batched_mul(rff.W',X) - return [cos.(WtX); sin.(WtX)] -end - -function (rff::RandomFeatures.RandomFourierFeatures{T})(X::AbstractArray{T}) where {T <: Real} - return reshape( - rff(reshape(X, size(X, 1), :)) - , :, size(X)[2:end]...) -end - -export RandomFourierFeatures - -#rff = RandomFourierFeatures(10=>20,0.1f0) -#x = randn(Float32,10,11,1) -#isapprox(rff(x[:,:]) , rff(x)[:,:,1]) - -###RANDOM ORIENTATION FEATURES### - -#We want to build a map that takes a pair of residue locations+orientations (each residue is a vector and a rot matrix) -#and projects this to M features. These will be the pairwise distances between points. - -###SHARED -struct RandomOrientationFeatures{A} - FA::A - FB::A -end - -Flux.@functor RandomOrientationFeatures -Optimisers.trainable(::RandomOrientationFeatures) = (;) # no trainable parameters - - -# d1: input dimension, d2: output dimension (d1 => d2) -function RandomOrientationFeatures(dim::Integer, σ::AbstractFloat) - isfinite(σ) && σ > 0 || throw(ArgumentError("scale must be finite and positive")) - return RandomOrientationFeatures(randn(typeof(σ), 3, dim, 1) .* σ, randn(typeof(σ), 3, dim, 1) .* σ) -end - -###NON-GRAPH -#For non-graph version, with batch dim -function T_R3(mat::AbstractArray, rot::AbstractArray, trans::AbstractArray) - rotc = reshape(rot, 3,3,:) - trans = reshape(trans, 3,1,:) - matc = reshape(mat,3,size(mat,2),:) - rotated_mat = batched_mul(rotc,matc) .+ trans - return rotated_mat -end - -function (rof::RandomOrientationFeatures)(Ti::Tuple{AbstractArray, AbstractArray}) - dim = size(rof.FA, 2) - Nr,batch = size(Ti[1])[3:4] - - p1 = reshape(T_R3(rof.FA, Ti[1], Ti[2]),3,dim,Nr,batch) - p2 = reshape(T_R3(rof.FB, Ti[1], Ti[2]),3,dim,Nr,batch) - return sqrt.(sum(abs2.(unsqueeze(p1,dims=4) .- unsqueeze(p2, dims = 3)),dims=1))[1,:,:,:,:] -end - - - - -###GRAPH - will be enabled when InvariantPointAttention is registered - -#= -#Note: This might be bad at sensing position in frame A for a frame B that is far from A. Unclear. -####Random Orientation Features#### -function rot_features(T1,T2,FA,FB) - l = size(T1.translations, 2) - pointsA = transform(T1, FA) - pointsB = transform(T2, FB) - return reshape(sqrt.(sum(abs2.(pointsA .- pointsB), dims = 1)), (:, l)) -end - -subt(xi, xj, e) = xj .- xi -function graph_rot_features(g, gT, FA, FB) - K = size(FA, 2) - pointsA = transform(gT, FA) - pointsB = transform(gT, FB) - diffs = apply_edges(subt, g, xi=pointsB, xj=pointsA) - return reshape(sqrt.(sum(abs2.(diffs), dims = 1)), (K, :)) -end - - -function (rof::RandomOrientationFeatures)(g::GNNGraph, gT::RigidTransformation{T}) where {T <: Real} - return graph_rot_features(g, gT,rof.FA,rof.FB) -end - -function (rof::RandomOrientationFeatures)(T1::RigidTransformation{T}, T2::RigidTransformation{T}) where {T <: Real} - return rot_features(T1,T2,rof.FA,rof.FB) -end - -=# - -export RandomOrientationFeatures - - - -end diff --git a/test/runtests.jl b/test/runtests.jl index ce8d477..8ef4702 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,11 +1,12 @@ -using RandomFeatures +using RandomFeatureMaps using Test -@testset "RandomFeatures.jl" begin - rff = RandomFourierFeatures(10=>20,0.1f0) - x = randn(Float32,10,11,1) - @test isapprox(rff(x[:,:]) , rff(x)[:,:,1]) +@testset "RandomFeatureMaps.jl" begin + rff = RandomFourierFeatures(10 => 20, 0.1f0) + x = randn(Float32, 10, 4) + @test rff(x) |> size == (20, 4) + @test rff(reshape(x, 10, 2, 2)) |> size == (20, 2, 2) - rof = RandomOrientationFeatures(16,0.1f0) - @test size(rof((randn(3,3,10,1),randn(3,1,10,1)))) == (16, 10, 10, 1) + rof = RandomOrientationFeatures(10, 0.1f0) + @test rof((randn(Float32, 3, 3, 4, 2), randn(Float32, 3, 1, 4, 2))) |> size == (10, 4, 4, 2) end