
Missing W_f Implementation in ChildSumTreeLSTMCell: Alignment with Paper and Proposed Fix #7848

Open
scopeofaperture opened this issue Dec 18, 2024 · 0 comments


📚 Documentation

Issue is in the examples and thus in the tutorials

I believe the current implementation of the ChildSumTreeLSTMCell in examples/pytorch/tree_lstm/tree_lstm.py does not conform to the Child-Sum Tree LSTM described in the original paper or the documentation. Specifically, the implementation lacks the weight matrix $W^{(f)}$ described in Equation 4 of the paper.

Evidence:

The equation specifies a weight matrix $W^{(f)}$, which contributes to the calculation of forget gates for each child node. However, this matrix is not implemented in the code.

$f_{jk} = \sigma\left(W^{(f)} x_j + U^{(f)} h_k + b^{(f)}\right)$ (Equation 4)

The current code does not include a linear layer applying $W^{(f)}$ to the node input $x_j$. As a result, the forget gates are computed from the children's hidden states alone, which deviates from the paper and can lead to incorrect behavior.
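For reference, the per-child forget gate in Equation 4 combines an input term and a child hidden-state term. A minimal sketch of that computation in plain PyTorch (the shapes and variable names here are illustrative, not taken from the example):

```python
import torch

x_size, h_size, num_children = 3, 4, 2

# Illustrative parameters corresponding to W^{(f)}, U^{(f)}, b^{(f)} in Eq. 4.
W_f = torch.randn(h_size, x_size)
U_f = torch.randn(h_size, h_size)
b_f = torch.zeros(h_size)

x_j = torch.randn(x_size)                       # input at node j
h_children = torch.randn(num_children, h_size)  # hidden states h_k of the children

# f_jk = sigmoid(W_f x_j + U_f h_k + b_f), computed for every child k at once.
f = torch.sigmoid(x_j @ W_f.T + h_children @ U_f.T + b_f)
print(f.shape)  # one forget gate vector per child
```

Dropping the `x_j @ W_f.T` term is exactly the deviation described above: the gate then sees only each child's hidden state.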

Proposed Fix:

I have drafted an alternative implementation of ChildSumTreeLSTMCell, based on the existing example and using the old API, that incorporates $W^{(f)}$:

import torch
import torch.nn as nn

class ChildSumTreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(ChildSumTreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(torch.zeros(1, 3 * h_size))

        # parameters for the per-child forget gate (equation 4)
        self.W_f = nn.Linear(x_size, h_size, bias=False)
        self.U_f = nn.Linear(h_size, h_size, bias=False)
        self.b_f = nn.Parameter(torch.zeros(1, h_size))

    def message_func(self, edges):
        return {"h": edges.src["h"], "c": edges.src["c"]}

    def reduce_func(self, nodes):
        # equation (2): h_tilde is the sum of the children's hidden states
        h_tild = torch.sum(nodes.mailbox["h"], 1)
        # equation (4): f_jk = sigmoid(W_f x_j + U_f h_k + b_f), one gate per child;
        # unsqueeze broadcasts the input term over the children dimension
        wx = self.W_f(nodes.data["x"]).unsqueeze(1)
        uh = self.U_f(nodes.mailbox["h"])
        f = torch.sigmoid(wx + uh + self.b_f.unsqueeze(1))
        # children's contribution to the cell state (the sum term in equation 7)
        c_tild = torch.sum(f * nodes.mailbox["c"], 1)
        return {"h_tild": h_tild, "c_tild": c_tild}

    def apply_node_func(self, nodes):
        # equation (3), (5), (6)
        iou = self.W_iou(nodes.data["x"]) + self.b_iou
        if "h_tild" in nodes.data:
            iou += self.U_iou(nodes.data["h_tild"])
        i, o, u = torch.chunk(iou, 3, 1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        # equation (7)
        c = i * u
        if "c_tild" in nodes.data:
            c += nodes.data["c_tild"]
        # equation (8)
        h = o * torch.tanh(c)
        return {"h": h, "c": c}

Observed Issue with Training:

Currently, the cell is not training as expected for my use case, so I am hoping both to help fix the example and to get confirmation that the cell is implemented correctly.

Request for Feedback:

  1. Could someone confirm whether the proposed implementation aligns with the Child-Sum Tree LSTM described in the paper?
  2. Any suggestions to align it with DGL best practices?

Thank you for your time and assistance. I am happy to contribute further by refining the implementation or submitting a PR if this approach is confirmed to be correct (I am new to graphs and DGL).
