diff --git a/gvp/__init__.py b/gvp/__init__.py index 6e2e161..eaf8435 100644 --- a/gvp/__init__.py +++ b/gvp/__init__.py @@ -329,7 +329,7 @@ def forward(self, x, edge_index, edge_attr, :param edge_index: array of shape [2, n_edges] :param edge_attr: tuple (s, V) of `torch.Tensor` :param autoregressive_x: tuple (s, V) of `torch.Tensor`. - If not `None`, will be used as srcqq node embeddings + If not `None`, will be used as src node embeddings for forming messages where src >= dst. The corrent node embeddings `x` will still be the base of the update and the pointwise feedforward. @@ -371,4 +371,4 @@ def forward(self, x, edge_index, edge_attr, if node_mask is not None: x_[0][node_mask], x_[1][node_mask] = x[0], x[1] x = x_ - return x \ No newline at end of file + return x