
In graph_network.py of gns module, it seems that the class InteractionNetwork does not update edge features. #92

Open
Yanghuoshan opened this issue Dec 3, 2024 · 4 comments

Comments

@Yanghuoshan

Describe the bug
I checked the code of the class InteractionNetwork in graph_network.py and found that in the message function, the new edge_features are computed and returned, but never used to replace the original ones anywhere the update function can see. As a result, the update function still receives the initial edge_features tensor that was passed to propagate, so the residual connection simply doubles the original tensor instead of adding the residual to the updated edge features.

class InteractionNetwork(MessagePassing):
  def __init__(
      self,
      nnode_in: int,
      nnode_out: int,
      nedge_in: int,
      nedge_out: int,
      nmlp_layers: int,
      mlp_hidden_dim: int,
  ):
    # Aggregate features from neighbors
    super(InteractionNetwork, self).__init__(aggr='add')
    # Node MLP
    self.node_fn = nn.Sequential(*[build_mlp(nnode_in + nedge_out,
                                             [mlp_hidden_dim
                                              for _ in range(nmlp_layers)],
                                             nnode_out),
                                   nn.LayerNorm(nnode_out)])
    # Edge MLP
    self.edge_fn = nn.Sequential(*[build_mlp(nnode_in + nnode_in + nedge_in,
                                             [mlp_hidden_dim
                                              for _ in range(nmlp_layers)],
                                             nedge_out),
                                   nn.LayerNorm(nedge_out)])

  def forward(self,
              x: torch.tensor,
              edge_index: torch.tensor,
              edge_features: torch.tensor):
 
    # Save particle state and edge features
    x_residual = x
    edge_features_residual = edge_features
    # Start propagating messages.
    # Takes in the edge indices and all additional data which is needed to
    # construct messages and to update node embeddings.
    x, edge_features = self.propagate(
        edge_index=edge_index, x=x, edge_features=edge_features)

    return x + x_residual, edge_features + edge_features_residual

  def message(self,
              x_i: torch.tensor,
              x_j: torch.tensor,
              edge_features: torch.tensor) -> torch.tensor:
    # Concat edge features with a final shape of [nedges, latent_dim*3]
    edge_features = torch.cat([x_i, x_j, edge_features], dim=-1)
    edge_features = self.edge_fn(edge_features)  # <-- Here is the question. At line 198 in graph_network.py
    return edge_features

  def update(self,
             x_updated: torch.tensor,
             x: torch.tensor,
             edge_features: torch.tensor):  # <-- edge_features here are still the original ones
    
    # Concat node features with a final shape of
    # [nparticles, latent_dim (or nnode_in) *2]
    x_updated = torch.cat([x_updated, x], dim=-1)
    x_updated = self.node_fn(x_updated)
    return x_updated, edge_features

To Reproduce
I instantiated this class separately to verify the issue. The code is as follows:

from gns.graph_network import *
import torch
from torch_geometric.data import Data
simulator = InteractionNetwork(
    nnode_in=2,
    nnode_out=2,
    nedge_in=2,
    nedge_out=2,
    nmlp_layers=2,
    mlp_hidden_dim=2
)
edge_index = torch.tensor([[0, 1],
                           [1, 0]], dtype=torch.long)
x = torch.tensor([[1, 1], [2, 2]], dtype=torch.float)
edge_attr = torch.tensor([[1, 1], [2, 2]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
print(edge_attr)
print(simulator(x=x, edge_index=edge_index, edge_features=edge_attr)[1])

The outputs are as follows:

tensor([[1., 1.],
        [2., 2.]])
tensor([[2., 2.],
        [4., 4.]])

Expected behavior
Maybe you could store the updated tensor in a member variable, for example:

  def message(self,
              x_i: torch.tensor,
              x_j: torch.tensor,
              edge_features: torch.tensor) -> torch.tensor:
 
    # Concat edge features with a final shape of [nedges, latent_dim*3]
    edge_features = torch.cat([x_i, x_j, edge_features], dim=-1)
    edge_features = self.edge_fn(edge_features)
    self.new_edge_features = edge_features
    return edge_features

  def update(self,
             x_updated: torch.tensor,
             x: torch.tensor,
             edge_features: torch.tensor):
    
    # Concat node features with a final shape of
    # [nparticles, latent_dim (or nnode_in) *2]
    x_updated = torch.cat([x_updated, x], dim=-1)
    x_updated = self.node_fn(x_updated)
    return x_updated, self.new_edge_features

Additional context
Maybe the code is correct and I missed something, or I misunderstood the formulas in the paper. I would greatly appreciate it if you could respond as soon as possible.

@yjchoi1
Collaborator

yjchoi1 commented Dec 3, 2024

Thank you for leaving the comment. The edge feature is first updated in the message function. The update function does not update the edge feature; it takes in the updated edge feature computed by the message function and returns it as-is.

@Yanghuoshan
Author

Yanghuoshan commented Dec 4, 2024

> Thank you for leaving the comment. The edge feature is first updated in the message function. The update function does not update the edge feature; it takes in the updated edge feature computed by the message function and returns it as-is.

Thanks for your reply. However, the update function does not seem to take in the updated edge features; it uses the initial ones. According to the PyG documentation, the edge_features argument of update resolves to the tensor initially passed to the propagate function, not to the updated edge features computed by the message function. If it is as you said, could you please explain this?

@yjchoi1
Collaborator

yjchoi1 commented Dec 4, 2024

I found this in the PyG documentation. Based on it, propagate() first calls message(), which takes in any argument that was initially passed to propagate(). After message() constructs the message, which is essentially the updated edge feature, aggregate() takes in the output of the message computation. update() then takes the output of the aggregation as its first argument, plus any argument that was initially passed to propagate(). The update() function uses the output computed by message() (I refer to message_passing.py). I hope this helps. I will also double-check the part you pointed out where the edge feature doubles.
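For reference, the control flow inside propagate() can be paraphrased like this (a simplified sketch, not the actual PyG source; collect and inspect_args stand in for PyG's internal argument collection and inspection):

```
def propagate(self, edge_index, **kwargs):
    coll_dict = collect(edge_index, kwargs)          # the original kwargs
    msg = self.message(**inspect_args('message', coll_dict))
    out = self.aggregate(msg, **inspect_args('aggregate', coll_dict))
    # update() gets the aggregated output as its first argument, but any
    # *named* parameter (e.g. edge_features) is resolved from coll_dict,
    # i.e. from the tensors originally passed to propagate().
    return self.update(out, **inspect_args('update', coll_dict))
```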

@Yanghuoshan
Author

Yanghuoshan commented Dec 4, 2024

Thanks! I found this issue while stepping through the class InteractionNetwork in the debugger: the edge features inside update() are the original ones. Maybe this can help you. Thanks again for your reply.
