Skip to content

Commit

Permalink
fix the inference
Browse files Browse the repository at this point in the history
  • Loading branch information
cminus01 committed Sep 9, 2022
1 parent 8933948 commit 52d81ee
Show file tree
Hide file tree
Showing 23 changed files with 867 additions and 282 deletions.
4 changes: 2 additions & 2 deletions examples/pytorch/amr_graph_construction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .amr_graph_construction import AmrGraphConstruction
from .amr_graph_construction import AMRGraphConstruction

__all__ = [
"AmrGraphConstruction",
"AMRGraphConstruction",
]
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from graph4nlp.pytorch.modules.graph_construction.base import StaticGraphConstructionBase


class AmrGraphConstruction(StaticGraphConstructionBase):
class AMRGraphConstruction(StaticGraphConstructionBase):
"""
Dependency-parsing-tree based graph construction class
Expand All @@ -22,7 +22,7 @@ def __init__(
self,
vocab,
):
super(AmrGraphConstruction, self).__init__()
super(AMRGraphConstruction, self).__init__()
self.vocab = vocab
self.verbose = 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ model_args:
graph_construction_args:
graph_construction_share:
root_dir: "examples/pytorch/amr_graph_construction/amr_semantic_parsing/graph2seq/jobs"
topology_subdir: 'AmrGraph'
topology_subdir: 'AMRGraph'
share_vocab: true
thread_number: 3

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)

from amr_graph_construction import (
AmrGraphConstruction,
AMRGraphConstruction,
)
from graph4nlp.pytorch.modules.graph_construction.ie_graph_construction import IEBasedGraphConstruction
from graph4nlp.pytorch.modules.graph_construction.node_embedding_based_graph_construction import (
Expand All @@ -40,9 +40,9 @@
from amr_semantic_parsing.graph2seq.utils import get_log, wordid2str


class AmrDataItem(DataItem):
class AMRDataItem(DataItem):
def __init__(self, input_text, output_text, tokenizer, share_vocab=True):
super(AmrDataItem, self).__init__(input_text, tokenizer)
super(AMRDataItem, self).__init__(input_text, tokenizer)
self.output_text = output_text
self.share_vocab = share_vocab

Expand Down Expand Up @@ -118,7 +118,7 @@ def _build_logger(self, log_file):
self.logger = get_log(log_file)

def _build_dataloader(self):
topology_builder = AmrGraphConstruction
topology_builder = AMRGraphConstruction
dataset = JobsDataset(
root_dir=self.opt["model_args"]["graph_construction_args"]["graph_construction_share"][
"root_dir"
Expand Down Expand Up @@ -155,7 +155,7 @@ def _build_dataloader(self):
"graph_construction_share"
]["nlp_processor_args"],
topology_builder=topology_builder,
dataitem=AmrDataItem,
dataitem=AMRDataItem,
)

self.train_dataloader = DataLoader(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ model_args:
graph_construction_args:
graph_construction_share:
root_dir: 'examples/pytorch/semantic_parsing/graph2tree/jobs/jobs_data'
topology_subdir: 'AmrGraph'
topology_subdir: 'AMRGraph'
thread_number: 15
share_vocab: True
port: 9000
Expand All @@ -38,7 +38,8 @@ model_args:
normalizeParentheses: False
normalizeOtherBrackets: False
tokenize.whitespace: True
ssplit.isOneSentence: False
ssplit.isOneSentence: True
#ssplit.isOneSentence: False

graph_construction_private:
edge_strategy: 'heterogeneous'
Expand Down
108 changes: 108 additions & 0 deletions examples/pytorch/amr_graph_construction/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
The inference code.
In this file, we will run the inference by using the prediction API \
in the GeneratorInferenceWrapper.
The GeneratorInferenceWrapper takes the raw inputs and produces the outputs.
"""
import argparse
import random
import warnings
import numpy as np
import torch
from examples.pytorch.amr_graph_construction.runner import AMRDataItem

