-
Notifications
You must be signed in to change notification settings - Fork 84
/
Copy pathmodel.py
155 lines (141 loc) · 6.79 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
import tensorflow as tf
try:
    # IPU-specific TensorFlow extensions; only present in Graphcore's TF build.
    from tensorflow.python import ipu
except ImportError:
    # Running on a stock TensorFlow build (CPU/GPU) — the ipu module is optional here.
    pass
from absl import flags
import layers
from data_utils.data_generators import EDGE_FEATURE_DIMS, NODE_FEATURE_DIMS
# Command-line hyper-parameters for the model architecture and regularization.
flags.DEFINE_enum("dtype", "float16", ("float16", "float32"), "data dtype")
flags.DEFINE_integer("n_latent", 300, "number of latent units in the network")
flags.DEFINE_integer("n_hidden", 600, "dimensionality for the hidden MLP layers")
flags.DEFINE_integer("n_mlp_layers", 2, "total number of layers in the MLPs (including output)")
flags.DEFINE_integer("n_embedding_channels", 100, "how many channels to use for the input embeddings")
flags.DEFINE_integer("n_graph_layers", 5, "how many message-passing steps in the model")
flags.DEFINE_integer("replicas", 1, "The number of replicas to scale the model over.")
flags.DEFINE_enum("opt", "adam", ("SGD", "adam"), "which optimizer to use")
flags.DEFINE_float("nodes_dropout", 0.0, "dropout for nodes")
flags.DEFINE_float("edges_dropout", 0.0, "dropout for edges")
flags.DEFINE_float("globals_dropout", 0.0, "dropout for globals")
flags.DEFINE_boolean("use_edges", True, "use edges in GIN")
flags.DEFINE_enum(
    "model", "graph_isomorphism", ("graph_network", "interaction_network", "graph_isomorphism"), help="model to use"
)
# Flags are resolved lazily at call time; other flags referenced below
# (e.g. n_nodes_per_pack, micro_batch_size, gather_scatter) are defined elsewhere.
FLAGS = flags.FLAGS
def get_default_mlp(activate_final, name=None):
    """Build the standard flag-configured MLP used throughout the model.

    Depth, hidden width, and output width come from the command-line flags
    ``n_mlp_layers``, ``n_hidden``, and ``n_latent`` respectively.

    Args:
        activate_final: whether the output layer is followed by an activation.
        name: optional layer name forwarded to ``layers.MLP``.
    """
    mlp_config = dict(
        n_layers=FLAGS.n_mlp_layers,
        n_hidden=FLAGS.n_hidden,
        n_out=FLAGS.n_latent,
        activate_final=activate_final,
        name=name,
    )
    return layers.MLP(**mlp_config)
