Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 62 additions & 36 deletions chatterbot/chatterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Union
from chatterbot.storage import StorageAdapter
from chatterbot.logic import LogicAdapter
from chatterbot.search import TextSearch, IndexedTextSearch
from chatterbot.search import TextSearch, IndexedTextSearch, SemanticVectorSearch
from chatterbot.tagging import PosLemmaTagger
from chatterbot.conversation import Statement
from chatterbot import languages
Expand Down Expand Up @@ -74,41 +74,60 @@ def __init__(self, name, stream=False, **kwargs):

tagger_language = kwargs.get('tagger_language', languages.ENG)

try:
Tagger = kwargs.get('tagger', PosLemmaTagger)

# Allow instances to be provided for performance optimization
# (Example: a pre-loaded model in a tagger when unit testing)
if not isinstance(Tagger, type):
self.tagger = Tagger
else:
self.tagger = Tagger(language=tagger_language)
except IOError as io_error:
# Return a more helpful error message if possible
if "Can't find model" in str(io_error):
model_name = utils.get_model_for_language(tagger_language)
if hasattr(tagger_language, 'ENGLISH_NAME'):
language_name = tagger_language.ENGLISH_NAME
# Check if storage adapter has a preferred tagger
PreferredTagger = self.storage.get_preferred_tagger()

if PreferredTagger is not None:
# Storage adapter specifies its own tagger
self.tagger = PreferredTagger(language=tagger_language)
else:
# Use default or user-specified tagger
try:
Tagger = kwargs.get('tagger', PosLemmaTagger)

# Allow instances to be provided for performance optimization
# (Example: a pre-loaded model in a tagger when unit testing)
if not isinstance(Tagger, type):
self.tagger = Tagger
else:
language_name = tagger_language
raise self.ChatBotException(
'Setup error:\n'
f'The Spacy model for "{language_name}" language is missing.\n'
'Please install the model using the command:\n\n'
f'python -m spacy download {model_name}\n\n'
'See https://spacy.io/usage/models for more information about available models.'
) from io_error
else:
raise io_error
self.tagger = Tagger(language=tagger_language)
except IOError as io_error:
# Return a more helpful error message if possible
if "Can't find model" in str(io_error):
model_name = utils.get_model_for_language(tagger_language)
if hasattr(tagger_language, 'ENGLISH_NAME'):
language_name = tagger_language.ENGLISH_NAME
else:
language_name = tagger_language
raise self.ChatBotException(
'Setup error:\n'
f'The Spacy model for "{language_name}" language is missing.\n'
'Please install the model using the command:\n\n'
f'python -m spacy download {model_name}\n\n'
'See https://spacy.io/usage/models for more information about available models.'
) from io_error
else:
raise io_error

# Initialize search algorithms
primary_search_algorithm = IndexedTextSearch(self, **kwargs)
text_search_algorithm = TextSearch(self, **kwargs)
semantic_vector_search_algorithm = SemanticVectorSearch(self, **kwargs)

self.search_algorithms = {
primary_search_algorithm.name: primary_search_algorithm,
text_search_algorithm.name: text_search_algorithm
text_search_algorithm.name: text_search_algorithm,
semantic_vector_search_algorithm.name: semantic_vector_search_algorithm
}

# Check if storage adapter has a preferred search algorithm
preferred_search_algorithm = self.storage.get_preferred_search_algorithm()
if preferred_search_algorithm and preferred_search_algorithm in self.search_algorithms:
# Set as default for logic adapters that don't specify their own search algorithm
# This ensures BestMatch and other adapters use the optimal search method
self.logger.info(f'Storage adapter prefers search algorithm: {preferred_search_algorithm}')
kwargs.setdefault('search_algorithm_name', preferred_search_algorithm)

for adapter in logic_adapters:
utils.validate_adapter_class(adapter, LogicAdapter)
logic_adapter = utils.initialize_class(adapter, self, **kwargs)
Expand Down Expand Up @@ -191,15 +210,22 @@ def get_response(self, statement: Union[Statement, str, dict] = None, **kwargs)
input_statement.in_response_to = previous_statement.text

