diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 2ad445a7..cf593959 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -86,18 +86,20 @@ jobs: - name: Build cython extensions with tracing run: CYTHON_TRACE=1 python setup.py build_ext --define CYTHON_TRACE if: ${{ !contains(matrix.python-version, 'pypy') }} - - name: Install requirements + - name: Prepare for install run: | # Newer coveralls do not work with github actions. pip install 'coveralls<3.0.0' pip install cython pip install -U 'setuptools<72.2' - python testsrequire.py - python setup.py develop - # Limit each test time execution. - pip install pytest-timeout - env: - USE_NUMPY: ${{ matrix.use-numpy }} + - name: Install requirements (numpy) + run: pip install -e .[lz4,zstd,numpy] + if: ${{ matrix.use-numpy }} + - name: Install requirements + run: pip install -e .[lz4,zstd] + if: ${{ !matrix.use-numpy }} + - name: Install test requirements + run: pip install -r tests/requirements.txt - name: Run tests run: coverage run -m pytest --timeout=10 -v timeout-minutes: 5 @@ -159,10 +161,8 @@ jobs: echo '127.0.0.1 clickhouse-server' | sudo tee /etc/hosts > /dev/null - name: Install requirements run: | - python testsrequire.py - python setup.py develop - env: - USE_NUMPY: 1 + pip install -e .[lz4,zstd,numpy] + pip install -r tests/requirements.txt - name: Run tests under valgrind run: valgrind -s --error-exitcode=1 --suppressions=valgrind.supp py.test -v env: diff --git a/clickhouse_driver/util/escape.py b/clickhouse_driver/util/escape.py index 465c42fc..5b5b3036 100644 --- a/clickhouse_driver/util/escape.py +++ b/clickhouse_driver/util/escape.py @@ -1,6 +1,6 @@ from datetime import date, datetime, time from enum import Enum -from functools import wraps +from functools import wraps, partial from uuid import UUID from pytz import timezone @@ -26,7 +26,12 @@ def escape_datetime(item, context): if item.tzinfo is not None: item = item.astimezone(server_tz) - return "'%s'" % item.strftime('%Y-%m-%d %H:%M:%S') + if item.microsecond: + format = '%Y-%m-%d %H:%M:%S.%f' + else: + format = '%Y-%m-%d %H:%M:%S' + + return f"'{item.strftime(format)}'" def maybe_enquote_for_server(f): @@ -34,19 +39,33 @@ def maybe_enquote_for_server(f): def wrapper(*args, **kwargs): rv = f(*args, **kwargs) - if kwargs.get('for_server'): - is_str = isinstance(rv, str) + if not kwargs.get('for_server'): + return rv - if not is_str or (is_str and not rv.startswith("'")): - rv = "'%s'" % rv + is_str = isinstance(rv, str) - return rv + nested = kwargs.get('nested') + item = kwargs['item'] if 'item' in kwargs else args[0] + if is_str and not isinstance(item, (list, tuple)): + if rv[0] == "'": + if nested: + return f"\\'{rv[1:-1]}\\'" + return rv + if nested: + return f"\\'{rv}\\'" + return f"'{rv}'" + + if nested: + return str(rv) + return f"'{rv!s}'" return wrapper @maybe_enquote_for_server -def escape_param(item, context, for_server=False): +def escape_param( + item, context, for_server=False, nested=False +): if item is None: return 'NULL' @@ -54,32 +73,63 @@ def escape_param(item, context, for_server=False): return escape_datetime(item, context) elif isinstance(item, date): - return "'%s'" % item.strftime('%Y-%m-%d') + return f"'{item.strftime('%Y-%m-%d')}'" elif isinstance(item, time): - return "'%s'" % item.strftime('%H:%M:%S') + return f"'{item.strftime('%H:%M:%S')}'" elif isinstance(item, str): # We need double escaping for server-side parameters. if for_server: item = ''.join(escape_chars_map.get(c, c) for c in item) - return "'%s'" % ''.join(escape_chars_map.get(c, c) for c in item) + return f"'{''.join(escape_chars_map.get(c, c) for c in item)}'" elif isinstance(item, list): - return "[%s]" % ', '.join( - str(escape_param(x, context, for_server=for_server)) for x in item + serialized_array = ', '.join( + str( + escape_param( + x, + context, + for_server=for_server, + nested=True, + ) + ) for x in item ) + return f'[{serialized_array}]' elif isinstance(item, tuple): - return "(%s)" % ', '.join( - str(escape_param(x, context, for_server=for_server)) for x in item + serialized_tuple = ', '.join( + str( + escape_param( + x, + context, + for_server=for_server, + nested=True, + ) + ) for x in item + ) + + return f'({serialized_tuple})' + + elif isinstance(item, dict): + serializer = partial( + escape_param, + context=context, + for_server=for_server, + nested=True, + ) + + serialized_dict = ', '.join( + f'{serializer(key)!s}: {serializer(value)!s}' + for key, value in item.items() ) + return f'{{{serialized_dict}}}' elif isinstance(item, Enum): return escape_param(item.value, context, for_server=for_server) elif isinstance(item, UUID): - return "'%s'" % str(item) + return f"'{item!s}'" else: return item diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 00000000..1eb6248b --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,5 @@ +pytest +pytest-subtests +parameterized +freezegun +pytest-timeout diff --git a/tests/test_substitution.py b/tests/test_substitution.py index 800759fd..c3438a4f 100644 --- a/tests/test_substitution.py +++ b/tests/test_substitution.py @@ -1,10 +1,13 @@ # coding=utf-8 from __future__ import unicode_literals +import struct +import unittest from datetime import date, datetime, time from decimal import Decimal +from ipaddress import ip_address from unittest.mock import Mock -from uuid import UUID +from uuid import UUID, uuid4 from enum import IntEnum, Enum from pytz import timezone @@ -247,13 +250,185 @@ class ServerSideParametersSubstitutionTestCase(BaseTestCase): client_kwargs = {'settings': {'server_side_params': True}} + def _test_type_aliases(self, x, type_name, type_postfix=''): + aliases = self.client.execute( + "SELECT name FROM system.data_type_families " + f"WHERE alias_to = '{type_name}'" + ) + for (alias,) in aliases: + with self.subTest( + msg=f'{alias}{type_postfix}', + alias_to=f'{type_name}{type_postfix}', + ): + rv = self.client.execute( + f'SELECT {{x:{alias}{type_postfix}}}', {'x': x} + ) + self.assertEqual(rv, [(x, )]) + + def _test_type_serialization(self, x, type_pattern, type_postfix=''): + matching_types = self.client.execute( + f"SELECT name FROM system.data_type_families " + f"WHERE match(name, '{type_pattern}')" + ) + self.assertGreaterEqual( + len(matching_types), 1, msg='Matching types not found' + ) + for (matching_type,) in matching_types: + with self.subTest(msg=f'{matching_type}{type_postfix}'): + rv = self.client.execute( + f'SELECT {{x:{matching_type}{type_postfix}}}', {'x': x} + ) + self.assertEqual(rv, [(x, )]) + self._test_type_aliases(x, matching_type, type_postfix) + def test_int(self): - rv = self.client.execute('SELECT {x:Int32}', {'x': 123}) - self.assertEqual(rv, [(123, )]) + self._test_type_serialization(123, '^Int\\d+$') + + def test_uint(self): + self._test_type_serialization(123, '^UInt\\d+$') + + def test_float(self): + # Make sure float is the same in single and double precision + x = struct.unpack('=f', struct.pack('=f', 123.45))[0] + + self._test_type_serialization(x, '^Float\\d+$') + + def test_decimal(self): + x = Decimal(12345) / Decimal(100) + self._test_type_serialization(x, '^Decimal$', '(5,2)') def test_str(self): - rv = self.client.execute('SELECT {x:Int32}', {'x': '123'}) - self.assertEqual(rv, [(123, )]) + x = "123'" + self._test_type_serialization(x, '^String$') + + def test_date(self): + x = date(year=2024, month=1, day=18) + self._test_type_serialization(x, '^Date\\d*$') + + def test_datetime(self): + x = datetime( + year=2024, + month=1, + day=18, + hour=23, + minute=12, + second=27, + microsecond=0, + tzinfo=None, + ) + self._test_type_serialization(x, '^DateTime$') + + def test_datetime64(self): + x = datetime( + year=2024, + month=1, + day=18, + hour=23, + minute=12, + second=27, + microsecond=123, + tzinfo=None, + ) + self._test_type_serialization(x, '^DateTime64$', '(6)') + + def test_enum(self): + class HelloEnum(Enum): + hello = 'hello' + x = HelloEnum.hello + with self.subTest(msg='Enum'): + rv = self.client.execute( + "SELECT {x:Enum('hello')}", {'x': x} + ) + self.assertEqual(rv, [(x.value, )]) + aliases = self.client.execute( + "SELECT name FROM system.data_type_families " + "WHERE alias_to = 'Enum'" + ) + for (alias,) in aliases: + with self.subTest( + msg=f"{alias}('hello')", + alias_to="Enum('hello')", + ): + rv = self.client.execute( + f"SELECT {{x:{alias}('hello')}}", {'x': x} + ) + self.assertEqual(rv, [(x.value, )]) + + def test_bool(self): + x = True + self._test_type_serialization(x, '^Bool$') + + def test_uuid(self): + x = uuid4() + self._test_type_serialization(x, '^UUID') + + def test_ipv4(self): + x = ip_address('127.0.0.1') + self._test_type_serialization(x, '^IPv4$') + + def test_ipv6(self): + x = ip_address('2001:db8::') + self._test_type_serialization(x, '^IPv6$') + + def test_array__int(self): + x = [1, 2, 3] + self._test_type_serialization(x, '^Array$', '(Int32)') + + def test_array__float(self): + x = [1.23, 2.34, 3.45] + self._test_type_serialization(x, '^Array$', '(Float64)') + + def test_array__str(self): + x = ['1', '2', '3'] + self._test_type_serialization(x, '^Array$', '(String)') + + def test_2d_array__int(self): + x = [[1, 2, 3], [5, 6]] + self._test_type_serialization(x, '^Array$', '(Array(Int32))') + + def test_2d_array__float(self): + x = [[1.23, 2.34, 3.45], [5.67, 6.78]] + self._test_type_serialization(x, '^Array$', '(Array(Float64))') + + def test_2d_array__str(self): + x = [['1', '2', '3'], ['5', '6']] + self._test_type_serialization(x, '^Array$', '(Array(String))') + + def test_3d_array__int(self): + x = [[[1, 2, 3], [5, 6]]] + self._test_type_serialization(x, '^Array$', '(Array(Array(Int32)))') + + def test_3d_array__float(self): + x = [[[1.23, 2.34, 3.45], [5.67, 6.78]]] + self._test_type_serialization(x, '^Array$', '(Array(Array(Float64)))') + + def test_3d_array__str(self): + x = [[['1', '2', '3'], ['5', '6']]] + self._test_type_serialization(x, '^Array$', '(Array(Array(String)))') + + def test_tuple(self): + x = (1, 1.23, '123') + self._test_type_serialization(x, '^Tuple$', '(Int32, Float64, String)') + + def test_nested_tuple(self): + x = (1, (1.23, '123'), [('1', 1), ('2', 2), ('3', 3)]) + self._test_type_serialization( + x, + '^Tuple$', + '(Int32, Tuple(Float64, String), Array(Tuple(String, Int32)))' + ) + + def test_map__int(self): + x = {1: 2, 3: 4} + self._test_type_serialization(x, '^Map$', '(UInt32, UInt32)') + + def test_map__string(self): + x = {'1': '34', '2': '45'} + self._test_type_serialization(x, '^Map$', '(String, String)') + + @unittest.skip('Duplicate keys not supported') + def test_map__duplicate_keys(self): + pass def test_escaped_str(self): rv = self.client.execute( diff --git a/testsrequire.py b/testsrequire.py deleted file mode 100644 index 06c98cc4..00000000 --- a/testsrequire.py +++ /dev/null @@ -1,27 +0,0 @@ -import os -import sys - -USE_NUMPY = bool(int(os.getenv('USE_NUMPY', '0'))) - -tests_require = [ - 'pytest', - 'parameterized', - 'freezegun', - 'zstd', - 'clickhouse-cityhash>=1.0.2.1' -] - -if sys.implementation.name == 'pypy': - tests_require.append('lz4<=3.0.1') -else: - tests_require.append('lz4') - -if USE_NUMPY: - tests_require.extend(['numpy', 'pandas']) - -try: - from pip import main as pipmain -except ImportError: - from pip._internal import main as pipmain - -pipmain(['install'] + tests_require)