def create_model():
    """Build the GNN classifier selected by ``FLAGS.model``.

    All three variants share the same packed-graph input signature (see
    ``_build_input_list``) and end in a single sigmoid probability per graph.

    Returns:
        A ``tf.keras.Model`` mapping
        (nodes, edges, receivers, senders, node_graph_idx, edge_graph_idx)
        to per-graph probabilities.
    """
    inputs_list = _build_input_list()
    if FLAGS.model == "graph_network":
        x = layers.EncoderLayer(
            edge_model_fn=lambda: get_default_mlp(activate_final=False, name="edge_encoder"),
            node_model_fn=lambda: get_default_mlp(activate_final=False, name="node_encoder"),
        )(inputs_list)
        for _ in range(FLAGS.n_graph_layers):
            x = layers.GraphNetworkLayer(
                edge_model_fn=lambda: get_default_mlp(activate_final=False, name="edge"),
                node_model_fn=lambda: get_default_mlp(activate_final=False, name="node"),
                global_model_fn=lambda: get_default_mlp(activate_final=False, name="global"),
                nodes_dropout=FLAGS.nodes_dropout,
                edges_dropout=FLAGS.edges_dropout,
                globals_dropout=FLAGS.globals_dropout,
            )(x)
        return tf.keras.Model(inputs_list, _decode_globals_to_prob(x))
    elif FLAGS.model == "interaction_network":
        x = layers.EncoderLayer(
            edge_model_fn=lambda: get_default_mlp(activate_final=False, name="edge_encoder"),
            node_model_fn=lambda: get_default_mlp(activate_final=False, name="node_encoder"),
        )(inputs_list)
        for _ in range(FLAGS.n_graph_layers):
            x = layers.InteractionNetworkLayer(
                edge_model_fn=lambda: get_default_mlp(activate_final=False, name="edge"),
                node_model_fn=lambda: get_default_mlp(activate_final=False, name="node"),
                nodes_dropout=FLAGS.nodes_dropout,
                edges_dropout=FLAGS.edges_dropout,
            )(x)
        return tf.keras.Model(inputs_list, _decode_globals_to_prob(x))
    elif FLAGS.model == "graph_isomorphism":
        graph_tuple = layers.GinEncoderLayer()(inputs_list)
        for i in range(FLAGS.n_graph_layers):
            graph_tuple = layers.GraphIsomorphismLayer(
                # FIX: the original condition ``i < FLAGS.n_graph_layers`` was
                # always True, contradicting the stated intent that the final
                # layer before the output decoder is NOT activated; ``- 1``
                # de-activates only the last message-passing step.
                # ``i=i`` binds the loop variable at definition time so each
                # lambda sees its own iteration (avoids late-binding closure).
                get_mlp=lambda i=i: get_default_mlp(
                    name="GIN_mlp", activate_final=i < FLAGS.n_graph_layers - 1
                ),
                use_edges=FLAGS.use_edges,
                # edge embedding dimensionality must match the input to the layer:
                # first layer sees raw embeddings, later layers see latent vectors
                edge_dim=FLAGS.n_latent if i > 0 else FLAGS.n_embedding_channels,
                dropout=FLAGS.nodes_dropout,
            )(graph_tuple)
        output_prob = tf.nn.sigmoid(layers.GinDecoderLayer()(graph_tuple))
        return tf.keras.Model(inputs_list, output_prob)


def _build_input_list():
    """Build the Keras Input placeholders for one packed batch of graphs.

    Order: nodes, edges, receivers, senders, node_graph_idx, edge_graph_idx.
    With ``FLAGS.gather_scatter == "dense"`` the three index tensors are 2-D
    (presumably one-hot — defined by the data pipeline; confirm against
    data_utils), otherwise they are flat integer index vectors.
    """
    mb = FLAGS.micro_batch_size
    inputs_list = [
        # inputs are categorical features
        tf.keras.Input((FLAGS.n_nodes_per_pack, NODE_FEATURE_DIMS), dtype=tf.int32, batch_size=mb),
        tf.keras.Input((FLAGS.n_edges_per_pack, EDGE_FEATURE_DIMS), dtype=tf.int32, batch_size=mb),
    ]
    if FLAGS.gather_scatter == "dense":
        index_shapes = [
            (FLAGS.n_edges_per_pack, FLAGS.n_nodes_per_pack),   # receivers
            (FLAGS.n_edges_per_pack, FLAGS.n_nodes_per_pack),   # senders
            (FLAGS.n_nodes_per_pack, FLAGS.n_graphs_per_pack),  # node graph idx
        ]
    else:
        index_shapes = [
            (FLAGS.n_edges_per_pack,),  # receivers
            (FLAGS.n_edges_per_pack,),  # senders
            (FLAGS.n_nodes_per_pack,),  # node graph idx
        ]
    inputs_list.extend(
        tf.keras.Input(shape, dtype=tf.int32, batch_size=mb) for shape in index_shapes
    )
    # edge graph idx
    inputs_list.append(tf.keras.Input((FLAGS.n_edges_per_pack,), dtype=tf.int32, batch_size=mb))
    return inputs_list


def _decode_globals_to_prob(x):
    """Decode the graph state to one logit per graph and apply a sigmoid."""
    output_logits = layers.DecoderLayer(
        global_model_fn=lambda: layers.MLP(
            n_layers=3, n_hidden=FLAGS.n_hidden, n_out=1, activate_final=False, name="output_logits"
        )
    )(x)
    # dummy dim needed -- see
    # https://www.tensorflow.org/tutorials/distribute/custom_training#define_the_loss_function
    return tf.nn.sigmoid(output_logits)