from graph4nlp.pytorch.datasets.mawps import MawpsDatasetForTree, tokenize_mawps
from graph4nlp.pytorch.inference_wrapper.generator_inference_wrapper_for_tree import (
GeneratorInferenceWrapper,
)
from graph4nlp.pytorch.modules.utils.config_utils import load_json_config
from amr_graph_construction import AMRGraphConstruction
from utils import AMRDataItem, RGCNGraph2Tree, InferenceText2TreeDataset
warnings.filterwarnings("ignore")


class Mawps:
    """
    Inference runner: restores a trained ``RGCNGraph2Tree`` checkpoint and
    decodes raw text through a ``GeneratorInferenceWrapper``.

    Parameters
    ----------
    opt: dict
        Parsed configuration. The keys read here are
        ``env_args.seed``, ``env_args.gpuid``,
        ``checkpoint_args.out_dir`` and ``checkpoint_args.checkpoint_name``;
        the whole mapping is also forwarded to the inference wrapper as ``cfg``.

    Raises
    ------
    TypeError
        If ``opt`` is ``None`` (the configuration is mandatory).
    """

    # Sample input used when ``translate`` is called without arguments.
    DEFAULT_RAW_CONTENTS = (
        "2 dogs are barking . 1 more dogs start to bark . how many dogs are barking",
    )

    def __init__(self, opt=None):
        super().__init__()
        if opt is None:
            # Fail early with a readable message instead of an opaque
            # "'NoneType' object is not subscriptable" a few lines below.
            raise TypeError("opt must be a configuration mapping, got None")
        self.opt = opt

        # Seed every RNG used here so that runs are repeatable.
        seed = self.opt["env_args"]["seed"]
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.device = self._resolve_device(self.opt["env_args"]["gpuid"])

        self._build_model()

    @staticmethod
    def _resolve_device(gpuid):
        """Map the configured ``gpuid`` to a ``torch.device`` (-1 selects the CPU)."""
        if gpuid == -1:
            return torch.device("cpu")
        return torch.device("cuda:{}".format(gpuid))

    def _build_model(self):
        """Load the checkpointed model onto ``self.device`` and wrap it for inference."""
        checkpoint_args = self.opt["checkpoint_args"]
        self.model = RGCNGraph2Tree.load_checkpoint(
            checkpoint_args["out_dir"], checkpoint_args["checkpoint_name"]
        ).to(self.device)

        # The AMR topology builder is only supplied when the checkpoint was
        # trained with the "amr" graph construction; otherwise None is passed.
        if self.model.graph_construction_name == "amr":
            topology_builder = AMRGraphConstruction
        else:
            topology_builder = None

        self.inference_tool = GeneratorInferenceWrapper(
            cfg=self.opt,
            model=self.model,
            beam_size=2,
            lower_case=True,
            tokenizer=tokenize_mawps,
            dataset=InferenceText2TreeDataset,
            data_item=AMRDataItem,
            init_edge_vocab=True,
            edge_vocab=self.model.edge_vocab,
            topology_builder=topology_builder,
        )

    @torch.no_grad()
    def translate(self, raw_contents=None, batch_size=1):
        """
        Run prediction on raw text and print the result.

        Parameters
        ----------
        raw_contents: list of str, optional
            Raw inputs handed to ``inference_tool.predict``. When omitted, the
            built-in sample sentence is used (the original behaviour).
        batch_size: int, default=1
            Batch size forwarded to ``inference_tool.predict``.

        Returns
        -------
        The value returned by ``inference_tool.predict`` (also printed).
        """
        if raw_contents is None:
            raw_contents = list(self.DEFAULT_RAW_CONTENTS)

        self.model.eval()
        ret = self.inference_tool.predict(
            raw_contents=raw_contents,
            batch_size=batch_size,
        )
        print(ret)
        return ret


