@@ -3,6 +3,7 @@

 from spektral.data.utils import (
     batch_generator,
+    collate_labels_batch,
     collate_labels_disjoint,
     get_spec,
     prepend_none,
@@ -78,10 +79,10 @@ def train_step(inputs, target):
     **Arguments**

     - `dataset`: a `spektral.data.Dataset` object;
-    - `batch_size`: size of the mini-batches;
-    - `epochs`: number of epochs to iterate over the dataset. By default (`None`)
+    - `batch_size`: int, size of the mini-batches;
+    - `epochs`: int, number of epochs to iterate over the dataset. By default (`None`)
       iterates indefinitely;
-    - `shuffle`: whether to shuffle the dataset at the start of each epoch.
+    - `shuffle`: bool, whether to shuffle the dataset at the start of each epoch.
     """

     def __init__(self, dataset, batch_size=1, epochs=None, shuffle=True):
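For reference, a minimal usage sketch of the signature documented above (not part of the diff; `dataset` is assumed to be a loaded `spektral.data.Dataset` and `model` a compiled Keras model):

```python
from spektral.data import DisjointLoader

# With epochs=None the loader loops forever, so Keras needs an explicit
# steps_per_epoch to delimit epochs.
loader = DisjointLoader(dataset, batch_size=32, epochs=None, shuffle=True)
model.fit(loader.load(), steps_per_epoch=loader.steps_per_epoch, epochs=10)
```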
@@ -178,11 +179,10 @@ class SingleLoader(Loader):
     **Arguments**

     - `dataset`: a `spektral.data.Dataset` object with only one graph;
-    - `epochs`: number of epochs to iterate over the dataset. By default (`None`)
+    - `epochs`: int, number of epochs to iterate over the dataset. By default (`None`)
       iterates indefinitely;
-    - `shuffle`: whether to shuffle the data at the start of each epoch;
-    - `sample_weights`: if given, these will be appended to the output
-      automatically.
+    - `shuffle`: bool, whether to shuffle the data at the start of each epoch;
+    - `sample_weights`: Numpy array; if given, it is appended to the output automatically.

     **Output**

@@ -197,9 +197,8 @@ class SingleLoader(Loader):
     - `e`: same as `dataset[0].e`;

     `labels` is the same as `dataset[0].y`.
-    `sample_weights` is the same object passed to the constructor.
-

+    `sample_weights` is the same array passed when creating the loader.
     """

     def __init__(self, dataset, epochs=None, sample_weights=None):
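A hedged sketch of the `sample_weights` behaviour described above, e.g. for weighting the loss in node classification (`dataset` and `model` are assumptions, not part of the diff):

```python
import numpy as np

from spektral.data import SingleLoader

# `dataset` is assumed to hold exactly one graph (e.g., a citation network);
# the weights are appended to every output batch.
weights = np.ones(dataset[0].n_nodes)
loader = SingleLoader(dataset, epochs=None, sample_weights=weights)
model.fit(loader.load(), steps_per_epoch=loader.steps_per_epoch, epochs=100)
```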
@@ -262,6 +261,8 @@ class DisjointLoader(Loader):
     **Arguments**

     - `dataset`: a graph Dataset;
+    - `node_level`: bool, if `True`, stack the labels vertically for node-level
+      prediction;
     - `batch_size`: size of the mini-batches;
     - `epochs`: number of epochs to iterate over the dataset. By default (`None`)
       iterates indefinitely;
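A sketch of what the new `node_level` flag implies for the output shapes (assumptions noted in the comments):

```python
from spektral.data import DisjointLoader

# Assumed: `dataset` has per-node labels of width n_labels.
loader = DisjointLoader(dataset, node_level=True, batch_size=16)
inputs, target = next(loader)
# inputs[0] (x) has shape [n_nodes_total, n_node_features]; target should have
# shape [n_nodes_total, n_labels], where n_nodes_total sums over the 16 graphs.
```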
@@ -321,7 +322,7 @@ def tf_signature(self):
         Adjacency matrix has shape [n_nodes, n_nodes]
         Node features have shape [n_nodes, n_node_features]
         Edge features have shape [n_edges, n_edge_features]
-        Targets have shape [..., n_labels]
+        Targets have shape [*, n_labels]
         """
         signature = self.dataset.signature
         if "y" in signature:
@@ -347,33 +348,40 @@ class BatchLoader(Loader):
     If `n_max` is the number of nodes of the biggest graph in the batch, then
     the padding consists of adding zeros to the node features, adjacency matrix,
     and edge attributes of each graph so that they have shapes
-    `(n_max, n_node_features)`, `(n_max, n_max)`, and
-    `(n_max, n_max, n_edge_features)` respectively.
+    `[n_max, n_node_features]`, `[n_max, n_max]`, and
+    `[n_max, n_max, n_edge_features]` respectively.

     The zero-padding is done batch-wise, which saves memory at the cost of
     more computation. If latency is an issue but memory isn't, or if the
     dataset has graphs with a similar number of nodes, you can use
-    the `PackedBatchLoader` that first zero-pads all the dataset and then
+    the `PackedBatchLoader`, which zero-pads the whole dataset once and then
     iterates over it.

     Note that the adjacency matrix and edge attributes are returned as dense
-    arrays (mostly due to the lack of support for sparse tensor operations for
-    rank >2).
+    arrays.

-    Only graph-level labels are supported with this loader (i.e., labels are not
-    zero-padded because they are assumed to have no "node" dimensions).
+    If `mask=True`, node attributes will be extended with a binary mask that indicates
+    valid nodes (the last feature of each node will be 1 if the node was originally in
+    the graph and 0 if it is a fake node added by zero-padding).
+
+    Use this flag in conjunction with `layers.base.GraphMasking` to start the propagation
+    of masks in a model (necessary for node-level prediction and for models that use a
+    dense pooling layer like DiffPool or MinCutPool).
+
+    If `node_level=False`, the labels are interpreted as graph-level labels and
+    are returned as an array of shape `[batch, n_labels]`.
+    If `node_level=True`, the labels are padded along the node dimension and are
+    returned as an array of shape `[batch, n_max, n_labels]`.

     **Arguments**

     - `dataset`: a graph Dataset;
-    - `mask`: if True, node attributes will be extended with a binary mask that
-      indicates valid nodes (the last feature of each node will be 1 if the node is valid
-      and 0 otherwise). Use this flag in conjunction with layers.base.GraphMasking to
-      start the propagation of masks in a model.
-    - `batch_size`: size of the mini-batches;
-    - `epochs`: number of epochs to iterate over the dataset. By default (`None`)
+    - `mask`: bool, whether to add a mask to the node features;
+    - `batch_size`: int, size of the mini-batches;
+    - `epochs`: int, number of epochs to iterate over the dataset. By default (`None`)
       iterates indefinitely;
-    - `shuffle`: whether to shuffle the data at the start of each epoch.
+    - `shuffle`: bool, whether to shuffle the data at the start of each epoch;
+    - `node_level`: bool, if `True`, pad the labels along the node dimension.

     **Output**

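A shape sketch of the padded output described in the new docstring (assuming graphs with `F` node features, per-node labels of width `L`, and no edge attributes, so the inputs are just `x` and `a`):

```python
from spektral.data import BatchLoader

loader = BatchLoader(dataset, mask=True, node_level=True, batch_size=8)
(x, a), y = next(loader)
# x: [8, n_max, F + 1]  (the last feature is the binary validity mask)
# a: [8, n_max, n_max]  (zero-padded dense adjacency matrices)
# y: [8, n_max, L]      (labels padded along the node dimension)
```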
@@ -385,19 +393,30 @@ class BatchLoader(Loader):
     - `a`: adjacency matrices of shape `[batch, n_max, n_max]`;
     - `e`: edge attributes of shape `[batch, n_max, n_max, n_edge_features]`.

-    `labels` have shape `[batch, n_labels]`.
+    `labels` have shape `[batch, n_labels]` if `node_level=False` or
+    `[batch, n_max, n_labels]` otherwise.
     """

-    def __init__(self, dataset, mask=False, batch_size=1, epochs=None, shuffle=True):
+    def __init__(
+        self,
+        dataset,
+        mask=False,
+        batch_size=1,
+        epochs=None,
+        shuffle=True,
+        node_level=False,
+    ):
         self.mask = mask
+        self.node_level = node_level
+        self.signature = dataset.signature
         super().__init__(dataset, batch_size=batch_size, epochs=epochs, shuffle=shuffle)

     def collate(self, batch):
         packed = self.pack(batch)

         y = packed.pop("y_list", None)
         if y is not None:
-            y = np.array(y)
+            y = collate_labels_batch(y, node_level=self.node_level)

         output = to_batch(**packed, mask=self.mask)
         output = sp_matrices_to_sp_tensors(output)
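`collate_labels_batch` is introduced by this commit; its implementation is not shown here, but based on the docstring its intended behaviour is roughly the following (a hypothetical `collate_labels_batch_sketch`, not the library function):

```python
import numpy as np

# Hedged sketch: graph-level labels are stacked as-is, node-level labels are
# zero-padded to n_max along the node dimension before stacking.
def collate_labels_batch_sketch(y_list, node_level=False):
    if not node_level:
        return np.array(y_list)  # [batch, n_labels]
    n_max = max(y.shape[0] for y in y_list)
    padded = [np.pad(y, ((0, n_max - y.shape[0]), (0, 0))) for y in y_list]
    return np.array(padded)  # [batch, n_max, n_labels]
```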
@@ -415,12 +434,13 @@ def tf_signature(self):
         Adjacency matrix has shape [batch, n_nodes, n_nodes]
         Node features have shape [batch, n_nodes, n_node_features]
         Edge features have shape [batch, n_nodes, n_nodes, n_edge_features]
-        Targets have shape [batch, ..., n_labels]
+        Labels have shape [batch, n_labels]
         """
-        signature = self.dataset.signature
+        signature = self.signature
         for k in signature:
             signature[k]["shape"] = prepend_none(signature[k]["shape"])
-        if "x" in signature:
+        if "x" in signature and self.mask:
+            # In case we have a mask, the mask is concatenated to the features
             signature["x"]["shape"] = signature["x"]["shape"][:-1] + (
                 signature["x"]["shape"][-1] + 1,
             )
@@ -430,6 +450,9 @@ def tf_signature(self):
         if "e" in signature:
             # Edge attributes have an extra None dimension in batch mode
             signature["e"]["shape"] = prepend_none(signature["e"]["shape"])
+        if "y" in signature and self.node_level:
+            # Node labels have an extra None dimension
+            signature["y"]["shape"] = prepend_none(signature["y"]["shape"])

         return to_tf_signature(signature)

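An illustrative check of the two new branches (the shapes are expectations derived from the code above, with `F = n_node_features` and `L = n_labels` assumed):

```python
from spektral.data import BatchLoader

loader = BatchLoader(dataset, mask=True, node_level=True)
spec = loader.tf_signature()
# x: TensorSpec(shape=(None, None, F + 1))  -- the mask adds one feature column
# y: TensorSpec(shape=(None, None, L))      -- node labels gain a None dimension
```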
@@ -454,10 +477,12 @@ class PackedBatchLoader(BatchLoader):
     **Arguments**

     - `dataset`: a graph Dataset;
-    - `batch_size`: size of the mini-batches;
-    - `epochs`: number of epochs to iterate over the dataset. By default (`None`)
+    - `mask`: bool, whether to add a mask to the node features;
+    - `batch_size`: int, size of the mini-batches;
+    - `epochs`: int, number of epochs to iterate over the dataset. By default (`None`)
       iterates indefinitely;
-    - `shuffle`: whether to shuffle the data at the start of each epoch.
+    - `shuffle`: bool, whether to shuffle the data at the start of each epoch;
+    - `node_level`: bool, if `True`, pad the labels along the node dimension.

     **Output**

@@ -469,22 +494,35 @@ class PackedBatchLoader(BatchLoader):
     - `a`: adjacency matrices of shape `[batch, n_max, n_max]`;
     - `e`: edge attributes of shape `[batch, n_max, n_max, n_edge_features]`.

-    `labels` have shape `[batch, ..., n_labels]`.
+    `labels` have shape `[batch, n_labels]` if `node_level=False` or
+    `[batch, n_max, n_labels]` otherwise.
     """

-    def __init__(self, dataset, mask=False, batch_size=1, epochs=None, shuffle=True):
+    def __init__(
+        self,
+        dataset,
+        mask=False,
+        batch_size=1,
+        epochs=None,
+        shuffle=True,
+        node_level=False,
+    ):
         super().__init__(
-            dataset, mask=mask, batch_size=batch_size, epochs=epochs, shuffle=shuffle
+            dataset,
+            mask=mask,
+            batch_size=batch_size,
+            epochs=epochs,
+            shuffle=shuffle,
+            node_level=node_level,
         )

         # Drop the Dataset container and work on packed tensors directly
         packed = self.pack(self.dataset)

         y = packed.pop("y_list", None)
         if y is not None:
-            y = np.array(y)
+            y = collate_labels_batch(y, node_level=self.node_level)

-        self.signature = dataset.signature
         self.dataset = to_batch(**packed, mask=mask)
         if y is not None:
             self.dataset += (y,)
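Design note: since `BatchLoader.__init__` now stores `self.signature = dataset.signature` before the Dataset container is dropped, the old assignment removed above becomes redundant, and the duplicated `tf_signature` override deleted in the next hunk is no longer needed. A usage sketch (assuming the padded dataset fits in memory):

```python
from spektral.data import PackedBatchLoader

# The whole dataset is padded once at construction time, so each iteration
# only slices pre-built arrays.
loader = PackedBatchLoader(dataset, batch_size=32, node_level=True)
inputs, target = next(loader)
```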
@@ -501,29 +539,6 @@ def collate(self, batch):
         else:
             return batch[:-1], batch[-1]

-    def tf_signature(self):
-        """
-        Adjacency matrix has shape [batch, n_nodes, n_nodes]
-        Node features have shape [batch, n_nodes, n_node_features]
-        Edge features have shape [batch, n_nodes, n_nodes, n_edge_features]
-        Targets have shape [batch, ..., n_labels]
-        """
-        signature = self.signature
-        for k in signature:
-            signature[k]["shape"] = prepend_none(signature[k]["shape"])
-        if "x" in signature:
-            signature["x"]["shape"] = signature["x"]["shape"][:-1] + (
-                signature["x"]["shape"][-1] + 1,
-            )
-        if "a" in signature:
-            # Adjacency matrix in batch mode is dense
-            signature["a"]["spec"] = tf.TensorSpec
-        if "e" in signature:
-            # Edge attributes have an extra None dimension in batch mode
-            signature["e"]["shape"] = prepend_none(signature["e"]["shape"])
-
-        return to_tf_signature(signature)
-
     @property
     def steps_per_epoch(self):
         if len(self.dataset) > 0:
@@ -544,10 +559,10 @@ class MixedLoader(Loader):
     **Arguments**

     - `dataset`: a graph Dataset;
-    - `batch_size`: size of the mini-batches;
-    - `epochs`: number of epochs to iterate over the dataset. By default (`None`)
+    - `batch_size`: int, size of the mini-batches;
+    - `epochs`: int, number of epochs to iterate over the dataset. By default (`None`)
       iterates indefinitely;
-    - `shuffle`: whether to shuffle the data at the start of each epoch.
+    - `shuffle`: bool, whether to shuffle the data at the start of each epoch.

     **Output**

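A mixed-mode usage sketch (assuming `dataset` stores the single shared adjacency matrix in `dataset.a` and a different node-feature matrix per graph):

```python
from spektral.data import MixedLoader

loader = MixedLoader(dataset, batch_size=16, epochs=1, shuffle=True)
inputs, target = next(loader)
# inputs = (x, a): x has shape [batch, n_nodes, n_node_features]; a is the
# same adjacency matrix for every sample in the batch.
```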