
Commit e522034

Update compatible GNN modules' documentation
1 parent 40bfab1

1 file changed: project/utils/deepinteract_modules.py (+4, -26 lines)
@@ -2,7 +2,6 @@
 from argparse import ArgumentParser
 from math import sqrt
 
-import atom3.case as ca
 import dgl
 import dgl.function as fn
 import numpy as np
@@ -12,7 +11,6 @@
 import torch.nn as nn
 import torchmetrics as tm
 import wandb
-from dgl.nn.pytorch import GATConv
 from dgl.nn.pytorch import GraphConv
 from torch.nn import functional as F
 from torch.optim import AdamW
@@ -1535,7 +1533,6 @@ def __init__(self, num_node_input_feats: int, num_edge_input_feats: int, gnn_act
 
         # Set up GNN node and edge embedding layers (if requested)
         self.using_gcn = self.gnn_layer_type.lower() == 'gcn'
-        self.using_gat = self.gnn_layer_type.lower() == 'gat'
         self.using_node_embedding = self.num_node_input_feats != self.num_gnn_hidden_channels
         self.node_in_embedding = nn.Linear(self.num_node_input_feats, self.num_gnn_hidden_channels, bias=False) \
             if self.using_node_embedding \
@@ -1586,16 +1583,6 @@ def build_gnn_module(self):
                                     weight=True,
                                     activation=None,
                                     allow_zero_in_degree=False) for _ in range(self.num_gnn_layers)]
-        elif self.using_gat:
-            gnn_layers = [GATConv(in_feats=num_node_input_feats,
-                                  out_feats=num_node_input_feats,
-                                  num_heads=self.num_gnn_attention_heads,
-                                  feat_drop=0.1,
-                                  attn_drop=0.1,
-                                  negative_slope=0.2,
-                                  residual=False,
-                                  activation=None,
-                                  allow_zero_in_degree=False) for _ in range(self.num_gnn_layers)]
         else:  # Default to using a Geometric Transformer for learning node representations
             if self.num_gnn_layers > 0:
                 gnn_layers = [DGLGeometricTransformer(node_count_limit=self.max_num_graph_nodes,
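
Note: for context, a minimal sketch of the GCN branch this diff leaves in place. It is not the repository's exact code; `num_feats` and `num_layers` are hypothetical stand-ins for the module's num_node_input_feats and num_gnn_layers hyperparameters.

import torch.nn as nn
from dgl.nn.pytorch import GraphConv

num_feats, num_layers = 128, 2  # hypothetical values
# Stack identically-sized GraphConv layers, mirroring the retained branch above
gnn_module = nn.ModuleList([
    GraphConv(in_feats=num_feats,
              out_feats=num_feats,
              weight=True,
              activation=None,
              allow_zero_in_degree=False)
    for _ in range(num_layers)
])
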
@@ -1666,17 +1653,6 @@ def gnn_forward(self, graph: dgl.DGLGraph):
                 graph.ndata['f'] = layer(graph, graph.ndata['f'], edge_weight=graph.edata['f'][:, 1]).squeeze()
                 # Retain the original batch number of nodes and edges
                 graph.set_batch_num_nodes(batch_num_nodes), graph.set_batch_num_edges(batch_num_edges)
-        elif self.using_gat:
-            # Forward propagate with each GNN layer
-            for layer in self.gnn_module:
-                # Cache the original batch number of nodes and edges
-                batch_num_nodes, batch_num_edges = graph.batch_num_nodes(), graph.batch_num_edges()
-                if self.num_gnn_attention_heads > 1:
-                    graph.ndata['f'] = torch.sum(layer(graph, graph.ndata['f']), dim=1)  # Sum the attention heads
-                else:
-                    graph.ndata['f'] = layer(graph, graph.ndata['f']).squeeze()
-                # Retain the original batch number of nodes and edges
-                graph.set_batch_num_nodes(batch_num_nodes), graph.set_batch_num_edges(batch_num_edges)
         else:  # The GeometricTransformer updates simply by returning a graph containing the updated node/edge feats
             for layer in self.gnn_module:
                 graph = layer(graph)  # Geometric Transformers can handle their own depth
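
Note: for readers tracking what was removed, a hedged, self-contained reconstruction of the deleted GAT forward step. The toy graph, feature size, and head count below are illustrative only; only the head-summing behavior comes from the diff above.

import dgl
import torch
from dgl.nn.pytorch import GATConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # toy 3-node cycle (hypothetical)
feats = torch.randn(3, 8)              # stand-in for graph.ndata['f']
layer = GATConv(in_feats=8, out_feats=8, num_heads=4)
out = layer(g, feats)                  # per-head output: shape (3, 4, 8)
summed = torch.sum(out, dim=1)         # sum attention heads -> shape (3, 8)

Summing (rather than concatenating) the heads kept the node feature width fixed at out_feats, which is what let the deleted branch stack layers with equal input and output sizes.
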
@@ -2209,14 +2185,16 @@ def add_model_specific_args(parent_parser):
         # -----------------
         parser.add_argument('--gnn_layer_type', type=str, default='geotran',
                             help='Which type of GNN layer to use'
-                                 ' (i.e. gat for DGLGATConv or geotran for DGLGeometricTransformer)')
+                                 ' (i.e., gcn for GraphConv,'
+                                 ' geotran w/ --disable_geometric_mode for DGLGraphTransformer,'
+                                 ' or geotran for DGLGeometricTransformer)')
         parser.add_argument('--num_gnn_hidden_channels', type=int, default=128,
                             help='Dimensionality of GNN filters (for nodes and edges alike after embedding)')
         parser.add_argument('--num_gnn_attention_heads', type=int, default=4,
                             help='How many multi-head GNN attention blocks to run in parallel')
         parser.add_argument('--interact_module_type', type=str, default='dil_resnet',
                             help='Which type of dense prediction interaction module to use'
-                                 ' (i.e. dil_resnet for Dilated ResNet, or deeplab for DeepLabV3Plus)')
+                                 ' (i.e. dil_resnet for ResNet2DInputWithOptAttention, or deeplab for DeepLabV3Plus)')
         parser.add_argument('--num_interact_hidden_channels', type=int, default=128,
                             help='Dimensionality of interaction module filters')
         parser.add_argument('--use_interact_attention', action='store_true', dest='use_interact_attention',
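
Note: a short, hypothetical sketch of how the revised --gnn_layer_type flag is parsed and consumed, assuming the same lowercase comparison the module uses to set using_gcn.

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--gnn_layer_type', type=str, default='geotran',
                    help='gcn for GraphConv or geotran for DGLGeometricTransformer')
args = parser.parse_args(['--gnn_layer_type', 'gcn'])
using_gcn = args.gnn_layer_type.lower() == 'gcn'  # mirrors self.using_gcn
print(using_gcn)  # -> True
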
