From cb8d287890fd30b217d5448018116a1b127ae1fb Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Thu, 21 Mar 2024 17:23:45 +0100 Subject: [PATCH 1/7] wip: expand --- Manifest.toml | 262 +++++++++++++++++++++++++++++--------------------- src/layers.jl | 80 +++++++++++++++ 2 files changed, 234 insertions(+), 108 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 28ecc11..1573492 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.2" +julia_version = "1.10.2" manifest_format = "2.0" project_hash = "5dda15bf4a9cbd828be2b3ad2d454bfcbd288388" @@ -17,9 +17,9 @@ weakdeps = ["ChainRulesCore", "Test"] [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" +git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.6.2" +version = "4.0.4" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -45,9 +45,9 @@ version = "0.1.0" [[deps.BangBang]] deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"] -git-tree-sha1 = "e28912ce94077686443433c2800104b061a827ed" +git-tree-sha1 = "7aa7ad1682f3d5754e3491bb59b8103cae28e3a3" uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.3.39" +version = "0.3.40" [deps.BangBang.extensions] BangBangChainRulesCoreExt = "ChainRulesCore" @@ -72,21 +72,25 @@ uuid = "9718e550-a3fa-408a-8086-8db961cd8217" version = "0.1.1" [[deps.CEnum]] -git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" +git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.2" +version = "0.5.0" [[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] -git-tree-sha1 = "f98ae934cd677d51d2941088849f0bf2f59e6f6e" +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] +git-tree-sha1 = "4e42872be98fa3343c4f8458cbda8c5c6a6fa97c" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.53.0" +version = "1.63.0" [[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.16.0" +version = "1.23.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] @@ -95,10 +99,10 @@ uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" version = "0.3.0" [[deps.Compat]] -deps = ["UUIDs"] -git-tree-sha1 = "5ce999a19f4ca23ea484e92a1774a61b8ca4cf8e" +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.8.0" +version = "4.14.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -107,7 +111,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+0" +version = "1.1.0+0" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" @@ -122,9 +126,9 @@ version = "0.1.2" [[deps.ConstructionBase]] deps = ["LinearAlgebra"] -git-tree-sha1 = "fe2838a593b5f776e1597e086dcd47560d94e816" +git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.3" +version = "1.5.5" [deps.ConstructionBase.extensions] ConstructionBaseIntervalSetsExt = "IntervalSets" @@ -141,15 +145,15 @@ uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" version = "0.1.3" [[deps.DataAPI]] -git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.15.0" +version = "1.16.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "cf25ccb972fec4e4817764d01c82386ae94f77b4" +git-tree-sha1 = "0f4b5d62a88d8f59003e43c25a8a90de9eb76317" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.14" +version = "0.18.18" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -214,16 +218,26 @@ version = "0.1.1" uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "f372472e8672b1d993e93dada09e23139b509f9e" +deps = ["LinearAlgebra", "Random"] +git-tree-sha1 = "5b93957f6dcd33fc343044af3d48c215be2562f1" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.5.0" +version = "1.9.3" + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" + + [deps.FillArrays.weakdeps] + PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "e0a829d77e750a916a52df71b82fde7f6b336a92" +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "a5475163b611812d073171583982c42ea48d22b0" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.1" +version = "0.14.15" [deps.Flux.extensions] FluxAMDGPUExt = "AMDGPU" @@ -239,9 +253,9 @@ version = "0.14.1" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d" +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.35" +version = "0.10.36" weakdeps = ["StaticArrays"] [deps.ForwardDiff.extensions] @@ -249,9 +263,9 @@ weakdeps = ["StaticArrays"] [[deps.Functors]] deps = ["LinearAlgebra"] -git-tree-sha1 = "9a68d75d466ccc1218d0552a8e1631151c569545" +git-tree-sha1 = "8ae30e786837ce0a24f5e2186938bf3251ab94b2" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.5" +version = "0.4.8" [[deps.Future]] deps = ["Random"] @@ -259,21 +273,21 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "2e57b4a4f9cc15e85a24d603256fe08e527f48d1" +git-tree-sha1 = "47e4686ec18a9620850bad110b79966132f14283" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "8.8.1" +version = "10.0.2" [[deps.GPUArraysCore]] deps = ["Adapt"] -git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.5" +version = "0.1.6" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5" +git-tree-sha1 = "5d8c5713f38f7bc029e26627b687710ba406d0dd" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.10" +version = "0.4.12" [[deps.InitialValues]] git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" @@ -295,10 +309,10 @@ uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" [[deps.JLLWrappers]] -deps = ["Preferences"] -git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.4.1" +version = "1.5.0" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -308,9 +322,9 @@ version = "0.2.4" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" +git-tree-sha1 = "ed7167240f40e62d97c1f5f7735dea6de3cc5c49" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.8" +version = "0.9.18" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -319,16 +333,22 @@ version = "0.9.8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729" +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] +git-tree-sha1 = "ab01dde107f21aa76144d0771dccc08f152ccac7" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.1.0" +version = "6.6.2" + + [deps.LLVM.extensions] + BFloat16sExt = "BFloat16s" + + [deps.LLVM.weakdeps] + BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c35203c1e1002747da220ffc3c0762ce7754b08c" +git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.23+0" +version = "0.0.29+0" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -337,21 +357,26 @@ uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" +version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.84.0+0" +version = "8.4.0+0" [[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" +version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -362,9 +387,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "c3ce8e7420b3a6e071e0fe4745f5d4300e37b13f" +git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.24" +version = "0.3.27" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -386,15 +411,15 @@ version = "0.4.17" [[deps.MLUtils]] deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] -git-tree-sha1 = "3504cdb8c2bc05bde4d4b09a81b01df88fcbbba0" +git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.4.3" +version = "0.4.4" [[deps.MacroTools]] deps = ["Markdown", "Random"] -git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.10" +version = "0.5.13" [[deps.Markdown]] deps = ["Base64"] @@ -403,7 +428,7 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+0" +version = "2.28.2+1" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] @@ -422,22 +447,24 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.10.11" +version = "2023.1.10" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "3d42748c725c3f088bcda47fa2aca89e74d59d22" +git-tree-sha1 = "877f15c331337d54cf24c797d5bcb2e48ce21221" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.4" +version = "0.9.12" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] NNlibCUDAExt = "CUDA" + NNlibEnzymeCoreExt = "EnzymeCore" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NaNMath]] @@ -458,19 +485,19 @@ version = "1.2.0" [[deps.OneHotArrays]] deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] -git-tree-sha1 = "5e4029759e8699ec12ebdf8721e51a659443403c" +git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.4" +version = "0.2.5" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.21+4" +version = "0.3.23+4" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+0" +version = "0.8.1+2" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -480,31 +507,31 @@ version = "0.5.5+0" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "16776280310aa5553c370b9c7b17f34aadaf3c8e" +git-tree-sha1 = "264b061c1903bc0fe9be77cb9050ebacff66bb63" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.19" +version = "0.3.2" [[deps.OrderedCollections]] -git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.2" +version = "1.6.3" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.2" +version = "1.10.0" [[deps.PrecompileTools]] deps = ["Preferences"] -git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81" +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.1.2" +version = "1.2.1" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.0" +version = "1.4.3" [[deps.PrettyPrint]] git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" @@ -526,7 +553,7 @@ deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[deps.Random]] -deps = ["SHA", "Serialization"] +deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [[deps.RealDot]] @@ -575,19 +602,26 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[deps.SortingAlgorithms]] deps = ["DataStructures"] -git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.1" +version = "1.2.1" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.SparseInverseSubset]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" +uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" +version = "0.1.2" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "7beb031cf8145577fbccacd94b8a8f4ce78428d3" +git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.0" +version = "2.3.1" weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] @@ -600,13 +634,14 @@ uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" [[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore"] -git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.2" -weakdeps = ["Statistics"] +version = "1.9.3" +weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" StaticArraysStatisticsExt = "Statistics" [[deps.StaticArraysCore]] @@ -617,30 +652,41 @@ version = "1.4.2" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.9.0" +version = "1.10.0" [[deps.StatsAPI]] deps = ["LinearAlgebra"] -git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7" +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.6.0" +version = "1.7.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" +git-tree-sha1 = "1d77abd07f617c4868c33d4f5b9e1dbb2643c9cf" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.0" +version = "0.34.2" [[deps.StructArrays]] -deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"] -git-tree-sha1 = "521a0e828e98bb69042fec1809c1b5a680eb7389" +deps = ["ConstructionBase", "DataAPI", "Tables"] +git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.15" +version = "0.6.18" +weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] + + [deps.StructArrays.extensions] + StructArraysAdaptExt = "Adapt" + StructArraysGPUArraysCoreExt = "GPUArraysCore" + StructArraysSparseArraysExt = "SparseArrays" + StructArraysStaticArraysExt = "StaticArrays" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+6" +version = "7.2.1+1" [[deps.TOML]] deps = ["Dates"] @@ -654,10 +700,10 @@ uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" version = "1.0.1" [[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] -git-tree-sha1 = "1544b926975372da01227b382066ab70e574a3ec" +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.10.1" +version = "1.11.1" [[deps.Tar]] deps = ["ArgTools", "SHA"] @@ -670,9 +716,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.Transducers]] deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "53bd5978b182fa7c57577bdb452c35e5b4fb73a5" +git-tree-sha1 = "3064e780dbb8a9296ebb3af8f440f787bb5332af" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.78" +version = "0.4.80" [deps.Transducers.extensions] TransducersBlockArraysExt = "BlockArrays" @@ -709,13 +755,13 @@ version = "0.1.3" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+0" +version = "1.2.13+1" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "5be3ddb88fc992a7d8ea96c3f10a49a7e98ebc7b" +git-tree-sha1 = "4ddb4470e47b0094c93055a3bcae799165cc68f1" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.62" +version = "0.6.69" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -729,21 +775,21 @@ version = "0.6.62" [[deps.ZygoteRules]] deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" +git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.3" +version = "0.2.5" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+0" +version = "5.8.0+1" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.48.0+0" +version = "1.52.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+0" +version = "17.4.0+2" diff --git a/src/layers.jl b/src/layers.jl index 579b9d4..c5231a6 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -264,3 +264,83 @@ function (structuremodulelayer::Union{IPCrossAStructureModuleLayer, IPAStructure T_R = l.backbone(T_R, S_R) return T_R, S_R end + +struct IPACache + sizeL + sizeR + batchsize + + # cached arrays + qh # channel × head × residues (R) × batch + kh # channel × head × residues (L) × batch + vh # channel × head × residues (L) × batch + + #qhp # 3 × head × query points × residues (R) × batch + #khp # 3 × head × query points × residues (L) × batch + #vhp # 3 × head × query points × residues (L) × batch +end + +function IPACache(settings, batchsize) + (; c, N_head, N_query_points, N_point_values) = settings + qh = zeros(Float32, c, N_head, 0, batchsize) + kh = zeros(Float32, c, N_head, 0, batchsize) + vh = zeros(Float32, c, N_head, 0, batchsize) + #qhp = zeros(Float32, 3, N_head, N_query_points, 0, batchsize) + #khp = zeros(Float32, 3, N_head, N_query_points, 0, batchsize) + #vhp = zeros(Float32, 3, N_head, N_query_points, 0, batchsize) + IPACache(0, 0, batchsize, qh, kh, vh, #=qhp, khp, vhp,=# ) +end + +function expand( + ipa::IPCrossA, + cache::IPACache, + TiL::Tuple, siL::AbstractArray, ΔL::Integer, + TiR::Tuple, siR::AbstractArray, ΔR::Integer, +) + dims, c, N_head, N_query_points, N_point_values, c_z, Typ, pairwise = ipa.settings + L, R, batchsize = cache.sizeL, cache.sizeR, cache.batchsize + + layer = ipa.layers + Δqh = reshape(layer.proj_qh(@view siR[:,R+1:R+ΔR,:]), (c, N_head, ΔR, :)) + Δkh = reshape(layer.proj_kh(@view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, :)) + Δvh = reshape(layer.proj_vh(@view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, :)) + + kh = cat(cache.kh, Δkh, dims = 3) + vh = cat(cache.vh, Δvh, dims = 3) + + # calculate inner products + ΔqhT = permutedims(Δqh, (3, 1, 2, 4)) + kh = permutedims(kh, (1, 3, 2, 4)) + ΔqhTkh = permutedims(batched_mul(ΔqhT, kh), (3, 1, 2, 4)) + + dim_scale = Float32(1/sqrt(c)) + Δatt_logits = reshape(dim_scale .* ΔqhTkh, (N_head, ΔR, L + ΔL, batchsize)) + + w_L = Float32(sqrt(1/2)) # TODO + Δatt = softmax(w_L .* Δatt_logits, dims = 3) + + # take the attention weighted sum of the value vectors + oh = sumdrop( + reshape(Δatt, (1, N_head, ΔR, L + ΔL, batchsize)) .* + reshape( vh, (c, N_head, 1, L + ΔL, batchsize)), + dims = 4, + ) + + o = cat( + reshape(oh, (c * N_head, ΔR, batchsize)), + zeros(Float32, (4 * N_point_values * N_head, ΔR, batchsize)), + dims = 1, + ) + cache = IPACache( + L + ΔL, + R + ΔR, + batchsize, + cat(cache.qh, Δqh, dims = 3), + cat(cache.kh, Δkh, dims = 3), + cat(cache.vh, Δvh, dims = 3), + ) + + layer.ipa_linear(o), cache +end + +sumdrop(x; dims) = dropdims(sum(x; dims); dims) From 9b22e35f4df3251871047daa4f3a4990e8b45d0e Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Fri, 22 Mar 2024 16:30:59 +0100 Subject: [PATCH 2/7] vector points --- src/layers.jl | 87 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 68 insertions(+), 19 deletions(-) diff --git a/src/layers.jl b/src/layers.jl index c5231a6..a96e5ab 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -139,7 +139,7 @@ function (ipa::Union{IPCrossA, IPA})(TiL::Tuple{AbstractArray,AbstractArray}, si qhp = reshape(l.proj_qhp(siR),(3,N_head*N_query_points,N_frames_R,:)) khp = reshape(l.proj_khp(siL),(3,N_head*N_query_points,N_frames_L,:)) vhp = reshape(l.proj_vhp(siL),(3,N_head*N_point_values,N_frames_L,:)) - + # This should be Q'K, following IPA, which isn't like the regular QK' # Dot products between queries and keys. #FramesR, c, N_head, Batch @@ -275,9 +275,9 @@ struct IPACache kh # channel × head × residues (L) × batch vh # channel × head × residues (L) × batch - #qhp # 3 × head × query points × residues (R) × batch - #khp # 3 × head × query points × residues (L) × batch - #vhp # 3 × head × query points × residues (L) × batch + qhp # 3 × {head × query points} × residues (R) × batch + khp # 3 × {head × query points} × residues (L) × batch + vhp # 3 × {head × point values} × residues (L) × batch end function IPACache(settings, batchsize) @@ -285,10 +285,10 @@ function IPACache(settings, batchsize) qh = zeros(Float32, c, N_head, 0, batchsize) kh = zeros(Float32, c, N_head, 0, batchsize) vh = zeros(Float32, c, N_head, 0, batchsize) - #qhp = zeros(Float32, 3, N_head, N_query_points, 0, batchsize) - #khp = zeros(Float32, 3, N_head, N_query_points, 0, batchsize) - #vhp = zeros(Float32, 3, N_head, N_query_points, 0, batchsize) - IPACache(0, 0, batchsize, qh, kh, vh, #=qhp, khp, vhp,=# ) + qhp = zeros(Float32, 3, N_head * N_query_points, 0, batchsize) + khp = zeros(Float32, 3, N_head * N_query_points, 0, batchsize) + vhp = zeros(Float32, 3, N_head * N_point_values, 0, batchsize) + IPACache(0, 0, batchsize, qh, kh, vh, qhp, khp, vhp) end function expand( @@ -301,22 +301,49 @@ function expand( L, R, batchsize = cache.sizeL, cache.sizeR, cache.batchsize layer = ipa.layers - Δqh = reshape(layer.proj_qh(@view siR[:,R+1:R+ΔR,:]), (c, N_head, ΔR, :)) - Δkh = reshape(layer.proj_kh(@view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, :)) - Δvh = reshape(layer.proj_vh(@view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, :)) + + gamma_h = min.(softplus(layer.gamma_h), 1f2) + + Δqh = reshape(layer.proj_qh(@view siR[:,R+1:R+ΔR,:]), (c, N_head, ΔR, batchsize)) + Δkh = reshape(layer.proj_kh(@view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, batchsize)) + Δvh = reshape(layer.proj_vh(@view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, batchsize)) + + Δqhp = reshape(layer.proj_qhp(@view siR[:,R+1:R+ΔR,:]), (3, N_head * N_query_points, ΔR, batchsize)) + Δkhp = reshape(layer.proj_khp(@view siL[:,L+1:L+ΔL,:]), (3, N_head * N_query_points, ΔL, batchsize)) + Δvhp = reshape(layer.proj_vhp(@view siL[:,L+1:L+ΔL,:]), (3, N_head * N_point_values, ΔL, batchsize)) kh = cat(cache.kh, Δkh, dims = 3) vh = cat(cache.vh, Δvh, dims = 3) + khp = cat(cache.khp, Δkhp, dims = 3) + vhp = cat(cache.vhp, Δvhp, dims = 3) + # calculate inner products ΔqhT = permutedims(Δqh, (3, 1, 2, 4)) kh = permutedims(kh, (1, 3, 2, 4)) ΔqhTkh = permutedims(batched_mul(ΔqhT, kh), (3, 1, 2, 4)) - dim_scale = Float32(1/sqrt(c)) - Δatt_logits = reshape(dim_scale .* ΔqhTkh, (N_head, ΔR, L + ΔL, batchsize)) + # transform vector points to the global frames + rot_TiL, translate_TiL = TiL + rot_TiR, translate_TiR = TiR + ΔTqhp = reshape(T_R3(Δqhp, @view(rot_TiR[:,:,R+1:R+ΔR,:]), @view(translate_TiR[:,:,R+1:R+ΔR,:])), (3, N_head, N_query_points, ΔR, batchsize)) + Tkhp = reshape( + T_R3(reshape(khp, (3, N_head * N_query_points, (L + ΔL) * batchsize)), @view(rot_TiL[:,:,1:L+ΔL,:]), @view(translate_TiL[:,:,1:L+ΔL,:])), + (3, N_head, N_query_points, L + ΔL, batchsize) + ) + Tvhp = reshape( + T_R3(reshape(vhp, (3, N_head * N_point_values, (L + ΔL) * batchsize)), @view(rot_TiL[:,:,1:L+ΔL,:]), @view(translate_TiL[:,:,1:L+ΔL,:])), + (3, N_head, N_point_values, L + ΔL, batchsize) + ) - w_L = Float32(sqrt(1/2)) # TODO + diffs = unsqueeze(ΔTqhp, dims = 5) .- unsqueeze(Tkhp, dims = 4) + sum_norms = sumdrop(abs2, diffs, dims = (1, 3)) + + w_C = sqrt(2f0 / 9N_query_points) + dim_scale = sqrt(1f0 / c) + Δatt_logits = reshape(dim_scale .* ΔqhTkh .- w_C/2 .* gamma_h .* sum_norms, (N_head, ΔR, L + ΔL, batchsize)) + + w_L = sqrt(1f0/2) # TODO Δatt = softmax(w_L .* Δatt_logits, dims = 3) # take the attention weighted sum of the value vectors @@ -325,12 +352,30 @@ function expand( reshape( vh, (c, N_head, 1, L + ΔL, batchsize)), dims = 4, ) - - o = cat( - reshape(oh, (c * N_head, ΔR, batchsize)), - zeros(Float32, (4 * N_point_values * N_head, ΔR, batchsize)), - dims = 1, + ohp = reshape( + T_R3_inv( + reshape( + # 3 × N_head × N_point_values × ΔR × batch + sumdrop( + reshape(Δatt, (1, N_head, 1, ΔR, L + ΔL, batchsize)) .* + reshape(Tvhp, (3, N_head, N_point_values, 1, L + ΔL, batchsize)), + dims = 5, + ), + (3, N_head * N_point_values, ΔR * batchsize) + ), + @view(rot_TiR[:,:,R+1:R+ΔR,:]), + @view(translate_TiR[:,:,R+1:R+ΔR,:]) + ), + (3, N_head, N_point_values, ΔR, batchsize) ) + ohp_norms = sqrt.(sumdrop(abs2, ohp, dims = 1)) + + # concatenate all outputs + o = [ + reshape(oh, (c * N_head, ΔR, batchsize)) + reshape(ohp, (3 * N_head * N_point_values, ΔR, batchsize)) + reshape(ohp_norms, (N_head * N_point_values, ΔR, batchsize)) + ] cache = IPACache( L + ΔL, R + ΔR, @@ -338,9 +383,13 @@ function expand( cat(cache.qh, Δqh, dims = 3), cat(cache.kh, Δkh, dims = 3), cat(cache.vh, Δvh, dims = 3), + cat(cache.qhp, Δqhp, dims = 3), + cat(cache.khp, Δkhp, dims = 3), + cat(cache.vhp, Δvhp, dims = 3), ) layer.ipa_linear(o), cache end sumdrop(x; dims) = dropdims(sum(x; dims); dims) +sumdrop(f, x; dims) = dropdims(sum(f, x; dims); dims) From c5d01f28f61ef378f307a148d3a87564f2010d70 Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Fri, 22 Mar 2024 17:19:08 +0100 Subject: [PATCH 3/7] pairwise --- src/layers.jl | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/layers.jl b/src/layers.jl index a96e5ab..3f006a6 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -295,7 +295,8 @@ function expand( ipa::IPCrossA, cache::IPACache, TiL::Tuple, siL::AbstractArray, ΔL::Integer, - TiR::Tuple, siR::AbstractArray, ΔR::Integer, + TiR::Tuple, siR::AbstractArray, ΔR::Integer; + zij = nothing ) dims, c, N_head, N_query_points, N_point_values, c_z, Typ, pairwise = ipa.settings L, R, batchsize = cache.sizeL, cache.sizeR, cache.batchsize @@ -343,8 +344,14 @@ function expand( dim_scale = sqrt(1f0 / c) Δatt_logits = reshape(dim_scale .* ΔqhTkh .- w_C/2 .* gamma_h .* sum_norms, (N_head, ΔR, L + ΔL, batchsize)) - w_L = sqrt(1f0/2) # TODO - Δatt = softmax(w_L .* Δatt_logits, dims = 3) + if pairwise + bij = reshape(layer.pair(@view(zij[:,R+1:R+ΔR,1:L+ΔL,:])), (N_head, ΔR, L + ΔL, batchsize)) + w_L = sqrt(1f0/3) + Δatt = softmax(w_L .* (Δatt_logits .+ bij), dims = 3) + else + w_L = sqrt(1f0/2) + Δatt = softmax(w_L .* Δatt_logits, dims = 3) + end # take the attention weighted sum of the value vectors oh = sumdrop( @@ -376,6 +383,20 @@ function expand( reshape(ohp, (3 * N_head * N_point_values, ΔR, batchsize)) reshape(ohp_norms, (N_head * N_point_values, ΔR, batchsize)) ] + if pairwise + o = [ + o + reshape( + sumdrop( + reshape( Δatt, ( 1, N_head, ΔR, L + ΔL, batchsize)) .* + reshape(@view(zij[:,R+1:R+ΔR,1:L+ΔL,:]), (c_z, 1, ΔR, L + ΔL, batchsize)), + dims = 4 + ), + (c_z * N_head, ΔR, batchsize) + ) + ] + end + cache = IPACache( L + ΔL, R + ΔR, @@ -387,7 +408,6 @@ function expand( cat(cache.khp, Δkhp, dims = 3), cat(cache.vhp, Δvhp, dims = 3), ) - layer.ipa_linear(o), cache end From 60177a6d3a119605c97a0f47dbcfce3f315afa15 Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Fri, 22 Mar 2024 17:21:04 +0100 Subject: [PATCH 4/7] rename variable --- src/layers.jl | 54 +++++++++++++++++++++++++-------------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/src/layers.jl b/src/layers.jl index 3f006a6..2328730 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -299,19 +299,19 @@ function expand( zij = nothing ) dims, c, N_head, N_query_points, N_point_values, c_z, Typ, pairwise = ipa.settings - L, R, batchsize = cache.sizeL, cache.sizeR, cache.batchsize + L, R, B = cache.sizeL, cache.sizeR, cache.batchsize layer = ipa.layers gamma_h = min.(softplus(layer.gamma_h), 1f2) - Δqh = reshape(layer.proj_qh(@view siR[:,R+1:R+ΔR,:]), (c, N_head, ΔR, batchsize)) - Δkh = reshape(layer.proj_kh(@view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, batchsize)) - Δvh = reshape(layer.proj_vh(@view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, batchsize)) + Δqh = reshape(layer.proj_qh(@view siR[:,R+1:R+ΔR,:]), (c, N_head, ΔR, B)) + Δkh = reshape(layer.proj_kh(@view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, B)) + Δvh = reshape(layer.proj_vh(@view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, B)) - Δqhp = reshape(layer.proj_qhp(@view siR[:,R+1:R+ΔR,:]), (3, N_head * N_query_points, ΔR, batchsize)) - Δkhp = reshape(layer.proj_khp(@view siL[:,L+1:L+ΔL,:]), (3, N_head * N_query_points, ΔL, batchsize)) - Δvhp = reshape(layer.proj_vhp(@view siL[:,L+1:L+ΔL,:]), (3, N_head * N_point_values, ΔL, batchsize)) + Δqhp = reshape(layer.proj_qhp(@view siR[:,R+1:R+ΔR,:]), (3, N_head * N_query_points, ΔR, B)) + Δkhp = reshape(layer.proj_khp(@view siL[:,L+1:L+ΔL,:]), (3, N_head * N_query_points, ΔL, B)) + Δvhp = reshape(layer.proj_vhp(@view siL[:,L+1:L+ΔL,:]), (3, N_head * N_point_values, ΔL, B)) kh = cat(cache.kh, Δkh, dims = 3) vh = cat(cache.vh, Δvh, dims = 3) @@ -327,14 +327,14 @@ function expand( # transform vector points to the global frames rot_TiL, translate_TiL = TiL rot_TiR, translate_TiR = TiR - ΔTqhp = reshape(T_R3(Δqhp, @view(rot_TiR[:,:,R+1:R+ΔR,:]), @view(translate_TiR[:,:,R+1:R+ΔR,:])), (3, N_head, N_query_points, ΔR, batchsize)) + ΔTqhp = reshape(T_R3(Δqhp, @view(rot_TiR[:,:,R+1:R+ΔR,:]), @view(translate_TiR[:,:,R+1:R+ΔR,:])), (3, N_head, N_query_points, ΔR, B)) Tkhp = reshape( - T_R3(reshape(khp, (3, N_head * N_query_points, (L + ΔL) * batchsize)), @view(rot_TiL[:,:,1:L+ΔL,:]), @view(translate_TiL[:,:,1:L+ΔL,:])), - (3, N_head, N_query_points, L + ΔL, batchsize) + T_R3(reshape(khp, (3, N_head * N_query_points, (L + ΔL) * B)), @view(rot_TiL[:,:,1:L+ΔL,:]), @view(translate_TiL[:,:,1:L+ΔL,:])), + (3, N_head, N_query_points, L + ΔL, B) ) Tvhp = reshape( - T_R3(reshape(vhp, (3, N_head * N_point_values, (L + ΔL) * batchsize)), @view(rot_TiL[:,:,1:L+ΔL,:]), @view(translate_TiL[:,:,1:L+ΔL,:])), - (3, N_head, N_point_values, L + ΔL, batchsize) + T_R3(reshape(vhp, (3, N_head * N_point_values, (L + ΔL) * B)), @view(rot_TiL[:,:,1:L+ΔL,:]), @view(translate_TiL[:,:,1:L+ΔL,:])), + (3, N_head, N_point_values, L + ΔL, B) ) diffs = unsqueeze(ΔTqhp, dims = 5) .- unsqueeze(Tkhp, dims = 4) @@ -342,10 +342,10 @@ function expand( w_C = sqrt(2f0 / 9N_query_points) dim_scale = sqrt(1f0 / c) - Δatt_logits = reshape(dim_scale .* ΔqhTkh .- w_C/2 .* gamma_h .* sum_norms, (N_head, ΔR, L + ΔL, batchsize)) + Δatt_logits = reshape(dim_scale .* ΔqhTkh .- w_C/2 .* gamma_h .* sum_norms, (N_head, ΔR, L + ΔL, B)) if pairwise - bij = reshape(layer.pair(@view(zij[:,R+1:R+ΔR,1:L+ΔL,:])), (N_head, ΔR, L + ΔL, batchsize)) + bij = reshape(layer.pair(@view(zij[:,R+1:R+ΔR,1:L+ΔL,:])), (N_head, ΔR, L + ΔL, B)) w_L = sqrt(1f0/3) Δatt = softmax(w_L .* (Δatt_logits .+ bij), dims = 3) else @@ -355,8 +355,8 @@ function expand( # take the attention weighted sum of the value vectors oh = sumdrop( - reshape(Δatt, (1, N_head, ΔR, L + ΔL, batchsize)) .* - reshape( vh, (c, N_head, 1, L + ΔL, batchsize)), + reshape(Δatt, (1, N_head, ΔR, L + ΔL, B)) .* + reshape( vh, (c, N_head, 1, L + ΔL, B)), dims = 4, ) ohp = reshape( @@ -364,35 +364,35 @@ function expand( reshape( # 3 × N_head × N_point_values × ΔR × batch sumdrop( - reshape(Δatt, (1, N_head, 1, ΔR, L + ΔL, batchsize)) .* - reshape(Tvhp, (3, N_head, N_point_values, 1, L + ΔL, batchsize)), + reshape(Δatt, (1, N_head, 1, ΔR, L + ΔL, B)) .* + reshape(Tvhp, (3, N_head, N_point_values, 1, L + ΔL, B)), dims = 5, ), - (3, N_head * N_point_values, ΔR * batchsize) + (3, N_head * N_point_values, ΔR * B) ), @view(rot_TiR[:,:,R+1:R+ΔR,:]), @view(translate_TiR[:,:,R+1:R+ΔR,:]) ), - (3, N_head, N_point_values, ΔR, batchsize) + (3, N_head, N_point_values, ΔR, B) ) ohp_norms = sqrt.(sumdrop(abs2, ohp, dims = 1)) # concatenate all outputs o = [ - reshape(oh, (c * N_head, ΔR, batchsize)) - reshape(ohp, (3 * N_head * N_point_values, ΔR, batchsize)) - reshape(ohp_norms, (N_head * N_point_values, ΔR, batchsize)) + reshape(oh, (c * N_head, ΔR, B)) + reshape(ohp, (3 * N_head * N_point_values, ΔR, B)) + reshape(ohp_norms, (N_head * N_point_values, ΔR, B)) ] if pairwise o = [ o reshape( sumdrop( - reshape( Δatt, ( 1, N_head, ΔR, L + ΔL, batchsize)) .* - reshape(@view(zij[:,R+1:R+ΔR,1:L+ΔL,:]), (c_z, 1, ΔR, L + ΔL, batchsize)), + reshape( Δatt, ( 1, N_head, ΔR, L + ΔL, B)) .* + reshape(@view(zij[:,R+1:R+ΔR,1:L+ΔL,:]), (c_z, 1, ΔR, L + ΔL, B)), dims = 4 ), - (c_z * N_head, ΔR, batchsize) + (c_z * N_head, ΔR, B) ) ] end @@ -400,7 +400,7 @@ function expand( cache = IPACache( L + ΔL, R + ΔR, - batchsize, + B, cat(cache.qh, Δqh, dims = 3), cat(cache.kh, Δkh, dims = 3), cat(cache.vh, Δvh, dims = 3), From 81149b2d11c1981ee02c5e39aac7429eea08857a Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Fri, 22 Mar 2024 17:29:27 +0100 Subject: [PATCH 5/7] mask --- src/layers.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/layers.jl b/src/layers.jl index 2328730..dacdbe5 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -296,7 +296,8 @@ function expand( cache::IPACache, TiL::Tuple, siL::AbstractArray, ΔL::Integer, TiR::Tuple, siR::AbstractArray, ΔR::Integer; - zij = nothing + zij = nothing, + mask = 0, ) dims, c, N_head, N_query_points, N_point_values, c_z, Typ, pairwise = ipa.settings L, R, B = cache.sizeL, cache.sizeR, cache.batchsize @@ -344,13 +345,17 @@ function expand( dim_scale = sqrt(1f0 / c) Δatt_logits = reshape(dim_scale .* ΔqhTkh .- w_C/2 .* gamma_h .* sum_norms, (N_head, ΔR, L + ΔL, B)) + if mask != 0 + mask = unsqueeze(@view(mask[R+1:R+ΔR,1:L+ΔL]), dims = 1) + end + if pairwise bij = reshape(layer.pair(@view(zij[:,R+1:R+ΔR,1:L+ΔL,:])), (N_head, ΔR, L + ΔL, B)) w_L = sqrt(1f0/3) - Δatt = softmax(w_L .* (Δatt_logits .+ bij), dims = 3) + Δatt = softmax(w_L .* (Δatt_logits .+ bij) .+ mask, dims = 3) else w_L = sqrt(1f0/2) - Δatt = softmax(w_L .* Δatt_logits, dims = 3) + Δatt = softmax(w_L .* Δatt_logits .+ mask, dims = 3) end # take the attention weighted sum of the value vectors From b02ea5485d2a3bd69586b3e085ac6b4f7b859b63 Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Fri, 22 Mar 2024 18:04:50 +0100 Subject: [PATCH 6/7] test --- src/layers.jl | 26 ++++++++++++++++++-------- test/runtests.jl | 31 +++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/src/layers.jl b/src/layers.jl index dacdbe5..9bb4fbc 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -280,7 +280,12 @@ struct IPACache vhp # 3 × {head × point values} × residues (L) × batch end -function IPACache(settings, batchsize) +""" + IPACache(settings, batchsize) + +Initialize an empty IPA cache. +""" +function IPACache(settings::NamedTuple, batchsize::Integer) (; c, N_head, N_query_points, N_point_values) = settings qh = zeros(Float32, c, N_head, 0, batchsize) kh = zeros(Float32, c, N_head, 0, batchsize) @@ -301,18 +306,17 @@ function expand( ) dims, c, N_head, N_query_points, N_point_values, c_z, Typ, pairwise = ipa.settings L, R, B = cache.sizeL, cache.sizeR, cache.batchsize - layer = ipa.layers gamma_h = min.(softplus(layer.gamma_h), 1f2) - Δqh = reshape(layer.proj_qh(@view siR[:,R+1:R+ΔR,:]), (c, N_head, ΔR, B)) - Δkh = reshape(layer.proj_kh(@view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, B)) - Δvh = reshape(layer.proj_vh(@view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, B)) + Δqh = reshape(calldense(layer.proj_qh, @view siR[:,R+1:R+ΔR,:]), (c, N_head, ΔR, B)) + Δkh = reshape(calldense(layer.proj_kh, @view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, B)) + Δvh = reshape(calldense(layer.proj_vh, @view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, B)) - Δqhp = reshape(layer.proj_qhp(@view siR[:,R+1:R+ΔR,:]), (3, N_head * N_query_points, ΔR, B)) - Δkhp = reshape(layer.proj_khp(@view siL[:,L+1:L+ΔL,:]), (3, N_head * N_query_points, ΔL, B)) - Δvhp = reshape(layer.proj_vhp(@view siL[:,L+1:L+ΔL,:]), (3, N_head * N_point_values, ΔL, B)) + Δqhp = reshape(calldense(layer.proj_qhp, @view siR[:,R+1:R+ΔR,:]), (3, N_head * N_query_points, ΔR, B)) + Δkhp = reshape(calldense(layer.proj_khp, @view siL[:,L+1:L+ΔL,:]), (3, N_head * N_query_points, ΔL, B)) + Δvhp = reshape(calldense(layer.proj_vhp, @view siL[:,L+1:L+ΔL,:]), (3, N_head * N_point_values, ΔL, B)) kh = cat(cache.kh, Δkh, dims = 3) vh = cat(cache.vh, Δvh, dims = 3) @@ -418,3 +422,9 @@ end sumdrop(x; dims) = dropdims(sum(x; dims); dims) sumdrop(f, x; dims) = dropdims(sum(f, x; dims); dims) + +# dense(x) to avoid https://github.com/FluxML/Flux.jl/issues/2407 +function calldense(dense::Dense, x::AbstractArray) + d1 = size(dense.weight, 1) + reshape(dense(reshape(x, size(x, 1), :)), d1, size(x)[2:end]...) +end diff --git a/test/runtests.jl b/test/runtests.jl index 256ae63..f47f48c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,37 @@ using InvariantPointAttention +using InvariantPointAttention: get_rotation, get_translation using Test @testset "InvariantPointAttention.jl" begin # Write your tests here. + + + @testset "IPACache" begin + dims = 8 + c_z = 2 + settings = IPA_settings(dims; c_z) + ipa = IPCrossA(settings) + + # generate random data + L = 5 + R = 6 + B = 4 + siL = randn(Float32, dims, L, B) + siR = randn(Float32, dims, R, B) + zij = randn(Float32, c_z, R, L, B) + TiL = (get_rotation(L, B), get_translation(L, B)) + TiR = (get_rotation(R, B), get_translation(R, B)) + + # check the consistency + cache = InvariantPointAttention.IPACache(settings, B) + siR′, cache′ = InvariantPointAttention.expand(ipa, cache, TiL, siL, L, TiR, siR, R; zij) + @test size(siR′) == size(siR) + @test siR′ ≈ ipa(TiL, siL, TiR, siR; zij) + + # calculate in two steps + cache = InvariantPointAttention.IPACache(settings, B) + siR1, cache = InvariantPointAttention.expand(ipa, cache, TiL, siL, L, TiR, siR, 2; zij) + siR2, cache = InvariantPointAttention.expand(ipa, cache, TiL, siL, 0, TiR, siR, 4; zij) + @test cat(siR1, siR2, dims = 2) ≈ ipa(TiL, siL, TiR, siR; zij) + end end From 8649728ec94b4812c84b8a34a6d76981171ecd34 Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Fri, 22 Mar 2024 18:18:51 +0100 Subject: [PATCH 7/7] revert changes to Manifest.toml --- Manifest.toml | 262 +++++++++++++++++++++----------------------------- 1 file changed, 108 insertions(+), 154 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 1573492..28ecc11 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.2" +julia_version = "1.9.2" manifest_format = "2.0" project_hash = "5dda15bf4a9cbd828be2b3ad2d454bfcbd288388" @@ -17,9 +17,9 @@ weakdeps = ["ChainRulesCore", "Test"] [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" +version = "3.6.2" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -45,9 +45,9 @@ version = "0.1.0" [[deps.BangBang]] deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"] -git-tree-sha1 = "7aa7ad1682f3d5754e3491bb59b8103cae28e3a3" +git-tree-sha1 = "e28912ce94077686443433c2800104b061a827ed" uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.3.40" +version = "0.3.39" [deps.BangBang.extensions] BangBangChainRulesCoreExt = "ChainRulesCore" @@ -72,25 +72,21 @@ uuid = "9718e550-a3fa-408a-8086-8db961cd8217" version = "0.1.1" [[deps.CEnum]] -git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" +git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.5.0" +version = "0.4.2" [[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "4e42872be98fa3343c4f8458cbda8c5c6a6fa97c" +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] +git-tree-sha1 = "f98ae934cd677d51d2941088849f0bf2f59e6f6e" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.63.0" +version = "1.53.0" [[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.23.0" -weakdeps = ["SparseArrays"] - - [deps.ChainRulesCore.extensions] - ChainRulesCoreSparseArraysExt = "SparseArrays" +version = "1.16.0" [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] @@ -99,10 +95,10 @@ uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" version = "0.3.0" [[deps.Compat]] -deps = ["TOML", "UUIDs"] -git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" +deps = ["UUIDs"] +git-tree-sha1 = "5ce999a19f4ca23ea484e92a1774a61b8ca4cf8e" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.14.0" +version = "4.8.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -111,7 +107,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.0+0" +version = "1.0.5+0" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" @@ -126,9 +122,9 @@ version = "0.1.2" [[deps.ConstructionBase]] deps = ["LinearAlgebra"] -git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2" +git-tree-sha1 = "fe2838a593b5f776e1597e086dcd47560d94e816" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.5" +version = "1.5.3" [deps.ConstructionBase.extensions] ConstructionBaseIntervalSetsExt = "IntervalSets" @@ -145,15 +141,15 @@ uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" version = "0.1.3" [[deps.DataAPI]] -git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.16.0" +version = "1.15.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "0f4b5d62a88d8f59003e43c25a8a90de9eb76317" +git-tree-sha1 = "cf25ccb972fec4e4817764d01c82386ae94f77b4" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.18" +version = "0.18.14" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -218,26 +214,16 @@ version = "0.1.1" uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] -deps = ["LinearAlgebra", "Random"] -git-tree-sha1 = "5b93957f6dcd33fc343044af3d48c215be2562f1" +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "f372472e8672b1d993e93dada09e23139b509f9e" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.9.3" - - [deps.FillArrays.extensions] - FillArraysPDMatsExt = "PDMats" - FillArraysSparseArraysExt = "SparseArrays" - FillArraysStatisticsExt = "Statistics" - - [deps.FillArrays.weakdeps] - PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.5.0" [[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "a5475163b611812d073171583982c42ea48d22b0" +deps = ["Adapt", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "e0a829d77e750a916a52df71b82fde7f6b336a92" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.15" +version = "0.14.1" [deps.Flux.extensions] FluxAMDGPUExt = "AMDGPU" @@ -253,9 +239,9 @@ version = "0.14.15" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" +version = "0.10.35" weakdeps = ["StaticArrays"] [deps.ForwardDiff.extensions] @@ -263,9 +249,9 @@ weakdeps = ["StaticArrays"] [[deps.Functors]] deps = ["LinearAlgebra"] -git-tree-sha1 = "8ae30e786837ce0a24f5e2186938bf3251ab94b2" +git-tree-sha1 = "9a68d75d466ccc1218d0552a8e1631151c569545" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.8" +version = "0.4.5" [[deps.Future]] deps = ["Random"] @@ -273,21 +259,21 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "47e4686ec18a9620850bad110b79966132f14283" +git-tree-sha1 = "2e57b4a4f9cc15e85a24d603256fe08e527f48d1" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.0.2" +version = "8.8.1" [[deps.GPUArraysCore]] deps = ["Adapt"] -git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.6" +version = "0.1.5" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "5d8c5713f38f7bc029e26627b687710ba406d0dd" +git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.12" +version = "0.4.10" [[deps.InitialValues]] git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" @@ -309,10 +295,10 @@ uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" [[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +deps = ["Preferences"] +git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" +version = "1.4.1" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -322,9 +308,9 @@ version = "0.2.4" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "ed7167240f40e62d97c1f5f7735dea6de3cc5c49" +git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.18" +version = "0.9.8" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -333,22 +319,16 @@ version = "0.9.18" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "ab01dde107f21aa76144d0771dccc08f152ccac7" +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] +git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.6.2" - - [deps.LLVM.extensions] - BFloat16sExt = "BFloat16s" - - [deps.LLVM.weakdeps] - BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "6.1.0" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" +git-tree-sha1 = "c35203c1e1002747da220ffc3c0762ce7754b08c" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.29+0" +version = "0.0.23+0" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -357,26 +337,21 @@ uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" +version = "0.6.3" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" +version = "7.84.0+0" [[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" +version = "1.10.2+0" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -387,9 +362,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" +git-tree-sha1 = "c3ce8e7420b3a6e071e0fe4745f5d4300e37b13f" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.27" +version = "0.3.24" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -411,15 +386,15 @@ version = "0.4.17" [[deps.MLUtils]] deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] -git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" +git-tree-sha1 = "3504cdb8c2bc05bde4d4b09a81b01df88fcbbba0" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.4.4" +version = "0.4.3" [[deps.MacroTools]] deps = ["Markdown", "Random"] -git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.13" +version = "0.5.10" [[deps.Markdown]] deps = ["Base64"] @@ -428,7 +403,7 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" +version = "2.28.2+0" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] @@ -447,24 +422,22 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" +version = "2022.10.11" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "877f15c331337d54cf24c797d5bcb2e48ce21221" +git-tree-sha1 = "3d42748c725c3f088bcda47fa2aca89e74d59d22" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.12" +version = "0.9.4" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] NNlibCUDAExt = "CUDA" - NNlibEnzymeCoreExt = "EnzymeCore" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NaNMath]] @@ -485,19 +458,19 @@ version = "1.2.0" [[deps.OneHotArrays]] deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] -git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" +git-tree-sha1 = "5e4029759e8699ec12ebdf8721e51a659443403c" uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.5" +version = "0.2.4" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" +version = "0.3.21+4" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" +version = "0.8.1+0" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -507,31 +480,31 @@ version = "0.5.5+0" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "264b061c1903bc0fe9be77cb9050ebacff66bb63" +git-tree-sha1 = "16776280310aa5553c370b9c7b17f34aadaf3c8e" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.3.2" +version = "0.2.19" [[deps.OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" +version = "1.6.2" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" +version = "1.9.2" [[deps.PrecompileTools]] deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81" uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" +version = "1.1.2" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" +version = "1.4.0" [[deps.PrettyPrint]] git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" @@ -553,7 +526,7 @@ deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[deps.Random]] -deps = ["SHA"] +deps = ["SHA", "Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [[deps.RealDot]] @@ -602,26 +575,19 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[deps.SortingAlgorithms]] deps = ["DataStructures"] -git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" +git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.2.1" +version = "1.1.1" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" - -[[deps.SparseInverseSubset]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" -uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" -version = "0.1.2" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" +git-tree-sha1 = "7beb031cf8145577fbccacd94b8a8f4ce78428d3" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.1" +version = "2.3.0" weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] @@ -634,14 +600,13 @@ uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" [[deps.StaticArrays]] -deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" +deps = ["LinearAlgebra", "Random", "StaticArraysCore"] +git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.3" -weakdeps = ["ChainRulesCore", "Statistics"] +version = "1.6.2" +weakdeps = ["Statistics"] [deps.StaticArrays.extensions] - StaticArraysChainRulesCoreExt = "ChainRulesCore" StaticArraysStatisticsExt = "Statistics" [[deps.StaticArraysCore]] @@ -652,41 +617,30 @@ version = "1.4.2" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" +version = "1.9.0" [[deps.StatsAPI]] deps = ["LinearAlgebra"] -git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7" uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.7.0" +version = "1.6.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "1d77abd07f617c4868c33d4f5b9e1dbb2643c9cf" +git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.2" +version = "0.34.0" [[deps.StructArrays]] -deps = ["ConstructionBase", "DataAPI", "Tables"] -git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" +deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"] +git-tree-sha1 = "521a0e828e98bb69042fec1809c1b5a680eb7389" uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.18" -weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] - - [deps.StructArrays.extensions] - StructArraysAdaptExt = "Adapt" - StructArraysGPUArraysCoreExt = "GPUArraysCore" - StructArraysSparseArraysExt = "SparseArrays" - StructArraysStaticArraysExt = "StaticArrays" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" +version = "0.6.15" [[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" +version = "5.10.1+6" [[deps.TOML]] deps = ["Dates"] @@ -700,10 +654,10 @@ uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" version = "1.0.1" [[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] +git-tree-sha1 = "1544b926975372da01227b382066ab70e574a3ec" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.11.1" +version = "1.10.1" [[deps.Tar]] deps = ["ArgTools", "SHA"] @@ -716,9 +670,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.Transducers]] deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "3064e780dbb8a9296ebb3af8f440f787bb5332af" +git-tree-sha1 = "53bd5978b182fa7c57577bdb452c35e5b4fb73a5" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.80" +version = "0.4.78" [deps.Transducers.extensions] TransducersBlockArraysExt = "BlockArrays" @@ -755,13 +709,13 @@ version = "0.1.3" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" +version = "1.2.13+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "4ddb4470e47b0094c93055a3bcae799165cc68f1" +git-tree-sha1 = "5be3ddb88fc992a7d8ea96c3f10a49a7e98ebc7b" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.69" +version = "0.6.62" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -775,21 +729,21 @@ version = "0.6.69" [[deps.ZygoteRules]] deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" +git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.5" +version = "0.2.3" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" +version = "5.8.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" +version = "1.48.0+0" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" +version = "17.4.0+0"