Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ jobs:
python -Wonce -m unittest discover -s tests -v
- name: Run tests for Django example app
run: |
python -m pip install "django<=4.1"
python -Wonce runtests.py
python -Wonce examples/django_example/manage.py test examples/django_example/
# --------------------------------------------------------------
# TODO: Fix & re-enable later
Expand Down
10 changes: 6 additions & 4 deletions chatterbot/chatterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ def __init__(self, name, stream=False, **kwargs):
try:
Tagger = kwargs.get('tagger', PosLemmaTagger)

self.tagger = Tagger(language=tagger_language)
# Allow instances to be provided for performance optimization
# (Example: a pre-loaded model in a tagger when unit testing)
if not isinstance(Tagger, type):
self.tagger = Tagger
else:
self.tagger = Tagger(language=tagger_language)
except IOError as io_error:
# Return a more helpful error message if possible
if "Can't find model" in str(io_error):
Expand Down Expand Up @@ -224,7 +229,6 @@ def get_response(self, statement: Union[Statement, str, dict] = None, **kwargs)
# Save the response generated for the input
self.learn_response(response, previous_statement=input_statement)


return response

def generate_response(self, input_statement, additional_response_selection_parameters=None):
Expand Down Expand Up @@ -345,8 +349,6 @@ def get_latest_response(self, conversation: str):
Returns the latest response in a conversation if it exists.
Returns None if a matching conversation cannot be found.
"""
from chatterbot.conversation import Statement as StatementObject

conversation_statements = list(self.storage.filter(
conversation=conversation,
order_by=['id']
Expand Down
5 changes: 4 additions & 1 deletion chatterbot/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ class SpacySimilarity(Comparator):
python -m spacy download en_core_web_sm
python -m spacy download de_core_news_sm

Alternatively, the ``spacy`` models can be installed as Python packages. The following lines could be included in a ``requirements.txt`` or ``pyproject.yml`` file if you needed to pin specific versions:
Alternatively, the ``spacy`` models can be installed as Python
packages. The following lines could be included in a
``requirements.txt`` or ``pyproject.yml`` file if you needed to pin
specific versions:

.. code-block:: text

Expand Down
1 change: 0 additions & 1 deletion chatterbot/logic/best_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def process(self, input_statement: Statement, additional_response_selection_para
additional_response_selection_parameters
)


# Get all statements with text similar to the closest match
response_list = list(self.chatbot.storage.filter(**response_selection_parameters))

Expand Down
2 changes: 1 addition & 1 deletion chatterbot/logic/unit_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_unit(self, unit_variations):
for unit in unit_variations:
try:
return getattr(self.unit_registry, unit)
except Exception:
except AttributeError:
continue
return None

Expand Down
25 changes: 18 additions & 7 deletions chatterbot/response_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,28 @@ def get_most_frequent_response(input_statement: Statement, response_list: list[S

:return: The response statement with the greatest number of occurrences.
"""
matching_response = None
occurrence_count = -1

logger = logging.getLogger(__name__)
logger.info('Selecting response with greatest number of occurrences.')

# Collect all unique text values from response_list
response_texts = set(statement.text for statement in response_list)

# Fetch all statements matching the input in a single query
# Then count occurrences in memory
all_matching = list(storage.filter(in_response_to=input_statement.text))

# Count how many times each response text appears in the database
occurrence_counts = {}
for statement in all_matching:
if statement.text in response_texts:
occurrence_counts[statement.text] = occurrence_counts.get(statement.text, 0) + 1

# Find the response with the highest occurrence count
matching_response = None
occurrence_count = -1

for statement in response_list:
count = len(list(storage.filter(
text=statement.text,
in_response_to=input_statement.text)
))
count = occurrence_counts.get(statement.text, 0)

# Keep the more common statement
if count >= occurrence_count:
Expand Down
8 changes: 8 additions & 0 deletions chatterbot/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def search(self, input_statement, **additional_parameters):

yield statement

if confidence >= 1.0:
self.chatbot.logger.info('Exact match found, stopping search')
break


class TextSearch:
"""
Expand Down Expand Up @@ -149,3 +153,7 @@ def search(self, input_statement, **additional_parameters):
))

yield statement

if confidence >= 1.0:
self.chatbot.logger.info('Exact match found, stopping search')
break
2 changes: 1 addition & 1 deletion chatterbot/storage/django_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def filter(self, **kwargs):
search_in_response_to_contains = kwargs.pop('search_in_response_to_contains', None)

# Convert a single sting into a list if only one tag is provided
if type(tags) == str:
if isinstance(tags, str):
tags = [tags]

if tags:
Expand Down
20 changes: 11 additions & 9 deletions chatterbot/storage/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,17 @@ def filter(self, **kwargs):
for order in order_by:
mongo_ordering.append((order, pymongo.ASCENDING))

total_statements = self.statements.count_documents(kwargs)

for start_index in range(0, total_statements, page_size):
if mongo_ordering:
for match in self.statements.find(kwargs).sort(mongo_ordering).skip(start_index).limit(page_size):
yield self.mongo_to_object(match)
else:
for match in self.statements.find(kwargs).skip(start_index).limit(page_size):
yield self.mongo_to_object(match)
# Build the query cursor
if mongo_ordering:
cursor = self.statements.find(kwargs).sort(mongo_ordering)
else:
cursor = self.statements.find(kwargs)

# Use batch_size for efficient pagination without counting total documents
cursor = cursor.batch_size(page_size)

for match in cursor:
yield self.mongo_to_object(match)

def create(self, **kwargs):
"""
Expand Down
3 changes: 2 additions & 1 deletion chatterbot/storage/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

