diff --git a/pina/model/block/message_passing/__init__.py b/pina/model/block/message_passing/__init__.py
new file mode 100644
index 000000000..a4b122016
--- /dev/null
+++ b/pina/model/block/message_passing/__init__.py
@@ -0,0 +1,9 @@
+"""Module for the message passing blocks of the graph neural models."""
+
+__all__ = [
+    "InteractionNetworkBlock",
+    "DeepTensorNetworkBlock",
+]
+
+from .interaction_network_block import InteractionNetworkBlock
+from .deep_tensor_network_block import DeepTensorNetworkBlock
diff --git a/pina/model/block/message_passing/deep_tensor_network_block.py b/pina/model/block/message_passing/deep_tensor_network_block.py
new file mode 100644
index 000000000..950e32f05
--- /dev/null
+++ b/pina/model/block/message_passing/deep_tensor_network_block.py
@@ -0,0 +1,152 @@
+"""Module for the Deep Tensor Network block."""
+
+import torch
+from torch_geometric.nn import MessagePassing
+from ....utils import check_consistency
+
+
+class DeepTensorNetworkBlock(MessagePassing):
+    """
+    Implementation of the Deep Tensor Network block.
+
+    This block is used to perform message-passing between nodes and edges in a
+    graph neural network, following the scheme proposed by Schütt et al.
+    (2017). It serves as an inner block in a larger graph neural network
+    architecture.
+
+    The message between two nodes connected by an edge is computed by applying
+    linear transformations to the sender node features and to the edge
+    features, taking their element-wise product, and passing the result
+    through a further linear layer followed by a non-linear activation
+    function. Messages are then aggregated using an aggregation scheme
+    (e.g., sum, mean, min, max, or product).
+
+    The update step is performed by a simple addition of the incoming messages
+    to the node features.
+
+    .. seealso::
+
+        **Original reference**: Schütt, K., Arbabzadah, F., Chmiela, S. et al.
+        *Quantum-Chemical Insights from Deep Tensor Neural Networks*.
+        Nature Communications 8, 13890 (2017).
+        DOI: `10.1038/ncomms13890 <https://doi.org/10.1038/ncomms13890>`_.
+    """
+
+    def __init__(
+        self,
+        node_feature_dim,
+        edge_feature_dim,
+        activation=torch.nn.Tanh,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
+    ):
+        """
+        Initialization of the :class:`DeepTensorNetworkBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int edge_feature_dim: The dimension of the edge features.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.Tanh`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source". The
+            "source_to_target" flow means that messages are sent from the
+            source node to the target node, while the "target_to_source" flow
+            means that messages are sent from the target node to the source
+            node. See :class:`torch_geometric.nn.MessagePassing` for more
+            details. Default is "source_to_target".
+        :raises ValueError: If `node_feature_dim` is not a positive integer.
+        :raises ValueError: If `edge_feature_dim` is not a positive integer.
+        """
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        # Check consistency
+        check_consistency(node_feature_dim, int)
+        check_consistency(edge_feature_dim, int)
+
+        # Check values
+        if node_feature_dim <= 0:
+            raise ValueError(
+                "`node_feature_dim` must be a positive integer,"
+                f" got {node_feature_dim}."
+            )
+
+        if edge_feature_dim <= 0:
+            raise ValueError(
+                "`edge_feature_dim` must be a positive integer,"
+                f" got {edge_feature_dim}."
+            )
+
+        # Initialize parameters
+        self.node_feature_dim = node_feature_dim
+        self.edge_feature_dim = edge_feature_dim
+        # Instantiate the activation, since it is applied directly to the
+        # messages (the class object itself is not callable on tensors)
+        self.activation = activation()
+
+        # Layer for processing node features
+        self.node_layer = torch.nn.Linear(
+            in_features=self.node_feature_dim,
+            out_features=self.node_feature_dim,
+            bias=True,
+        )
+
+        # Layer for processing edge features
+        self.edge_layer = torch.nn.Linear(
+            in_features=self.edge_feature_dim,
+            out_features=self.node_feature_dim,
+            bias=True,
+        )
+
+        # Layer for computing the message
+        self.message_layer = torch.nn.Linear(
+            in_features=self.node_feature_dim,
+            out_features=self.node_feature_dim,
+            bias=False,
+        )
+
+    def forward(self, x, edge_index, edge_attr):
+        """
+        Forward pass of the block, triggering the message-passing routine.
+
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices.
+        :param edge_attr: The edge attributes.
+        :type edge_attr: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
+
+    def message(self, x_j, edge_attr):
+        """
+        Compute the message to be passed between nodes and edges.
+
+        :param x_j: The node features of the sender nodes.
+        :type x_j: torch.Tensor | LabelTensor
+        :param edge_attr: The edge attributes.
+        :type edge_attr: torch.Tensor | LabelTensor
+        :return: The message to be passed.
+        :rtype: torch.Tensor
+        """
+        # Process node and edge features
+        filter_node = self.node_layer(x_j)
+        filter_edge = self.edge_layer(edge_attr)
+
+        # Compute the message to be passed
+        message = self.message_layer(filter_node * filter_edge)
+
+        return self.activation(message)
+
+    def update(self, message, x):
+        """
+        Update the node features with the received messages.
+
+        :param torch.Tensor message: The message to be passed.
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return x + message
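A minimal usage sketch for the block above. The 4-node toy graph, random tensors, and feature dimensions are illustrative assumptions; the import path follows the new `__init__.py`:

```python
import torch
from pina.model.block.message_passing import DeepTensorNetworkBlock

# Toy graph: 4 nodes with 8 features each, 5 directed edges with 3 features
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3, 0], [1, 2, 3, 0, 2]])
edge_attr = torch.randn(5, 3)

block = DeepTensorNetworkBlock(node_feature_dim=8, edge_feature_dim=3)
out = block(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([4, 8]): the residual update preserves the dim
```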
+ """ + super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) + + # Check consistency + check_consistency(node_feature_dim, int) + check_consistency(edge_feature_dim, int) + + # Check values + if node_feature_dim <= 0: + raise ValueError( + "`node_feature_dim` must be a positive integer," + f" got {node_feature_dim}." + ) + + if edge_feature_dim <= 0: + raise ValueError( + "`edge_feature_dim` must be a positive integer," + f" got {edge_feature_dim}." + ) + + # Initialize parameters + self.node_feature_dim = node_feature_dim + self.edge_feature_dim = edge_feature_dim + self.activation = activation + + # Layer for processing node features + self.node_layer = torch.nn.Linear( + in_features=self.node_feature_dim, + out_features=self.node_feature_dim, + bias=True, + ) + + # Layer for processing edge features + self.edge_layer = torch.nn.Linear( + in_features=self.edge_feature_dim, + out_features=self.node_feature_dim, + bias=True, + ) + + # Layer for computing the message + self.message_layer = torch.nn.Linear( + in_features=self.node_feature_dim, + out_features=self.node_feature_dim, + bias=False, + ) + + def forward(self, x, edge_index, edge_attr): + """ + Forward pass of the block, triggering the message-passing routine. + + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge indeces. + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor | LabelTensor + :return: The updated node features. + :rtype: torch.Tensor + """ + return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr) + + def message(self, x_j, edge_attr): + """ + Compute the message to be passed between nodes and edges. + + :param x_j: The node features of the sender nodes. + :type x_j: torch.Tensor | LabelTensor + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor | LabelTensor + :return: The message to be passed. + :rtype: torch.Tensor + """ + # Process node and edge features + filter_node = self.node_layer(x_j) + filter_edge = self.edge_layer(edge_attr) + + # Compute the message to be passed + message = self.message_layer(filter_node * filter_edge) + + return self.activation(message) + + def update(self, message, x): + """ + Update the node features with the received messages. + + :param torch.Tensor message: The message to be passed. + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :return: The updated node features. + :rtype: torch.Tensor + """ + return x + message diff --git a/pina/model/block/message_passing/egnn_block.py b/pina/model/block/message_passing/egnn_block.py new file mode 100644 index 000000000..b6a605070 --- /dev/null +++ b/pina/model/block/message_passing/egnn_block.py @@ -0,0 +1,137 @@ +"""Module for the E(n) Equivariant Graph Neural Network block.""" + +import torch +from torch_geometric.nn import MessagePassing +from torch_geometric.utils import degree + + +class EnEquivariantGraphBlock(MessagePassing): + """ + Implementation of the E(n) Equivariant Graph Neural Network block. + + This block is used to perform message-passing between nodes and edges in a + graph neural network, following the scheme proposed by Satorras et al. (2021). + It serves as an inner block in a larger graph neural network architecture. + + The message between two nodes connected by an edge is computed by applying a + linear transformation to the sender node features and the edge features, + followed by a non-linear activation function. 
diff --git a/pina/model/block/message_passing/interaction_network_block.py b/pina/model/block/message_passing/interaction_network_block.py
new file mode 100644
index 000000000..f27169448
--- /dev/null
+++ b/pina/model/block/message_passing/interaction_network_block.py
@@ -0,0 +1,168 @@
+"""Module for the Interaction Network block."""
+
+import torch
+from torch_geometric.nn import MessagePassing
+from ....model import FeedForward
+from ....utils import check_consistency
+
+
+class InteractionNetworkBlock(MessagePassing):
+    """
+    Implementation of the Interaction Network block.
+
+    This block is used to perform message-passing between nodes and edges in a
+    graph neural network, following the scheme proposed by Battaglia et al.
+    (2016). It serves as an inner block in a larger graph neural network
+    architecture.
+
+    The message between two nodes connected by an edge is computed by applying
+    a multi-layer perceptron (MLP) to the concatenation of the sender and
+    recipient node features. Messages are then aggregated using an
+    aggregation scheme (e.g., sum, mean, min, max, or product).
+
+    The update step is performed by applying another MLP to the concatenation
+    of the incoming messages and the node features.
+
+    .. seealso::
+
+        **Original reference**: Battaglia, P. W., et al. (2016).
+        *Interaction Networks for Learning about Objects, Relations and
+        Physics*.
+        In Advances in Neural Information Processing Systems (NeurIPS 2016).
+        ArXiv: `1612.00222 <https://arxiv.org/abs/1612.00222>`_.
+    """
+
+    def __init__(
+        self,
+        node_feature_dim,
+        hidden_dim,
+        n_message_layers=2,
+        n_update_layers=2,
+        activation=torch.nn.SiLU,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
+    ):
+        """
+        Initialization of the :class:`InteractionNetworkBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int hidden_dim: The dimension of the hidden features.
+        :param int n_message_layers: The number of layers in the message
+            network. Default is 2.
+        :param int n_update_layers: The number of layers in the update
+            network. Default is 2.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.SiLU`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source". The
+            "source_to_target" flow means that messages are sent from the
+            source node to the target node, while the "target_to_source" flow
+            means that messages are sent from the target node to the source
+            node. See :class:`torch_geometric.nn.MessagePassing` for more
+            details. Default is "source_to_target".
+        :raises ValueError: If `node_feature_dim` is not a positive integer.
+        :raises ValueError: If `hidden_dim` is not a positive integer.
+        :raises ValueError: If `n_message_layers` is not a positive integer.
+        :raises ValueError: If `n_update_layers` is not a positive integer.
+        """
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        # Check consistency
+        check_consistency(node_feature_dim, int)
+        check_consistency(hidden_dim, int)
+        check_consistency(n_message_layers, int)
+        check_consistency(n_update_layers, int)
+
+        # Check values
+        if node_feature_dim <= 0:
+            raise ValueError(
+                "`node_feature_dim` must be a positive integer,"
+                f" got {node_feature_dim}."
+            )
+
+        if hidden_dim <= 0:
+            raise ValueError(
+                f"`hidden_dim` must be a positive integer, got {hidden_dim}."
+            )
+
+        if n_message_layers <= 0:
+            raise ValueError(
+                "`n_message_layers` must be a positive integer,"
+                f" got {n_message_layers}."
+            )
+
+        if n_update_layers <= 0:
+            raise ValueError(
+                "`n_update_layers` must be a positive integer,"
+                f" got {n_update_layers}."
+            )
+
+        # Initialize parameters
+        self.node_feature_dim = node_feature_dim
+        self.hidden_dim = hidden_dim
+        self.activation = activation
+
+        # Message network
+        self.message_net = FeedForward(
+            input_dimensions=2 * self.node_feature_dim,
+            output_dimensions=self.hidden_dim,
+            inner_size=self.hidden_dim,
+            n_layers=n_message_layers,
+            func=self.activation,
+        )
+
+        # Update network
+        self.update_net = FeedForward(
+            input_dimensions=self.node_feature_dim + self.hidden_dim,
+            output_dimensions=self.hidden_dim,
+            inner_size=self.node_feature_dim,
+            n_layers=n_update_layers,
+            func=self.activation,
+        )
+
+    def forward(self, x, edge_index, edge_attr=None):
+        """
+        Forward pass of the block, triggering the message-passing routine.
+
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices.
+        :param edge_attr: The edge attributes, currently unused by the
+            message function. Default is ``None``.
+        :type edge_attr: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        # TODO: edge_attr is not used in the message function
+        return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
+
+    def message(self, x_i, x_j):
+        """
+        Compute the message to be passed between nodes and edges.
+
+        :param x_i: The node features of the recipient nodes.
+        :type x_i: torch.Tensor | LabelTensor
+        :param x_j: The node features of the sender nodes.
+        :type x_j: torch.Tensor | LabelTensor
+        :return: The message to be passed.
+        :rtype: torch.Tensor
+        """
+        return self.message_net(torch.cat((x_i, x_j), dim=-1))
+
+    def update(self, message, x):
+        """
+        Update the node features with the received messages.
+
+        :param torch.Tensor message: The message to be passed.
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.update_net(torch.cat((x, message), dim=-1))
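A usage sketch with toy sizes. Since the update network returns `hidden_dim` features, chaining several of these blocks implicitly assumes `hidden_dim == node_feature_dim`, as below:

```python
import torch
from pina.model.block.message_passing import InteractionNetworkBlock

x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

block = InteractionNetworkBlock(node_feature_dim=8, hidden_dim=8)
out = block(x, edge_index)  # edge_attr is optional and currently unused
print(out.shape)  # torch.Size([4, 8])
```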
diff --git a/pina/model/block/message_passing/radial_field_network_block.py b/pina/model/block/message_passing/radial_field_network_block.py
new file mode 100644
index 000000000..0d3257d48
--- /dev/null
+++ b/pina/model/block/message_passing/radial_field_network_block.py
@@ -0,0 +1,133 @@
+"""Module for the Radial Field Network block."""
+
+import torch
+from torch_geometric.nn import MessagePassing
+from ....model import FeedForward
+from ....utils import check_consistency
+
+
+class RadialFieldNetworkBlock(MessagePassing):
+    """
+    Implementation of the Radial Field Network block.
+
+    This block is used to perform message-passing between nodes and edges in a
+    graph neural network, following the scheme proposed by Köhler et al.
+    (2020). It serves as an inner block in a larger graph neural network
+    architecture.
+
+    The message between two nodes connected by an edge is computed by scaling
+    the difference between the sender and recipient node features by a
+    learned radial field, evaluated on their distance. Messages are then
+    aggregated using an aggregation scheme (e.g., sum, mean, min, max, or
+    product).
+
+    The update step is performed by a simple addition of the incoming messages
+    to the node features.
+
+    .. seealso::
+
+        **Original reference**: Köhler, J., Klein, L., & Noé, F. (2020).
+        *Equivariant Flows: Exact Likelihood Generative Learning for
+        Symmetric Densities*.
+        In International Conference on Machine Learning (pp. 5361-5370). PMLR.
+    """
+
+    def __init__(
+        self,
+        node_feature_dim,
+        radial_hidden_dim=16,
+        n_radial_layers=2,
+        activation=torch.nn.ReLU,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
+    ):
+        """
+        Initialization of the :class:`RadialFieldNetworkBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int radial_hidden_dim: The hidden dimension of the radial field
+            network. Default is 16.
+        :param int n_radial_layers: The number of layers in the radial field
+            network. Default is 2.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.ReLU`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source". The
+            "source_to_target" flow means that messages are sent from the
+            source node to the target node, while the "target_to_source" flow
+            means that messages are sent from the target node to the source
+            node. See :class:`torch_geometric.nn.MessagePassing` for more
+            details. Default is "source_to_target".
+        :raises ValueError: If `node_feature_dim` is not a positive integer.
+        """
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        # Check consistency
+        check_consistency(node_feature_dim, int)
+
+        # Check values
+        if node_feature_dim <= 0:
+            raise ValueError(
+                "`node_feature_dim` must be a positive integer,"
+                f" got {node_feature_dim}."
+            )
+
+        # Initialize parameters
+        self.node_feature_dim = node_feature_dim
+        self.activation = activation
+
+        # Radial field network, mapping pairwise distances to scalar weights
+        self.radial_field = FeedForward(
+            input_dimensions=1,
+            output_dimensions=1,
+            inner_size=radial_hidden_dim,
+            n_layers=n_radial_layers,
+            func=self.activation,
+        )
+
+    def forward(self, x, edge_index):
+        """
+        Forward pass of the block, triggering the message-passing routine.
+
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices. In the original
+            formulation, the messages are aggregated from all nodes, not only
+            from the neighbours.
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.propagate(edge_index=edge_index, x=x)
+
+    def message(self, x_i, x_j):
+        """
+        Compute the message to be passed between nodes and edges.
+
+        :param x_i: The node features of the recipient nodes.
+        :type x_i: torch.Tensor | LabelTensor
+        :param x_j: The node features of the sender nodes.
+        :type x_j: torch.Tensor | LabelTensor
+        :return: The message to be passed.
+        :rtype: torch.Tensor
+        """
+        # Pairwise distance, computed edge-wise along the feature dimension
+        r = torch.norm(x_i - x_j, dim=-1, keepdim=True)
+
+        return self.radial_field(r) * (x_i - x_j)
+
+    def update(self, message, x):
+        """
+        Update the node features with the received messages.
+
+        :param torch.Tensor message: The message to be passed.
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return x + message
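A shape check under the same toy-graph assumptions as above; here the node features play the role of particle coordinates, as in the flow setting of the paper:

```python
import torch
from pina.model.block.message_passing.radial_field_network_block import (
    RadialFieldNetworkBlock,
)

x = torch.randn(4, 3)  # node features acting as coordinates
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

block = RadialFieldNetworkBlock(node_feature_dim=3)
out = block(x, edge_index)
print(out.shape)  # torch.Size([4, 3])
```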
+ """ + super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) + + # Check consistency + check_consistency(node_feature_dim, int) + check_consistency(hidden_dim, int) + check_consistency(n_message_layers, int) + check_consistency(n_update_layers, int) + + # Check values + if node_feature_dim <= 0: + raise ValueError( + "`node_feature_dim` must be a positive integer," + f" got {node_feature_dim}." + ) + + if hidden_dim <= 0: + raise ValueError( + "`hidden_dim` must be a positive integer," f" got {hidden_dim}." + ) + + if n_message_layers <= 0: + raise ValueError( + "`n_message_layers` must be a positive integer," + f" got {n_message_layers}." + ) + + if n_update_layers <= 0: + raise ValueError( + "`n_update_layers` must be a positive integer," + f" got {n_update_layers}." + ) + + # Initialize parameters + self.node_feature_dim = node_feature_dim + self.hidden_dim = hidden_dim + self.activation = activation + + # Message network + self.message_net = FeedForward( + input_dimensions=2 * self.node_feature_dim, + output_dimensions=self.hidden_dim, + inner_size=self.hidden_dim, + n_layers=n_message_layers, + func=self.activation, + ) + + # Update network + self.update_net = FeedForward( + input_dimensions=self.node_feature_dim + self.hidden_dim, + output_dimensions=self.hidden_dim, + inner_size=self.node_feature_dim, + n_layers=n_update_layers, + func=self.activation, + ) + + def forward(self, x, edge_index, edge_attr): + """ + Forward pass of the block, triggering the message-passing routine. + + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge indeces. + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor | LabelTensor + :return: The updated node features. + :rtype: torch.Tensor + """ + + # TODO: edge_attr is not used in the message function + return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr) + + def message(self, x_i, x_j): + """ + Compute the message to be passed between nodes and edges. + + :param x_i: The node features of the recipient nodes. + :type x_i: torch.Tensor | LabelTensor + :param x_j: The node features of the sender nodes. + :type x_j: torch.Tensor | LabelTensor + :return: The message to be passed. + :rtype: torch.Tensor + """ + return self.message_net(torch.cat((x_i, x_j), dim=-1)) + + def update(self, message, x): + """ + Update the node features with the received messages. + + :param torch.Tensor message: The message to be passed. + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :return: The updated node features. + :rtype: torch.Tensor + """ + return self.update_net(torch.cat((x, message), dim=-1)) diff --git a/pina/model/block/message_passing/radial_field_network_block.py b/pina/model/block/message_passing/radial_field_network_block.py new file mode 100644 index 000000000..0d3257d48 --- /dev/null +++ b/pina/model/block/message_passing/radial_field_network_block.py @@ -0,0 +1,133 @@ +"""Module for the Radial Field Network block.""" + +import torch +from ....model import FeedForward +from torch_geometric.nn import MessagePassing +from ....utils import check_consistency + + +class RadialFieldNetworkBlock(MessagePassing): + """ + Implementation of the Radial Field Network block. + + This block is used to perform message-passing between nodes and edges in a + graph neural network, following the scheme proposed by Köhler et al. (2020). + It serves as an inner block in a larger graph neural network architecture. 
+    def __init__(
+        self,
+        node_feature_dim,
+        node_pos_dim,
+        hidden_dim,
+        radial_hidden_dim=16,
+        n_message_layers=2,
+        n_update_layers=2,
+        n_radial_layers=2,
+        activation=torch.nn.ReLU,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
+    ):
+        """
+        Initialization of the :class:`SchnetBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int node_pos_dim: The dimension of the node positions.
+        :param int hidden_dim: The dimension of the hidden features.
+        :param int radial_hidden_dim: The hidden dimension of the radial
+            network. Default is 16.
+        :param int n_message_layers: The number of layers in the message
+            network. Default is 2.
+        :param int n_update_layers: The number of layers in the update
+            network. Default is 2.
+        :param int n_radial_layers: The number of layers in the radial
+            network. Default is 2.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.ReLU`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source". The
+            "source_to_target" flow means that messages are sent from the
+            source node to the target node, while the "target_to_source" flow
+            means that messages are sent from the target node to the source
+            node. See :class:`torch_geometric.nn.MessagePassing` for more
+            details. Default is "source_to_target".
+        :raises ValueError: If `node_feature_dim` is not a positive integer.
+        """
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        # Check consistency
+        check_consistency(node_feature_dim, int)
+
+        # Check values
+        if node_feature_dim <= 0:
+            raise ValueError(
+                "`node_feature_dim` must be a positive integer,"
+                f" got {node_feature_dim}."
+            )
+
+        # Initialize parameters
+        self.node_feature_dim = node_feature_dim
+        self.node_pos_dim = node_pos_dim
+        self.hidden_dim = hidden_dim
+        self.activation = activation
+
+        # Radial network, mapping pairwise distances to scalar filters
+        self.radial_field = FeedForward(
+            input_dimensions=1,
+            output_dimensions=1,
+            inner_size=radial_hidden_dim,
+            n_layers=n_radial_layers,
+            func=self.activation,
+        )
+
+        # Update network, combining node positions and aggregated messages
+        self.update_net = FeedForward(
+            input_dimensions=self.node_pos_dim + self.hidden_dim,
+            output_dimensions=self.hidden_dim,
+            inner_size=self.hidden_dim,
+            n_layers=n_update_layers,
+            func=self.activation,
+        )
+
+        # Message network, transforming the sender node features
+        self.message_net = FeedForward(
+            input_dimensions=self.node_feature_dim,
+            output_dimensions=self.hidden_dim,
+            inner_size=self.hidden_dim,
+            n_layers=n_message_layers,
+            func=self.activation,
+        )
+
+    def forward(self, x, pos, edge_index):
+        """
+        Forward pass of the block, triggering the message-passing routine.
+
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :param pos: The node positions.
+        :type pos: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices. In the original
+            formulation, the messages are aggregated from all nodes, not only
+            from the neighbours.
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.propagate(edge_index=edge_index, x=x, pos=pos)
+
+    def message(self, x_j, pos_i, pos_j):
+        """
+        Compute the message to be passed between nodes and edges.
+
+        :param x_j: The node features of the sender nodes.
+        :type x_j: torch.Tensor | LabelTensor
+        :param pos_i: The positions of the recipient nodes.
+        :type pos_i: torch.Tensor | LabelTensor
+        :param pos_j: The positions of the sender nodes.
+        :type pos_j: torch.Tensor | LabelTensor
+        :return: The message to be passed.
+        :rtype: torch.Tensor
+        """
+        # Edge-wise distance between the two node positions
+        r = torch.norm(pos_i - pos_j, dim=-1, keepdim=True)
+
+        return self.radial_field(r) * self.message_net(x_j)
+
+    def update(self, message, pos):
+        """
+        Update the node features with the received messages.
+
+        :param torch.Tensor message: The aggregated messages.
+        :param pos: The node positions.
+        :type pos: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.update_net(torch.cat((pos, message), dim=-1))
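Finally, a shape check for the SchNet block with assumed toy sizes. The output lives in `hidden_dim`, so chaining several blocks implicitly requires `node_feature_dim == hidden_dim`:

```python
import torch
from pina.model.block.message_passing.schnet_block import SchnetBlock

x = torch.randn(4, 8)    # node features
pos = torch.randn(4, 3)  # node positions
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

block = SchnetBlock(node_feature_dim=8, node_pos_dim=3, hidden_dim=16)
out = block(x, pos, edge_index)
print(out.shape)  # torch.Size([4, 16])
```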