@@ -2,7 +2,6 @@
 from argparse import ArgumentParser
 from math import sqrt
 
-import atom3.case as ca
 import dgl
 import dgl.function as fn
 import numpy as np
@@ -12,7 +11,6 @@
 import torch.nn as nn
 import torchmetrics as tm
 import wandb
-from dgl.nn.pytorch import GATConv
 from dgl.nn.pytorch import GraphConv
 from torch.nn import functional as F
 from torch.optim import AdamW
@@ -1535,7 +1533,6 @@ def __init__(self, num_node_input_feats: int, num_edge_input_feats: int, gnn_act
 
         # Set up GNN node and edge embedding layers (if requested)
         self.using_gcn = self.gnn_layer_type.lower() == 'gcn'
-        self.using_gat = self.gnn_layer_type.lower() == 'gat'
         self.using_node_embedding = self.num_node_input_feats != self.num_gnn_hidden_channels
         self.node_in_embedding = nn.Linear(self.num_node_input_feats, self.num_gnn_hidden_channels, bias=False) \
             if self.using_node_embedding \
@@ -1586,16 +1583,6 @@ def build_gnn_module(self):
                                     weight=True,
                                     activation=None,
                                     allow_zero_in_degree=False) for _ in range(self.num_gnn_layers)]
-        elif self.using_gat:
-            gnn_layers = [GATConv(in_feats=num_node_input_feats,
-                                  out_feats=num_node_input_feats,
-                                  num_heads=self.num_gnn_attention_heads,
-                                  feat_drop=0.1,
-                                  attn_drop=0.1,
-                                  negative_slope=0.2,
-                                  residual=False,
-                                  activation=None,
-                                  allow_zero_in_degree=False) for _ in range(self.num_gnn_layers)]
         else:  # Default to using a Geometric Transformer for learning node representations
             if self.num_gnn_layers > 0:
                 gnn_layers = [DGLGeometricTransformer(node_count_limit=self.max_num_graph_nodes,
@@ -1666,17 +1653,6 @@ def gnn_forward(self, graph: dgl.DGLGraph):
                 graph.ndata['f'] = layer(graph, graph.ndata['f'], edge_weight=graph.edata['f'][:, 1]).squeeze()
                 # Retain the original batch number of nodes and edges
                 graph.set_batch_num_nodes(batch_num_nodes), graph.set_batch_num_edges(batch_num_edges)
-        elif self.using_gat:
-            # Forward propagate with each GNN layer
-            for layer in self.gnn_module:
-                # Cache the original batch number of nodes and edges
-                batch_num_nodes, batch_num_edges = graph.batch_num_nodes(), graph.batch_num_edges()
-                if self.num_gnn_attention_heads > 1:
-                    graph.ndata['f'] = torch.sum(layer(graph, graph.ndata['f']), dim=1)  # Sum the attention heads
-                else:
-                    graph.ndata['f'] = layer(graph, graph.ndata['f']).squeeze()
-                # Retain the original batch number of nodes and edges
-                graph.set_batch_num_nodes(batch_num_nodes), graph.set_batch_num_edges(batch_num_edges)
         else:  # The GeometricTransformer updates simply by returning a graph containing the updated node/edge feats
             for layer in self.gnn_module:
                 graph = layer(graph)  # Geometric Transformers can handle their own depth
@@ -2209,14 +2185,16 @@ def add_model_specific_args(parent_parser):
         # -----------------
         parser.add_argument('--gnn_layer_type', type=str, default='geotran',
                             help='Which type of GNN layer to use'
-                                 ' (i.e. gat for DGLGATConv or geotran for DGLGeometricTransformer)')
+                                 ' (i.e., gcn for GraphConv,'
+                                 ' geotran w/ --disable_geometric_mode for DGLGraphTransformer,'
+                                 ' or geotran for DGLGeometricTransformer)')
         parser.add_argument('--num_gnn_hidden_channels', type=int, default=128,
                             help='Dimensionality of GNN filters (for nodes and edges alike after embedding)')
         parser.add_argument('--num_gnn_attention_heads', type=int, default=4,
                             help='How many multi-head GNN attention blocks to run in parallel')
         parser.add_argument('--interact_module_type', type=str, default='dil_resnet',
                             help='Which type of dense prediction interaction module to use'
-                                 ' (i.e. dil_resnet for Dilated ResNet, or deeplab for DeepLabV3Plus)')
+                                 ' (i.e. dil_resnet for ResNet2DInputWithOptAttention, or deeplab for DeepLabV3Plus)')
         parser.add_argument('--num_interact_hidden_channels', type=int, default=128,
                             help='Dimensionality of interaction module filters')
         parser.add_argument('--use_interact_attention', action='store_true', dest='use_interact_attention',