@@ -2,7 +2,6 @@
 from argparse import ArgumentParser
 from math import sqrt
 
-import atom3.case as ca
 import dgl
 import dgl.function as fn
 import numpy as np
@@ -12,7 +11,6 @@
 import torch.nn as nn
 import torchmetrics as tm
 import wandb
-from dgl.nn.pytorch import GATConv
 from dgl.nn.pytorch import GraphConv
 from torch.nn import functional as F
 from torch.optim import AdamW

@@ -1535,7 +1533,6 @@ def __init__(self, num_node_input_feats: int, num_edge_input_feats: int, gnn_act
 
         # Set up GNN node and edge embedding layers (if requested)
         self.using_gcn = self.gnn_layer_type.lower() == 'gcn'
-        self.using_gat = self.gnn_layer_type.lower() == 'gat'
         self.using_node_embedding = self.num_node_input_feats != self.num_gnn_hidden_channels
         self.node_in_embedding = nn.Linear(self.num_node_input_feats, self.num_gnn_hidden_channels, bias=False) \
             if self.using_node_embedding \

@@ -1586,16 +1583,6 @@ def build_gnn_module(self):
                                     weight=True,
                                     activation=None,
                                     allow_zero_in_degree=False) for _ in range(self.num_gnn_layers)]
-        elif self.using_gat:
-            gnn_layers = [GATConv(in_feats=num_node_input_feats,
-                                  out_feats=num_node_input_feats,
-                                  num_heads=self.num_gnn_attention_heads,
-                                  feat_drop=0.1,
-                                  attn_drop=0.1,
-                                  negative_slope=0.2,
-                                  residual=False,
-                                  activation=None,
-                                  allow_zero_in_degree=False) for _ in range(self.num_gnn_layers)]
         else: # Default to using a Geometric Transformer for learning node representations
             if self.num_gnn_layers > 0:
                 gnn_layers = [DGLGeometricTransformer(node_count_limit=self.max_num_graph_nodes,

@@ -1666,17 +1653,6 @@ def gnn_forward(self, graph: dgl.DGLGraph):
                 graph.ndata['f'] = layer(graph, graph.ndata['f'], edge_weight=graph.edata['f'][:, 1]).squeeze()
                 # Retain the original batch number of nodes and edges
                 graph.set_batch_num_nodes(batch_num_nodes), graph.set_batch_num_edges(batch_num_edges)
-        elif self.using_gat:
-            # Forward propagate with each GNN layer
-            for layer in self.gnn_module:
-                # Cache the original batch number of nodes and edges
-                batch_num_nodes, batch_num_edges = graph.batch_num_nodes(), graph.batch_num_edges()
-                if self.num_gnn_attention_heads > 1:
-                    graph.ndata['f'] = torch.sum(layer(graph, graph.ndata['f']), dim=1) # Sum the attention heads
-                else:
-                    graph.ndata['f'] = layer(graph, graph.ndata['f']).squeeze()
-                # Retain the original batch number of nodes and edges
-                graph.set_batch_num_nodes(batch_num_nodes), graph.set_batch_num_edges(batch_num_edges)
         else: # The GeometricTransformer updates simply by returning a graph containing the updated node/edge feats
             for layer in self.gnn_module:
                 graph = layer(graph) # Geometric Transformers can handle their own depth

@@ -2209,14 +2185,16 @@ def add_model_specific_args(parent_parser):
         # -----------------
         parser.add_argument('--gnn_layer_type', type=str, default='geotran',
                             help='Which type of GNN layer to use'
-                                 ' (i.e. gat for DGLGATConv or geotran for DGLGeometricTransformer)')
+                                 ' (i.e., gcn for GraphConv,'
+                                 ' geotran w/ --disable_geometric_mode for DGLGraphTransformer,'
+                                 ' or geotran for DGLGeometricTransformer)')
         parser.add_argument('--num_gnn_hidden_channels', type=int, default=128,
                             help='Dimensionality of GNN filters (for nodes and edges alike after embedding)')
         parser.add_argument('--num_gnn_attention_heads', type=int, default=4,
                             help='How many multi-head GNN attention blocks to run in parallel')
         parser.add_argument('--interact_module_type', type=str, default='dil_resnet',
                             help='Which type of dense prediction interaction module to use'
-                                 ' (i.e. dil_resnet for Dilated ResNet, or deeplab for DeepLabV3Plus)')
+                                 ' (i.e. dil_resnet for ResNet2DInputWithOptAttention, or deeplab for DeepLabV3Plus)')
         parser.add_argument('--num_interact_hidden_channels', type=int, default=128,
                             help='Dimensionality of interaction module filters')
         parser.add_argument('--use_interact_attention', action='store_true', dest='use_interact_attention',
|
|