diff --git a/.appveyor.yml b/.appveyor.yml index a33a9acae..43b836a8f 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -20,7 +20,7 @@ cache: install: - "%PYTHON%\\python.exe -m pip install wheel" - - "%PYTHON%\\python.exe -m pip install -e .[test] --verbose" + - "%PYTHON%\\python.exe -m pip install -e .[test]" build: false diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b13a6296..1abff5e61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,22 @@ # Changelog All notable changes to this project will be documented in this file. +## [0.20.0] +### Added +- Add new intent parser: `LookupIntentParser` [#759](https://github.com/snipsco/snips-nlu/pull/759) + +### Changed +- Replace `DeterministicIntentParser` by `LookupIntentParser` in default configs [#829](https://github.com/snipsco/snips-nlu/pull/829) +- Bumped `snips-nlu-parsers` to `0.3.x` introducing new builtin entities: + - `snips/time` + - `snips/timePeriod` + - `snips/date` + - `snips/datePeriod` + - `snips/city` + - `snips/country` + - `snips/region` + + ## [0.19.8] ### Added - Add filter for entity match feature [#814](https://github.com/snipsco/snips-nlu/pull/814) @@ -296,6 +312,7 @@ several commands. - Fix compiling issue with `bindgen` dependency when installing from source - Fix issue in `CRFSlotFiller` when handling builtin entities +[0.20.0]: https://github.com/snipsco/snips-nlu/compare/0.19.8...0.20.0 [0.19.8]: https://github.com/snipsco/snips-nlu/compare/0.19.7...0.19.8 [0.19.7]: https://github.com/snipsco/snips-nlu/compare/0.19.6...0.19.7 [0.19.6]: https://github.com/snipsco/snips-nlu/compare/0.19.5...0.19.6 diff --git a/README.rst b/README.rst index 4a3acc6e1..361addbfe 100644 --- a/README.rst +++ b/README.rst @@ -256,6 +256,15 @@ Licence This library is provided by `Snips `_ as Open Source software. See `LICENSE `_ for more information. + +Geonames Licence +---------------- + +The `snips/city`, `snips/country` and `snips/region` builtin entities rely on +software from Geonames, which is made available under a Creative Commons Attribution 4.0 +license international. For the license and warranties for Geonames please refer to: https://creativecommons.org/licenses/by/4.0/legalcode. + + .. _external language resources: https://github.com/snipsco/snips-nlu-language-resources .. _forum: https://forum.snips.ai/ .. _blog post: https://medium.com/snips-ai/an-introduction-to-snips-nlu-the-open-source-library-behind-snips-embedded-voice-platform-b12b1a60a41a diff --git a/docs/source/api.rst b/docs/source/api.rst index 55396e8eb..b86db553c 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -27,6 +27,9 @@ Intent Parser .. autoclass:: DeterministicIntentParser :members: +.. autoclass:: LookupIntentParser + :members: + .. autoclass:: ProbabilisticIntentParser :members: @@ -89,6 +92,9 @@ Configurations .. autoclass:: DeterministicIntentParserConfig :members: +.. autoclass:: LookupIntentParserConfig + :members: + .. autoclass:: ProbabilisticIntentParserConfig :members: diff --git a/docs/source/custom_processing_units.rst b/docs/source/custom_processing_units.rst index 9d7edbc80..2831a246e 100644 --- a/docs/source/custom_processing_units.rst +++ b/docs/source/custom_processing_units.rst @@ -4,7 +4,7 @@ Custom Processing Units ======================= The Snips NLU library provides a default NLU pipeline containing built-in -processing units such as the :class:`.DeterministicIntentParser` or the +processing units such as the :class:`.LookupIntentParser` or the :class:`.ProbabilisticIntentParser`. 
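Editorial aside on the default-pipeline change recorded in the changelog and docs above: the same pipeline can be assembled explicitly from the public configuration classes referenced in this changeset. A minimal sketch (illustrative, not part of the diff), assuming the `SnipsNLUEngine`, `NLUEngineConfig`, `LookupIntentParserConfig` and `ProbabilisticIntentParserConfig` classes exposed by the library:

from snips_nlu import SnipsNLUEngine
from snips_nlu.pipeline.configs import (
    LookupIntentParserConfig, NLUEngineConfig, ProbabilisticIntentParserConfig)

# Mirror of the updated default configs: the strict lookup parser is tried
# first, and the probabilistic parser is used as a fallback when it fails.
engine_config = NLUEngineConfig([
    LookupIntentParserConfig(ignore_stop_words=True),
    ProbabilisticIntentParserConfig(),
])
engine = SnipsNLUEngine(config=engine_config)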
However, it is possible to define custom processing units and use them in a @@ -14,7 +14,7 @@ The main processing unit of the Snips NLU processing pipeline is the :class:`.SnipsNLUEngine`. This engine relies on a list of :class:`.IntentParser` that are called successively until one of them manages to extract an intent. By default, two parsers are used by the engine: a -:class:`.DeterministicIntentParser` and a :class:`.ProbabilisticIntentParser`. +:class:`.LookupParser` and a :class:`.ProbabilisticIntentParser`. Let's focus on the probabilistic intent parser. This parser parses text using two steps: first it classifies the intent using an @@ -82,12 +82,12 @@ naive keyword matching logic: "slots_keywords": self.slots_keywords, "config": self.config.to_dict() } - with path.open(mode="w") as f: + with path.open(mode="w", encoding="utf8") as f: f.write(json_string(model)) @classmethod def from_path(cls, path, **shared): - with path.open() as f: + with path.open(encoding="utf8") as f: model = json.load(f) slot_filler = cls() slot_filler.language = model["language"] @@ -188,12 +188,12 @@ this: "slots_keywords": self.slots_keywords, "config": self.config.to_dict() } - with path.open(mode="w") as f: + with path.open(mode="w", encoding="utf8") as f: f.write(json_string(model)) @classmethod def from_path(cls, path, **shared): - with path.open() as f: + with path.open(encoding="utf8") as f: model = json.load(f) slot_filler = cls() slot_filler.language = model["language"] diff --git a/setup.py b/setup.py index 27b749fc5..514fe2deb 100644 --- a/setup.py +++ b/setup.py @@ -30,9 +30,8 @@ "scikit-learn>=0.21.1,<0.22; python_version>='3.5'", "scipy>=1.0,<2.0", "sklearn-crfsuite>=0.3.6,<0.4", - "snips-nlu-parsers>=0.2,<0.3", - "snips-nlu-utils>=0.8,<0.9", - "deprecation>=2,<3", + "snips-nlu-parsers>=0.3,<0.4", + "snips_nlu_utils>=0.9,<0.10", ] extras_require = { diff --git a/snips_nlu/__about__.py b/snips_nlu/__about__.py index 050edcb4c..6158ceef6 100644 --- a/snips_nlu/__about__.py +++ b/snips_nlu/__about__.py @@ -13,13 +13,13 @@ __email__ = "clement.doumouro@snips.ai, adrien.ball@snips.ai" __license__ = "Apache License, Version 2.0" -__version__ = "0.19.8" -__model_version__ = "0.19.0" +__version__ = "0.20.0" +__model_version__ = "0.20.0" __download_url__ = "https://github.com/snipsco/snips-nlu-language-resources/releases/download" __compatibility__ = "https://raw.githubusercontent.com/snipsco/snips-nlu-language-resources/master/compatibility.json" __shortcuts__ = "https://raw.githubusercontent.com/snipsco/snips-nlu-language-resources/master/shortcuts.json" -__entities_download_url__ = "https://resources.snips.ai/nlu-lm/gazetteer-entities" +__entities_download_url__ = "https://resources.snips.ai/nlu/gazetteer-entities" # pylint:enable=line-too-long diff --git a/snips_nlu/common/utils.py b/snips_nlu/common/utils.py index 5e8bc7837..a3550bc22 100644 --- a/snips_nlu/common/utils.py +++ b/snips_nlu/common/utils.py @@ -163,6 +163,10 @@ def get_package_path(name): def deduplicate_overlapping_items(items, overlap_fn, sort_key_fn): + """Deduplicates the items by looping over the items, sorted using + sort_key_fn, and checking overlaps with previously seen items using + overlap_fn + """ sorted_items = sorted(items, key=sort_key_fn) deduplicated_items = [] for item in sorted_items: @@ -173,6 +177,9 @@ def deduplicate_overlapping_items(items, overlap_fn, sort_key_fn): def replace_entities_with_placeholders(text, entities, placeholder_fn): + """Processes the text in order to replace entity values with placeholders + 
as defined by the placeholder function + """ if not entities: return dict(), text @@ -207,6 +214,8 @@ def replace_entities_with_placeholders(text, entities, placeholder_fn): def deduplicate_overlapping_entities(entities): + """Deduplicates entities based on overlapping ranges""" + def overlap(lhs_entity, rhs_entity): return ranges_overlap(lhs_entity[RES_MATCH_RANGE], rhs_entity[RES_MATCH_RANGE]) diff --git a/snips_nlu/dataset/utils.py b/snips_nlu/dataset/utils.py index 816f906f2..a142c4f03 100644 --- a/snips_nlu/dataset/utils.py +++ b/snips_nlu/dataset/utils.py @@ -52,3 +52,16 @@ def get_dataset_gazetteer_entities(dataset, intent=None): if intent is not None: return extract_intent_entities(dataset, is_gazetteer_entity)[intent] return {e for e in dataset[ENTITIES] if is_gazetteer_entity(e)} + + +def get_stop_words_whitelist(dataset, stop_words): + """Extracts stop words whitelists per intent consisting of entity values + that appear in the stop_words list""" + entity_values_per_intent = extract_entity_values( + dataset, apply_normalization=True) + stop_words_whitelist = dict() + for intent, entity_values in iteritems(entity_values_per_intent): + whitelist = stop_words.intersection(entity_values) + if whitelist: + stop_words_whitelist[intent] = whitelist + return stop_words_whitelist diff --git a/snips_nlu/default_configs/config_de.py b/snips_nlu/default_configs/config_de.py index f19037965..200fc3064 100644 --- a/snips_nlu/default_configs/config_de.py +++ b/snips_nlu/default_configs/config_de.py @@ -4,9 +4,7 @@ "unit_name": "nlu_engine", "intent_parsers_configs": [ { - "unit_name": "deterministic_intent_parser", - "max_queries": 500, - "max_pattern_length": 1000, + "unit_name": "lookup_intent_parser", "ignore_stop_words": True }, { diff --git a/snips_nlu/default_configs/config_en.py b/snips_nlu/default_configs/config_en.py index 8ea17a19f..12f7ae12e 100644 --- a/snips_nlu/default_configs/config_en.py +++ b/snips_nlu/default_configs/config_en.py @@ -4,9 +4,7 @@ "unit_name": "nlu_engine", "intent_parsers_configs": [ { - "unit_name": "deterministic_intent_parser", - "max_queries": 500, - "max_pattern_length": 1000, + "unit_name": "lookup_intent_parser", "ignore_stop_words": True }, { diff --git a/snips_nlu/default_configs/config_es.py b/snips_nlu/default_configs/config_es.py index 3ae5a6d93..28969ced2 100644 --- a/snips_nlu/default_configs/config_es.py +++ b/snips_nlu/default_configs/config_es.py @@ -4,9 +4,7 @@ "unit_name": "nlu_engine", "intent_parsers_configs": [ { - "unit_name": "deterministic_intent_parser", - "max_queries": 500, - "max_pattern_length": 1000, + "unit_name": "lookup_intent_parser", "ignore_stop_words": True }, { diff --git a/snips_nlu/default_configs/config_fr.py b/snips_nlu/default_configs/config_fr.py index cdd1c35fe..a2da590a6 100644 --- a/snips_nlu/default_configs/config_fr.py +++ b/snips_nlu/default_configs/config_fr.py @@ -4,9 +4,7 @@ "unit_name": "nlu_engine", "intent_parsers_configs": [ { - "unit_name": "deterministic_intent_parser", - "max_queries": 500, - "max_pattern_length": 1000, + "unit_name": "lookup_intent_parser", "ignore_stop_words": True }, { diff --git a/snips_nlu/default_configs/config_it.py b/snips_nlu/default_configs/config_it.py index cdd1c35fe..a2da590a6 100644 --- a/snips_nlu/default_configs/config_it.py +++ b/snips_nlu/default_configs/config_it.py @@ -4,9 +4,7 @@ "unit_name": "nlu_engine", "intent_parsers_configs": [ { - "unit_name": "deterministic_intent_parser", - "max_queries": 500, - "max_pattern_length": 1000, + "unit_name": 
"lookup_intent_parser", "ignore_stop_words": True }, { diff --git a/snips_nlu/default_configs/config_ja.py b/snips_nlu/default_configs/config_ja.py index 49652df1f..b28791fea 100644 --- a/snips_nlu/default_configs/config_ja.py +++ b/snips_nlu/default_configs/config_ja.py @@ -4,9 +4,7 @@ "unit_name": "nlu_engine", "intent_parsers_configs": [ { - "unit_name": "deterministic_intent_parser", - "max_queries": 500, - "max_pattern_length": 1000, + "unit_name": "lookup_intent_parser", "ignore_stop_words": False }, { diff --git a/snips_nlu/default_configs/config_ko.py b/snips_nlu/default_configs/config_ko.py index 7381bd7dc..1630796bc 100644 --- a/snips_nlu/default_configs/config_ko.py +++ b/snips_nlu/default_configs/config_ko.py @@ -4,9 +4,7 @@ "unit_name": "nlu_engine", "intent_parsers_configs": [ { - "unit_name": "deterministic_intent_parser", - "max_queries": 500, - "max_pattern_length": 1000, + "unit_name": "lookup_intent_parser", "ignore_stop_words": False }, { diff --git a/snips_nlu/default_configs/config_pt_br.py b/snips_nlu/default_configs/config_pt_br.py index cf1a72823..450f0dbba 100644 --- a/snips_nlu/default_configs/config_pt_br.py +++ b/snips_nlu/default_configs/config_pt_br.py @@ -4,9 +4,7 @@ "unit_name": "nlu_engine", "intent_parsers_configs": [ { - "unit_name": "deterministic_intent_parser", - "max_queries": 500, - "max_pattern_length": 1000, + "unit_name": "lookup_intent_parser", "ignore_stop_words": True }, { diff --git a/snips_nlu/default_configs/config_pt_pt.py b/snips_nlu/default_configs/config_pt_pt.py index cf1a72823..450f0dbba 100644 --- a/snips_nlu/default_configs/config_pt_pt.py +++ b/snips_nlu/default_configs/config_pt_pt.py @@ -4,9 +4,7 @@ "unit_name": "nlu_engine", "intent_parsers_configs": [ { - "unit_name": "deterministic_intent_parser", - "max_queries": 500, - "max_pattern_length": 1000, + "unit_name": "lookup_intent_parser", "ignore_stop_words": True }, { diff --git a/snips_nlu/intent_classifier/featurizer.py b/snips_nlu/intent_classifier/featurizer.py index ffd8ab045..a1a44876f 100644 --- a/snips_nlu/intent_classifier/featurizer.py +++ b/snips_nlu/intent_classifier/featurizer.py @@ -758,7 +758,7 @@ def persist(self, path): } vectorizer_json = json_string(self_as_dict) vectorizer_path = path / "vectorizer.json" - with vectorizer_path.open(mode="w") as f: + with vectorizer_path.open(mode="w", encoding="utf8") as f: f.write(vectorizer_json) self.persist_metadata(path) diff --git a/snips_nlu/intent_classifier/log_reg_classifier.py b/snips_nlu/intent_classifier/log_reg_classifier.py index e70bbcd7c..137ccae49 100644 --- a/snips_nlu/intent_classifier/log_reg_classifier.py +++ b/snips_nlu/intent_classifier/log_reg_classifier.py @@ -222,7 +222,8 @@ def persist(self, path): } classifier_json = json_string(self_as_dict) - with (path / "intent_classifier.json").open(mode="w") as f: + with (path / "intent_classifier.json").open(mode="w", + encoding="utf8") as f: f.write(classifier_json) self.persist_metadata(path) diff --git a/snips_nlu/intent_parser/__init__.py b/snips_nlu/intent_parser/__init__.py index d5b4b0f10..1b0d446de 100644 --- a/snips_nlu/intent_parser/__init__.py +++ b/snips_nlu/intent_parser/__init__.py @@ -1,3 +1,4 @@ from .deterministic_intent_parser import DeterministicIntentParser from .intent_parser import IntentParser +from .lookup_intent_parser import LookupIntentParser from .probabilistic_intent_parser import ProbabilisticIntentParser diff --git a/snips_nlu/intent_parser/deterministic_intent_parser.py 
b/snips_nlu/intent_parser/deterministic_intent_parser.py index 7612095be..88e19c6eb 100644 --- a/snips_nlu/intent_parser/deterministic_intent_parser.py +++ b/snips_nlu/intent_parser/deterministic_intent_parser.py @@ -21,7 +21,7 @@ RES_MATCH_RANGE, RES_SLOTS, RES_VALUE, SLOT_NAME, START, TEXT, UTTERANCES, RES_PROBA) from snips_nlu.dataset import validate_and_format_dataset -from snips_nlu.dataset.utils import extract_entity_values +from snips_nlu.dataset.utils import get_stop_words_whitelist from snips_nlu.entity_parser.builtin_entity_parser import is_builtin_entity from snips_nlu.exceptions import IntentNotFoundError, LoadingError from snips_nlu.intent_parser.intent_parser import IntentParser @@ -143,7 +143,7 @@ def fit(self, dataset, force_retrain=True): self.slot_names_to_entities = get_slot_name_mappings(dataset) self.group_names_to_slot_names = _get_group_names_to_slot_names( self.slot_names_to_entities) - self._stop_words_whitelist = _get_stop_words_whitelist( + self._stop_words_whitelist = get_stop_words_whitelist( dataset, self._stop_words) # Do not use ambiguous patterns that appear in more than one intent @@ -239,11 +239,12 @@ def placeholder_fn(entity_name): cleaned_processed_text = self._preprocess_text(processed_text, intent) for regex in self.regexes_per_intent[intent]: - res = self._get_matching_result(text, cleaned_processed_text, - regex, intent, mapping) + res = self._get_matching_result(text, cleaned_text, regex, + intent) if res is None and cleaned_text != cleaned_processed_text: - res = self._get_matching_result(text, cleaned_text, regex, - intent) + res = self._get_matching_result( + text, cleaned_processed_text, regex, intent, mapping) + if res is not None: results.append(res) break @@ -300,6 +301,7 @@ def get_slots(self, text, intent): if intent not in self.regexes_per_intent: raise IntentNotFoundError(intent) + slots = self.parse(text, intents=[intent])[RES_SLOTS] if slots is None: slots = [] @@ -408,7 +410,7 @@ def persist(self, path): parser_json = json_string(self.to_dict()) parser_path = path / "intent_parser.json" - with parser_path.open(mode="w") as f: + with parser_path.open(mode="w", encoding="utf8") as f: f.write(parser_json) self.persist_metadata(path) @@ -514,14 +516,3 @@ def sort_key_fn(slot): def _get_entity_name_placeholder(entity_label, language): return "%%%s%%" % "".join( tokenize_light(entity_label, language)).upper() - - -def _get_stop_words_whitelist(dataset, stop_words): - entity_values_per_intent = extract_entity_values( - dataset, apply_normalization=True) - stop_words_whitelist = dict() - for intent, entity_values in iteritems(entity_values_per_intent): - whitelist = stop_words.intersection(entity_values) - if whitelist: - stop_words_whitelist[intent] = whitelist - return stop_words_whitelist diff --git a/snips_nlu/intent_parser/lookup_intent_parser.py b/snips_nlu/intent_parser/lookup_intent_parser.py new file mode 100644 index 000000000..4f4bf082d --- /dev/null +++ b/snips_nlu/intent_parser/lookup_intent_parser.py @@ -0,0 +1,509 @@ +from __future__ import unicode_literals + +import json +import logging +from builtins import str +from collections import defaultdict +from itertools import combinations +from pathlib import Path + +from future.utils import iteritems, itervalues +from snips_nlu_utils import normalize, hash_str + +from snips_nlu.common.log_utils import log_elapsed_time, log_result +from snips_nlu.common.utils import ( + check_persisted_path, deduplicate_overlapping_entities, fitted_required, + json_string) +from 
snips_nlu.constants import ( + DATA, END, ENTITIES, ENTITY, ENTITY_KIND, INTENTS, LANGUAGE, RES_INTENT, + RES_INTENT_NAME, RES_MATCH_RANGE, RES_SLOTS, SLOT_NAME, START, TEXT, + UTTERANCES, RES_PROBA) +from snips_nlu.dataset import ( + validate_and_format_dataset, extract_intent_entities) +from snips_nlu.dataset.utils import get_stop_words_whitelist +from snips_nlu.entity_parser.builtin_entity_parser import is_builtin_entity +from snips_nlu.exceptions import IntentNotFoundError, LoadingError +from snips_nlu.intent_parser.intent_parser import IntentParser +from snips_nlu.pipeline.configs import LookupIntentParserConfig +from snips_nlu.preprocessing import tokenize_light +from snips_nlu.resources import get_stop_words +from snips_nlu.result import ( + empty_result, intent_classification_result, parsing_result, + unresolved_slot, extraction_result) + +logger = logging.getLogger(__name__) + + +@IntentParser.register("lookup_intent_parser") +class LookupIntentParser(IntentParser): + """A deterministic Intent parser implementation based on a dictionary + + This intent parser is very strict by nature, and tends to have a very good + precision but a low recall. For this reason, it is interesting to use it + first before potentially falling back to another parser. + """ + + config_type = LookupIntentParserConfig + + def __init__(self, config=None, **shared): + """The lookup intent parser can be configured by passing a + :class:`.LookupIntentParserConfig`""" + super(LookupIntentParser, self).__init__(config, **shared) + self._language = None + self._stop_words = None + self._stop_words_whitelist = None + self._map = None + self._intents_names = [] + self._slots_names = [] + self._intents_mapping = dict() + self._slots_mapping = dict() + self._entity_scopes = None + + @property + def language(self): + return self._language + + @language.setter + def language(self, value): + self._language = value + if value is None: + self._stop_words = None + else: + if self.config.ignore_stop_words: + self._stop_words = get_stop_words(self.resources) + else: + self._stop_words = set() + + @property + def fitted(self): + """Whether or not the intent parser has already been trained""" + return self._map is not None + + @log_elapsed_time( + logger, logging.INFO, "Fitted lookup intent parser in {elapsed_time}") + def fit(self, dataset, force_retrain=True): + """Fits the intent parser with a valid Snips dataset""" + logger.info("Fitting lookup intent parser...") + dataset = validate_and_format_dataset(dataset) + self.load_resources_if_needed(dataset[LANGUAGE]) + self.fit_builtin_entity_parser_if_needed(dataset) + self.fit_custom_entity_parser_if_needed(dataset) + self.language = dataset[LANGUAGE] + self._entity_scopes = _get_entity_scopes(dataset) + self._map = dict() + self._stop_words_whitelist = get_stop_words_whitelist( + dataset, self._stop_words) + entity_placeholders = _get_entity_placeholders(dataset, self.language) + + ambiguous_keys = set() + for (key, val) in self._generate_io_mapping(dataset[INTENTS], + entity_placeholders): + key = hash_str(key) + # handle key collisions -*- flag ambiguous entries -*- + if key in self._map and self._map[key] != val: + ambiguous_keys.add(key) + else: + self._map[key] = val + + # delete ambiguous keys + for key in ambiguous_keys: + self._map.pop(key) + + return self + + @log_result(logger, logging.DEBUG, "LookupIntentParser result -> {result}") + @log_elapsed_time(logger, logging.DEBUG, "Parsed in {elapsed_time}.") + @fitted_required + def parse(self, text, intents=None, 
top_n=None): + """Performs intent parsing on the provided *text* + + Intent and slots are extracted simultaneously through pattern matching + + Args: + text (str): input + intents (str or list of str): if provided, reduces the scope of + intent parsing to the provided list of intents + top_n (int, optional): when provided, this method will return a + list of at most top_n most likely intents, instead of a single + parsing result. + Note that the returned list can contain less than ``top_n`` + elements, for instance when the parameter ``intents`` is not + None, or when ``top_n`` is greater than the total number of + intents. + + Returns: + dict or list: the most likely intent(s) along with the extracted + slots. See :func:`.parsing_result` and :func:`.extraction_result` + for the output format. + + Raises: + NotTrained: when the intent parser is not fitted + """ + if top_n is None: + top_intents = self._parse_top_intents(text, top_n=1, + intents=intents) + if top_intents: + intent = top_intents[0][RES_INTENT] + slots = top_intents[0][RES_SLOTS] + if intent[RES_PROBA] <= 0.5: + # return None in case of ambiguity + return empty_result(text, probability=1.0) + return parsing_result(text, intent, slots) + return empty_result(text, probability=1.0) + return self._parse_top_intents(text, top_n=top_n, intents=intents) + + def _parse_top_intents(self, text, top_n, intents=None): + if isinstance(intents, str): + intents = {intents} + elif isinstance(intents, list): + intents = set(intents) + + if top_n < 1: + raise ValueError( + "top_n argument must be greater or equal to 1, but got: %s" + % top_n) + + results_per_intent = defaultdict(list) + for text_candidate, entities in self._get_candidates(text, intents): + val = self._map.get(hash_str(text_candidate)) + if val is not None: + result = self._parse_map_output(text, val, entities, intents) + if result: + intent_name = result[RES_INTENT][RES_INTENT_NAME] + results_per_intent[intent_name].append(result) + + results = [] + for intent_results in itervalues(results_per_intent): + sorted_results = sorted(intent_results, + key=lambda res: len(res[RES_SLOTS])) + results.append(sorted_results[0]) + + # In some rare cases there can be multiple ambiguous intents + # In such cases, priority is given to results containing fewer slots + weights = [1.0 / (1.0 + len(res[RES_SLOTS])) for res in results] + total_weight = sum(weights) + + for res, weight in zip(results, weights): + res[RES_INTENT][RES_PROBA] = weight / total_weight + + results = sorted(results, key=lambda r: -r[RES_INTENT][RES_PROBA]) + return results[:top_n] + + def _get_candidates(self, text, intents): + candidates = defaultdict(list) + for grouped_entity_scope in self._entity_scopes: + entity_scope = grouped_entity_scope["entity_scope"] + intent_group = grouped_entity_scope["intent_group"] + intent_group = [intent_ for intent_ in intent_group + if intents is None or intent_ in intents] + if not intent_group: + continue + + builtin_entities = self.builtin_entity_parser.parse( + text, scope=entity_scope["builtin"], use_cache=True) + custom_entities = self.custom_entity_parser.parse( + text, scope=entity_scope["custom"], use_cache=True) + all_entities = builtin_entities + custom_entities + all_entities = deduplicate_overlapping_entities(all_entities) + + # We generate all subsets of entities to match utterances + # containing ambivalent words which can be both entity values or + # random words + for entities in _get_entities_combinations(all_entities): + processed_text = 
self._replace_entities_with_placeholders( + text, entities) + for intent in intent_group: + cleaned_text = self._preprocess_text(text, intent) + cleaned_processed_text = self._preprocess_text( + processed_text, intent) + + raw_candidate = cleaned_text, [] + placeholder_candidate = cleaned_processed_text, entities + intent_candidates = [raw_candidate, placeholder_candidate] + for text_input, text_entities in intent_candidates: + if text_input not in candidates \ + or text_entities not in candidates[text_input]: + candidates[text_input].append(text_entities) + yield text_input, text_entities + + def _parse_map_output(self, text, output, entities, intents): + """Parse the map output to the parser's result format""" + intent_id, slot_ids = output + intent_name = self._intents_names[intent_id] + if intents is not None and intent_name not in intents: + return None + + parsed_intent = intent_classification_result( + intent_name=intent_name, probability=1.0) + slots = [] + # assert invariant + assert len(slot_ids) == len(entities) + for slot_id, entity in zip(slot_ids, entities): + slot_name = self._slots_names[slot_id] + rng_start = entity[RES_MATCH_RANGE][START] + rng_end = entity[RES_MATCH_RANGE][END] + slot_value = text[rng_start:rng_end] + entity_name = entity[ENTITY_KIND] + slot = unresolved_slot( + [rng_start, rng_end], slot_value, entity_name, slot_name) + slots.append(slot) + + return extraction_result(parsed_intent, slots) + + @fitted_required + def get_intents(self, text): + """Returns the list of intents ordered by decreasing probability + + The length of the returned list is exactly the number of intents in the + dataset + 1 for the None intent + """ + nb_intents = len(self._intents_names) + top_intents = [intent_result[RES_INTENT] for intent_result in + self._parse_top_intents(text, top_n=nb_intents)] + matched_intents = {res[RES_INTENT_NAME] for res in top_intents} + for intent in self._intents_names: + if intent not in matched_intents: + top_intents.append(intent_classification_result(intent, 0.0)) + + # The None intent is not included in the lookup table and is thus + # never matched by the lookup parser + top_intents.append(intent_classification_result(None, 0.0)) + return top_intents + + @fitted_required + def get_slots(self, text, intent): + """Extracts slots from a text input, with the knowledge of the intent + + Args: + text (str): input + intent (str): the intent which the input corresponds to + + Returns: + list: the list of extracted slots + + Raises: + IntentNotFoundError: When the intent was not part of the training + data + """ + if intent is None: + return [] + + if intent not in self._intents_names: + raise IntentNotFoundError(intent) + + slots = self.parse(text, intents=[intent])[RES_SLOTS] + if slots is None: + slots = [] + return slots + + def _get_intent_stop_words(self, intent): + whitelist = self._stop_words_whitelist.get(intent, set()) + return self._stop_words.difference(whitelist) + + def _get_intent_id(self, intent_name): + """generate a numeric id for an intent + + Args: + intent_name (str): intent name + + Returns: + int: numeric id + + """ + intent_id = self._intents_mapping.get(intent_name) + if intent_id is None: + intent_id = len(self._intents_names) + self._intents_names.append(intent_name) + self._intents_mapping[intent_name] = intent_id + + return intent_id + + def _get_slot_id(self, slot_name): + """generate a numeric id for a slot + + Args: + slot_name (str): intent name + + Returns: + int: numeric id + + """ + slot_id = 
self._slots_mapping.get(slot_name) + if slot_id is None: + slot_id = len(self._slots_names) + self._slots_names.append(slot_name) + self._slots_mapping[slot_name] = slot_id + + return slot_id + + def _preprocess_text(self, txt, intent): + """Replaces stop words and characters that are tokenized out by + whitespaces""" + stop_words = self._get_intent_stop_words(intent) + tokens = tokenize_light(txt, self.language) + cleaned_string = " ".join( + [tkn for tkn in tokens if normalize(tkn) not in stop_words]) + return cleaned_string.lower() + + def _generate_io_mapping(self, intents, entity_placeholders): + """Generate input-output pairs""" + for intent_name, intent in sorted(iteritems(intents)): + intent_id = self._get_intent_id(intent_name) + for entry in intent[UTTERANCES]: + yield self._build_io_mapping( + intent_id, entry, entity_placeholders) + + def _build_io_mapping(self, intent_id, utterance, entity_placeholders): + input_ = [] + output = [intent_id] + slots = [] + for chunk in utterance[DATA]: + if SLOT_NAME in chunk: + slot_name = chunk[SLOT_NAME] + slot_id = self._get_slot_id(slot_name) + entity_name = chunk[ENTITY] + placeholder = entity_placeholders[entity_name] + input_.append(placeholder) + slots.append(slot_id) + else: + input_.append(chunk[TEXT]) + output.append(slots) + + intent = self._intents_names[intent_id] + key = self._preprocess_text(" ".join(input_), intent) + + return key, output + + def _replace_entities_with_placeholders(self, text, entities): + if not entities: + return text + entities = sorted(entities, key=lambda e: e[RES_MATCH_RANGE][START]) + processed_text = "" + current_idx = 0 + for ent in entities: + start = ent[RES_MATCH_RANGE][START] + end = ent[RES_MATCH_RANGE][END] + processed_text += text[current_idx:start] + place_holder = _get_entity_name_placeholder( + ent[ENTITY_KIND], self.language) + processed_text += place_holder + current_idx = end + processed_text += text[current_idx:] + + return processed_text + + @check_persisted_path + def persist(self, path): + """Persists the object at the given path""" + path.mkdir() + parser_json = json_string(self.to_dict()) + parser_path = path / "intent_parser.json" + + with parser_path.open(mode="w", encoding="utf8") as pfile: + pfile.write(parser_json) + self.persist_metadata(path) + + @classmethod + def from_path(cls, path, **shared): + """Loads a :class:`LookupIntentParser` instance from a path + + The data at the given path must have been generated using + :func:`~LookupIntentParser.persist` + """ + path = Path(path) + model_path = path / "intent_parser.json" + if not model_path.exists(): + raise LoadingError( + "Missing lookup intent parser metadata file: %s" + % model_path.name) + + with model_path.open(encoding="utf8") as pfile: + metadata = json.load(pfile) + return cls.from_dict(metadata, **shared) + + def to_dict(self): + """Returns a json-serializable dict""" + stop_words_whitelist = None + if self._stop_words_whitelist is not None: + stop_words_whitelist = { + intent: sorted(values) + for intent, values in iteritems(self._stop_words_whitelist)} + return { + "config": self.config.to_dict(), + "language_code": self.language, + "map": self._map, + "slots_names": self._slots_names, + "intents_names": self._intents_names, + "entity_scopes": self._entity_scopes, + "stop_words_whitelist": stop_words_whitelist, + } + + @classmethod + def from_dict(cls, unit_dict, **shared): + """Creates a :class:`LookupIntentParser` instance from a dict + + The dict must have been generated with + 
:func:`~LookupIntentParser.to_dict` + """ + config = cls.config_type.from_dict(unit_dict["config"]) + parser = cls(config=config, **shared) + parser.language = unit_dict["language_code"] + # pylint:disable=protected-access + parser._map = _convert_dict_keys_to_int(unit_dict["map"]) + parser._slots_names = unit_dict["slots_names"] + parser._intents_names = unit_dict["intents_names"] + parser._entity_scopes = unit_dict["entity_scopes"] + if parser.fitted: + whitelist = unit_dict["stop_words_whitelist"] + parser._stop_words_whitelist = { + intent: set(values) for intent, values in iteritems(whitelist)} + # pylint:enable=protected-access + return parser + + +def _get_entity_scopes(dataset): + intent_entities = extract_intent_entities(dataset) + intent_groups = [] + entity_scopes = [] + for intent, entities in sorted(iteritems(intent_entities)): + scope = { + "builtin": list( + {ent for ent in entities if is_builtin_entity(ent)}), + "custom": list( + {ent for ent in entities if not is_builtin_entity(ent)}) + } + if scope in entity_scopes: + group_idx = entity_scopes.index(scope) + intent_groups[group_idx].append(intent) + else: + entity_scopes.append(scope) + intent_groups.append([intent]) + return [ + { + "intent_group": intent_group, + "entity_scope": entity_scope + } for intent_group, entity_scope in zip(intent_groups, entity_scopes) + ] + + +def _get_entity_placeholders(dataset, language): + return { + e: _get_entity_name_placeholder(e, language) for e in dataset[ENTITIES] + } + + +def _get_entity_name_placeholder(entity_label, language): + return "%%%s%%" % "".join(tokenize_light(entity_label, language)).upper() + + +def _convert_dict_keys_to_int(dct): + if isinstance(dct, dict): + return {int(k): v for k, v in iteritems(dct)} + return dct + + +def _get_entities_combinations(entities): + yield () + for nb_entities in reversed(range(1, len(entities) + 1)): + for combination in combinations(entities, nb_entities): + yield combination diff --git a/snips_nlu/nlu_engine/nlu_engine.py b/snips_nlu/nlu_engine/nlu_engine.py index 303609858..1af2bb410 100644 --- a/snips_nlu/nlu_engine/nlu_engine.py +++ b/snips_nlu/nlu_engine/nlu_engine.py @@ -313,7 +313,7 @@ def persist(self, path): model_json = json_string(model) model_path = path / "nlu_engine.json" - with model_path.open(mode="w") as f: + with model_path.open(mode="w", encoding="utf8") as f: f.write(model_json) if self.fitted: diff --git a/snips_nlu/pipeline/configs/__init__.py b/snips_nlu/pipeline/configs/__init__.py index f8bb79361..027f286c2 100644 --- a/snips_nlu/pipeline/configs/__init__.py +++ b/snips_nlu/pipeline/configs/__init__.py @@ -1,9 +1,10 @@ from .config import Config, ProcessingUnitConfig from .features import default_features_factories -from .intent_classifier import ( - LogRegIntentClassifierConfig, IntentClassifierDataAugmentationConfig, - FeaturizerConfig, CooccurrenceVectorizerConfig) +from .intent_classifier import (CooccurrenceVectorizerConfig, FeaturizerConfig, + IntentClassifierDataAugmentationConfig, + LogRegIntentClassifierConfig) from .intent_parser import (DeterministicIntentParserConfig, + LookupIntentParserConfig, ProbabilisticIntentParserConfig) from .nlu_engine import NLUEngineConfig from .slot_filler import CRFSlotFillerConfig, SlotFillerDataAugmentationConfig diff --git a/snips_nlu/pipeline/configs/intent_parser.py b/snips_nlu/pipeline/configs/intent_parser.py index bdc56e083..b7cc32abc 100644 --- a/snips_nlu/pipeline/configs/intent_parser.py +++ b/snips_nlu/pipeline/configs/intent_parser.py @@ -95,3 +95,33 @@ 
def to_dict(self): "max_pattern_length": self.max_pattern_length, "ignore_stop_words": self.ignore_stop_words } + + +class LookupIntentParserConfig(FromDict, ProcessingUnitConfig): + """Configuration of a :class:`.LookupIntentParser` + + Args: + ignore_stop_words (bool, optional): If True, stop words will be + removed before building patterns. + """ + + def __init__(self, ignore_stop_words=False): + self.ignore_stop_words = ignore_stop_words + + @property + def unit_name(self): + from snips_nlu.intent_parser.lookup_intent_parser import \ + LookupIntentParser + return LookupIntentParser.unit_name + + def get_required_resources(self): + return { + CUSTOM_ENTITY_PARSER_USAGE: CustomEntityParserUsage.WITHOUT_STEMS, + STOP_WORDS: self.ignore_stop_words + } + + def to_dict(self): + return { + "unit_name": self.unit_name, + "ignore_stop_words": self.ignore_stop_words + } diff --git a/snips_nlu/pipeline/processing_unit.py b/snips_nlu/pipeline/processing_unit.py index 72abffd63..5f47a1861 100644 --- a/snips_nlu/pipeline/processing_unit.py +++ b/snips_nlu/pipeline/processing_unit.py @@ -165,7 +165,7 @@ def persist_metadata(self, path, **kwargs): metadata = {"unit_name": self.unit_name} metadata.update(kwargs) metadata_json = json_string(metadata) - with (path / "metadata.json").open(mode="w") as f: + with (path / "metadata.json").open(mode="w", encoding="utf8") as f: f.write(metadata_json) @abstractmethod diff --git a/snips_nlu/slot_filler/crf_slot_filler.py b/snips_nlu/slot_filler/crf_slot_filler.py index 8a68a49d6..60b198268 100644 --- a/snips_nlu/slot_filler/crf_slot_filler.py +++ b/snips_nlu/slot_filler/crf_slot_filler.py @@ -363,7 +363,7 @@ def persist(self, path): } model_json = json_string(model) model_path = path / "slot_filler.json" - with model_path.open(mode="w") as f: + with model_path.open(mode="w", encoding="utf8") as f: f.write(model_json) self.persist_metadata(path) diff --git a/snips_nlu/slot_filler/keyword_slot_filler.py b/snips_nlu/slot_filler/keyword_slot_filler.py index 1830bc14e..2df876c6e 100644 --- a/snips_nlu/slot_filler/keyword_slot_filler.py +++ b/snips_nlu/slot_filler/keyword_slot_filler.py @@ -56,7 +56,7 @@ def persist(self, path): "slots_keywords": self.slots_keywords, "config": self.config.to_dict() } - with path.open(mode="w") as f: + with path.open(mode="w", encoding="utf8") as f: f.write(json_string(model)) @classmethod diff --git a/snips_nlu/tests/test_builtin_entity_parser.py b/snips_nlu/tests/test_builtin_entity_parser.py index da7178744..6cface15f 100644 --- a/snips_nlu/tests/test_builtin_entity_parser.py +++ b/snips_nlu/tests/test_builtin_entity_parser.py @@ -41,7 +41,7 @@ def test_should_parse_grammar_entities(self): def test_should_parse_gazetteer_entities(self): # Given - text = "je veux ecouter les daft punk s'il vous plait" + text = "je veux ecouter daft punk s'il vous plait" parser = BuiltinEntityParser.build( language="fr", gazetteer_entity_scope=["snips/musicArtist"]) @@ -56,8 +56,8 @@ def test_should_parse_gazetteer_entities(self): }, "entity_kind": "snips/musicArtist", "range": { - "end": 29, - "start": 20 + "end": 25, + "start": 16 }, "value": "daft punk" } @@ -107,7 +107,7 @@ def test_should_not_disambiguate_grammar_and_gazetteer_entities(self): }, "resolved_value": { "kind": "MusicTrack", - "value": "3 nuits par semaine" + "value": "Trois nuits par semaine" }, "entity_kind": "snips/musicTrack" } diff --git a/snips_nlu/tests/test_cli.py b/snips_nlu/tests/test_cli.py index 0a2260ff7..883f915fc 100644 --- a/snips_nlu/tests/test_cli.py +++ 
b/snips_nlu/tests/test_cli.py @@ -53,7 +53,7 @@ def setUp(self): self.beverage_dataset_path = self.fixture_dir / "beverage_dataset.json" if self.beverage_dataset_path.exists(): self.beverage_dataset_path.unlink() - with self.beverage_dataset_path.open(mode="w") as f: + with self.beverage_dataset_path.open(mode="w", encoding="utf8") as f: f.write(json_string(beverage_dataset)) self.tmp_file_path = self.fixture_dir / next( @@ -147,7 +147,7 @@ def test_generate_dataset(self): values: - [new york, big apple]""" self.tmp_file_path = self.tmp_file_path.with_suffix(".yaml") - with self.tmp_file_path.open(mode="w") as f: + with self.tmp_file_path.open(mode="w", encoding="utf8") as f: f.write(unicode_string(yaml_string)) # When diff --git a/snips_nlu/tests/test_deterministic_intent_parser.py b/snips_nlu/tests/test_deterministic_intent_parser.py index 2c21a218c..ff7990bcb 100644 --- a/snips_nlu/tests/test_deterministic_intent_parser.py +++ b/snips_nlu/tests/test_deterministic_intent_parser.py @@ -3,6 +3,7 @@ import io from builtins import range +from copy import deepcopy from checksumdir import dirhash from mock import patch @@ -42,7 +43,7 @@ def setUp(self): This is a [dummy_slot_name](dummy_1) query with another [dummy_slot_name2](dummy_2) [startTime](at 10p.m.) or [startTime](tomorrow) - - "This is a [dummy_slot_name](dummy_1) " + - "This is a [dummy_slot_name](dummy_1) " - "[startTime](tomorrow evening) there is a [dummy_slot_name](dummy_1)" --- @@ -452,7 +453,7 @@ def test_should_parse_stop_words_slots(self): - [this thing, that] """) - resources = self.get_resources("en") + resources = deepcopy(self.get_resources("en")) resources[STOP_WORDS] = {"a", "this", "that"} dataset = Dataset.from_yaml_files("en", [dataset_stream]).json parser_config = DeterministicIntentParserConfig(ignore_stop_words=True) @@ -482,24 +483,26 @@ def test_should_parse_stop_words_slots(self): def test_should_get_intents(self): # Given - dataset_stream = io.StringIO(""" + dataset_stream = io.StringIO( + """ --- type: intent name: greeting1 utterances: - - Hello [name](John) + - Hello John --- type: intent name: greeting2 utterances: - - How are you [name](Thomas) - + - Hello [name](John) + --- type: intent name: greeting3 utterances: - - Hi [name](Robert)""") + - "[greeting](Hello) [name](John)" + """) dataset = Dataset.from_yaml_files("en", [dataset_stream]).json parser = DeterministicIntentParser().fit(dataset) @@ -509,10 +512,22 @@ def test_should_get_intents(self): # Then expected_intents = [ - {RES_INTENT_NAME: "greeting1", RES_PROBA: 1.0}, - {RES_INTENT_NAME: "greeting2", RES_PROBA: 0.0}, - {RES_INTENT_NAME: "greeting3", RES_PROBA: 0.0}, - {RES_INTENT_NAME: None, RES_PROBA: 0.0} + { + RES_INTENT_NAME: "greeting1", + RES_PROBA: 1. / (1. + 1. / 2. + 1. / 3.) + }, + { + RES_INTENT_NAME: "greeting2", + RES_PROBA: (1. / 2.) / (1. + 1. / 2. + 1. / 3.) + }, + { + RES_INTENT_NAME: "greeting3", + RES_PROBA: (1. / 3.) / (1. + 1. / 2. + 1. / 3.) 
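        # Editorial note (not part of the diff): the expected values above
        # follow the ranking scheme introduced with the lookup parser in this
        # PR: each matched intent is weighted by 1 / (1 + number_of_tagged_slots)
        # and the weights are normalized to sum to 1, which yields weights of
        # 1, 1/2 and 1/3 here for utterances with 0, 1 and 2 slots respectively.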
+ }, + { + RES_INTENT_NAME: None, + RES_PROBA: 0.0 + }, ] def sorting_key(intent_res): diff --git a/snips_nlu/tests/test_lookup_intent_parser.py b/snips_nlu/tests/test_lookup_intent_parser.py new file mode 100644 index 000000000..a0750f337 --- /dev/null +++ b/snips_nlu/tests/test_lookup_intent_parser.py @@ -0,0 +1,1156 @@ +# coding=utf-8 +from __future__ import unicode_literals + +import io +from copy import deepcopy + +from mock import patch +from snips_nlu_utils import hash_str + +from snips_nlu.constants import ( + DATA, ENTITY, RES_ENTITY, RES_INTENT, RES_INTENT_NAME, + RES_PROBA, RES_SLOTS, RES_VALUE, SLOT_NAME, TEXT, STOP_WORDS) +from snips_nlu.dataset import Dataset +from snips_nlu.entity_parser import BuiltinEntityParser +from snips_nlu.exceptions import IntentNotFoundError, NotTrained +from snips_nlu.intent_parser import LookupIntentParser +from snips_nlu.intent_parser.lookup_intent_parser import _get_entity_scopes +from snips_nlu.pipeline.configs import LookupIntentParserConfig +from snips_nlu.result import ( + empty_result, extraction_result, intent_classification_result, + unresolved_slot, parsing_result) +from snips_nlu.tests.utils import FixtureTest, TEST_PATH, EntityParserMock + + +class TestLookupIntentParser(FixtureTest): + def setUp(self): + super(TestLookupIntentParser, self).setUp() + slots_dataset_stream = io.StringIO( + """ +--- +type: intent +name: dummy_intent_1 +slots: + - name: dummy_slot_name + entity: dummy_entity_1 + - name: dummy_slot_name2 + entity: dummy_entity_2 + - name: startTime + entity: snips/datetime +utterances: + - > + This is a [dummy_slot_name](dummy_1) query with another + [dummy_slot_name2](dummy_2) [startTime](at 10p.m.) or + [startTime](tomorrow) + - "This is a [dummy_slot_name](dummy_1) " + - "[startTime](tomorrow evening) there is a [dummy_slot_name](dummy_1)" + +--- +type: entity +name: dummy_entity_1 +automatically_extensible: no +values: +- [dummy_a, dummy 2a, dummy a, 2 dummy a] +- [dummy_b, dummy b, dummy_bb, dummy_b] +- dummy d + +--- +type: entity +name: dummy_entity_2 +automatically_extensible: no +values: +- [dummy_c, 3p.m., dummy_cc, dummy c]""") + self.slots_dataset = Dataset.from_yaml_files( + "en", [slots_dataset_stream]).json + + def test_should_parse_intent(self): + # Given + dataset_stream = io.StringIO( + """ +--- +type: intent +name: intent1 +utterances: + - foo bar baz + +--- +type: intent +name: intent2 +utterances: + - foo bar ban""") + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + parser = LookupIntentParser().fit(dataset) + text = "foo bar ban" + + # When + parsing = parser.parse(text) + + # Then + probability = 1.0 + expected_intent = intent_classification_result( + intent_name="intent2", probability=probability) + + self.assertEqual(expected_intent, parsing[RES_INTENT]) + + def test_should_parse_intent_with_filter(self): + # Given + dataset_stream = io.StringIO( + """ +--- +type: intent +name: intent1 +utterances: + - foo bar baz + +--- +type: intent +name: intent2 +utterances: + - foo bar ban""") + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + parser = LookupIntentParser().fit(dataset) + text = "foo bar ban" + + # When + parsing = parser.parse(text, intents=["intent1"]) + + # Then + self.assertEqual(empty_result(text, 1.0), parsing) + + def test_should_parse_top_intents(self): + # Given + dataset_stream = io.StringIO(""" +--- +type: intent +name: intent1 +utterances: + - meeting [time:snips/datetime](today) + +--- +type: intent +name: intent2 +utterances: + - meeting tomorrow + 
+--- +type: intent +name: intent3 +utterances: + - "[event_type](call) [time:snips/datetime](at 9pm)" + +--- +type: entity +name: event_type +values: + - meeting + - feedback session""") + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + parser = LookupIntentParser().fit(dataset) + text = "meeting tomorrow" + + # When + results = parser.parse(text, top_n=3) + + # Then + time_slot = { + "entity": "snips/datetime", + "range": {"end": 16, "start": 8}, + "slotName": "time", + "value": "tomorrow" + } + event_slot = { + "entity": "event_type", + "range": {"end": 7, "start": 0}, + "slotName": "event_type", + "value": "meeting" + } + weight_intent_1 = 1. / 2. + weight_intent_2 = 1. + weight_intent_3 = 1. / 3. + total_weight = weight_intent_1 + weight_intent_2 + weight_intent_3 + proba_intent2 = weight_intent_2 / total_weight + proba_intent1 = weight_intent_1 / total_weight + proba_intent3 = weight_intent_3 / total_weight + expected_results = [ + extraction_result( + intent_classification_result( + intent_name="intent2", probability=proba_intent2), + slots=[]), + extraction_result( + intent_classification_result( + intent_name="intent1", probability=proba_intent1), + slots=[time_slot]), + extraction_result( + intent_classification_result( + intent_name="intent3", probability=proba_intent3), + slots=[event_slot, time_slot]) + ] + self.assertEqual(expected_results, results) + + @patch("snips_nlu.intent_parser.lookup_intent_parser" ".get_stop_words") + def test_should_parse_intent_with_stop_words(self, mock_get_stop_words): + # Given + mock_get_stop_words.return_value = {"a", "hey"} + dataset = self.slots_dataset + config = LookupIntentParserConfig(ignore_stop_words=True) + parser = LookupIntentParser(config).fit(dataset) + text = "Hey this is dummy_a query with another dummy_c at 10p.m. " \ + "or at 12p.m." 
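        # Editorial note (not part of the diff): with ignore_stop_words=True,
        # _preprocess_text drops the mocked stop words ("a" and "hey") and
        # lowercases the text before the utterance key is hashed, so this
        # query reduces to the same lookup key as the training utterance
        # "This is a [dummy_slot_name](dummy_1) query with another ..." and
        # is matched exactly.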
+ + # When + parsing = parser.parse(text) + + # Then + probability = 1.0 + expected_intent = intent_classification_result( + intent_name="dummy_intent_1", probability=probability) + + self.assertEqual(expected_intent, parsing[RES_INTENT]) + + def test_should_parse_intent_with_duplicated_slot_names(self): + # Given + slots_dataset_stream = io.StringIO(""" +--- +type: intent +name: math_operation +slots: + - name: number + entity: snips/number +utterances: + - what is [number](one) plus [number](one)""") + dataset = Dataset.from_yaml_files("en", [slots_dataset_stream]).json + parser = LookupIntentParser().fit(dataset) + text = "what is one plus one" + + # When + parsing = parser.parse(text) + + # Then + probability = 1.0 + expected_intent = intent_classification_result( + intent_name="math_operation", probability=probability) + expected_slots = [ + { + "entity": "snips/number", + "range": {"end": 11, "start": 8}, + "slotName": "number", + "value": "one" + }, + { + "entity": "snips/number", + "range": {"end": 20, "start": 17}, + "slotName": "number", + "value": "one" + } + ] + + self.assertDictEqual(expected_intent, parsing[RES_INTENT]) + self.assertListEqual(expected_slots, parsing[RES_SLOTS]) + + def test_should_parse_intent_with_ambivalent_words(self): + # Given + slots_dataset_stream = io.StringIO(""" +--- +type: intent +name: give_flower +utterances: + - give a rose to [name](emily) + - give a daisy to [name](tom) + - give a tulip to [name](daisy) + """) + dataset = Dataset.from_yaml_files("en", + [slots_dataset_stream]).json + parser = LookupIntentParser().fit(dataset) + text = "give a daisy to emily" + + # When + parsing = parser.parse(text) + + # Then + expected_intent = intent_classification_result( + intent_name="give_flower", probability=1.0) + expected_slots = [ + { + "entity": "name", + "range": {"end": 21, "start": 16}, + "slotName": "name", + "value": "emily" + } + ] + + self.assertDictEqual(expected_intent, parsing[RES_INTENT]) + self.assertListEqual(expected_slots, parsing[RES_SLOTS]) + + def test_should_ignore_completely_ambiguous_utterances(self): + # Given + dataset_stream = io.StringIO( + """ +--- +type: intent +name: dummy_intent_1 +utterances: + - Hello world + +--- +type: intent +name: dummy_intent_2 +utterances: + - Hello world""") + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + parser = LookupIntentParser().fit(dataset) + text = "Hello world" + + # When + res = parser.parse(text) + + # Then + self.assertEqual(empty_result(text, 1.0), res) + + def test_should_ignore_very_ambiguous_utterances(self): + # Given + dataset_stream = io.StringIO(""" +--- +type: intent +name: intent_1 +utterances: + - "[event_type](meeting) tomorrow" + +--- +type: intent +name: intent_2 +utterances: + - call [time:snips/datetime](today) + +--- +type: entity +name: event_type +values: + - call + - diner""") + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + parser = LookupIntentParser().fit(dataset) + text = "call tomorrow" + + # When + res = parser.parse(text) + + # Then + self.assertEqual(empty_result(text, 1.0), res) + + def test_should_parse_slightly_ambiguous_utterances(self): + # Given + dataset_stream = io.StringIO(""" +--- +type: intent +name: intent_1 +utterances: + - call tomorrow + +--- +type: intent +name: intent_2 +utterances: + - call [time:snips/datetime](today)""") + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + parser = LookupIntentParser().fit(dataset) + text = "call tomorrow" + + # When + res = parser.parse(text) + + # 
Then + expected_intent = intent_classification_result( + intent_name="intent_1", probability=2. / 3.) + expected_result = parsing_result(text, expected_intent, []) + self.assertEqual(expected_result, res) + + def test_should_not_parse_when_not_fitted(self): + # Given + parser = LookupIntentParser() + + # When / Then + self.assertFalse(parser.fitted) + with self.assertRaises(NotTrained): + parser.parse("foobar") + + def test_should_parse_intent_after_deserialization(self): + # Given + dataset = self.slots_dataset + shared = self.get_shared_data(dataset) + parser = LookupIntentParser(**shared).fit(dataset) + parser.persist(self.tmp_file_path) + deserialized_parser = LookupIntentParser.from_path( + self.tmp_file_path, **shared) + text = "this is a dummy_a query with another dummy_c at 10p.m. or " \ + "at 12p.m." + + # When + parsing = deserialized_parser.parse(text) + + # Then + probability = 1.0 + expected_intent = intent_classification_result( + intent_name="dummy_intent_1", probability=probability) + self.assertEqual(expected_intent, parsing[RES_INTENT]) + + def test_should_parse_slots(self): + # Given + dataset = self.slots_dataset + parser = LookupIntentParser().fit(dataset) + texts = [ + ( + "this is a dummy a query with another dummy_c at 10p.m. or at" + " 12p.m.", + [ + unresolved_slot( + match_range=(10, 17), + value="dummy a", + entity="dummy_entity_1", + slot_name="dummy_slot_name", + ), + unresolved_slot( + match_range=(37, 44), + value="dummy_c", + entity="dummy_entity_2", + slot_name="dummy_slot_name2", + ), + unresolved_slot( + match_range=(45, 54), + value="at 10p.m.", + entity="snips/datetime", + slot_name="startTime", + ), + unresolved_slot( + match_range=(58, 67), + value="at 12p.m.", + entity="snips/datetime", + slot_name="startTime", + ), + ], + ), + ( + "this, is,, a, dummy a query with another dummy_c at 10pm or " + "at 12p.m.", + [ + unresolved_slot( + match_range=(14, 21), + value="dummy a", + entity="dummy_entity_1", + slot_name="dummy_slot_name", + ), + unresolved_slot( + match_range=(41, 48), + value="dummy_c", + entity="dummy_entity_2", + slot_name="dummy_slot_name2", + ), + unresolved_slot( + match_range=(49, 56), + value="at 10pm", + entity="snips/datetime", + slot_name="startTime", + ), + unresolved_slot( + match_range=(60, 69), + value="at 12p.m.", + entity="snips/datetime", + slot_name="startTime", + ), + ], + ), + ( + "this is a dummy b", + [ + unresolved_slot( + match_range=(10, 17), + value="dummy b", + entity="dummy_entity_1", + slot_name="dummy_slot_name", + ) + ], + ), + ( + " this is a dummy b ", + [ + unresolved_slot( + match_range=(11, 18), + value="dummy b", + entity="dummy_entity_1", + slot_name="dummy_slot_name", + ) + ], + ), + ( + " at 8am ’ there is a dummy a", + [ + unresolved_slot( + match_range=(1, 7), + value="at 8am", + entity="snips/datetime", + slot_name="startTime", + ), + unresolved_slot( + match_range=(21, 29), + value="dummy a", + entity="dummy_entity_1", + slot_name="dummy_slot_name", + ), + ], + ), + ] + + for text, expected_slots in texts: + # When + parsing = parser.parse(text) + + # Then + self.assertListEqual(expected_slots, parsing[RES_SLOTS]) + + def test_should_parse_stop_words_slots(self): + # Given + dataset_stream = io.StringIO(""" +--- +type: intent +name: search +utterances: + - search + - search [search_object](this) + - search [search_object](a cat) + +--- +type: entity +name: search_object +values: + - [this thing, that] + """) + + resources = deepcopy(self.get_resources("en")) + resources[STOP_WORDS] = {"a", "this", 
"that"} + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + parser_config = LookupIntentParserConfig(ignore_stop_words=True) + parser = LookupIntentParser(config=parser_config, resources=resources) + parser.fit(dataset) + + # When + res_1 = parser.parse("search this") + res_2 = parser.parse("search that") + + # Then + expected_intent = intent_classification_result( + intent_name="search", probability=1.0) + expected_slots_1 = [ + unresolved_slot(match_range=(7, 11), value="this", + entity="search_object", + slot_name="search_object") + ] + expected_slots_2 = [ + unresolved_slot(match_range=(7, 11), value="that", + entity="search_object", + slot_name="search_object") + ] + self.assertEqual(expected_intent, res_1[RES_INTENT]) + self.assertEqual(expected_intent, res_2[RES_INTENT]) + self.assertListEqual(expected_slots_1, res_1[RES_SLOTS]) + self.assertListEqual(expected_slots_2, res_2[RES_SLOTS]) + + def test_should_get_intents(self): + # Given + dataset_stream = io.StringIO( + """ +--- +type: intent +name: greeting1 +utterances: + - Hello John + +--- +type: intent +name: greeting2 +utterances: + - Hello [name](John) + +--- +type: intent +name: greeting3 +utterances: + - "[greeting](Hello) [name](John)" + """) + + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + parser = LookupIntentParser().fit(dataset) + + # When + top_intents = parser.get_intents("Hello John") + + # Then + expected_intents = [ + { + RES_INTENT_NAME: "greeting1", + RES_PROBA: 1. / (1. + 1. / 2. + 1. / 3.) + }, + { + RES_INTENT_NAME: "greeting2", + RES_PROBA: (1. / 2.) / (1. + 1. / 2. + 1. / 3.) + }, + { + RES_INTENT_NAME: "greeting3", + RES_PROBA: (1. / 3.) / (1. + 1. / 2. + 1. / 3.) + }, + { + RES_INTENT_NAME: None, + RES_PROBA: 0.0 + }, + ] + + self.assertListEqual(expected_intents, top_intents) + + def test_should_get_slots(self): + # Given + slots_dataset_stream = io.StringIO( + """ +--- +type: intent +name: greeting1 +utterances: + - Hello [name1](John) + +--- +type: intent +name: greeting2 +utterances: + - Hello [name2](Thomas) + +--- +type: intent +name: goodbye +utterances: + - Goodbye [name](Eric)""") + dataset = Dataset.from_yaml_files("en", [slots_dataset_stream]).json + parser = LookupIntentParser().fit(dataset) + + # When + slots_greeting1 = parser.get_slots("Hello John", "greeting1") + slots_greeting2 = parser.get_slots("Hello Thomas", "greeting2") + slots_goodbye = parser.get_slots("Goodbye Eric", "greeting1") + + # Then + self.assertEqual(1, len(slots_greeting1)) + self.assertEqual(1, len(slots_greeting2)) + self.assertEqual(0, len(slots_goodbye)) + + self.assertEqual("John", slots_greeting1[0][RES_VALUE]) + self.assertEqual("name1", slots_greeting1[0][RES_ENTITY]) + self.assertEqual("Thomas", slots_greeting2[0][RES_VALUE]) + self.assertEqual("name2", slots_greeting2[0][RES_ENTITY]) + + def test_should_get_no_slots_with_none_intent(self): + # Given + slots_dataset_stream = io.StringIO( + """ +--- +type: intent +name: greeting +utterances: + - Hello [name](John)""") + dataset = Dataset.from_yaml_files("en", [slots_dataset_stream]).json + parser = LookupIntentParser().fit(dataset) + + # When + slots = parser.get_slots("Hello John", None) + + # Then + self.assertListEqual([], slots) + + def test_get_slots_should_raise_with_unknown_intent(self): + # Given + slots_dataset_stream = io.StringIO( + """ +--- +type: intent +name: greeting1 +utterances: + - Hello [name1](John) + +--- +type: intent +name: goodbye +utterances: + - Goodbye [name](Eric)""") + dataset = 
Dataset.from_yaml_files("en", [slots_dataset_stream]).json + parser = LookupIntentParser().fit(dataset) + + # When / Then + with self.assertRaises(IntentNotFoundError): + parser.get_slots("Hello John", "greeting3") + + def test_should_parse_slots_after_deserialization(self): + # Given + dataset = self.slots_dataset + shared = self.get_shared_data(dataset) + parser = LookupIntentParser(**shared).fit(dataset) + parser.persist(self.tmp_file_path) + deserialized_parser = LookupIntentParser.from_path( + self.tmp_file_path, **shared) + + texts = [ + ( + "this is a dummy a query with another dummy_c at 10p.m. or at" + " 12p.m.", + [ + unresolved_slot( + match_range=(10, 17), + value="dummy a", + entity="dummy_entity_1", + slot_name="dummy_slot_name", + ), + unresolved_slot( + match_range=(37, 44), + value="dummy_c", + entity="dummy_entity_2", + slot_name="dummy_slot_name2", + ), + unresolved_slot( + match_range=(45, 54), + value="at 10p.m.", + entity="snips/datetime", + slot_name="startTime", + ), + unresolved_slot( + match_range=(58, 67), + value="at 12p.m.", + entity="snips/datetime", + slot_name="startTime", + ), + ], + ), + ( + "this, is,, a, dummy a query with another dummy_c at 10pm or " + "at 12p.m.", + [ + unresolved_slot( + match_range=(14, 21), + value="dummy a", + entity="dummy_entity_1", + slot_name="dummy_slot_name", + ), + unresolved_slot( + match_range=(41, 48), + value="dummy_c", + entity="dummy_entity_2", + slot_name="dummy_slot_name2", + ), + unresolved_slot( + match_range=(49, 56), + value="at 10pm", + entity="snips/datetime", + slot_name="startTime", + ), + unresolved_slot( + match_range=(60, 69), + value="at 12p.m.", + entity="snips/datetime", + slot_name="startTime", + ), + ], + ), + ( + "this is a dummy b", + [ + unresolved_slot( + match_range=(10, 17), + value="dummy b", + entity="dummy_entity_1", + slot_name="dummy_slot_name", + ) + ], + ), + ( + " this is a dummy b ", + [ + unresolved_slot( + match_range=(11, 18), + value="dummy b", + entity="dummy_entity_1", + slot_name="dummy_slot_name", + ) + ], + ), + ] + + for text, expected_slots in texts: + # When + parsing = deserialized_parser.parse(text) + + # Then + self.assertListEqual(expected_slots, parsing[RES_SLOTS]) + + def test_should_be_serializable_into_bytearray(self): + # Given + dataset_stream = io.StringIO( + """ +--- +type: intent +name: MakeTea +utterances: +- make me [number_of_cups:snips/number](one) cup of tea +- i want [number_of_cups] cups of tea please +- can you prepare [number_of_cups] cup of tea ? 
+ +--- +type: intent +name: MakeCoffee +utterances: +- make me [number_of_cups:snips/number](two) cups of coffee +- brew [number_of_cups] cups of coffee +- can you prepare [number_of_cups] cup of coffee""") + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + shared = self.get_shared_data(dataset) + intent_parser = LookupIntentParser(**shared).fit(dataset) + + # When + intent_parser_bytes = intent_parser.to_byte_array() + loaded_intent_parser = LookupIntentParser.from_byte_array( + intent_parser_bytes, **shared) + result = loaded_intent_parser.parse("make me two cups of coffee") + + # Then + self.assertEqual("MakeCoffee", result[RES_INTENT][RES_INTENT_NAME]) + + def test_should_parse_naughty_strings(self): + # Given + dataset_stream = io.StringIO( + """ +--- +type: intent +name: my_intent +utterances: +- this is [slot1:entity1](my first entity) +- this is [slot2:entity2](second_entity)""") + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + naughty_strings_path = TEST_PATH / "resources" / "naughty_strings.txt" + with naughty_strings_path.open(encoding="utf8") as f: + naughty_strings = [line.strip("\n") for line in f.readlines()] + + # When + parser = LookupIntentParser().fit(dataset) + + # Then + for s in naughty_strings: + with self.fail_if_exception("Exception raised"): + parser.parse(s) + + def test_should_fit_with_naughty_strings_no_tags(self): + # Given + naughty_strings_path = TEST_PATH / "resources" / "naughty_strings.txt" + with naughty_strings_path.open(encoding="utf8") as f: + naughty_strings = [line.strip("\n") for line in f.readlines()] + + utterances = [ + {DATA: [{TEXT: naughty_string}]} + for naughty_string in naughty_strings + ] + + # When + naughty_dataset = { + "intents": {"naughty_intent": {"utterances": utterances}}, + "entities": dict(), + "language": "en", + } + + # Then + with self.fail_if_exception("Exception raised"): + LookupIntentParser().fit(naughty_dataset) + + def test_should_fit_and_parse_with_non_ascii_tags(self): + # Given + inputs = ["string%s" % i for i in range(10)] + utterances = [ + { + DATA: [ + { + TEXT: string, + ENTITY: "non_ascìi_entïty", + SLOT_NAME: "non_ascìi_slöt", + } + ] + } + for string in inputs + ] + + # When + naughty_dataset = { + "intents": {"naughty_intent": {"utterances": utterances}}, + "entities": { + "non_ascìi_entïty": { + "use_synonyms": False, + "automatically_extensible": True, + "matching_strictness": 1.0, + "data": [], + } + }, + "language": "en", + } + + # Then + with self.fail_if_exception("Exception raised"): + parser = LookupIntentParser().fit(naughty_dataset) + parsing = parser.parse("string0") + + expected_slot = { + "entity": "non_ascìi_entïty", + "range": {"start": 0, "end": 7}, + "slotName": "non_ascìi_slöt", + "value": "string0", + } + intent_name = parsing[RES_INTENT][RES_INTENT_NAME] + self.assertEqual("naughty_intent", intent_name) + self.assertListEqual([expected_slot], parsing[RES_SLOTS]) + + def test_should_be_serializable_before_fitting(self): + # Given + config = LookupIntentParserConfig(ignore_stop_words=True) + parser = LookupIntentParser(config=config) + + # When + parser.persist(self.tmp_file_path) + + # Then + expected_dict = { + "config": { + "unit_name": "lookup_intent_parser", + "ignore_stop_words": True, + }, + "language_code": None, + "intents_names": [], + "map": None, + "slots_names": [], + "entity_scopes": None, + "stop_words_whitelist": None + } + + metadata = {"unit_name": "lookup_intent_parser"} + self.assertJsonContent(self.tmp_file_path / "metadata.json", 
metadata) + self.assertJsonContent( + self.tmp_file_path / "intent_parser.json", expected_dict) + + @patch("snips_nlu.intent_parser.lookup_intent_parser.get_stop_words") + def test_should_be_serializable(self, mock_get_stop_words): + # Given + dataset_stream = io.StringIO( + """ +--- +type: intent +name: searchFlight +slots: + - name: origin + entity: city + - name: destination + entity: city +utterances: + - find me a flight from [origin](Paris) to [destination](New York) + - I need a flight to [destination](Berlin) + +--- +type: entity +name: city +values: + - london + - [new york, big apple] + - [paris, city of lights]""") + + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + + mock_get_stop_words.return_value = {"a", "me"} + config = LookupIntentParserConfig(ignore_stop_words=True) + parser = LookupIntentParser(config=config).fit(dataset) + + # When + parser.persist(self.tmp_file_path) + + # Then + expected_dict = { + "config": { + "unit_name": "lookup_intent_parser", + "ignore_stop_words": True, + }, + "intents_names": ["searchFlight"], + "language_code": "en", + "map": { + "-2020846245": [0, [0, 1]], + "-1558674456": [0, [1]], + }, + "slots_names": ["origin", "destination"], + "entity_scopes": [ + { + "entity_scope": {"builtin": [], "custom": ["city"]}, + "intent_group": ["searchFlight"] + } + ], + "stop_words_whitelist": dict() + } + metadata = {"unit_name": "lookup_intent_parser"} + self.assertJsonContent(self.tmp_file_path / "metadata.json", metadata) + self.assertJsonContent( + self.tmp_file_path / "intent_parser.json", expected_dict) + + def test_should_be_deserializable(self): + # Given + parser_dict = { + "config": { + "unit_name": "lookup_intent_parser", + "ignore_stop_words": True + }, + "language_code": "en", + "map": { + hash_str("make coffee"): [0, []], + hash_str("prepare % snipsnumber % coffees"): [0, [0]], + hash_str("% snipsnumber % teas at % snipstemperature %"): + [1, [0, 1]], + }, + "slots_names": ["nb_cups", "tea_temperature"], + "intents_names": ["MakeCoffee", "MakeTea"], + "entity_scopes": [ + { + "entity_scope": { + "builtin": ["snips/number"], + "custom": [], + }, + "intent_group": ["MakeCoffee"] + }, + { + "entity_scope": { + "builtin": ["snips/number", "snips/temperature"], + "custom": [], + }, + "intent_group": ["MakeTea"] + }, + ], + "stop_words_whitelist": dict() + } + self.tmp_file_path.mkdir() + metadata = {"unit_name": "lookup_intent_parser"} + self.writeJsonContent( + self.tmp_file_path / "intent_parser.json", parser_dict) + self.writeJsonContent(self.tmp_file_path / "metadata.json", metadata) + resources = self.get_resources("en") + builtin_entity_parser = BuiltinEntityParser.build(language="en") + custom_entity_parser = EntityParserMock() + + # When + parser = LookupIntentParser.from_path( + self.tmp_file_path, custom_entity_parser=custom_entity_parser, + builtin_entity_parser=builtin_entity_parser, + resources=resources) + res_make_coffee = parser.parse("make me a coffee") + res_make_tea = parser.parse("two teas at 90°C please") + + # Then + expected_result_coffee = parsing_result( + input="make me a coffee", + intent=intent_classification_result("MakeCoffee", 1.0), + slots=[]) + expected_result_tea = parsing_result( + input="two teas at 90°C please", + intent=intent_classification_result("MakeTea", 1.0), + slots=[ + { + "entity": "snips/number", + "range": {"end": 3, "start": 0}, + "slotName": "nb_cups", + "value": "two" + }, + { + "entity": "snips/temperature", + "range": {"end": 16, "start": 12}, + "slotName": "tea_temperature", + 
"value": "90°C" + } + ]) + self.assertEqual(expected_result_coffee, res_make_coffee) + self.assertEqual(expected_result_tea, res_make_tea) + + def test_should_be_deserializable_before_fitting(self): + # Given + parser_dict = { + "config": {}, + "language_code": None, + "map": None, + "slots_names": [], + "intents_names": [], + "entity_scopes": None + } + self.tmp_file_path.mkdir() + metadata = {"unit_name": "dict_deterministic_intent_parser"} + self.writeJsonContent( + self.tmp_file_path / "intent_parser.json", parser_dict) + self.writeJsonContent(self.tmp_file_path / "metadata.json", metadata) + + # When + parser = LookupIntentParser.from_path(self.tmp_file_path) + + # Then + config = LookupIntentParserConfig() + expected_parser = LookupIntentParser(config=config) + self.assertEqual(parser.to_dict(), expected_parser.to_dict()) + + def test_get_entity_scopes(self): + # Given + dataset_stream = io.StringIO(""" +--- +type: intent +name: intent1 +utterances: + - meeting [schedule_time:snips/datetime](today) + +--- +type: intent +name: intent2 +utterances: + - hello world + +--- +type: intent +name: intent3 +utterances: + - what will be the weather [weather_time:snips/datetime](tomorrow) + +--- +type: intent +name: intent4 +utterances: + - find a flight for [city](Paris) [flight_time:snips/datetime](tomorrow)""") + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + + # When + entity_scopes = _get_entity_scopes(dataset) + + # Then + expected_scopes = [ + { + "entity_scope": { + "builtin": ["snips/datetime"], + "custom": [] + }, + "intent_group": ["intent1", "intent3"] + }, + { + "entity_scope": { + "builtin": [], + "custom": [] + }, + "intent_group": ["intent2"] + }, + { + "entity_scope": { + "builtin": ["snips/datetime"], + "custom": ["city"] + }, + "intent_group": ["intent4"] + } + ] + + def sort_key(group_scope): + return " ".join(group_scope["intent_group"]) + + self.assertListEqual(sorted(expected_scopes, key=sort_key), + sorted(entity_scopes, key=sort_key)) diff --git a/snips_nlu/tests/test_preprocessing.py b/snips_nlu/tests/test_preprocessing.py index e8d4bf3c5..02226c42c 100644 --- a/snips_nlu/tests/test_preprocessing.py +++ b/snips_nlu/tests/test_preprocessing.py @@ -57,9 +57,11 @@ def test_should_tokenize_symbols(self): # Then expected_tokens = [ - Token(value='$$', start=0, end=2), + Token(value='$', start=0, end=1), + Token(value='$', start=1, end=2), Token(value='%', start=3, end=4), - Token(value='!!', start=5, end=7) + Token(value='!', start=5, end=6), + Token(value='!', start=6, end=7) ] self.assertListEqual(tokens, expected_tokens) diff --git a/snips_nlu/tests/utils.py b/snips_nlu/tests/utils.py index 8f1e6f04c..e1792f302 100644 --- a/snips_nlu/tests/utils.py +++ b/snips_nlu/tests/utils.py @@ -76,12 +76,12 @@ def assertFileContent(self, path, expected_content): @staticmethod def writeJsonContent(path, json_dict): json_content = json_string(json_dict) - with path.open(mode="w") as f: + with path.open(mode="w", encoding="utf8") as f: f.write(json_content) @staticmethod def writeFileContent(path, content): - with path.open(mode="w") as f: + with path.open(mode="w", encoding="utf8") as f: f.write(unicode_string(content)) @@ -153,7 +153,7 @@ def fitted(self, value): def persist(self, path): path = Path(path) path.mkdir() - with (path / "metadata.json").open(mode="w") as f: + with (path / "metadata.json").open(mode="w", encoding="utf8") as f: unit_dict = {"unit_name": self.unit_name, "fitted": self.fitted} f.write(json_string(unit_dict)) @@ -206,8 +206,10 @@ def 
fit(self, dataset, intent): class EntityParserMock(EntityParser): - def __init__(self, entities): + def __init__(self, entities=None): super(EntityParserMock, self).__init__() + if entities is None: + entities = dict() self.entities = entities def persist(self, path):