diff --git a/Tutorial.md b/Tutorial.md index 8e98e9a..a0a064e 100644 --- a/Tutorial.md +++ b/Tutorial.md @@ -303,6 +303,7 @@ Question answer matching is a crucial subtask of the question answering problem, [ARC-I](https://arxiv.org/abs/1503.03244) (NeuronBlocks) | 0.7508 [ARC-II](https://arxiv.org/abs/1503.03244) (NeuronBlocks) | 0.7612 [MatchPyramid](https://arxiv.org/abs/1602.06359) (NeuronBlocks) | 0.763 + [MV-LSTM](https://arxiv.org/abs/1511.08277) (NeuronBlocks) | 0.774 BiLSTM+Match Attention (NeuronBlocks) | 0.786 diff --git a/Tutorial_zh_CN.md b/Tutorial_zh_CN.md index 7d93db9..e243958 100644 --- a/Tutorial_zh_CN.md +++ b/Tutorial_zh_CN.md @@ -292,6 +292,7 @@ Question answer matching is a crucial subtask of the question answering problem, [ARC-I](https://arxiv.org/abs/1503.03244) (NeuronBlocks) | 0.7508 [ARC-II](https://arxiv.org/abs/1503.03244) (NeuronBlocks) | 0.7612 [MatchPyramid](https://arxiv.org/abs/1602.06359) (NeuronBlocks) | 0.763 + [MV-LSTM](https://arxiv.org/abs/1511.08277) (NeuronBlocks) | 0.774 BiLSTM+Match Attention (NeuronBlocks) | 0.786 diff --git a/block_zoo/PoolingKmax2D.py b/block_zoo/PoolingKmax2D.py new file mode 100644 index 0000000..402a4dd --- /dev/null +++ b/block_zoo/PoolingKmax2D.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +from block_zoo.BaseLayer import BaseLayer, BaseConf +from utils.DocInherit import DocInherit + + +class PoolingKmax2DConf(BaseConf): + """ + Args: + pool_type (str): 'max', default is 'max'. + k (int): how many element to reserve. + """ + def __init__(self, **kwargs): + super(PoolingKmax2DConf, self).__init__(**kwargs) + + @DocInherit + def default(self): + self.pool_type = 'max' # Supported: ['max'] + self.k = 50 + + @DocInherit + def declare(self): + self.num_of_inputs = 1 + self.input_ranks = [4] + + + @DocInherit + def inference(self): + self.output_dim = [self.input_dims[0][0], self.input_dims[0][3] * self.k] + self.output_rank = len(self.output_dim) + + @DocInherit + def verify(self): + super(PoolingKmax2DConf, self).verify() + necessary_attrs_for_user = ['pool_type'] + for attr in necessary_attrs_for_user: + self.add_attr_exist_assertion_for_user(attr) + self.add_attr_value_assertion('pool_type', ['max']) + + assert all([input_rank == 4 for input_rank in self.input_ranks]), "Cannot apply a pooling layer on a tensor of which the rank is not 4. Usually, a tensor whose rank is 4, e.g. [batch size, length, width, feature]" + assert self.output_dim[-1] != -1, "The shape of input is %s , and the input channel number of pooling should not be -1." % (str(self.input_dims[0])) + +class PoolingKmax2D(BaseLayer): + """ Pooling layer + Args: + layer_conf (PoolingKmax2DConf): configuration of a layer + """ + def __init__(self, layer_conf): + super(PoolingKmax2D, self).__init__(layer_conf) + self.k = layer_conf.k + + def forward(self, string, string_len=None): + """ process inputs + Args: + string (Tensor): tensor with shape: [batch_size, length, width, feature_dim] + string_len (Tensor): [batch_size], default is None. + Returns: + Tensor: Pooling result of string + """ + string = string.permute(0, 3, 1, 2) + string = string.view(string.size()[0], string.size()[1], -1) + index = string.topk(self.k, dim=-1)[1].sort(dim=-1)[0] + string = string.gather(-1, index) + string = string.view(string.size()[0], -1) + + return string, string_len diff --git a/block_zoo/__init__.py b/block_zoo/__init__.py index 7351a69..837e706 100644 --- a/block_zoo/__init__.py +++ b/block_zoo/__init__.py @@ -12,6 +12,7 @@ from .Conv import Conv, ConvConf from .Pooling import Pooling, PoolingConf from .ConvPooling import ConvPooling, ConvPoolingConf +from .PoolingKmax2D import PoolingKmax2D, PoolingKmax2DConf from .Dropout import Dropout, DropoutConf diff --git a/block_zoo/op/Combination2D.py b/block_zoo/op/Combination2D.py new file mode 100644 index 0000000..99560b4 --- /dev/null +++ b/block_zoo/op/Combination2D.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + +import numpy as np +import logging + +from block_zoo.BaseLayer import BaseConf +from utils.DocInherit import DocInherit +from utils.exceptions import ConfigurationError +import copy + +class Combination2DConf(BaseConf): + """ Configuration for combination layer + Args: + operations (list): a subset of ["dot", "bilinear", "add"]. + """ + def __init__(self, **kwargs): + super(Combination2DConf, self).__init__(**kwargs) + + @DocInherit + def default(self): + self.operations = ["dot", "bilinear", "add"] + + @DocInherit + def declare(self): + self.num_of_inputs = -1 + self.input_ranks = [-1] + + @DocInherit + def inference(self): + self.output_dim = [self.input_dims[0][0], self.input_dims[0][1], self.input_dims[1][1], len(self.operations)] + if "add" in self.operations: + self.output_dim[-1] = self.output_dim[-1] + self.input_dims[0][-1] - 1 + + super(Combination2DConf, self).inference() + + @DocInherit + def verify(self): + super(Combination2DConf, self).verify() + + # to check if the ranks of all the inputs are equal + rank_equal_flag = True + for i in range(len(self.input_ranks)): + if self.input_ranks[i] != self.input_ranks[0]: + rank_equal_flag = False + break + if rank_equal_flag == False: + raise ConfigurationError("For layer Combination, the ranks of each inputs should be consistent!") + + +class Combination2D(nn.Module): + """ Combination2D layer to merge the representation of two sequence + Args: + layer_conf (Combination2DConf): configuration of a layer + """ + def __init__(self, layer_conf): + super(Combination2D, self).__init__() + self.layer_conf = layer_conf + + self.weight_bilinear = torch.nn.Linear(self.layer_conf.input_dims[0][-1], self.layer_conf.input_dims[0][-1]) + + + logging.warning("The length Combination layer returns is the length of first input") + + def forward(self, *args): + """ process inputs + Args: + args (list): [string, string_len, string2, string2_len, ...] + e.g. string (Variable): [batch_size, dim], string_len (ndarray): [batch_size] + Returns: + Variable: [batch_size, width, height, dim], None + """ + + result = [] + if "dot" in self.layer_conf.operations: + string1 = args[0] + string2 = args[2] + result_multiply = torch.matmul(string1, string2.transpose(1,2)) + + result.append(torch.unsqueeze(result_multiply, 3)) + + + if "bilinear" in self.layer_conf.operations: + string1 = args[0] + string2 = args[2] + string1 = self.weight_bilinear(string1) + result_multiply = torch.matmul(string1, string2.transpose(1,2)) + + result.append(torch.unsqueeze(result_multiply, 3)) + + if "add" in self.layer_conf.operations: + string1 = args[0] + string2 = args[2] + x_new = torch.stack([string1]*string2.size()[1], 2) # [batch_size, x_max_len, y_max_len, dim] + y_new = torch.stack([string2]*string1.size()[1], 1) # [batch_size, x_max_len, y_max_len, dim] + result.append((x_new + y_new)) + + return torch.cat(result, 3), args[1] diff --git a/block_zoo/op/Expand_plus.py b/block_zoo/op/Expand_plus.py deleted file mode 100644 index 17ebb47..0000000 --- a/block_zoo/op/Expand_plus.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. - -# Come from http://www.hangli-hl.com/uploads/3/1/6/8/3168008/hu-etal-nips2014.pdf [ARC-II] - -import torch -import torch.nn as nn -import copy - -from block_zoo.BaseLayer import BaseLayer, BaseConf -from utils.DocInherit import DocInherit -from utils.exceptions import ConfigurationError - -class Expand_plusConf(BaseConf): - """Configuration for Expand_plus layer - - """ - def __init__(self, **kwargs): - super(Expand_plusConf, self).__init__(**kwargs) - - @DocInherit - def default(self): - self.operation = 'Plus' - - @DocInherit - def declare(self): - self.num_of_inputs = 2 - self.input_ranks = [3, 3] - - @DocInherit - def inference(self): - self.output_dim = copy.deepcopy(self.input_dims[0]) - if self.input_dims[0][1] == -1 or self.input_dims[1][1] == -1: - raise ConfigurationError("For Expand_plus layer, the sequence length should be fixed") - self.output_dim.insert(2, self.input_dims[1][1]) # y_len - super(Expand_plusConf, self).inference() # PUT THIS LINE AT THE END OF inference() - - @DocInherit - def verify(self): - super(Expand_plusConf, self).verify() - - -class Expand_plus(BaseLayer): - """ Expand_plus layer - Given sequences X and Y, put X and Y expand_dim, and then add. - - Args: - layer_conf (Expand_plusConf): configuration of a layer - - """ - def __init__(self, layer_conf): - - super(Expand_plus, self).__init__(layer_conf) - assert layer_conf.input_dims[0][-1] == layer_conf.input_dims[1][-1] - - - def forward(self, x, x_len, y, y_len): - """ - - Args: - x: [batch_size, x_max_len, dim]. - x_len: [batch_size], default is None. - y: [batch_size, y_max_len, dim]. - y_len: [batch_size], default is None. - - Returns: - output: batch_size, x_max_len, y_max_len, dim]. - - """ - - x_new = torch.stack([x]*y.size()[1], 2) # [batch_size, x_max_len, y_max_len, dim] - y_new = torch.stack([y]*x.size()[1], 1) # [batch_size, x_max_len, y_max_len, dim] - - return x_new + y_new, None - - diff --git a/block_zoo/op/__init__.py b/block_zoo/op/__init__.py index 896cef6..eb67ce5 100644 --- a/block_zoo/op/__init__.py +++ b/block_zoo/op/__init__.py @@ -3,7 +3,7 @@ from .Concat2D import Concat2D, Concat2DConf from .Concat3D import Concat3D, Concat3DConf from .Combination import Combination, CombinationConf +from .Combination2D import Combination2D, Combination2DConf from .Match import Match, MatchConf from .Flatten import Flatten, FlattenConf -from .Expand_plus import Expand_plus, Expand_plusConf -from .CalculateDistance import CalculateDistance, CalculateDistanceConf \ No newline at end of file +from .CalculateDistance import CalculateDistance, CalculateDistanceConf diff --git a/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_arcii.json b/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_arcii.json index e9bf0d9..c8d81f4 100644 --- a/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_arcii.json +++ b/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_arcii.json @@ -114,8 +114,9 @@ }, { "layer_id": "match", - "layer": "Expand_plus", + "layer": "Combination2D", "conf": { + "operations": ["add"] }, "inputs": ["s1_conv_1", "s2_conv_1"] }, @@ -209,4 +210,4 @@ ] }, "metrics": ["auc", "accuracy"] -} \ No newline at end of file +} diff --git a/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_mvlstm.json b/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_mvlstm.json new file mode 100644 index 0000000..8d99868 --- /dev/null +++ b/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_mvlstm.json @@ -0,0 +1,147 @@ +{ + "license": "Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license.", + "tool_version": "1.1.0", + "model_description": "This model is used for question answer matching task, and it achieved auc: 0.7736 in WikiQACorpus test set", + "inputs": { + "use_cache": true, + "dataset_type": "classification", + "data_paths": { + "train_data_path": "./dataset/WikiQACorpus/WikiQA-train.tsv", + "valid_data_path": "./dataset/WikiQACorpus/WikiQA-dev.tsv", + "test_data_path": "./dataset/WikiQACorpus/WikiQA-test.tsv", + "pre_trained_emb": "./dataset/GloVe/glove.840B.300d.txt" + }, + "file_with_col_header": true, + "add_start_end_for_seq": true, + "file_header": { + "question_id": 0, + "question_text": 1, + "document_id": 2, + "document_title": 3, + "passage_id": 4, + "passage_text": 5, + "label": 6 + }, + "model_inputs": { + "question": ["question_text"], + "passage": ["passage_text"] + }, + "target": ["label"] + }, + "outputs":{ + "save_base_dir": "./models/wikiqa_bilstm/", + "model_name": "model.nb", + "train_log_name": "train.log", + "test_log_name": "test.log", + "predict_log_name": "predict.log", + "predict_fields": ["prediction"], + "predict_output_name": "predict.tsv", + "cache_dir": ".cache.wikiqa/" + }, + "training_params": { + "vocabulary": { + "min_word_frequency": 1 + }, + "optimizer": { + "name": "Adam", + "params": { + } + }, + "lr_decay": 0.9, + "minimum_lr": 0.00005, + "epoch_start_lr_decay": 3, + "use_gpu": true, + "batch_size": 64, + "batch_num_to_show_results": 100, + "max_epoch": 10, + "valid_times_per_epoch": 5, + "fixed_lengths":{ + "question": 200, + "passage": 200 + } + }, + "architecture":[ + { + "layer": "Embedding", + "conf": { + "word": { + "cols": ["question_text", "passage_text"], + "dim": 300, + "fix_weight": true + } + } + }, + { + "layer_id": "question_dropout", + "layer": "Dropout", + "conf": { + "dropout": 0.5 + }, + "inputs": ["question"] + }, + { + "layer_id": "passage_dropout", + "layer": "Dropout", + "conf": { + "dropout": 0.5 + }, + "inputs": ["passage"] + }, + { + "layer_id": "question_1", + "layer": "BiLSTM", + "conf": { + "hidden_dim": 256, + "dropout": 0.2, + "num_layers": 2 + }, + "inputs": ["question_dropout"] + }, + { + "layer_id": "passage_1", + "layer": "question_1", + "inputs": ["passage_dropout"] + }, + { + "layer_id": "comb_qp", + "layer": "Combination2D", + "conf": { + "operations": ["dot", "bilinear"] + }, + "inputs": ["question_1", "passage_1"] + }, + { + "layer_id": "pooltest", + "layer": "PoolingKmax2D", + "conf": { + "pool_type": "max", + "k": 50 + }, + "inputs": ["comb_qp"] + }, + { + "output_layer_flag": true, + "layer_id": "output", + "layer": "Linear", + "conf": { + "hidden_dim": [128,2], + "activation": "PReLU", + "last_hidden_activation": false + }, + "inputs": ["pooltest"] + } + ], + "loss": { + "losses": [ + { + "type": "CrossEntropyLoss", + "conf": { + "weight": [0.1,0.9], + "size_average": true + }, + "inputs": ["output","label"] + } + ] + }, + "metrics": ["auc","accuracy"] +}