REDIS_TRANSLATION_TABLE = str.maketrans(REDIS_ESCAPE_CHARACTERS)


def _escape_redis_special_characters(text):
"""
Escape special characters in a string that are used in redis queries.
Expand Down Expand Up @@ -369,7 +370,7 @@ def get_random(self):

if documents:
return self.model_to_object(documents[0])

raise self.EmptyDatabaseException()

def drop(self):
Expand Down
19 changes: 11 additions & 8 deletions chatterbot/storage/sql_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def filter(self, **kwargs):
search_in_response_to_contains = kwargs.pop('search_in_response_to_contains', None)

# Convert a single sting into a list if only one tag is provided
if type(tags) == str:
if isinstance(tags, str):
tags = [tags]

if len(kwargs) == 0:
Expand Down Expand Up @@ -240,15 +240,18 @@ def create(
)

tags = frozenset(tags) if tags else frozenset()
for tag_name in frozenset(tags):
# TODO: Query existing tags in bulk
tag = session.query(Tag).filter_by(name=tag_name).first()

if not tag:
# Create the tag
tag = Tag(name=tag_name)
# Batch query tags
if tags:
existing_tags = session.query(Tag).filter(Tag.name.in_(tags)).all()
existing_tag_dict = {tag.name: tag for tag in existing_tags}

statement.tags.append(tag)
for tag_name in tags:
tag = existing_tag_dict.get(tag_name)
if not tag:
# Create the tag if it doesn't exist
tag = Tag(name=tag_name)
statement.tags.append(tag)

session.add(statement)

Expand Down
3 changes: 2 additions & 1 deletion chatterbot/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def train(self, data_path: str, limit=None):
)
)


class CsvFileTrainer(GenericFileTrainer):
"""
.. note::
Expand Down Expand Up @@ -550,7 +551,7 @@ def safe_extract(tar, path='.', members=None, *, numeric_owner=False):
self.chatbot.logger.info('File extracted to {}'.format(self.data_path))

return True

def _get_file_list(self, data_path: str, limit: Union[int, None]):
"""
Get a list of files to read from the data set.
Expand Down
3 changes: 1 addition & 2 deletions chatterbot/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
"""
from __future__ import annotations

from typing import Any, List, Sequence
from typing import List

from langchain_core.documents import Document
from redisvl.redis.utils import convert_bytes
from redisvl.query import FilterQuery

from langchain_core.documents import Document
from langchain_redis.vectorstores import RedisVectorStore as LangChainRedisVectorStore


Expand Down
19 changes: 10 additions & 9 deletions docs/testing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ You can run ChatterBot's main test suite using Python's built-in test runner. Fo

python -m unittest discover -s tests -v

This command will run all tests including Django integration tests (if Django is installed).

*Note* that the ``unittest`` command also allows you to specify individual test cases to run.
For example, the following command will run all tests in the test-module `tests/logic/`

Expand All @@ -34,22 +36,21 @@ Tests can also be run in "fail fast" mode, in which case they will run until the

python -m unittest discover -f tests

For more information on ``unittest`` functionality, see the `unittest documentation`_.

Django integration tests
------------------------

Tests for Django integration have been included in the `tests_django` directory and
can be run with:
Django integration tests are included in ``tests/django_integration/`` and will automatically run
when you execute the main test suite (if Django is installed). If Django is not available,
these tests will be gracefully skipped.

.. sourcecode:: sh
To run only Django integration tests:

python runtests.py
.. sourcecode:: sh

Django example app tests
------------------------
python -m unittest discover -s tests/django_integration/ -v

Tests for the example Django app can be run with the following command from within the `examples/django_example` directory.
The Django example app tests can be run separately with the following command from within
the `examples/django_example` directory:

.. sourcecode:: sh

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ classifiers = [
"Programming Language :: Python :: 3 :: Only",
]
dependencies = [
"mathparse>=0.1,<0.2",
"mathparse>=0.2,<0.3",
"python-dateutil>=2.9,<2.10",
"sqlalchemy>=2.0,<2.1",
"spacy>=3.8,<3.9",
Expand All @@ -74,6 +74,7 @@ test = [
"sphinx>=5.3,<8.2",
"sphinx-sitemap>=2.6.0",
"huggingface_hub",
"django<=4.1,<6.0"
]
dev = [
"pint>=0.8.1",
Expand Down
22 changes: 0 additions & 22 deletions runtests.py

This file was deleted.

21 changes: 20 additions & 1 deletion tests/base_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,25 @@


class ChatBotTestCase(TestCase):
"""
Base test case class that provides common test utilities.
"""

# Share a single tagger instance across all tests in a test class to avoid
# repeatedly loading the spaCy model (saves 1-3 seconds per test)
_shared_tagger = None

@classmethod
def setUpClass(cls):
super().setUpClass()
if cls._shared_tagger is None:
from chatterbot.tagging import PosLemmaTagger
cls._shared_tagger = PosLemmaTagger()

def setUp(self):
self.chatbot = ChatBot('Test Bot', **self.get_kwargs())
kwargs = self.get_kwargs()
kwargs['tagger'] = self._shared_tagger
self.chatbot = ChatBot('Test Bot', **kwargs)

def _add_search_text(self, **kwargs):
"""
Expand Down Expand Up @@ -106,6 +122,9 @@ def setUpClass(cls):
except ServerSelectionTimeoutError:
raise SkipTest('Unable to connect to Mongo DB.')

# Initialize the shared tagger
super().setUpClass()

def get_kwargs(self):
kwargs = super().get_kwargs()
kwargs['database_uri'] = 'mongodb://localhost:27017/chatterbot_test_database'
Expand Down
Loading
Loading