# Make sure the input statement has its search text saved

if not input_statement.search_text:
_search_text = self.tagger.get_text_index_string(input_statement.text)
input_statement.search_text = _search_text

if not input_statement.search_in_response_to and input_statement.in_response_to:
input_statement.search_in_response_to = self.tagger.get_text_index_string(
input_statement.in_response_to
)
if not self.tagger.needs_text_indexing():
# Tagger doesn't transform text, use it directly
if not input_statement.search_text:
input_statement.search_text = input_statement.text
if not input_statement.search_in_response_to and input_statement.in_response_to:
input_statement.search_in_response_to = input_statement.in_response_to
else:
# Use tagger for text indexing or transformations
if not input_statement.search_text:
_search_text = self.tagger.get_text_index_string(input_statement.text)
input_statement.search_text = _search_text

if not input_statement.search_in_response_to and input_statement.in_response_to:
input_statement.search_in_response_to = self.tagger.get_text_index_string(
input_statement.in_response_to
)

response = self.generate_response(
input_statement,
Expand Down
70 changes: 70 additions & 0 deletions chatterbot/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,73 @@ def search(self, input_statement, **additional_parameters):
if confidence >= 1.0:
self.chatbot.logger.info('Exact match found, stopping search')
break


class SemanticVectorSearch:
    """
    Semantic vector search for storage adapters that use vector embeddings.
    Does not require a tagger or comparison function - relies on the storage
    adapter's native vector similarity search capabilities.

    :param search_page_size:
        The maximum number of records to load into memory at a time when searching.
        Defaults to 1000
    """

    name = 'semantic_vector_search'

    def __init__(self, chatbot, **kwargs):
        self.chatbot = chatbot

        # Upper bound on how many records the storage adapter should
        # return per page when filtering.
        self.search_page_size = kwargs.get(
            'search_page_size', 1000
        )

    def search(self, input_statement, **additional_parameters):
        """
        Search for semantically similar statements using vector similarity.
        Confidence scores are calculated by the storage adapter based on
        vector distances and returned in the results.

        :param input_statement: A statement.
        :type input_statement: chatterbot.conversation.Statement

        :param **additional_parameters: Additional parameters to be passed
            to the ``filter`` method of the storage adapter when searching.

        :rtype: Generator yielding one closest matching statement at a time.
        """
        self.chatbot.logger.info('Beginning semantic vector search')

        search_parameters = {
            'search_in_response_to_contains': input_statement.text,
            'persona_not_startswith': 'bot:',
            'page_size': self.search_page_size
        }

        # Caller-supplied filters override or extend the defaults above.
        if additional_parameters:
            search_parameters.update(additional_parameters)

        statement_list = self.chatbot.storage.filter(**search_parameters)

        best_confidence_so_far = 0

        self.chatbot.logger.info('Processing search results')

        # Yield statements with confidence scores from vector similarity
        for statement in statement_list:
            # Confidence should already be set by the storage adapter;
            # fall back to 0.0 if it was not provided.
            confidence = getattr(statement, 'confidence', 0.0)

            # Fix: write the value back so every yielded statement is
            # guaranteed to carry a ``confidence`` attribute, even when
            # the storage adapter did not set one. Downstream consumers
            # read ``statement.confidence`` directly.
            statement.confidence = confidence

            # Only log when we find a new best match, to keep the log
            # output proportional to search quality rather than volume.
            if confidence > best_confidence_so_far:
                best_confidence_so_far = confidence

                # Lazy %-style arguments skip formatting when INFO is disabled.
                self.chatbot.logger.info(
                    'Similar statement found: %s %s',
                    statement.in_response_to, confidence
                )

            yield statement

            # A perfect score means no later result can be better.
            if confidence >= 1.0:
                self.chatbot.logger.info('Exact match found, stopping search')
                break
Loading
Loading