diff --git a/chatterbot/tagging.py b/chatterbot/tagging.py
index 9645536b5..c32e62512 100644
--- a/chatterbot/tagging.py
+++ b/chatterbot/tagging.py
@@ -23,17 +23,37 @@ def __init__(self, language=None):
 
     def get_text_index_string(self, text: Union[str, List[str]]):
         if isinstance(text, list):
-            documents = self.nlp.pipe(text)
+            documents = self.nlp.pipe(text, batch_size=1000, n_process=1)
             return [document._.search_index for document in documents]
         else:
             document = self.nlp(text)
             return document._.search_index
 
-    def as_nlp_pipeline(self, texts: Union[List[str], Tuple[str, dict]]):
+    def as_nlp_pipeline(
+        self,
+        texts: Union[List[str], Tuple[str, dict]],
+        batch_size: int = 1000,
+        n_process: int = 1
+    ):
+        """
+        Process texts through the spaCy NLP pipeline with optimized batching.
+
+        :param texts: Text strings or tuples of (text, context_dict)
+        :param batch_size: Number of texts per batch (default 1000)
+        :param n_process: Number of worker processes for spaCy's pipe (set >1 to use multiprocessing)
+
+        Usage:
+            documents = tagger.as_nlp_pipeline(texts)
+            documents = tagger.as_nlp_pipeline(texts, batch_size=2000, n_process=4)
+        """
         process_as_tuples = texts and isinstance(texts[0], tuple)
 
-        documents = self.nlp.pipe(texts, as_tuples=process_as_tuples)
+        documents = self.nlp.pipe(
+            texts,
+            as_tuples=process_as_tuples,
+            batch_size=batch_size,
+            n_process=n_process
+        )
 
         return documents
 
@@ -58,20 +78,37 @@ def get_text_index_string(self, text: Union[str, List[str]]) -> str:
         Return a string of text containing part-of-speech, lemma pairs.
         """
         if isinstance(text, list):
-            documents = self.nlp.pipe(text)
+            documents = self.nlp.pipe(text, batch_size=1000, n_process=1)
             return [document._.search_index for document in documents]
         else:
             document = self.nlp(text)
             return document._.search_index
 
-    def as_nlp_pipeline(self, texts: Union[List[str], Tuple[str, dict]]):
+    def as_nlp_pipeline(
+        self,
+        texts: Union[List[str], Tuple[str, dict]],
+        batch_size: int = 1000,
+        n_process: int = 1
+    ):
         """
         Accepts a single string or a list of strings, or a list of tuples
         where the first element is the text and the second element is
         a dictionary of context to return alongside the generated document.
-        """
+        :param texts: Text strings or tuples of (text, context_dict)
+        :param batch_size: Number of texts per batch (default 1000)
+        :param n_process: Number of worker processes for spaCy's pipe (set >1 to use multiprocessing)
+
+        Usage:
+            documents = tagger.as_nlp_pipeline(texts)
+            documents = tagger.as_nlp_pipeline(texts, batch_size=2000, n_process=4)
+        """
         process_as_tuples = texts and isinstance(texts[0], tuple)
 
-        documents = self.nlp.pipe(texts, as_tuples=process_as_tuples)
+        documents = self.nlp.pipe(
+            texts,
+            as_tuples=process_as_tuples,
+            batch_size=batch_size,
+            n_process=n_process
+        )
 
         return documents
diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py
index 76a600592..e5c260264 100644
--- a/chatterbot/trainers.py
+++ b/chatterbot/trainers.py
@@ -90,28 +90,40 @@ def train(self, conversation: List[str]):
         """
         previous_statement_text = None
         previous_statement_search_text = ''
         statements_to_create = []
 
-        # Run the pipeline in bulk to improve performance
-        documents = self.chatbot.tagger.as_nlp_pipeline(conversation)
-
+        # Preprocess all text before NLP analysis
+        preprocessed_texts = conversation
+        for preprocessor in self.chatbot.preprocessors:
+            preprocessed_texts = [
+                preprocessor(Statement(text=text)).text
+                for text in preprocessed_texts
+            ]
+
+        # Batch process with NLP
+        documents = list(self.chatbot.tagger.as_nlp_pipeline(
+            preprocessed_texts,
+            batch_size=2000,
+            # NOTE: Not all spaCy models support multi-processing
+            n_process=1
+        ))
+
+        # Create statements from processed documents
         for document in tqdm(documents, desc='List Trainer', disable=self.disable_progress):
             statement_search_text = document._.search_index
 
-            statement = self.get_preprocessed_statement(
-                Statement(
-                    text=document.text,
-                    search_text=statement_search_text,
-                    in_response_to=previous_statement_text,
-                    search_in_response_to=previous_statement_search_text,
-                    conversation='training'
-                )
+            statement = Statement(
+                text=document.text,
+                search_text=statement_search_text,
+                in_response_to=previous_statement_text,
+                search_in_response_to=previous_statement_search_text,
+                conversation='training'
             )
 
             previous_statement_text = statement.text
             previous_statement_search_text = statement_search_text
 
             statements_to_create.append(statement)
 
         self.chatbot.storage.create_many(statements_to_create)
@@ -134,21 +144,45 @@ def train(self, *corpus_paths: Union[str, List[str]]):
 
         for corpus, categories, _file_path in tqdm(
             load_corpus(*data_file_paths),
-            desc='ChatterBot Corpus Trainer',
+            desc='Training corpus',
             disable=self.disable_progress
         ):
             statements_to_create = []
 
-            # Train the chat bot with each statement and response pair
-            for conversation in corpus:
-
-                # Run the pipeline in bulk to improve performance
-                documents = self.chatbot.tagger.as_nlp_pipeline(conversation)
+            # Collect all texts from all conversations for batch processing
+            all_texts = []
+            conversation_lengths = []
+            for conversation in corpus:
+                conversation_lengths.append(len(conversation))
+                all_texts.extend(conversation)
+
+            # Preprocess all texts
+            preprocessed_texts = all_texts
+            for preprocessor in self.chatbot.preprocessors:
+                preprocessed_texts = [
+                    preprocessor(Statement(text=text)).text
+                    for text in preprocessed_texts
+                ]
+
+            # Batch process all texts with NLP
+            documents = list(self.chatbot.tagger.as_nlp_pipeline(
+                preprocessed_texts,
+                batch_size=2000,
+                # NOTE: Not all spaCy models support multi-processing
+                n_process=1
+            ))
+
+            # Reconstruct conversations from batch-processed documents
+            doc_index = 0
+            for conversation_length in conversation_lengths:
                previous_statement_text = None
                previous_statement_search_text = ''
 
-                for document in documents:
+                for _ in range(conversation_length):
+                    document = documents[doc_index]
+                    doc_index += 1
+
                    statement_search_text = document._.search_index
 
                    statement = Statement(
@@ -161,11 +195,8 @@ def train(self, *corpus_paths: Union[str, List[str]]):
                    statement.add_tags(*categories)
 
-                    statement = self.get_preprocessed_statement(statement)
-
                    previous_statement_text = statement.text
                    previous_statement_search_text = statement_search_text
-
                    statements_to_create.append(statement)
 
            if statements_to_create:
@@ -283,18 +314,28 @@ def train(self, data_path: str, limit=None):
        text_row = self.field_map['text']
 
+        # Collect all rows first to avoid re-reading file
+        rows_list = [row for row in data if len(row) > 0]
+
+        # Extract text and metadata for each row
+        text_values = []
+        contexts = []
+
        try:
-            documents = self.chatbot.tagger.as_nlp_pipeline([
-                (
-                    row[text_row],
-                    {
-                        # Include any defined metadata columns
-                        key: row[value]
-                        for key, value in self.field_map.items()
-                        if key != text_row
-                    }
-                ) for row in data if len(row) > 0
-            ])
+            for row in rows_list:
+                context = {
+                    key: row[value]
+                    for key, value in self.field_map.items()
+                    if key != text_row
+                }
+                contexts.append(context)
+
+                # Preprocess text
+                text = row[text_row]
+                for preprocessor in self.chatbot.preprocessors:
+                    text = preprocessor(Statement(text=text)).text
+
+                text_values.append((text, context))
        except KeyError as e:
            raise KeyError(
                f'{e}. Please check the field_map parameter used to initialize '
@@ -302,21 +343,50 @@
                f'Current mapping: {self.field_map}'
            )
 
+        # Batch process with NLP
+        documents = self.chatbot.tagger.as_nlp_pipeline(
+            text_values,
+            batch_size=2000,
+            # NOTE: Not all spaCy models support multi-processing
+            n_process=1
+        )
+
+        # Convert to list for processing
+        documents_list = list(documents)
+
        response_to_search_index_mapping = {}
 
        if 'in_response_to' in self.field_map.keys():
-            # Generate the search_in_response_to value for the in_response_to fields
-            response_documents = self.chatbot.tagger.as_nlp_pipeline([
-                (
-                    row[self.field_map['in_response_to']]
-                ) for row in data if len(row) > 0 and row[self.field_map['in_response_to']] is not None
-            ])
-
-            # (Process the response values the same way as the text values)
-            for document in response_documents:
-                response_to_search_index_mapping[document.text] = document._.search_index
-
-        for document, context in documents:
+            # Process response references for search indexing
+            in_response_to_field = self.field_map['in_response_to']
+            response_texts = [
+                row[in_response_to_field]
+                for row in rows_list
+                if row[in_response_to_field] is not None
+            ]
+
+            if response_texts:
+                # Preprocess response texts
+                preprocessed_response_texts = response_texts
+                for preprocessor in self.chatbot.preprocessors:
+                    preprocessed_response_texts = [
+                        preprocessor(Statement(text=text)).text
+                        for text in preprocessed_response_texts
+                    ]
+
+                # Batch process response texts
+                response_documents = self.chatbot.tagger.as_nlp_pipeline(
+                    preprocessed_response_texts,
+                    batch_size=2000,
+                    # NOTE: Not all spaCy models support multi-processing
+                    n_process=1
+                )
+
+                for document in response_documents:
+                    response_to_search_index_mapping[document.text] = document._.search_index
+
+        # Create statements from processed documents
+        for document, context in tqdm(documents_list, desc='Creating statements', disable=self.disable_progress, leave=False):
            statement = Statement(
                text=document.text,
                conversation=context.get('conversation', 'training'),
@@ -342,9 +412,6 @@ def train(self, data_path: str, limit=None):
                statement.in_response_to = previous_statement_text
                statement.search_in_response_to = previous_statement_search_text
 
-            for preprocessor in self.chatbot.preprocessors:
-                statement = preprocessor(statement)
-
            previous_statement_text = statement.text
            previous_statement_search_text = statement.search_text
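
Note: a minimal usage sketch of the batched pipeline introduced above (illustrative only, not part of the patch; the `tagger` instance and the sample texts are assumed):

    # Any tagger from chatterbot/tagging.py that exposes as_nlp_pipeline()
    texts = ['Hello there.', 'How are you today?', 'Quite well, thank you.']

    # Stream the texts through spaCy in batches. Keep n_process=1 unless the
    # loaded spaCy model is known to handle multiprocessing safely.
    documents = tagger.as_nlp_pipeline(texts, batch_size=2000, n_process=1)

    # Each Doc produced by the pipeline carries the computed search index
    search_indexes = [document._.search_index for document in documents]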