################################################################################
# ArgParse and Helper Functions #
################################################################################
def get_args(argv=None):
    """
    Parse the command-line arguments of the inference script.

    Parameters
    ----------
    argv: list of str, optional
        Argument vector to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, which is what the script entry point relies on.

    Returns
    -------
    dict
        ``{"json_config": <path>}``. The option is required, so argparse
        exits with an error (``SystemExit``) when it is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-json_config",
        "--json_config",
        required=True,
        type=str,
        help="path to the json config file",
    )
    return vars(parser.parse_args(argv))


def print_config(config):
    """
    Pretty-print ``config`` to stdout between two banner lines.

    Parameters
    ----------
    config:
        Any object accepted by ``pprint.pprint`` (the loaded configuration).
    """
    # Imported locally: pprint is only needed by this debugging helper.
    import pprint

    banner = "**************** MODEL CONFIGURATION ****************"
    print(banner)
    pprint.pprint(config)
    print(banner)


if __name__ == "__main__":
    import multiprocessing

    # The "spawn" start method is selected unconditionally (it used to be
    # gated on macOS only). It must be set before any worker is created.
    multiprocessing.set_start_method("spawn")

    # Read the config path from the command line and load the JSON config.
    cfg = get_args()
    config = load_json_config(cfg["json_config"])
    # Call print_config(config) here to inspect the loaded configuration.

    # Build the runner (loads the checkpoint) and decode the sample input.
    runner = Mawps(opt=config)
    runner.translate()
82 changes: 12 additions & 70 deletions examples/pytorch/amr_graph_construction/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,89 +2,29 @@
import copy
import random
import time
from typing import Union
import warnings
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from copy import deepcopy
from torch.utils.data import DataLoader
from tqdm import tqdm
from graph4nlp.pytorch.data.data import GraphData
from graph4nlp.pytorch.data.dataset import DataItem, Text2TreeDataItem
from graph4nlp.pytorch.data.dataset import DataItem, Text2TreeDataItem, Text2TreeDataset

from graph4nlp.pytorch.datasets.mawps import MawpsDatasetForTree
from graph4nlp.pytorch.models.graph2tree import Graph2Tree
from graph4nlp.pytorch.modules.utils.config_utils import load_json_config
from graph4nlp.pytorch.modules.utils.tree_utils import Tree

from examples.pytorch.rgcn.rgcn import RGCN
from amr_graph_construction import AmrGraphConstruction
from amr_graph_construction import AMRGraphConstruction
from utils import AMRDataItem, EdgeText2TreeDataset, RGCNGraph2Tree

warnings.filterwarnings("ignore")

class AmrDataItem(DataItem):
    """
    Data item that keeps the target text and target tree next to the input.

    Parameters
    ----------
    input_text:
        Passed to ``DataItem.__init__`` together with ``tokenizer``.
    output_text:
        Target text; tokenized by ``extract``.
    output_tree:
        Target tree, stored as-is on the item.
    tokenizer:
        Callable applied to node tokens and to ``output_text``.
    share_vocab: bool, default=True
        Stored as-is on the item.
    """

    def __init__(self, input_text, output_text, output_tree, tokenizer, share_vocab=True):
        super().__init__(input_text, tokenizer)
        self.output_text = output_text
        self.share_vocab = share_vocab
        self.output_tree = output_tree

    def extract(self):
        """
        Returns
        -------
        Input tokens and output tokens
        """
        g: GraphData = self.graph

        # Tokens of every graph node, in node order.
        input_tokens = []
        for i in range(g.get_node_num()):
            tokenized_token = self.tokenizer(g.node_attributes[i]["token"])
            input_tokens.extend(tokenized_token)

        # The sentences stored on the graph are appended as well, split on
        # single spaces.
        # NOTE(review): indentation was lost in this view; the loop is taken
        # to be a sibling of the node loop, not nested in it -- confirm.
        for s in g.graph_attributes["sentence"]:
            input_tokens.extend(s.strip().split(" "))

        output_tokens = self.tokenizer(self.output_text)

        return input_tokens, output_tokens

class RGCNGraph2Tree(Graph2Tree):
    """``Graph2Tree`` variant whose GNN encoder is an ``RGCN``."""

    def _build_gnn_encoder(
        self,
        gnn,
        num_layers,
        input_size,
        hidden_size,
        output_size,
        direction_option,
        feats_dropout,
        gnn_heads=None,
        gnn_residual=True,
        gnn_attn_dropout=0.0,
        gnn_activation=F.relu,  # gat
        gnn_bias=True,
        gnn_allow_zero_in_degree=True,
        gnn_norm="both",
        gnn_weight=True,
        gnn_use_edge_weight=False,
        gnn_gcn_norm="both",  # gcn
        gnn_n_etypes=1,  # ggnn
        gnn_aggregator_type="lstm",  # graphsage
        **kwargs
    ):
        """
        Build ``self.gnn_encoder``.

        Only ``gnn == "rgcn"`` is supported. Of all the parameters, only
        ``num_layers``, ``input_size``, ``hidden_size`` and ``output_size``
        are forwarded to ``RGCN``; the others are accepted but unused here.

        Raises
        ------
        NotImplementedError
            If ``gnn`` is anything other than ``"rgcn"``.
        """
        if gnn != "rgcn":
            raise NotImplementedError(
                "RGCNGraph2Tree only supports gnn='rgcn', got {!r}".format(gnn)
            )

        # NOTE(review): num_rels=77 and gpu=0 are hard-coded; presumably they
        # should follow the edge vocabulary size and the configured device.
        self.gnn_encoder = RGCN(
            num_layers,
            input_size,
            hidden_size,
            output_size,
            num_rels=77,
            gpu=0,
        )

class Mawps:
def __init__(self, opt=None):
super(Mawps, self).__init__()
Expand Down Expand Up @@ -113,6 +53,7 @@ def __init__(self, opt=None):
self._build_optimizer()

def _build_dataloader(self):
graph_type = self.opt["model_args"]["graph_construction_name"]
para_dic = {
"root_dir": self.data_dir,
"word_emb_size": self.opt["model_args"]["graph_initialization_args"]["input_size"],
Expand All @@ -136,12 +77,12 @@ def _build_dataloader(self):
"nlp_processor_args": self.opt["model_args"]["graph_construction_args"][
"graph_construction_share"
]["nlp_processor_args"],
"dataitem": Text2TreeDataItem,
#"dataitem": AmrDataItem,
"topology_builder": AmrGraphConstruction,
"dataitem": Text2TreeDataItem if graph_type != "amr" else AMRDataItem,
#"dataitem": AMRDataItem,
"topology_builder": AMRGraphConstruction if graph_type == "amr" else None,
}

dataset = MawpsDatasetForTree(**para_dic)
dataset = EdgeText2TreeDataset(**para_dic)

self.train_data_loader = DataLoader(
dataset.train,
Expand All @@ -157,6 +98,7 @@ def _build_dataloader(self):
dataset.val, batch_size=1, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn
)
self.vocab_model = dataset.vocab_model
self.edge_vocab = dataset.edge_vocab
self.src_vocab = self.vocab_model.in_word_vocab
self.tgt_vocab = self.vocab_model.out_word_vocab
#self.num_rel = len(dataset.edge_vocab)
Expand All @@ -167,7 +109,7 @@ def _build_model(self):
"""For encoder-decoder"""
print(self.opt["model_args"]["graph_embedding_name"])
if self.opt["model_args"]["graph_embedding_name"] == "rgcn":
self.model = RGCNGraph2Tree.from_args(self.opt, vocab_model=self.vocab_model)
self.model = RGCNGraph2Tree.from_args(opt=self.opt, vocab_model=self.vocab_model, edge_vocab=self.edge_vocab)
else:
self.model = Graph2Tree.from_args(self.opt, vocab_model=self.vocab_model)
self.model.init(self.opt["training_args"]["init_weight"])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from amr_graph_construction import (
AmrGraphConstruction,
AMRGraphConstruction,
)


Expand All @@ -8,7 +8,7 @@ def test_amr():
"find all languageid0 job in locid0"
)

AmrGraphConstruction.static_topology(
AMRGraphConstruction.static_topology(
raw_data,
verbose=1,
)
Expand Down
Loading

0 comments on commit 52d81ee

Please sign in to comment.