
Shape mismatch error when switching the model to GAT #40

Open
Yujun-Yan opened this issue Jul 6, 2021 · 3 comments

Comments

@Yujun-Yan

Hi, I got a shape mismatch error on the line `x_j += edge_attr` in the `message` function of the `GATConv` class when I tried to switch to the GAT model. It seems that the reshape `x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)` in the `forward` function messes up the shape of `x_j`.
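For context, here is a minimal sketch in plain PyTorch (toy sizes, hypothetical, outside of PyG) of the shapes `message` expects: `x_j` and `edge_attr` must both line up as `[E, heads, emb_dim]`. Reshaping `x` to 3-D before `propagate` changes which axis the per-edge gather runs over (recent PyG versions gather along `node_dim`, which defaults to `-2`), so the gathered `x_j` no longer matches `edge_attr`.

```python
import torch

# Toy sizes, hypothetical and for illustration only
N, E, heads, emb_dim = 10, 30, 2, 5

x = torch.randn(N, heads * emb_dim)    # node features, [N, heads * emb_dim]
src = torch.randint(0, N, (E,))        # source node of each edge

# What message() needs: per-edge tensors of shape [E, heads, emb_dim]
x_j = x[src].view(E, heads, emb_dim)   # gather on the node axis, then split heads
edge_attr = torch.randn(E, heads * emb_dim).view(E, heads, emb_dim)
out = x_j + edge_attr                  # OK: both are [E, heads, emb_dim]
```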

wubo2180 commented Sep 2, 2021

You may change `x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)` to `x = self.weight_linear(x)`, and add `x_i = x_i.view(-1, self.heads, self.emb_dim)` and `x_j = x_j.view(-1, self.heads, self.emb_dim)` in `def message(self, edge_index, x_i, x_j, edge_attr):`. The whole code is as follows:
```python
def forward(self, x, edge_index, edge_attr):
    # add self-loops in the edge space
    edge_index = add_self_loops(edge_index, num_nodes=x.size(0))

    # add features corresponding to self-loop edges
    self_loop_attr = torch.zeros(x.size(0), 2)
    self_loop_attr[:, 0] = 4  # bond type for self-loop edges
    self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
    edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

    edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

    # x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)
    x = self.weight_linear(x)
    # return self.propagate(self.aggr, edge_index[0], x=x, edge_attr=edge_embeddings)
    return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

def message(self, edge_index, x_i, x_j, edge_attr):
    x_i = x_i.view(-1, self.heads, self.emb_dim)
    x_j = x_j.view(-1, self.heads, self.emb_dim)
    edge_attr = edge_attr.view(-1, self.heads, self.emb_dim)
    x_j += edge_attr

    alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)

    alpha = F.leaky_relu(alpha, self.negative_slope)
    alpha = softmax(alpha, edge_index[0])

    return x_j * alpha.view(-1, self.heads, 1)
```
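Note: this works because `propagate` now gathers `x_i`/`x_j` from the flat 2-D `x` of shape `[N, heads * emb_dim]` along the node axis, and the split into heads happens only after the per-edge tensors exist, so everything `message` sees lines up as `[E, heads, emb_dim]`.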

kajjana commented Apr 20, 2022

I changed the code as you described, but I still get this error:

```
Traceback (most recent call last):
  File "attribute_masking.py", line 784, in <module>
    main()
  File "attribute_masking.py", line 778, in main
    train_loss, train_acc_atom, train_acc_bond = train(mask_edge, model_list, loader, optimizer_list, device)
  File "attribute_masking.py", line 699, in train
    node_rep = model(batch.x, batch.edge_index, batch.edge_attr)
  File "/home/programs/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "attribute_masking.py", line 276, in forward
    h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
  File "/home/programs/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "attribute_masking.py", line 157, in forward
    return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)
  File "/home/programs/conda/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py", line 344, in propagate
    out = self.aggregate(out, **aggr_kwargs)
  File "/home/programs/conda/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py", line 428, in aggregate
    reduce=self.aggr)
  File "/home/programs/conda/lib/python3.7/site-packages/torch_scatter/scatter.py", line 152, in scatter
    return scatter_sum(src, index, dim, out, dim_size)
  File "/home/programs/conda/lib/python3.7/site-packages/torch_scatter/scatter.py", line 11, in scatter_sum
    index = broadcast(index, src, dim)
  File "/home/programs/conda/lib/python3.7/site-packages/torch_scatter/utils.py", line 12, in broadcast
    src = src.expand(other.size())
RuntimeError: The expanded size of the tensor (2) must match the existing size (2918) at non-singleton dimension 1. Target sizes: [2918, 2, 300]. Tensor sizes: [1, 2918, 1]
```

chao1224 commented Jul 25, 2022

@kajjana
It seems to be an issue with `self.node_dim`. It is originally set to `-2`, and I hack around it by setting it to `0`. Specifically, there are two ways to handle this:

1. Rewrite the following function:

   ```python
   def aggregate(self, inputs: Tensor, index: Tensor,
                 ptr: Optional[Tensor] = None,
                 dim_size: Optional[int] = None) -> Tensor:
       if ptr is not None:
           ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
           return segment_csr(inputs, ptr, reduce=self.aggr)
       else:
           # scatter along dim 0 (the node axis) instead of dim=self.node_dim
           return scatter(inputs, index, dim=0, dim_size=dim_size, reduce=self.aggr)
   ```

2. Or, more simply, set `self.node_dim = 0` in `GATConv` (see the sketch below).
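For option 2, a minimal sketch, assuming a PyG version whose `MessagePassing` constructor accepts `node_dim` and keeping the other constructor arguments as in this repo's `GATConv`:

```python
from torch_geometric.nn import MessagePassing

class GATConv(MessagePassing):
    def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr="add"):
        # node_dim=0 makes propagate() gather x_j and scatter messages along
        # dim 0 (the node axis), which is what the 3-D [E, heads, emb_dim]
        # messages above assume; equivalent to setting self.node_dim = 0.
        super(GATConv, self).__init__(aggr=aggr, node_dim=0)
        # ... rest of the constructor unchanged ...
```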

I have a clean version in this repo; feel free to check it out.
