I am trying to train a GNN where the edge weights of the graph are also learned (cross-posted from the GraphNeuralNetworks.jl repo). To do that, I manually set the edge weights of the graph and the weights of the neural network inside the prediction function. I have the following MWE:
using Graphs, GraphNeuralNetworks, Flux, ComponentArrays, Zygote
time = 1:10
obs = rand(9,10)
x0 = obs[:,1]
fullGraph = GNNGraph(complete_digraph(3))
layer1 = GCNConv(3 => 10,tanh,use_edge_weight=true)
layer2 = GCNConv(10 => 3,use_edge_weight=true)
chain = GNNChain(layer1,layer2)
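pinit = ComponentArray{Float32}(weights = rand(ne(fullGraph)),
                                layer1 = layer1.weight,layer2 = layer2.weight)
function predict(p,x0)
    # rebuild the graph and attach the learnable edge weights
    fullGraph = GNNGraph(complete_digraph(3))
    fullGraph = set_edge_weight(fullGraph,p.weights)
    # in-place broadcast assignment into the layer weights;
    # presumably the copyto! reported in the error below
    chain.layers[1].weight .= p.layer1
    chain.layers[2].weight .= p.layer2
    sol = chain(fullGraph,reshape(x0,(3,3)))
    return Array(sol)
end
function loss_function(p)
    # predict from each of columns 2:end, stack the results, compare with obs[:,2:end]
    pred = reduce(hcat,[reshape(predict(p,obs[:,i]),(9,1)) for i in 2:size(obs,2)])
    sum(abs2,pred .- obs[:,2:end])
end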
Zygote.gradient(loss_function,pinit)
which yields the error ERROR: Mutating arrays is not supported -- called copyto!(Matrix{Float32}, ...). I tried setting up a Zygote.Buffer on the Flux chain, but I get the error:
Is there any way to do this with Zygote? Running Zygote.pullback works flawlessly.
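For reference, here is a minimal sketch of one non-mutating formulation I could imagine, in case it helps frame the question. It swaps the in-place `.=` updates for `Flux.destructure`, which returns a flat parameter vector plus a function (`re` here) that rebuilds a fresh chain from such a vector. Assumptions on my side: the name `θ` and the single shared `graph` are mine; `destructure` flattens the biases as well as the weight matrices, so more parameters are learned than in the MWE; `obs` is generated as `Float32` to match the layer parameters; and I have not confirmed that gradients with respect to `p.weights` propagate through `set_edge_weight`.

```julia
using Graphs, GraphNeuralNetworks, Flux, ComponentArrays, Zygote

obs = rand(Float32, 9, 10)
graph = GNNGraph(complete_digraph(3))

chain = GNNChain(GCNConv(3 => 10, tanh, use_edge_weight = true),
                 GCNConv(10 => 3, use_edge_weight = true))

# Flatten every trainable parameter of the chain (weights and biases).
# `re` rebuilds a new chain from a flat vector without touching `chain`.
θ0, re = Flux.destructure(chain)

# Hypothetical parameter layout: edge weights plus the flattened chain parameters.
pinit = ComponentArray{Float32}(weights = rand(Float32, ne(graph)), θ = θ0)

function predict(p, x0)
    g = set_edge_weight(graph, p.weights)  # returns a new graph, as in the MWE
    model = re(p.θ)                        # fresh chain, no in-place update
    return model(g, reshape(x0, 3, 3))
end

function loss_function(p)
    # predict from each of columns 2:end, stack the results, compare with obs[:, 2:end]
    pred = reduce(hcat, [vec(predict(p, obs[:, i])) for i in 2:size(obs, 2)])
    return sum(abs2, pred .- obs[:, 2:end])
end

grad = Zygote.gradient(loss_function, pinit)[1]
```

If this runs, `grad` should have the same layout as `pinit`, so `grad.weights` would hold the edge-weight gradient and `grad.θ` the gradient for the chain parameters.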