diff --git a/deep_qa/data/dataset.py b/deep_qa/data/dataset.py index 524aa7828..71a1ee822 100644 --- a/deep_qa/data/dataset.py +++ b/deep_qa/data/dataset.py @@ -81,12 +81,13 @@ def read_from_lines(lines: List[str], instance_class): instances = [instance_class.read_from_line(x) for x in lines] labels = [(x.label, x) for x in instances] labels.sort(key=lambda x: str(x[0])) - label_counts = [(label, len([x for x in group])) - for label, group in itertools.groupby(labels, lambda x: x[0])] - label_count_str = str(label_counts) - if len(label_count_str) > 100: - label_count_str = label_count_str[:100] + '...' - logger.info("Finished reading dataset; label counts: %s", label_count_str) + if len(labels) < 1: + label_counts = [(label, len([x for x in group])) + for label, group in itertools.groupby(labels, lambda x: x[0])] + label_count_str = str(label_counts) + if len(label_count_str) > 100: + label_count_str = label_count_str[:100] + '...' + logger.info("Finished reading dataset; label counts: %s", label_count_str) return TextDataset(instances) diff --git a/deep_qa/data/instances/text_classification/__init__.py b/deep_qa/data/instances/text_classification/__init__.py index dfd787a5b..a5c710bb3 100644 --- a/deep_qa/data/instances/text_classification/__init__.py +++ b/deep_qa/data/instances/text_classification/__init__.py @@ -1,3 +1,10 @@ +from deep_qa.data.instances.text_classification.frame_embedded_label_instance import FrameEmbeddedLabelInstance +from deep_qa.data.instances.text_classification.frame_instance import FrameInstance from .logical_form_instance import LogicalFormInstance, IndexedLogicalFormInstance from .text_classification_instance import TextClassificationInstance, IndexedTextClassificationInstance from .tuple_instance import TupleInstance, IndexedTupleInstance + +concrete_instances = { + 'FrameInstance': FrameInstance, + 'FrameEmbeddedLabelInstance': FrameEmbeddedLabelInstance + } diff --git 
def __str__(self):
    """
    Render the instance as its slot phrases (one per line, comma separated)
    followed by the string form of its label.
    """
    joined_slots = ',\n'.join(self.text)
    label_repr = str(self.label)
    pieces = ['FrameEmbeddedLabelInstance( [', joined_slots, '] , ', label_repr, ')']
    return ''.join(pieces)
+ # Label is a vector representation of the phrase + words = [] + for phrase in self.text: # phrases + phrase_words = self._words_from_text(phrase) + words.extend(phrase_words['words']) + return {'words': words, 'slot_names': SLOTNAMES_ORDERED} + + @staticmethod + def query_slot_from(slot_as_dims: str, + kv_separator: str=":"): + """ + :param slot_as_dims: "participant:water" + :param sparse_given_frame: If the expected slot name is given in the query + but its value is not, then pick the value from the sparse_given_frame + :param kv_separator: typically colon separated + :return: name=participant, val=0.98877,098762,-0.876,... embedding + """ + slot_name_val = slot_as_dims.split(kv_separator) + csv_of_floats = slot_name_val[1] + val_arr = numpy.array(list(csv_of_floats + .replace('\n', ',') + .replace(' ', '') + .replace(',,', ',') + .split(',')), + dtype='float64') + # val_arr = numpy.genfromtxt(StringIO(csv_of_floats), delimiter=",", dtype="float64", autostrip=True) + # The shape and type is automatically inferred by numpy based on csv of reals. + return {'name': slot_name_val[0], 'val': val_arr} + + @staticmethod + def unpack_input(frame_as_string: str, + kv_separator: str="\t"): + """ + :param frame_as_string: "event:plant absorb water###participant:water" TAB "participant:water" + :param kv_separator: typically TAB separated partial frame and query + :return: event:plant absorb water###participant:water, and query: participant:water + Both event and query will be lowercased + """ + # No information loss in lower-casing, and simplifies matching. 
+ partialframe_query = frame_as_string.lower().split(kv_separator) + if len(partialframe_query) != 2: + raise RuntimeError("Unexpected number (not 2) of fields in frame: " + frame_as_string) + return {'content': partialframe_query[0], 'query': partialframe_query[1]} + + @staticmethod + def given_slots_from(slots_csv: str, + values_separator: str="###", + kv_separator: str=":"): + """ + :param slots_csv: event:plant absorb water###participant:water + :param values_separator: typically "###" + :param kv_separator: typically ":" + :return: map of slotnames -> slot phrase [event -> plant absorb water , participant -> water] + """ + # ValueError: dictionary update sequence element # 3 has length 1; 2 is required + return dict(map(lambda x: x.split(kv_separator), slots_csv.split(values_separator))) + + @staticmethod + def dense_frame_from(sparse_frame: Dict[str, str], + query_slotname: str): + """ + Performs two types of padding: + i) unobserved slots are filled with self.unknown_slotval + ii) query slot is masked with self.unknown_queryval + The order of slots strictly follows from SLOTNAMES_ORDERED. + :param sparse_frame: + slotnames -> slot phrase [event -> plant absorb water , participant -> water] + :param query_slotname: + participant + :return: [plant absorb water, ques, missingval, missingval, ...] + """ + slots = [] + for slot_name in SLOTNAMES_ORDERED: + if slot_name == query_slotname: # query hence masked + slots.append(QUES_SLOTVAL) + elif slot_name in sparse_frame: # observed hence as-is + slots.append(sparse_frame[slot_name]) + else: # unobserved hence inserted + slots.append(UNKNOWN_SLOTVAL) + return slots + + @classmethod + @overrides + def read_from_line(cls, line: str): + """ + Reads a FrameEmbeddedLabelInstance from a line. The format is: + frame represented as list of TAB