diff --git a/.gitignore b/.gitignore
index 1dd7612f..cab75c64 100644
--- a/.gitignore
+++ b/.gitignore
@@ -149,3 +149,6 @@ cscope.*
 # config file
 /config
 local_scripts/
+
+**/amr_graph_construction/mawps/*
+
diff --git a/examples/pytorch/rgcn/rgcn.py b/examples/pytorch/rgcn/rgcn.py
index 25291e33..7cf0d9d0 100644
--- a/examples/pytorch/rgcn/rgcn.py
+++ b/examples/pytorch/rgcn/rgcn.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 from dgl.nn.pytorch import RelGraphConv
 
-from .base import GNNBase, GNNLayerBase
+from graph4nlp.pytorch.modules.graph_embedding_learning.base import GNNBase, GNNLayerBase
 
 
 class RGCN(GNNBase):
@@ -18,19 +18,18 @@ class RGCN(GNNBase):
         Number of RGCN layers.
     input_size : int, or pair of ints
         Input feature size.
-    hidden_size: int list of int
+    hidden_size: int
         Hidden layer size.
-        If a scalar is given, the sizes of all the hidden layers are the same.
-        If a list of scalar is given, each element in the list is the size of each hidden layer.
-        Example: [100,50]
     output_size : int
         Output feature size.
     num_rels : int
         Number of relations.
     num_bases : int, optional
-        Number of bases. Needed when ``regularizer`` is specified. Default: ``None``.
+        Number of bases. Needed when ``regularizer`` is specified. Default: ``-1`` (use ``num_rels`` bases).
     use_self_loop : bool, optional
-        True to include self loop message. Default: ``True``.
+        True to include self loop message. Default: ``False``.
+    gpu : int, optional
+        GPU device id to run on. Default: ``-1`` (CPU).
     dropout : float, optional
         Dropout rate. Default: ``0.0``
     """
@@ -42,12 +41,14 @@ def __init__(
         self,
         num_layers,
         hidden_size,
         output_size,
         num_rels,
-        num_bases=None,
+        num_bases=-1,
         use_self_loop=True,
         dropout=0.0,
     ):
         super(RGCN, self).__init__()
         self.num_layers = num_layers
+        if num_bases == -1:
+            num_bases = num_rels
         self.num_rels = num_rels
         self.num_bases = num_bases
         self.use_self_loop = use_self_loop
@@ -75,8 +76,7 @@ def __init__(
             )
         )
         # hidden layers
-        for l in range(1, self.num_layers - 1):
-            # due to multi-head, the input_size = hidden_size * num_heads
+        for l in range(1, self.num_layers-1):
             self.RGCN_layers.append(
                 RGCNLayer(
                     hidden_size[l - 1],
@@ -93,7 +93,7 @@ def __init__(
         # output projection
         self.RGCN_layers.append(
             RGCNLayer(
-                hidden_size[-1] if self.num_layers > 1 else input_size,
+                hidden_size,
                 output_size,
                 num_rels=self.num_rels,
                 regularizer="basis",
@@ -105,6 +105,9 @@ def __init__(
             )
         )
 
+        if self.gpu != -1:
+            self.to(device=self.gpu)
+
     def forward(self, graph):
         r"""Compute RGCN layer.
 
@@ -122,18 +125,19 @@ def forward(self, graph):
         Parameters
         ----------
         graph : GraphData
             The graph with node feature stored in the feature field named as
             "node_feat". The node features are used for message passing.
 
         Returns
         -------
         graph : GraphData
             The graph with generated node embedding stored in the feature field
             named as "node_emb".
         """
-        h = graph.node_features["node_feat"]
-        # get the node feature tensor from graph
-        g = graph.to_dgl()  # transfer the current NLPgraph to DGL graph
-        edge_type = g.edata[dgl.ETYPE].long()
-        # output projection
-        if self.num_layers > 1:
-            for l in range(0, self.num_layers - 1):
-                h = self.RGCN_layers[l](g, h, edge_type)
-
+        # transfer the current NLPgraph to DGL graph
+        g = graph.to_dgl()
+        h = graph.node_features['node_feat']
+        edge_type = graph.edge_features['token_id'].squeeze(1)
+        for l in range(self.num_layers - 1):
+            h = self.RGCN_layers[l](g, h, edge_type)
+            h = self.dropout(F.relu(h))
         logits = self.RGCN_layers[-1](g, h, edge_type)
+
+        # put the results into the NLPGraph
+        # graph.node_features['node_feat'] = h
+        graph.node_features["node_emb"] = logits
 
-        graph.node_features["node_emb"] = logits  # put the results into the NLPGraph
         return graph
@@ -176,7 +180,7 @@ def __init__(
         output_size,
         num_rels,
         regularizer=None,
-        num_bases=None,
+        num_bases=-1,
         bias=True,
         activation=None,
         self_loop=False,
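
The example model now imports its bases from the installed package path and reads edge types from the NLP graph itself. A minimal usage sketch, assuming the argument names from the class docstring above and a GraphData object `graph` that already carries "node_feat" node features and "token_id" edge features:

    from examples.pytorch.rgcn.rgcn import RGCN

    # num_bases=-1 falls back to one basis per relation (num_rels)
    model = RGCN(num_layers=2, input_size=300, hidden_size=128,
                 output_size=128, num_rels=40, num_bases=-1, dropout=0.1)
    graph = model(graph)                            # forward() runs message passing on graph.to_dgl()
    node_emb = graph.node_features["node_emb"]      # written back by forward()
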
""" - h = graph.node_features["node_feat"] - # get the node feature tensor from graph - g = graph.to_dgl() # transfer the current NLPgraph to DGL graph - edge_type = g.edata[dgl.ETYPE].long() - # output projection - if self.num_layers > 1: - for l in range(0, self.num_layers - 1): - h = self.RGCN_layers[l](g, h, edge_type) - + # transfer the current NLPgraph to DGL graph + g = graph.to_dgl() + h = graph.node_features['node_feat'] + edge_type = graph.edge_features['token_id'].squeeze(1) + for l in range(self.num_layers): + h = self.RGCN_layers[l](g, h, edge_type) + h = self.dropout(F.relu(h)) logits = self.RGCN_layers[-1](g, h, edge_type) + + # put the results into the NLPGraph + # graph.node_features['node_feat'] = h + graph.node_features["node_emb"] = logits - graph.node_features["node_emb"] = logits # put the results into the NLPGraph return graph @@ -176,7 +180,7 @@ def __init__( output_size, num_rels, regularizer=None, - num_bases=None, + num_bases=-1, bias=True, activation=None, self_loop=False, diff --git a/graph4nlp/pytorch/data/data.py b/graph4nlp/pytorch/data/data.py index c503c58f..3036eda8 100644 --- a/graph4nlp/pytorch/data/data.py +++ b/graph4nlp/pytorch/data/data.py @@ -8,7 +8,7 @@ """ import os import warnings -from collections import namedtuple +from collections import namedtuple, Counter from typing import Any, Callable, Dict, List, Tuple, Union import dgl import scipy.sparse @@ -99,6 +99,7 @@ def __init__(self, src=None, device: str = None, is_hetero: bool = False): self.batch_size = None # Batch size self._batch_num_nodes = None # Subgraph node number list with the length of batch size self._batch_num_edges = None # Subgraph edge number list with the length of batch size + self.batch_graph_attributes = [] # Subgraph attribute list with the length of batch size if src is not None: if isinstance(src, GraphData): @@ -176,7 +177,7 @@ def add_nodes(self, node_num: int, ntypes: List[str] = None): ) if not self.is_hetero: - if ntypes is None: + if ntypes is not None: raise ValueError( "The graph is homogeneous, ntypes should be None. Got {}".format(ntypes) ) @@ -878,7 +879,9 @@ def from_dgl(self, dgl_g: dgl.DGLGraph, is_hetero=False): # Add nodes self.add_nodes(dgl_g.number_of_nodes()) for k, v in dgl_g.ndata.items(): - self.node_features[k] = v + self.node_features['node_'+k] = v + + # node_features['node_embed'] -> tensor.size((num_of_node, emb_dim)) # Add edges src_tensor, tgt_tensor = dgl_g.edges() @@ -886,7 +889,9 @@ def from_dgl(self, dgl_g: dgl.DGLGraph, is_hetero=False): tgt_list = list(tgt_tensor.detach().cpu().numpy()) self.add_edges(src_list, tgt_list) for k, v in dgl_g.edata.items(): - self.edge_features[k] = v + self.edge_features['edge_'+k] = v + # edge_features['edge_emb'] -> tensor.size((number_of_edge, emb_dim)) + # edge_features['type'] -> tensor.size((number_of_edge,)) else: self.is_hetero = True # For heterogeneous DGL graphs, we perform the same routines for nodes and edges. 
@@ -904,19 +909,22 @@ def from_dgl(self, dgl_g: dgl.DGLGraph, is_hetero=False):
             # for feature_name, feature_value in node_data.items():
             #     self.node_features[feature_name] = feature_value
             node_data = dgl_g.ndata
-            ntypes = []
+            # ntypes = []
+            ntypes = [None for _ in range(dgl_g.number_of_nodes())]
             processed_node_types = False
             node_feat_dict = {}
             for feature_name, data_dict in node_data.items():
                 if not processed_node_types:
                     for node_type, node_feature in data_dict.items():
-                        ntypes += [node_type] * len(node_feature)
+                        for nidx in node_feature:
+                            ntypes[nidx] = node_type
+                        # ntypes += [node_type] * len(node_feature)
                     processed_node_types = True
                 # for node_type, node_feature in data_dict.items():
                 node_feat_dict[feature_name] = torch.cat(list(data_dict.values()), dim=0)
             self.add_nodes(len(ntypes), ntypes=ntypes)
             for feature_name, feature_value in node_feat_dict.items():
-                self.node_features[feature_name] = feature_value
+                self.node_features['node_'+feature_name] = feature_value
             # do the same thing for edges
             dgl_g_etypes = dgl_g.canonical_etypes
             # Add edges first
@@ -924,13 +932,15 @@ def from_dgl(self, dgl_g: dgl.DGLGraph, is_hetero=False):
             for etype in dgl_g_etypes:
                 num_edges = dgl_g.num_edges(etype)
                 src_type, r_type, dst_type = etype
-                srcs, dsts = dgl_g.find_edges(
-                    torch.tensor(list(range(num_edges)), dtype=torch.long), etype
-                )
+                # srcs, dsts = dgl_g.find_edges(
+                #     torch.tensor(list(range(num_edges)), dtype=torch.long), etype
+                # )
+                srcs, dsts = dgl_g.edges(etype=etype)
                 srcs, dsts = (
                     srcs.detach().cpu().numpy().tolist(),
                     dsts.detach().cpu().numpy().tolist(),
                 )
+
                 self.add_edges(srcs, dsts, etypes=[etype] * num_edges)
             if len(dgl_g_etypes) > 1:
                 for feature_name, feature_dict in dgl_g.edata.items():
@@ -945,7 +955,7 @@ def from_dgl(self, dgl_g: dgl.DGLGraph, is_hetero=False):
                     edge_feature_dict[feature_name] = feature_value
             # Add edge features then
             for feat_name, feat_value in edge_feature_dict.items():
-                self.edge_features[feat_name] = feat_value
+                self.edge_features['edge_'+feat_name] = feat_value
             # edge_data = dgl_g.edata
             # etypes = []
             # processed_edge_types = False
@@ -1330,7 +1340,7 @@ def split_features(self, input_tensor: torch.Tensor, type: str = "node") -> torc
     return output
 
 
-def from_dgl(g: dgl.DGLGraph) -> GraphData:
+def from_dgl(g: dgl.DGLGraph, is_hetero=False) -> GraphData:
     """
     Convert a dgl.DGLGraph to a GraphData object.
 
@@ -1338,14 +1348,15 @@ def from_dgl(g: dgl.DGLGraph) -> GraphData:
     ----------
     g : dgl.DGLGraph
         The source graph in DGLGraph format.
-
+    is_hetero: bool, default=False
+        Whether the graph should be heterogeneous
     Returns
     -------
     GraphData
         The converted graph in GraphData format.
     """
-    graph = GraphData(is_hetero=not g.is_homogeneous)
-    graph.from_dgl(g, is_hetero=not g.is_homogeneous)
+    graph = GraphData(is_hetero=is_hetero)
+    graph.from_dgl(g, is_hetero=is_hetero)
     return graph
@@ -1456,7 +1467,11 @@ def stack_edge_indices(gs):
     big_graph._batch_num_nodes = [g.get_node_num() for g in graphs]
     big_graph._batch_num_edges = [g.get_edge_num() for g in graphs]
 
-    # Step 8: merge node and edge types if the batch is heterograph
+    # Step 8: Insert graph attributes
+    for g in graphs:
+        big_graph.batch_graph_attributes.append(g.graph_attributes)
+
+    # Step 9: merge node and edge types if the batch is heterograph
     if is_heterograph:
         node_types = []
         edge_types = []
@@ -1501,6 +1516,7 @@ def from_batch(batch: GraphData) -> List[GraphData]:
         cum_n_edges += num_edges[i]
         cum_n_nodes += num_nodes[i]
         ret.append(g)
+        g.graph_attributes = batch.batch_graph_attributes[i]
 
     # Add node and edge features
     for k, v in batch._node_features.items():
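
To illustrate the new conversion contract (a sketch, not part of the patch; the feature names are only examples): DGL `ndata`/`edata` keys now come back prefixed with `node_`/`edge_`, and the caller decides explicitly whether the result is heterogeneous.

    import dgl
    import torch
    from graph4nlp.pytorch.data.data import from_dgl

    dgl_g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    dgl_g.ndata["feat"] = torch.randn(3, 16)
    dgl_g.edata["feat"] = torch.randn(3, 8)

    g = from_dgl(dgl_g, is_hetero=False)          # is_hetero is now an explicit argument
    node_feat = g.node_features["node_feat"]      # was dgl_g.ndata["feat"]
    edge_feat = g.edge_features["edge_feat"]      # was dgl_g.edata["feat"]
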
""" - graph = GraphData(is_hetero=not g.is_homogeneous) - graph.from_dgl(g, is_hetero=not g.is_homogeneous) + graph = GraphData(is_hetero=is_hetero) + graph.from_dgl(g, is_hetero=is_hetero) return graph @@ -1456,7 +1467,11 @@ def stack_edge_indices(gs): big_graph._batch_num_nodes = [g.get_node_num() for g in graphs] big_graph._batch_num_edges = [g.get_edge_num() for g in graphs] - # Step 8: merge node and edge types if the batch is heterograph + # Step 8: Insert graph attributes + for g in graphs: + big_graph.batch_graph_attributes.append(g.graph_attributes) + + # Step 9: merge node and edge types if the batch is heterograph if is_heterograph: node_types = [] edge_types = [] @@ -1501,6 +1516,7 @@ def from_batch(batch: GraphData) -> List[GraphData]: cum_n_edges += num_edges[i] cum_n_nodes += num_nodes[i] ret.append(g) + g.graph_attributes = batch.batch_graph_attributes[i] # Add node and edge features for k, v in batch._node_features.items(): diff --git a/graph4nlp/pytorch/data/dataset.py b/graph4nlp/pytorch/data/dataset.py index 9fe88e16..4698f00c 100644 --- a/graph4nlp/pytorch/data/dataset.py +++ b/graph4nlp/pytorch/data/dataset.py @@ -36,7 +36,7 @@ from ..modules.utils.tree_utils import Tree from ..modules.utils.tree_utils import Vocab as VocabForTree from ..modules.utils.tree_utils import VocabForAll -from ..modules.utils.vocab_utils import VocabModel +from ..modules.utils.vocab_utils import Vocab, VocabModel class DataItem(object): @@ -146,6 +146,16 @@ def extract(self): output_tokens = self.tokenizer(self.output_text) return input_tokens, output_tokens + + def extract_edge_tokens(self): + g: GraphData = self.graph + edge_tokens = [] + for i in range(g.get_edge_num()): + if "token" in g.edge_attributes[i]: + edge_tokens.append(g.edge_attributes[i]["token"]) + else: + edge_tokens.append("") + return edge_tokens class Text2LabelDataItem(DataItem): @@ -311,6 +321,8 @@ def __init__( for_inference=False, reused_vocab_model=None, nlp_processor_args=None, + init_edge_vocab=False, + is_hetero=False, **kwargs, ): """ @@ -357,6 +369,10 @@ def __init__( vocabulary data is located. nlp_processor_args: dict, default=None It contains the parameter for nlp processor such as ``stanza``. + init_edge_vocab: bool, default=False + Whether to initialize the edge vocabulary. + is_hetero: bool, default=False + Whether the graph is heterogeneous. 
         kwargs
         """
         super(Dataset, self).__init__()
@@ -385,6 +401,8 @@ def __init__(
         self.topology_builder = topology_builder
         self.topology_subdir = topology_subdir
         self.use_val_for_vocab = use_val_for_vocab
+        self.init_edge_vocab = init_edge_vocab
+        self.is_hetero = is_hetero
         for k, v in kwargs.items():
             setattr(self, k, v)
         self.__indices__ = None
@@ -659,6 +677,7 @@ def build_vocab(self):
             target_pretrained_word_emb_name=self.target_pretrained_word_emb_name,
             target_pretrained_word_emb_url=self.target_pretrained_word_emb_url,
             word_emb_size=self.word_emb_size,
+            init_edge_vocab=self.init_edge_vocab,
         )
 
         self.vocab_model = vocab_model
@@ -1077,6 +1096,11 @@ def build_vocab(self):
                 pretrained_word_emb_cache_dir=self.pretrained_word_emb_cache_dir,
                 embedding_dims=self.dec_emb_size,
             )
+        if self.init_edge_vocab:
+            all_edge_words = VocabModel.collect_edge_vocabs(data_for_vocab, self.tokenizer, lower_case=self.lower_case)
+            edge_vocab = Vocab(lower_case=self.lower_case, tokenizer=self.tokenizer)
+            edge_vocab.build_vocab(all_edge_words, max_vocab_size=None, min_vocab_freq=1)
+            edge_vocab.randomize_embeddings(self.word_emb_size)
 
         if self.share_vocab:
             all_words = Counter()
@@ -1119,6 +1143,7 @@ def build_vocab(self):
             in_word_vocab=src_vocab_model,
             out_word_vocab=tgt_vocab_model,
             share_vocab=src_vocab_model if self.share_vocab else None,
+            edge_vocab=edge_vocab if self.init_edge_vocab else None,
         )
 
         return self.vocab_model
@@ -1136,6 +1161,19 @@ def vectorization(self, data_items):
             token_matrix = torch.tensor(token_matrix, dtype=torch.long)
             graph.node_features["token_id"] = token_matrix
 
+            if self.is_hetero:
+                token_matrix = []
+                for edge_idx in range(graph.get_edge_num()):
+                    if "token" in graph.edge_attributes[edge_idx]:
+                        edge_token = graph.edge_attributes[edge_idx]["token"]
+                    else:
+                        edge_token = ""
+                    edge_token_id = self.vocab_model.edge_vocab[edge_token]
+                    graph.edge_attributes[edge_idx]["token_id"] = edge_token_id
+                    token_matrix.append([edge_token_id])
+                token_matrix = torch.tensor(token_matrix, dtype=torch.long)
+                graph.edge_features["token_id"] = token_matrix
+
             tgt = item.output_text
             tgt_list = self.vocab_model.out_word_vocab.get_symbol_idx_for_list(tgt.split())
             output_tree = Tree.convert_to_tree(
@@ -1144,7 +1182,7 @@ def vectorization(self, data_items):
             item.output_tree = output_tree
 
     @classmethod
-    def _vectorize_one_dataitem(cls, data_item, vocab_model, use_ie=False):
+    def _vectorize_one_dataitem(cls, data_item, vocab_model, use_ie=False, is_hetero=False):
         item = deepcopy(data_item)
         graph: GraphData = item.graph
         token_matrix = []
@@ -1156,6 +1194,21 @@ def _vectorize_one_dataitem(cls, data_item, vocab_model, use_ie=False):
         token_matrix = torch.tensor(token_matrix, dtype=torch.long)
         graph.node_features["token_id"] = token_matrix
 
+        if is_hetero:
+            if not hasattr(vocab_model, "edge_vocab"):
+                raise ValueError("Vocab model must have edge vocab attribute")
+            token_matrix = []
+            for edge_idx in range(graph.get_edge_num()):
+                if "token" in graph.edge_attributes[edge_idx]:
+                    edge_token = graph.edge_attributes[edge_idx]["token"]
+                else:
+                    edge_token = ""
+                edge_token_id = vocab_model.edge_vocab[edge_token]
+                graph.edge_attributes[edge_idx]["token_id"] = edge_token_id
+                token_matrix.append([edge_token_id])
+            token_matrix = torch.tensor(token_matrix, dtype=torch.long)
+            graph.edge_features["token_id"] = token_matrix
+
         if isinstance(item.output_text, str):
             tgt = item.output_text
             tgt_list = vocab_model.out_word_vocab.get_symbol_idx_for_list(tgt.split())
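
The two new dataset flags propagate from the concrete datasets below into Dataset.__init__, build_vocab, and vectorization. An illustrative instantiation (paths, subdir, and builder are placeholders, and the remaining required constructor arguments are omitted):

    from graph4nlp.pytorch.datasets.mawps import MawpsDatasetForTree

    dataset = MawpsDatasetForTree(
        root_dir="data/mawps",          # placeholder path
        topology_builder=None,          # a static topology builder, e.g. the AMR builder used in the tests
        topology_subdir="amr_graph",    # placeholder subdir
        graph_construction_name="amr",
        init_edge_vocab=True,           # also build vocab_model.edge_vocab from edge "token" attributes
        is_hetero=True,                 # vectorize edge tokens into edge_features["token_id"]
    )
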
diff --git a/graph4nlp/pytorch/datasets/mawps.py b/graph4nlp/pytorch/datasets/mawps.py
index 5c931907..5f770749 100644
--- a/graph4nlp/pytorch/datasets/mawps.py
+++ b/graph4nlp/pytorch/datasets/mawps.py
@@ -50,6 +50,8 @@ def __init__(
         max_word_vocab_size=100000,
         for_inference=False,
         reused_vocab_model=None,
+        init_edge_vocab=False,
+        is_hetero=False,
     ):
         """
         Parameters
@@ -97,7 +99,7 @@ def __init__(
         # then do the preprocessing and save them.
         super(MawpsDatasetForTree, self).__init__(
             root_dir=root_dir,
-            # topology_builder=topology_builder,
+            topology_builder=topology_builder,
             topology_subdir=topology_subdir,
             graph_construction_name=graph_construction_name,
             static_or_dynamic=static_or_dynamic,
@@ -118,4 +120,6 @@ def __init__(
             max_word_vocab_size=max_word_vocab_size,
             for_inference=for_inference,
             reused_vocab_model=reused_vocab_model,
+            init_edge_vocab=init_edge_vocab,
+            is_hetero=is_hetero,
         )
diff --git a/graph4nlp/pytorch/modules/utils/tree_utils.py b/graph4nlp/pytorch/modules/utils/tree_utils.py
index 5e9155bb..346dca32 100644
--- a/graph4nlp/pytorch/modules/utils/tree_utils.py
+++ b/graph4nlp/pytorch/modules/utils/tree_utils.py
@@ -132,10 +132,11 @@ def convert_to_tree(r_list, i_left, i_right, tgt_vocab):
 
 
 class VocabForAll:
-    def __init__(self, in_word_vocab, out_word_vocab, share_vocab):
+    def __init__(self, in_word_vocab, out_word_vocab, share_vocab, edge_vocab=None):
         self.in_word_vocab = in_word_vocab
         self.out_word_vocab = out_word_vocab
         self.share_vocab = share_vocab
+        self.edge_vocab = edge_vocab
 
     def get_vocab_size(self):
         if hasattr(self, "share_vocab"):
diff --git a/graph4nlp/pytorch/modules/utils/vocab_utils.py b/graph4nlp/pytorch/modules/utils/vocab_utils.py
index 82c7a742..36bf0f0e 100644
--- a/graph4nlp/pytorch/modules/utils/vocab_utils.py
+++ b/graph4nlp/pytorch/modules/utils/vocab_utils.py
@@ -47,6 +47,8 @@ class VocabModel(object):
         Word embedding size, default: ``None``.
     share_vocab : boolean
         Specify whether to share vocab between input and output text, default: ``True``.
+    init_edge_vocab: boolean
+        Specify whether to initialize edge vocab, default: ``False``.
 
     Examples
     -------
@@ -82,6 +84,7 @@ def __init__(
         # pretrained_word_emb_file=None,
         word_emb_size=None,
         share_vocab=True,
+        init_edge_vocab=False,
     ):
         super(VocabModel, self).__init__()
         self.tokenizer = tokenizer
@@ -150,6 +153,12 @@ def __init__(
                 self.out_word_vocab.randomize_embeddings(word_emb_size)
         else:
             self.out_word_vocab = self.in_word_vocab
+
+        if init_edge_vocab:
+            all_edge_words = VocabModel.collect_edge_vocabs(data_set, self.tokenizer, lower_case=lower_case)
+            self.edge_vocab = Vocab(lower_case=lower_case, tokenizer=self.tokenizer)
+            self.edge_vocab.build_vocab(all_edge_words, max_vocab_size=None, min_vocab_freq=1)
+            self.edge_vocab.randomize_embeddings(word_emb_size)
 
         if share_vocab:
             print("[ Initialized word embeddings: {} ]".format(self.in_word_vocab.embeddings.shape))
@@ -265,6 +274,14 @@ def collect_vocabs(all_instances, tokenizer, lower_case=True, share_vocab=True):
                 all_words[1].update(extracted_tokens[1])
         return all_words
 
+    @staticmethod
+    def collect_edge_vocabs(all_instances, tokenizer, lower_case=True):
+        """Count vocabulary tokens for edge."""
+        all_edges = Counter()
+        for instance in all_instances:
+            extracted_edge_tokens = instance.extract_edge_tokens()
+            all_edges.update(extracted_edge_tokens)
+        return all_edges
 
 
 class WordEmbModel(Vectors):
diff --git a/graph4nlp/pytorch/test/data_structure/test_graphdata.py b/graph4nlp/pytorch/test/data_structure/test_graphdata.py
index 7d438715..e73ea8ad 100644
--- a/graph4nlp/pytorch/test/data_structure/test_graphdata.py
+++ b/graph4nlp/pytorch/test/data_structure/test_graphdata.py
@@ -324,17 +324,17 @@ def test_conversion_dgl():
 
 def test_conversion_dgl_hetero():
     g = GraphData(is_hetero=True)
-    g.add_nodes(10, ntypes=["A"] * 5 + ["B"] * 5)
+    g.add_nodes(11, ntypes=["A"] * 5 + ["B"] * 6)
     # g.add_nodes
     for i in range(5):
-        g.add_edge(src=i, tgt=(i + 5) % 10, etype=("A", "R_ab", "B"))
+        g.add_edge(src=i, tgt=(i + 5) % 11, etype=("A", "R_ab", "B"))
     for i in range(5):
-        g.add_edge(src=(i + 5) % 10, tgt=i, etype=("B", "R_ba", "A"))
+        g.add_edge(src=(i + 6) % 11, tgt=i, etype=("B", "R_ba", "A"))
     for i in range(5):
         g.add_edge(src=i, tgt=(i + 1) % 5, etype=("A", "R_aa", "A"))
-    g.node_features["node_feat"] = torch.randn((10, 10))
-    g.node_features["zero"] = torch.zeros(10)
-    g.node_features["idx"] = torch.tensor(list(range(10)), dtype=torch.long)
+    g.node_features["node_feat"] = torch.randn((11, 10))
+    g.node_features["zero"] = torch.zeros(11)
+    g.node_features["idx"] = torch.tensor(list(range(11)), dtype=torch.long)
     g.edge_features["edge_feat"] = torch.randn((15, 10))
     g.edge_features["idx"] = torch.tensor(list(range(15)), dtype=torch.long)
     # Test to_dgl
@@ -582,3 +582,6 @@ def test_remove_edges():
     mem_report()
     g.remove_all_edges()
     mem_report()
+
+if __name__ == "__main__":
+    test_conversion_dgl_hetero()
\ No newline at end of file
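
The rewritten embedding-construction test below drives the edge-vocabulary path end to end; in isolation, the new VocabModel option reduces to roughly this (illustrative; `data_items` is assumed to be a list of DataItem objects whose .graph is already built):

    from graph4nlp.pytorch.modules.utils.vocab_utils import VocabModel

    vocab_model = VocabModel(
        data_items,
        max_word_vocab_size=None,
        min_word_vocab_freq=1,
        word_emb_size=300,
        init_edge_vocab=True,       # adds vocab_model.edge_vocab built via collect_edge_vocabs()
    )
    print(vocab_model.edge_vocab.embeddings.shape)   # randomly initialized, (edge vocab size, 300)
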
diff --git a/graph4nlp/pytorch/test/graph_construction/test_embedding_construction.py b/graph4nlp/pytorch/test/graph_construction/test_embedding_construction.py
index f0c1a6a4..d57f1d9e 100644
--- a/graph4nlp/pytorch/test/graph_construction/test_embedding_construction.py
+++ b/graph4nlp/pytorch/test/graph_construction/test_embedding_construction.py
@@ -1,23 +1,35 @@
 import torch
 
-from ...modules.graph_construction.embedding_construction import EmbeddingConstruction
-from ...modules.utils.padding_utils import pad_2d_vals_no_size
-from ...modules.utils.vocab_utils import VocabModel
+from graph4nlp.pytorch.modules.graph_embedding_initialization.embedding_construction import EmbeddingConstruction
+from graph4nlp.pytorch.modules.utils.padding_utils import pad_2d_vals_no_size
+from graph4nlp.pytorch.modules.utils.vocab_utils import VocabModel
+from graph4nlp.pytorch.data.dataset import Text2LabelDataItem
+from graph4nlp.pytorch.data.data import GraphData, to_batch
+from examples.pytorch.amr_graph_construction.amr_graph_construction import AMRGraphConstruction
+from graph4nlp.pytorch.data.dataset import Text2LabelDataset
+from graph4nlp.pytorch.modules.graph_construction.dependency_graph_construction import DependencyBasedGraphConstruction
+from stanfordcorenlp import StanfordCoreNLP
 
 
 if __name__ == "__main__":
     raw_text_data = [["I like nlp.", "Same here!"], ["I like graph.", "Same here!"]]
 
-    vocab_model = VocabModel(
-        raw_text_data, max_word_vocab_size=None, min_word_vocab_freq=1, word_emb_size=300
+    # src_text_seq = list(zip(*raw_text_data))[0]
+    # src_idx_seq = [vocab_model.word_vocab.to_index_sequence(each) for each in src_text_seq]
+    # src_len = torch.LongTensor([len(each) for each in src_idx_seq])
+    # num_seq = torch.LongTensor([len(src_len)])
+    # input_tensor = torch.LongTensor(pad_2d_vals_no_size(src_idx_seq))
+    # print("input_tensor: {}".format(input_tensor.shape))
+    raw_data = (
+        "We need to borrow 55% of the hammer price until we can get planning permission for restoration which will allow us to get a mortgage . I saw a nice dog and noticed he was eating a bone ."
     )
-    src_text_seq = list(zip(*raw_text_data))[0]
-    src_idx_seq = [vocab_model.word_vocab.to_index_sequence(each) for each in src_text_seq]
-    src_len = torch.LongTensor([len(each) for each in src_idx_seq])
-    num_seq = torch.LongTensor([len(src_len)])
-    input_tensor = torch.LongTensor(pad_2d_vals_no_size(src_idx_seq))
-    print("input_tensor: {}".format(input_tensor.shape))
-
-    emb_constructor = EmbeddingConstruction(vocab_model.word_vocab, "w2v", "bilstm", "bilstm", 128)
-    emb = emb_constructor(input_tensor, src_len, num_seq)
-    print("emb: {}".format(emb.shape))
+    graph = AMRGraphConstruction.static_topology(raw_data)
+    data_set = Text2LabelDataItem('I like nlp.')
+    data_set.graph = graph
+    vocab_model = VocabModel(
+        [data_set], max_word_vocab_size=None, min_word_vocab_freq=1, word_emb_size=300
+    )
+    emb_constructor = EmbeddingConstruction(vocab_model.in_word_vocab, False, emb_strategy="bert_bilstm_amr", hidden_size=300)
+    g = Text2LabelDataset._vectorize_one_dataitem(data_set, vocab_model)
+    emb = emb_constructor(to_batch([g.graph, g.graph]))
+    print("emb: {}".format(emb.node_features))
\ No newline at end of file
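
Taken together with the data.py changes, per-graph attributes now survive batching and unbatching; a short sketch (g1 and g2 are assumed GraphData objects with graph_attributes set):

    from graph4nlp.pytorch.data.data import to_batch, from_batch

    batched = to_batch([g1, g2])
    print(batched.batch_graph_attributes[0])   # graph_attributes of g1, collected in Step 8
    graphs = from_batch(batched)               # each returned graph gets its attributes back
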