From 5a4fc440e65811c0f42b9b3b8793e25600025f5d Mon Sep 17 00:00:00 2001
From: Dario Coscia
Date: Fri, 21 Mar 2025 10:34:52 +0100
Subject: [PATCH 1/6] add buggy egnn block

---
 .../model/block/message_passing/egnn_block.py | 61 +++++++++++++++++++
 1 file changed, 61 insertions(+)
 create mode 100644 pina/model/block/message_passing/egnn_block.py

diff --git a/pina/model/block/message_passing/egnn_block.py b/pina/model/block/message_passing/egnn_block.py
new file mode 100644
index 000000000..8154aeb8e
--- /dev/null
+++ b/pina/model/block/message_passing/egnn_block.py
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+from torch_geometric.nn import MessagePassing
+from torch_geometric.utils import degree
+from ....utils import check_consistency
+
+
+class EnEquivariantGraphBlock(MessagePassing):
+    def __init__(self,
+                 channels_h,
+                 channels_m,
+                 channels_a,
+                 aggr: str = 'add',
+                 hidden_channels: int = 64,
+                 **kwargs):
+        super().__init__(aggr=aggr, **kwargs)
+
+        self.phi_e = nn.Sequential(
+            nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels),
+            nn.LayerNorm(hidden_channels),
+            nn.SiLU(),
+            nn.Linear(hidden_channels, channels_m),
+            nn.LayerNorm(channels_m),
+            nn.SiLU()
+        )
+        self.phi_x = nn.Sequential(
+            nn.Linear(channels_m, hidden_channels),
+            nn.LayerNorm(hidden_channels),
+            nn.SiLU(),
+            nn.Linear(hidden_channels, 1),
+        )
+        self.phi_h = nn.Sequential(
+            nn.Linear(channels_h + channels_m, hidden_channels),
+            nn.LayerNorm(hidden_channels),
+            nn.SiLU(),
+            nn.Linear(hidden_channels, channels_h),
+        )
+
+    def forward(self, x, h, edge_attr, edge_index, c=None):
+        if c is None:
+            c = degree(edge_index[0], x.shape[0]).unsqueeze(-1)
+        return self.propagate(edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c)
+
+    def message(self, x_i, x_j, h_i, h_j, edge_attr):
+        mh_ij = self.phi_e(torch.cat([h_i, h_j, torch.norm(x_i - x_j, dim=-1, keepdim=True)**2, edge_attr], dim=-1))
+        mx_ij = (x_i - x_j) * self.phi_x(mh_ij)
+        return torch.cat((mx_ij, mh_ij), dim=-1)
+
+    def update(self, aggr_out, x, h, edge_attr, c):
+        m_x, m_h = aggr_out[:, :self.m_len], aggr_out[:, self.m_len:]
+        h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1))
+        x_l1 = x + (m_x / c)
+        return x_l1, h_l1
+
+    @property
+    def edge_function(self):
+        return self._edge_function
+
+    @property
+    def attribute_function(self):
+        return self._attribute_function
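For reference, the scheme this block sketches is the E(n)-equivariant message passing of Satorras et al. (2021); in the notation of the `phi_e`, `phi_x`, and `phi_h` networks above, one layer computes

\[
m_{ij} = \phi_e\big(h_i, h_j, \lVert x_i - x_j \rVert^2, a_{ij}\big), \qquad
x_i' = x_i + \frac{1}{C_i} \sum_{j \neq i} (x_i - x_j)\, \phi_x(m_{ij}), \qquad
h_i' = \phi_h\Big(h_i, \sum_{j \neq i} m_{ij}\Big),
\]

where \(C_i\) is a normalization constant (the node degree in the code above). As the commit title says, the block is still buggy: `self.m_len`, `self._edge_function`, and `self._attribute_function` are referenced but never defined.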
From a7c8c35b7254059ad40fbed6e5c5307c1c2888ef Mon Sep 17 00:00:00 2001
From: giovanni
Date: Wed, 9 Apr 2025 15:10:40 +0200
Subject: [PATCH 2/6] add deep tensor network block

---
 pina/model/block/message_passing/__init__.py   |   9 ++
 .../deep_tensor_network_block.py               | 128 ++++++++++++++++++
 .../model/block/message_passing/egnn_block.py  |  96 +++++++++----
 .../interaction_network_block.py               |  10 ++
 4 files changed, 213 insertions(+), 30 deletions(-)
 create mode 100644 pina/model/block/message_passing/__init__.py
 create mode 100644 pina/model/block/message_passing/deep_tensor_network_block.py
 create mode 100644 pina/model/block/message_passing/interaction_network_block.py

diff --git a/pina/model/block/message_passing/__init__.py b/pina/model/block/message_passing/__init__.py
new file mode 100644
index 000000000..a4b122016
--- /dev/null
+++ b/pina/model/block/message_passing/__init__.py
@@ -0,0 +1,9 @@
+"""Module for the message passing blocks of the graph neural models."""
+
+__all__ = [
+    "InteractionNetworkBlock",
+    "DeepTensorNetworkBlock",
+]
+
+from .interaction_network_block import InteractionNetworkBlock
+from .deep_tensor_network_block import DeepTensorNetworkBlock

diff --git a/pina/model/block/message_passing/deep_tensor_network_block.py b/pina/model/block/message_passing/deep_tensor_network_block.py
new file mode 100644
index 000000000..fe48d8e13
--- /dev/null
+++ b/pina/model/block/message_passing/deep_tensor_network_block.py
@@ -0,0 +1,128 @@
+"""Module for the Deep Tensor Network block."""
+
+import torch
+from torch_geometric.nn import MessagePassing
+
+
+class DeepTensorNetworkBlock(MessagePassing):
+    """
+    Implementation of the Deep Tensor Network block.
+
+    This block is used to perform message-passing between nodes and edges in a
+    graph neural network, following the scheme proposed by Schütt et al. (2017).
+    It serves as an inner block in a larger graph neural network architecture.
+
+    The message between two nodes connected by an edge is computed by applying a
+    linear transformation to the sender node features and the edge features,
+    followed by a non-linear activation function. Messages are then aggregated
+    using an aggregation scheme (e.g., sum, mean, min, max, or product).
+
+    The update step is performed by a simple addition of the incoming messages
+    to the node features.
+
+    .. seealso::
+
+        **Original reference**: Schütt, K., Arbabzadah, F., Chmiela, S. et al.
+        *Quantum-Chemical Insights from Deep Tensor Neural Networks*.
+        Nature Communications 8, 13890 (2017).
+        DOI: `_`
+    """
+
+    def __init__(
+        self,
+        node_feature_dim,
+        edge_feature_dim,
+        activation=torch.nn.Tanh,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
+    ):
+        """
+        Initialization of the :class:`AVNOBDeepTensorNetworkBlocklock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int edge_feature_dim: The dimension of the edge features.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.Tanh`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing.
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "source_to_target".
+        """
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        self.node_feature_dim = node_feature_dim
+        self.edge_feature_dim = edge_feature_dim
+        self.activation = activation()
+
+        # Layer for processing node features
+        self.node_layer = torch.nn.Linear(
+            in_features=self.node_feature_dim,
+            out_features=self.node_feature_dim,
+            bias=True,
+        )
+
+        # Layer for processing edge features
+        self.edge_layer = torch.nn.Linear(
+            in_features=self.edge_feature_dim,
+            out_features=self.node_feature_dim,
+            bias=True,
+        )
+
+        # Layer for computing the message
+        self.message_layer = torch.nn.Linear(
+            in_features=self.node_feature_dim,
+            out_features=self.node_feature_dim,
+            bias=False,
+        )
+
+    def forward(self, x, edge_index, edge_attr):
+        """
+        Forward pass of the block. It performs a message-passing operation
+        between nodes and edges.
+
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices.
+        :param edge_attr: The edge attributes.
+        :type edge_attr: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
+
+    def message(self, x_j, edge_attr):
+        """
+        Compute the message to be passed between nodes and edges.
+
+        :param x_j: The node features of the sender nodes.
+        :type x_j: torch.Tensor | LabelTensor
+        :param edge_attr: The edge attributes.
+        :type edge_attr: torch.Tensor | LabelTensor
+        :return: The message to be passed.
+        :rtype: torch.Tensor
+        """
+        # Process node and edge features
+        filter_node = self.node_layer(x_j)
+        filter_edge = self.edge_layer(edge_attr)
+
+        # Compute the message to be passed
+        message = self.message_layer(filter_node * filter_edge)
+
+        return self.activation(message)
+
+    def update(self, message, x):
+        """
+        Update the node features with the received messages.
+
+        :param torch.Tensor message: The message to be passed.
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return x + message

diff --git a/pina/model/block/message_passing/egnn_block.py b/pina/model/block/message_passing/egnn_block.py
index 8154aeb8e..7c137ac0e 100644
--- a/pina/model/block/message_passing/egnn_block.py
+++ b/pina/model/block/message_passing/egnn_block.py
@@ -1,61 +1,97 @@
+"""Module for the E(n) Equivariant Graph Neural Network block."""
+
 import torch
-import torch.nn as nn
 from torch_geometric.nn import MessagePassing
 from torch_geometric.utils import degree
-from ....utils import check_consistency
 
 
 class EnEquivariantGraphBlock(MessagePassing):
-    def __init__(self,
-                 channels_h,
-                 channels_m,
-                 channels_a,
-                 aggr: str = 'add',
-                 hidden_channels: int = 64,
-                 **kwargs):
+    """
+    TODO
+    """
+
+    def __init__(
+        self,
+        channels_h,
+        channels_m,
+        channels_a,
+        aggr: str = "add",
+        hidden_channels: int = 64,
+        **kwargs,
+    ):
+        """
+        TODO
+        """
         super().__init__(aggr=aggr, **kwargs)
 
-        self.phi_e = nn.Sequential(
-            nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels),
-            nn.LayerNorm(hidden_channels),
-            nn.SiLU(),
-            nn.Linear(hidden_channels, channels_m),
-            nn.LayerNorm(channels_m),
-            nn.SiLU()
+        self.phi_e = torch.nn.Sequential(
+            torch.nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels),
+            torch.nn.LayerNorm(hidden_channels),
+            torch.nn.SiLU(),
+            torch.nn.Linear(hidden_channels, channels_m),
+            torch.nn.LayerNorm(channels_m),
+            torch.nn.SiLU(),
         )
-        self.phi_x = nn.Sequential(
-            nn.Linear(channels_m, hidden_channels),
-            nn.LayerNorm(hidden_channels),
-            nn.SiLU(),
-            nn.Linear(hidden_channels, 1),
+        self.phi_x = torch.nn.Sequential(
+            torch.nn.Linear(channels_m, hidden_channels),
+            torch.nn.LayerNorm(hidden_channels),
+            torch.nn.SiLU(),
+            torch.nn.Linear(hidden_channels, 1),
+        )
+        self.phi_h = torch.nn.Sequential(
+            torch.nn.Linear(channels_h + channels_m, hidden_channels),
+            torch.nn.LayerNorm(hidden_channels),
+            torch.nn.SiLU(),
+            torch.nn.Linear(hidden_channels, channels_h),
         )
-        self.phi_h = nn.Sequential(
-            nn.Linear(channels_h + channels_m, hidden_channels),
-            nn.LayerNorm(hidden_channels),
-            nn.SiLU(),
-            nn.Linear(hidden_channels, channels_h),
-        )
 
     def forward(self, x, h, edge_attr, edge_index, c=None):
+        """
+        TODO
+        """
         if c is None:
             c = degree(edge_index[0], x.shape[0]).unsqueeze(-1)
-        return self.propagate(edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c)
+        return self.propagate(
+            edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c
+        )
 
     def message(self, x_i, x_j, h_i, h_j, edge_attr):
-        mh_ij = self.phi_e(torch.cat([h_i, h_j, torch.norm(x_i - x_j, dim=-1, keepdim=True)**2, edge_attr], dim=-1))
+        """
+        TODO
+        """
+        mh_ij = self.phi_e(
+            torch.cat(
+                [
+                    h_i,
+                    h_j,
+                    torch.norm(x_i - x_j, dim=-1, keepdim=True) ** 2,
+                    edge_attr,
+                ],
+                dim=-1,
+            )
+        )
         mx_ij = (x_i - x_j) * self.phi_x(mh_ij)
         return torch.cat((mx_ij, mh_ij), dim=-1)
 
     def update(self, aggr_out, x, h, edge_attr, c):
-        m_x, m_h = aggr_out[:, :self.m_len], aggr_out[:, self.m_len:]
+        """
+        TODO
+        """
+        m_x, m_h = aggr_out[:, : self.m_len], aggr_out[:, self.m_len :]
         h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1))
         x_l1 = x + (m_x / c)
         return x_l1, h_l1
 
     @property
     def edge_function(self):
+        """
+        TODO
+        """
         return self._edge_function
 
     @property
     def attribute_function(self):
+        """
+        TODO
+        """
         return self._attribute_function

diff --git a/pina/model/block/message_passing/interaction_network_block.py b/pina/model/block/message_passing/interaction_network_block.py
new file mode 100644
index 000000000..44ecccb27
--- /dev/null
+++ b/pina/model/block/message_passing/interaction_network_block.py
@@ -0,0 +1,10 @@
+"""Module for the Interaction Network block."""
+
+import torch
+from torch_geometric.nn import MessagePassing
+
+
+class InteractionNetworkBlock(MessagePassing):
+    """
+    TODO
+    """
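With the deep tensor network block in place, a minimal usage sketch — assuming a PINA checkout where the new module is importable and torch_geometric is installed; the graph sizes are made up purely for illustration:

import torch
from pina.model.block.message_passing import DeepTensorNetworkBlock

# Toy graph: 4 nodes with 8 features each, 5 directed edges with 3 features.
x = torch.rand(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3, 0], [1, 2, 3, 0, 2]])
edge_attr = torch.rand(5, 3)

block = DeepTensorNetworkBlock(node_feature_dim=8, edge_feature_dim=3)
out = block(x, edge_index, edge_attr)  # residual update: x + aggregated messages
print(out.shape)  # torch.Size([4, 8])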
""" super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) + # Check consistency + check_consistency(node_feature_dim, int) + check_consistency(edge_feature_dim, int) + + # Check values + if node_feature_dim <= 0: + raise ValueError( + "`node_feature_dim` must be a positive integer," + f" got {node_feature_dim}." + ) + + if edge_feature_dim <= 0: + raise ValueError( + "`edge_feature_dim` must be a positive integer," + f" got {edge_feature_dim}." + ) + + # Initialize parameters self.node_feature_dim = node_feature_dim self.edge_feature_dim = edge_feature_dim self.activation = activation @@ -82,8 +107,7 @@ def __init__( def forward(self, x, edge_index, edge_attr): """ - Forward pass of the block. It performs a message-passing operation - between nodes and edges. + Forward pass of the block, triggering the message-passing routine. :param x: The node features. :type x: torch.Tensor | LabelTensor diff --git a/pina/model/block/message_passing/interaction_network_block.py b/pina/model/block/message_passing/interaction_network_block.py index 44ecccb27..f27169448 100644 --- a/pina/model/block/message_passing/interaction_network_block.py +++ b/pina/model/block/message_passing/interaction_network_block.py @@ -2,9 +2,167 @@ import torch from torch_geometric.nn import MessagePassing +from ....model import FeedForward +from ....utils import check_consistency class InteractionNetworkBlock(MessagePassing): """ - TODO + Implementation of the Interaction Network block. + + This block is used to perform message-passing between nodes and edges in a + graph neural network, following the scheme proposed by Battaglia et al. + (2016). + It serves as an inner block in a larger graph neural network architecture. + + The message between two nodes connected by an edge is computed by applying a + multi-layer perceptron (MLP) to the concatenation of the sender and + recipient node features. Messages are then aggregated using an aggregation + scheme (e.g., sum, mean, min, max, or product). + + The update step is performed by applying another MLP to the concatenation of + the incoming messages and the node features. + + .. seealso:: + + **Original reference**: Battaglia, P. W., et al. (2016). + *Interaction Networks for Learning about Objects, Relations and + Physics*. + In Advances in Neural Information Processing Systems (NeurIPS 2016). + DOI: `_`. """ + + def __init__( + self, + node_feature_dim, + hidden_dim, + n_message_layers=2, + n_update_layers=2, + activation=torch.nn.SiLU, + aggr="add", + node_dim=-2, + flow="source_to_target", + ): + """ + Initialization of the :class:`InteractionNetworkBlock` class. + + :param int node_feature_dim: The dimension of the node features. + :param int hidden_dim: The dimension of the hidden features. + :param int n_message_layers: The number of layers in the message + network. Default is 2. + :param int n_update_layers: The number of layers in the update network. + Default is 2. + :param torch.nn.Module activation: The activation function. + Default is :class:`torch.nn.SiLU`. + :param str aggr: The aggregation scheme to use for message passing. + Available options are "add", "mean", "min", "max", "mul". + See :class:`torch_geometric.nn.MessagePassing` for more details. + Default is "add". + :param int node_dim: The axis along which to propagate. Default is -2. + :param str flow: The direction of message passing. Available options + are "source_to_target" and "target_to_source". 
+ The "source_to_target" flow means that messages are sent from + the source node to the target node, while the "target_to_source" + flow means that messages are sent from the target node to the + source node. See :class:`torch_geometric.nn.MessagePassing` for more + details. Default is "source_to_target". + :raises ValueError: If `node_feature_dim` is not a positive integer. + :raises ValueError: If `hidden_dim` is not a positive integer. + :raises ValueError: If `n_message_layers` is not a positive integer. + :raises ValueError: If `n_update_layers` is not a positive integer. + """ + super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) + + # Check consistency + check_consistency(node_feature_dim, int) + check_consistency(hidden_dim, int) + check_consistency(n_message_layers, int) + check_consistency(n_update_layers, int) + + # Check values + if node_feature_dim <= 0: + raise ValueError( + "`node_feature_dim` must be a positive integer," + f" got {node_feature_dim}." + ) + + if hidden_dim <= 0: + raise ValueError( + "`hidden_dim` must be a positive integer," f" got {hidden_dim}." + ) + + if n_message_layers <= 0: + raise ValueError( + "`n_message_layers` must be a positive integer," + f" got {n_message_layers}." + ) + + if n_update_layers <= 0: + raise ValueError( + "`n_update_layers` must be a positive integer," + f" got {n_update_layers}." + ) + + # Initialize parameters + self.node_feature_dim = node_feature_dim + self.hidden_dim = hidden_dim + self.activation = activation + + # Message network + self.message_net = FeedForward( + input_dimensions=2 * self.node_feature_dim, + output_dimensions=self.hidden_dim, + inner_size=self.hidden_dim, + n_layers=n_message_layers, + func=self.activation, + ) + + # Update network + self.update_net = FeedForward( + input_dimensions=self.node_feature_dim + self.hidden_dim, + output_dimensions=self.hidden_dim, + inner_size=self.node_feature_dim, + n_layers=n_update_layers, + func=self.activation, + ) + + def forward(self, x, edge_index, edge_attr): + """ + Forward pass of the block, triggering the message-passing routine. + + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge indeces. + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor | LabelTensor + :return: The updated node features. + :rtype: torch.Tensor + """ + + # TODO: edge_attr is not used in the message function + return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr) + + def message(self, x_i, x_j): + """ + Compute the message to be passed between nodes and edges. + + :param x_i: The node features of the recipient nodes. + :type x_i: torch.Tensor | LabelTensor + :param x_j: The node features of the sender nodes. + :type x_j: torch.Tensor | LabelTensor + :return: The message to be passed. + :rtype: torch.Tensor + """ + return self.message_net(torch.cat((x_i, x_j), dim=-1)) + + def update(self, message, x): + """ + Update the node features with the received messages. + + :param torch.Tensor message: The message to be passed. + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :return: The updated node features. 
From 9269702847c739fdc8f1add668fcc9c83efe096c Mon Sep 17 00:00:00 2001
From: AleDinve
Date: Thu, 24 Apr 2025 12:59:56 -0400
Subject: [PATCH 4/6] radial field

---
 .../radial_field_network_block.py | 142 ++++++++++++++++++
 1 file changed, 142 insertions(+)
 create mode 100644 pina/model/block/message_passing/radial_field_network_block.py

diff --git a/pina/model/block/message_passing/radial_field_network_block.py b/pina/model/block/message_passing/radial_field_network_block.py
new file mode 100644
index 000000000..ebaa42d28
--- /dev/null
+++ b/pina/model/block/message_passing/radial_field_network_block.py
@@ -0,0 +1,142 @@
+"""Module for the Radial Field Network block."""
+
+import torch
+from torch_geometric.nn import MessagePassing
+from ....utils import check_consistency
+
+
+class RadialFieldNetworkBlock(MessagePassing):
+    """
+    Implementation of the Radial Field Network block.
+
+    This block is used to perform message-passing between nodes and edges in a
+    graph neural network, following the scheme proposed by Köhler et al. (2020).
+    It serves as an inner block in a larger graph neural network architecture.
+
+    The message between two nodes connected by an edge is computed by rescaling
+    the difference of the node positions with a learnable radial function of
+    their distance. Messages are then aggregated using an aggregation scheme
+    (e.g., sum, mean, min, max, or product).
+
+    The update step is performed by a simple addition of the incoming messages
+    to the node features.
+
+    .. seealso::
+
+        **Original reference**: Köhler, J., Klein, L., & Noé, F. (2020).
+        *Equivariant Flows: Exact Likelihood Generative Learning for Symmetric
+        Densities*.
+        In International Conference on Machine Learning (pp. 5361-5370). PMLR.
+    """
+
+    def __init__(
+        self,
+        node_feature_dim,
+        hidden_dim,
+        edge_feature_dim,
+        activation=torch.nn.ReLU,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
+    ):
+        """
+        Initialization of the :class:`RadialFieldNetworkBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int hidden_dim: The dimension of the hidden features.
+        :param int edge_feature_dim: The dimension of the edge features.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.ReLU`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source".
+            The "source_to_target" flow means that messages are sent from
+            the source node to the target node, while the "target_to_source"
+            flow means that messages are sent from the target node to the
+            source node. See :class:`torch_geometric.nn.MessagePassing` for
+            more details. Default is "source_to_target".
+        :raises ValueError: If `node_feature_dim` is not a positive integer.
+        :raises ValueError: If `edge_feature_dim` is not a positive integer.
+        """
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        # Check consistency
+        check_consistency(node_feature_dim, int)
+        check_consistency(edge_feature_dim, int)
+
+        # Check values
+        if node_feature_dim <= 0:
+            raise ValueError(
+                "`node_feature_dim` must be a positive integer,"
+                f" got {node_feature_dim}."
+            )
+
+        if edge_feature_dim <= 0:
+            raise ValueError(
+                "`edge_feature_dim` must be a positive integer,"
+                f" got {edge_feature_dim}."
+            )
+
+        # Initialize parameters
+        self.node_feature_dim = node_feature_dim
+        self.edge_feature_dim = edge_feature_dim
+        self.hidden_dim = hidden_dim
+        self.activation = activation
+        self.layer = lambda i,o: torch.nn.Linear(
+            in_features=i,
+            out_features=o,
+            bias=True,
+        )
+        # Layer for processing node features
+        self.radial_field = torch.nn.Sequential([self.layer(1,self.hidden_dim),
+                                torch.nn.ReLU,
+                                self.layer(self.hidden_dim,1)]
+        )
+
+    def forward(self, x, edge_index):
+        """
+        Forward pass of the block, triggering the message-passing routine.
+
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices.
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.propagate(edge_index=edge_index, x=x)
+
+    def message(self, x_j, x_i):
+        """
+        Compute the message to be passed between nodes and edges.
+
+        :param x_j: The node features of the sender nodes.
+        :type x_j: torch.Tensor | LabelTensor
+        :param x_i: The node features of the recipient nodes.
+        :type x_i: torch.Tensor | LabelTensor
+        :return: The message to be passed.
+        :rtype: torch.Tensor
+        """
+        r = torch.norm(x_i-x_j)*(x_i-x_j)
+
+        return self.activation(self.radial_field(r))
+
+    def update(self, message, x):
+        """
+        Update the node features with the received messages.
+
+        :param torch.Tensor message: The message to be passed.
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return x + message

From b6f7c171c831d48d8666e884e65c8b5d8a9b9a64 Mon Sep 17 00:00:00 2001
From: AleDinve
Date: Thu, 24 Apr 2025 14:27:55 -0400
Subject: [PATCH 5/6] fix radial field

---
 pina/model/block/message_passing/radial_field_network_block.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pina/model/block/message_passing/radial_field_network_block.py b/pina/model/block/message_passing/radial_field_network_block.py
index ebaa42d28..4f5982fe3 100644
--- a/pina/model/block/message_passing/radial_field_network_block.py
+++ b/pina/model/block/message_passing/radial_field_network_block.py
@@ -105,7 +105,8 @@ def forward(self, x, edge_index):
 
         :param x: The node features.
         :type x: torch.Tensor | LabelTensor
-        :param torch.Tensor edge_index: The edge indices.
+        :param torch.Tensor edge_index: The edge indices. In the original formulation,
+            the messages are aggregated from all nodes, not only from the neighbours.
         :return: The updated node features.
         :rtype: torch.Tensor
         """
From 1c6bef4efbc186358161ba7feac06555f9288e4c Mon Sep 17 00:00:00 2001
From: AleDinve
Date: Thu, 24 Apr 2025 15:36:58 -0400
Subject: [PATCH 6/6] radial_field fix + schnet block

---
 .../radial_field_network_block.py          |  32 ++--
 .../block/message_passing/schnet_block.py  | 154 ++++++++++++++++++
 2 files changed, 166 insertions(+), 20 deletions(-)
 create mode 100644 pina/model/block/message_passing/schnet_block.py

diff --git a/pina/model/block/message_passing/radial_field_network_block.py b/pina/model/block/message_passing/radial_field_network_block.py
index 4f5982fe3..f7d55a948 100644
--- a/pina/model/block/message_passing/radial_field_network_block.py
+++ b/pina/model/block/message_passing/radial_field_network_block.py
@@ -1,6 +1,7 @@
 """Module for the Radial Field Network block."""
 
 import torch
+from ....model import FeedForward
 from torch_geometric.nn import MessagePassing
 from ....utils import check_consistency
 
@@ -34,7 +35,8 @@ def __init__(
         self,
         node_feature_dim,
         hidden_dim,
-        edge_feature_dim,
+        radial_hidden_dim=16,
+        n_radial_layers=2,
         activation=torch.nn.ReLU,
         aggr="add",
         node_dim=-2,
         flow="source_to_target",
@@ -66,7 +68,6 @@ def __init__(
 
         # Check consistency
         check_consistency(node_feature_dim, int)
-        check_consistency(edge_feature_dim, int)
 
         # Check values
         if node_feature_dim <= 0:
@@ -75,27 +76,18 @@ def __init__(
                 f" got {node_feature_dim}."
             )
 
-        if edge_feature_dim <= 0:
-            raise ValueError(
-                "`edge_feature_dim` must be a positive integer,"
-                f" got {edge_feature_dim}."
-            )
-
         # Initialize parameters
         self.node_feature_dim = node_feature_dim
-        self.edge_feature_dim = edge_feature_dim
         self.hidden_dim = hidden_dim
         self.activation = activation
-        self.layer = lambda i,o: torch.nn.Linear(
-            in_features=i,
-            out_features=o,
-            bias=True,
-        )
+
         # Layer for processing node features
-        self.radial_field = torch.nn.Sequential([self.layer(1,self.hidden_dim),
-                                torch.nn.ReLU,
-                                self.layer(self.hidden_dim,1)]
+        self.radial_field = FeedForward(
+            input_dimensions=1,
+            output_dimensions=1,
+            inner_size=radial_hidden_dim,
+            n_layers=n_radial_layers,
+            func=self.activation,
         )
 
@@ -124,10 +116,10 @@ def message(self, x_j, x_i):
         :return: The message to be passed.
         :rtype: torch.Tensor
         """
-        r = torch.norm(x_i-x_j)*(x_i-x_j)
+        r = torch.norm(x_i - x_j, dim=-1, keepdim=True)
 
-        return self.activation(self.radial_field(r))
+        return self.radial_field(r) * (x_i - x_j)
 
     def update(self, message, x):
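With that fix, a minimal sketch of how the radial field block can be driven (toy 3D positions; the block is not re-exported from `__init__.py` at this point in the series, hence the full module path):

import torch
from pina.model.block.message_passing.radial_field_network_block import (
    RadialFieldNetworkBlock,
)

pos = torch.rand(4, 3)  # node positions, playing the role of the node features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

block = RadialFieldNetworkBlock(node_feature_dim=3, hidden_dim=16)
out = block(pos, edge_index)  # equivariant update: pos + aggregated messages
print(out.shape)  # torch.Size([4, 3])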
diff --git a/pina/model/block/message_passing/schnet_block.py b/pina/model/block/message_passing/schnet_block.py
new file mode 100644
index 000000000..955fbbe8e
--- /dev/null
+++ b/pina/model/block/message_passing/schnet_block.py
@@ -0,0 +1,154 @@
+"""Module for the SchNet block."""
+
+import torch
+from ....model import FeedForward
+from torch_geometric.nn import MessagePassing
+from ....utils import check_consistency
+
+
+class SchnetBlock(MessagePassing):
+    """
+    Implementation of the SchNet block.
+
+    This block is used to perform message-passing between nodes and edges in a
+    graph neural network, following the scheme proposed by Schütt et al. (2017).
+    It serves as an inner block in a larger graph neural network architecture.
+
+    The message between two nodes connected by an edge is computed as a
+    continuous-filter convolution: the node features are passed through an MLP
+    and rescaled by a learnable radial filter of the distance between the
+    nodes. Messages are then aggregated using an aggregation scheme (e.g.,
+    sum, mean, min, max, or product).
+
+    The update step is performed by applying an MLP to the concatenation of
+    the node positions and the aggregated messages.
+
+    .. seealso::
+
+        **Original reference**: Schütt, K., Kindermans, P. J., Sauceda Felix,
+        H. E., Chmiela, S., Tkatchenko, A., & Müller, K. R. (2017).
+        *SchNet: A continuous-filter convolutional neural network for modeling
+        quantum interactions*.
+        Advances in Neural Information Processing Systems, 30.
+    """
+
+    def __init__(
+        self,
+        node_feature_dim,
+        node_pos_dim,
+        hidden_dim,
+        radial_hidden_dim=16,
+        n_message_layers=2,
+        n_update_layers=2,
+        n_radial_layers=2,
+        activation=torch.nn.ReLU,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
+    ):
+        """
+        Initialization of the :class:`SchnetBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int node_pos_dim: The dimension of the node positions.
+        :param int hidden_dim: The dimension of the hidden features.
+        :param int radial_hidden_dim: The dimension of the hidden features in
+            the radial filter network. Default is 16.
+        :param int n_message_layers: The number of layers in the message
+            network. Default is 2.
+        :param int n_update_layers: The number of layers in the update network.
+            Default is 2.
+        :param int n_radial_layers: The number of layers in the radial filter
+            network. Default is 2.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.ReLU`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source".
+            The "source_to_target" flow means that messages are sent from
+            the source node to the target node, while the "target_to_source"
+            flow means that messages are sent from the target node to the
+            source node. See :class:`torch_geometric.nn.MessagePassing` for
+            more details. Default is "source_to_target".
+        :raises ValueError: If `node_feature_dim` is not a positive integer.
+        """
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        # Check consistency
+        check_consistency(node_feature_dim, int)
+
+        # Check values
+        if node_feature_dim <= 0:
+            raise ValueError(
+                "`node_feature_dim` must be a positive integer,"
+                f" got {node_feature_dim}."
+            )
+
+        # Initialize parameters
+        self.node_feature_dim = node_feature_dim
+        self.node_pos_dim = node_pos_dim
+        self.hidden_dim = hidden_dim
+        self.activation = activation
+
+        # Radial filter network, acting on the interatomic distances
+        self.radial_field = FeedForward(
+            input_dimensions=1,
+            output_dimensions=1,
+            inner_size=radial_hidden_dim,
+            n_layers=n_radial_layers,
+            func=self.activation,
+        )
+
+        # Update network, acting on positions and aggregated messages
+        self.update_net = FeedForward(
+            input_dimensions=self.node_pos_dim + self.hidden_dim,
+            output_dimensions=self.hidden_dim,
+            inner_size=self.hidden_dim,
+            n_layers=n_update_layers,
+            func=self.activation,
+        )
+
+        # Message network, acting on the node features
+        self.message_net = FeedForward(
+            input_dimensions=self.node_feature_dim,
+            output_dimensions=self.hidden_dim,
+            inner_size=self.hidden_dim,
+            n_layers=n_message_layers,
+            func=self.activation,
+        )
+
+    def forward(self, x, pos, edge_index):
+        """
+        Forward pass of the block, triggering the message-passing routine.
+
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :param pos: The node positions.
+        :type pos: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices.
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.propagate(edge_index=edge_index, x=x, pos=pos)
+
+    def message(self, x_i, pos_i, pos_j):
+        """
+        Compute the message to be passed between nodes and edges.
+
+        :param x_i: The node features of the recipient nodes.
+        :type x_i: torch.Tensor | LabelTensor
+        :param pos_i: The positions of the recipient nodes.
+        :type pos_i: torch.Tensor | LabelTensor
+        :param pos_j: The positions of the sender nodes.
+        :type pos_j: torch.Tensor | LabelTensor
+        :return: The message to be passed.
+        :rtype: torch.Tensor
+        """
+        r = torch.norm(pos_i - pos_j, dim=-1, keepdim=True)
+        return self.radial_field(r) * self.message_net(x_i)
+
+    def update(self, message, pos):
+        """
+        Update the node features with the received messages.
+
+        :param torch.Tensor message: The aggregated messages.
+        :param pos: The node positions.
+        :type pos: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.update_net(torch.cat((pos, message), dim=-1))
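Closing the series, a usage sketch for the SchNet block (illustrative shapes; note that, as written, the block returns features of size `hidden_dim`, produced by the update network from the concatenated positions and messages):

import torch
from pina.model.block.message_passing.schnet_block import SchnetBlock

x = torch.rand(4, 8)    # node features
pos = torch.rand(4, 3)  # node positions
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

block = SchnetBlock(node_feature_dim=8, node_pos_dim=3, hidden_dim=16)
out = block(x, pos, edge_index)
print(out.shape)  # torch.Size([4, 16])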