diff --git a/src/extractors/text_to_text_extractor/methods/InputWithoutSpaces.py b/src/extractors/text_to_text_extractor/methods/InputWithoutSpaces.py index 83f58d0..93b80f8 100644 --- a/src/extractors/text_to_text_extractor/methods/InputWithoutSpaces.py +++ b/src/extractors/text_to_text_extractor/methods/InputWithoutSpaces.py @@ -6,7 +6,7 @@ class InputWithoutSpaces(TextToTextMethod): def train(self, extraction_data: ExtractionData): - pass + self.save_json("best_method.json", True) @staticmethod def trim_text(tag_texts: list[str]) -> str: diff --git a/src/extractors/text_to_text_extractor/test/test_text_to_text_extractor.py b/src/extractors/text_to_text_extractor/test/test_text_to_text_extractor.py index b408407..6336fc6 100644 --- a/src/extractors/text_to_text_extractor/test/test_text_to_text_extractor.py +++ b/src/extractors/text_to_text_extractor/test/test_text_to_text_extractor.py @@ -56,3 +56,24 @@ def test_predictions_two_samples(self): self.assertEqual(extraction_id, suggestions[0].id) self.assertEqual("entity_name", suggestions[0].entity_name) self.assertEqual("one", suggestions[0].text) + + def test_predictions_input_without_spaces(self): + sample = [ + TrainingSample( + labeled_data=LabeledData(label_text="onetwothree", language_iso="en"), tags_texts=["one two", "three"] + ) + ] + extraction_data = ExtractionData(samples=sample * 3, extraction_identifier=extraction_identifier) + + text_to_text_extractor = TextToTextExtractor(extraction_identifier=extraction_identifier) + text_to_text_extractor.create_model(extraction_data) + + suggestions = text_to_text_extractor.get_suggestions( + [PredictionSample.from_text("one two three four", "entity_name")] + ) + + self.assertEqual(1, len(suggestions)) + self.assertEqual(tenant, suggestions[0].tenant) + self.assertEqual(extraction_id, suggestions[0].id) + self.assertEqual("entity_name", suggestions[0].entity_name) + self.assertEqual("onetwothreefour", suggestions[0].text)