2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -23,7 +23,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
services:
redis:
image: redis/redis-stack-server:latest
7 changes: 7 additions & 0 deletions chatterbot/storage/mongodb.py
@@ -250,3 +250,10 @@ def drop(self):
Remove the database.
"""
self.client.drop_database(self.database.name)

def close(self):
"""
Close the MongoDB client connection.
"""
if hasattr(self, 'client'):
self.client.close()
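
For context, a minimal sketch of how the new close() hook might be used (the URI and database name are assumptions for illustration):

from chatterbot.storage import MongoDatabaseAdapter

adapter = MongoDatabaseAdapter(database_uri='mongodb://localhost:27017/chatterbot-database')
print(adapter.count())
adapter.close()  # releases the underlying pymongo client connection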
230 changes: 197 additions & 33 deletions chatterbot/storage/redis.py
@@ -1,27 +1,26 @@
from datetime import datetime
import json
import re
from chatterbot.storage import StorageAdapter
from chatterbot.conversation import Statement as StatementObject


# TODO: This list may not be exhaustive.
# Is there a full list of characters reserved by redis?
REDIS_ESCAPE_CHARACTERS = {
'\\': '\\\\',
':': '\\:',
'|': '\\|',
'%': '\\%',
'!': '\\!',
'-': '\\-',
}

REDIS_TRANSLATION_TABLE = str.maketrans(REDIS_ESCAPE_CHARACTERS)


def _escape_redis_special_characters(text):
"""
Escape special characters in a string that are used in redis queries.

This function escapes characters that would interfere with the query syntax
used in the filter() method, specifically:
- Pipe (|) which is used as the OR operator when joining search terms
- Characters that could break the wildcard pattern matching
"""
return text.translate(REDIS_TRANSLATION_TABLE)
from redisvl.query.filter import TokenEscaper

# Strip the trailing space and closing bracket from the default escape set,
# then append an escaped pipe and restore the bracket
escape_pattern = TokenEscaper.DEFAULT_ESCAPED_CHARS.rstrip(' ]') + r'\|]'

escaper = TokenEscaper(escape_chars_re=re.compile(escape_pattern))
return escaper.escape(text)
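
# Illustrative sketch of the resulting behavior (the pipe case is guaranteed
# by the pattern built above; other characters follow redisvl's defaults):
#
#     >>> _escape_redis_special_characters('a|b')
#     'a\\|b'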


class RedisVectorStorageAdapter(StorageAdapter):
@@ -158,7 +157,8 @@ def remove(self, statement):
Removes any responses from statements where the response text matches
the input text.
"""
self.vector_store.delete(ids=[statement.id.split(':')[1]])
client = self.vector_store.index.client
client.delete(statement.id)

def filter(self, page_size=4, **kwargs):
"""
@@ -180,6 +180,7 @@ def filter(self, page_size=4, **kwargs):
- search_in_response_to_contains
- order_by
"""
from redisvl.query import VectorQuery
from redisvl.query.filter import Tag, Text

# https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/query_syntax/
@@ -244,6 +245,68 @@ def filter(self, page_size=4, **kwargs):
else:
filter_condition = query

# Handle search_text parameter (used by BestMatch logic adapter)
# BestMatch uses search_text to find statements with matching indexed text.
# Since Redis doesn't store search_text as a field, we approximate this by:
# 1. Using the search_text value as a semantic query against in_response_to
# 2. This finds statements that are responses to similar inputs
# The effect is similar to BestMatch's Phase 2: finding alternate responses
if 'search_text' in kwargs:
_search_text = kwargs.get('search_text', '')

# Get embedding for the search text
# Note: search_text may be indexed (e.g., "NOUN:cat VERB:run") so this
# approximates finding responses to semantically similar queries
embedding = self.vector_store.embeddings.embed_query(_search_text)

# Build return fields from metadata schema
return_fields = [
'text', 'in_response_to', 'conversation', 'persona', 'tags', 'created_at'
]

# Use direct index query via RedisVL
# Search on the vectorized content (in_response_to) to find similar response patterns
query = VectorQuery(
vector=embedding,
vector_field_name='embedding',
return_fields=return_fields,
num_results=page_size,
filter_expression=filter_condition
)

# Execute query
results = self.vector_store.index.query(query)

# Convert results to Document objects
Document = self.get_statement_model()
documents = []
for result in results:
# Extract metadata and content
in_response_to = result.get('in_response_to', '')

# Convert created_at from integer (YYMMDD) to datetime
created_at_int = int(result.get('created_at', 0))
if created_at_int:
created_at = datetime.strptime(str(created_at_int), '%y%m%d')
else:
created_at = datetime.now()

metadata = {
'text': result.get('text', ''),
'conversation': result.get('conversation', ''),
'persona': result.get('persona', ''),
'tags': result.get('tags', ''),
'created_at': created_at,
}
doc = Document(
page_content=in_response_to,
metadata=metadata,
id=result['id']
)
documents.append(doc)

return [self.model_to_object(document) for document in documents]

ordering = kwargs.get('order_by', None)

if ordering:
@@ -252,14 +315,56 @@
if 'search_in_response_to_contains' in kwargs:
_search_text = kwargs.get('search_in_response_to_contains', '')

# TODO similarity_search_with_score
documents = self.vector_store.similarity_search(
_search_text,
k=page_size, # The number of results to return
return_all=True, # Include the full document with IDs
filter=filter_condition,
sort_by=ordering
# Get embedding for the search text
embedding = self.vector_store.embeddings.embed_query(_search_text)

# Build return fields from metadata schema
return_fields = [
'text', 'in_response_to', 'conversation', 'persona', 'tags', 'created_at'
]

# Use direct index query via RedisVL
# langchain's similarity_search has issues with filters in v0.2.4
# and may not work properly with existing indexes
# TODO: Look into similarity_search_with_score implementation
query = VectorQuery(
vector=embedding,
vector_field_name='embedding',
return_fields=return_fields,
num_results=page_size,
filter_expression=filter_condition
)

# Execute query
results = self.vector_store.index.query(query)

# Convert results to Document objects
Document = self.get_statement_model()
documents = []
for result in results:
# Extract metadata and content
in_response_to = result.get('in_response_to', '')

# Convert created_at from integer (YYMMDD) to datetime
created_at_int = int(result.get('created_at', 0))
if created_at_int:
created_at = datetime.strptime(str(created_at_int), '%y%m%d')
else:
created_at = datetime.now()

metadata = {
'text': result.get('text', ''),
'conversation': result.get('conversation', ''),
'persona': result.get('persona', ''),
'tags': result.get('tags', ''),
'created_at': created_at,
}
doc = Document(
page_content=in_response_to,
metadata=metadata,
id=result['id']
)
documents.append(doc)
else:
documents = self.vector_store.query_search(
k=page_size,
@@ -284,25 +389,29 @@ def create(

_default_date = datetime.now()

# Prevent duplicate tag entries in the database
unique_tags = list(set(tags)) if tags else []

metadata = {
'text': text,
'category': kwargs.get('category', ''),
# NOTE: `created_at` must have a valid numeric value or results will
# not be returned for similarity_search for some reason
'created_at': kwargs.get('created_at') or int(_default_date.strftime('%y%m%d')),
'tags': '|'.join(tags) if tags else '',
'tags': '|'.join(unique_tags) if unique_tags else '',
'conversation': kwargs.get('conversation', ''),
'persona': kwargs.get('persona', ''),
}

ids = self.vector_store.add_texts([in_response_to or ''], [metadata])

metadata['created_at'] = _default_date
metadata['tags'] = tags or []
metadata['tags'] = unique_tags
metadata.pop('text')
statement = StatementObject(
id=ids[0],
text=text,
in_response_to=in_response_to,
**metadata
)
return statement
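
# Worked example of the YYMMDD encoding used for created_at above:
#
#     >>> from datetime import datetime
#     >>> int(datetime(2025, 1, 15).strftime('%y%m%d'))
#     250115
#     >>> datetime.strptime('250115', '%y%m%d')
#     datetime.datetime(2025, 1, 15, 0, 0)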
@@ -320,7 +429,10 @@ def create_many(self, statements):
'conversation': statement.conversation or '',
'created_at': int(statement.created_at.strftime('%y%m%d')),
'persona': statement.persona or '',
'tags': '|'.join(statement.tags) if statement.tags else '',
# Prevent duplicate tag entries in the database
'tags': '|'.join(
list(set(statement.tags))
) if statement.tags else '',
}
) for statement in statements
]
@@ -334,12 +446,15 @@ def update(self, statement):
Modifies an entry in the database.
Creates an entry if one does not exist.
"""
# Prevent duplicate tag entries in the database
unique_tags = list(set(statement.tags)) if statement.tags else []

metadata = {
'text': statement.text,
'conversation': statement.conversation or '',
'created_at': int(statement.created_at.strftime('%y%m%d')),
'persona': statement.persona or '',
'tags': '|'.join(statement.tags) if statement.tags else '',
'tags': '|'.join(unique_tags) if unique_tags else '',
}

Document = self.get_statement_model()
@@ -349,9 +464,31 @@
)

if statement.id:
self.vector_store.add_texts(
[document.page_content], [metadata], keys=[statement.id.split(':')[1]]
# When updating with an existing ID, first delete the old entry
# to ensure a duplicate entry is not created
client = self.vector_store.index.client
client.delete(statement.id)

# NOTE: langchain-redis has an inconsistency - it uses :: for auto-generated
# IDs but : (single colon) when keys are explicitly provided
if '::' in statement.id:
key = statement.id.split('::', 1)[1]
elif ':' in statement.id:
key = statement.id.split(':', 1)[1]
else:
# If no delimiter found, use the entire ID as the key
key = statement.id

ids = self.vector_store.add_texts(
[document.page_content], [metadata], keys=[key]
)

# Normalize the ID to use :: delimiter (if langchain-redis returned single colon)
if ids and ':' in ids[0] and '::' not in ids[0]:
# Replace first occurrence of single colon with double colon
normalized_id = ids[0].replace(':', '::', 1)
# Update the key in Redis to use the correct format
client.rename(ids[0], normalized_id)
else:
self.vector_store.add_documents([document])
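
# Sketch of the delimiter handling above against hypothetical IDs; every
# form reduces to the same raw key:
#
#     for statement_id in ('chatterbot::abc123', 'chatterbot:abc123', 'abc123'):
#         if '::' in statement_id:
#             key = statement_id.split('::', 1)[1]
#         elif ':' in statement_id:
#             key = statement_id.split(':', 1)[1]
#         else:
#             key = statement_id
#         print(statement_id, '->', key)  # all print '... -> abc123'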

@@ -364,12 +501,31 @@ def get_random(self):
random_key = client.randomkey()

if random_key:
random_id = random_key.decode().split(':')[1]
# Get the hash data from Redis
data = client.hgetall(random_key)

documents = self.vector_store.get_by_ids([random_id])
if data and b'_metadata_json' in data:
# Parse the metadata
metadata = json.loads(data[b'_metadata_json'].decode())

if documents:
return self.model_to_object(documents[0])
# Convert created_at from integer (YYMMDD) back to datetime
if 'created_at' in metadata and isinstance(metadata['created_at'], int):
created_at_str = str(metadata['created_at'])
# Parse YYMMDD format
metadata['created_at'] = datetime.strptime(created_at_str, '%y%m%d')

# Get the in_response_to from the hash
in_response_to = data.get(b'in_response_to', b'').decode()

# Create a Document-like object to use with model_to_object
Document = self.get_statement_model()
document = Document(
page_content=in_response_to if in_response_to else '',
metadata=metadata,
id=random_key.decode()
)

return self.model_to_object(document)

raise self.EmptyDatabaseException()

@@ -389,3 +545,11 @@ def drop(self):
# we want is to delete all the keys in the index, but
# keep the index itself)
# self.vector_store.index.delete(drop=True)

def close(self):
"""
Close the Redis client connection.
"""
if hasattr(self, 'vector_store') and hasattr(self.vector_store, 'index'):
if hasattr(self.vector_store.index, 'client'):
self.vector_store.index.client.close()
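
Taken together, a hedged end-to-end sketch of the new search_text path and the close() hook (the ChatBot wiring and local Redis URI are assumptions for illustration):

from chatterbot import ChatBot

bot = ChatBot(
    'Example',
    storage_adapter='chatterbot.storage.RedisVectorStorageAdapter',
    database_uri='redis://localhost:6379'
)

# search_text is embedded and compared against the in_response_to vectors,
# returning statements that answered semantically similar inputs
for statement in bot.storage.filter(search_text='what time is it', page_size=4):
    print(statement.text)

bot.storage.close()  # new in this PR: closes the underlying Redis client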
8 changes: 8 additions & 0 deletions chatterbot/storage/sql_storage.py
@@ -410,3 +410,11 @@ def create_database(self):
"""
from chatterbot.ext.sqlalchemy_app.models import Base
Base.metadata.create_all(self.engine)

def close(self):
"""
Close the database connection and dispose of the engine.
This ensures proper cleanup of resources.
"""
if hasattr(self, 'engine'):
self.engine.dispose()
8 changes: 8 additions & 0 deletions chatterbot/storage/storage_adapter.py
@@ -165,6 +165,14 @@ def drop(self):
'The `drop` method is not implemented by this adapter.'
)

def close(self):
"""
Close any open connections or sessions.
This method should be called when the storage adapter is no longer needed
to properly clean up resources and avoid resource warnings.
"""
pass
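
# Because every adapter now inherits a no-op close(), callers can rely on
# one uniform cleanup pattern; a minimal sketch (the SQLite URI is an
# assumption for illustration):
#
#     from contextlib import closing
#     from chatterbot.storage import SQLStorageAdapter
#
#     with closing(SQLStorageAdapter(database_uri='sqlite:///db.sqlite3')) as storage:
#         print(storage.count())
#     # close() has run here, disposing the SQLAlchemy engine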

class EmptyDatabaseException(Exception):

def __init__(self, message=None):