From 567768b3df134555cff123b89f0b004601670caa Mon Sep 17 00:00:00 2001 From: mAyty Date: Sun, 19 Jan 2025 02:01:26 +0200 Subject: [PATCH 1/5] fix server-side query params escaping --- clickhouse_driver/util/escape.py | 56 ++++++++-- tests/test_substitution.py | 181 ++++++++++++++++++++++++++++++- testsrequire.py | 1 + 3 files changed, 224 insertions(+), 14 deletions(-) diff --git a/clickhouse_driver/util/escape.py b/clickhouse_driver/util/escape.py index 465c42fc..aea975cf 100644 --- a/clickhouse_driver/util/escape.py +++ b/clickhouse_driver/util/escape.py @@ -26,7 +26,12 @@ def escape_datetime(item, context): if item.tzinfo is not None: item = item.astimezone(server_tz) - return "'%s'" % item.strftime('%Y-%m-%d %H:%M:%S') + if item.microsecond: + format = '%Y-%m-%d %H:%M:%S.%f' + else: + format = '%Y-%m-%d %H:%M:%S' + + return "'%s'" % item.strftime(format) def maybe_enquote_for_server(f): @@ -34,19 +39,36 @@ def maybe_enquote_for_server(f): def wrapper(*args, **kwargs): rv = f(*args, **kwargs) - if kwargs.get('for_server'): - is_str = isinstance(rv, str) + if not kwargs.get('for_server'): + return rv + + is_str = isinstance(rv, str) + + nested = kwargs.get('nested') + item = kwargs['item'] if 'item' in kwargs else args[0] + if is_str and not isinstance(item, (list, tuple)): + if rv[0] == "'": + if nested: + return "\\'%s\\'" % rv[1:-1] + return rv + if nested: + return "\\'%s\\'" % rv + return "'%s'" % rv - if not is_str or (is_str and not rv.startswith("'")): - rv = "'%s'" % rv + if kwargs.get('for_iterable'): + return '%s' % rv - return rv + if nested: + return "\\'%s\\'" % rv + return "'%s'" % rv return wrapper @maybe_enquote_for_server -def escape_param(item, context, for_server=False): +def escape_param( + item, context, for_server=False, for_iterable=False, nested=False +): if item is None: return 'NULL' @@ -67,12 +89,28 @@ def escape_param(item, context, for_server=False): elif isinstance(item, list): return "[%s]" % ', '.join( - str(escape_param(x, context, for_server=for_server)) for x in item + str( + escape_param( + x, + context, + for_server=for_server, + for_iterable=True, + nested=True, + ) + ) for x in item ) elif isinstance(item, tuple): return "(%s)" % ', '.join( - str(escape_param(x, context, for_server=for_server)) for x in item + str( + escape_param( + x, + context, + for_server=for_server, + for_iterable=True, + nested=True, + ) + ) for x in item ) elif isinstance(item, Enum): diff --git a/tests/test_substitution.py b/tests/test_substitution.py index 800759fd..f31ca717 100644 --- a/tests/test_substitution.py +++ b/tests/test_substitution.py @@ -1,10 +1,13 @@ # coding=utf-8 from __future__ import unicode_literals +import struct +import unittest from datetime import date, datetime, time from decimal import Decimal +from ipaddress import ip_address from unittest.mock import Mock -from uuid import UUID +from uuid import UUID, uuid4 from enum import IntEnum, Enum from pytz import timezone @@ -247,13 +250,181 @@ class ServerSideParametersSubstitutionTestCase(BaseTestCase): client_kwargs = {'settings': {'server_side_params': True}} + def _test_type_aliases(self, x, type_name, type_postfix=''): + aliases = self.client.execute( + "SELECT name FROM system.data_type_families " + f"WHERE alias_to = '{type_name}'" + ) + for (alias,) in aliases: + with self.subTest( + msg=f'{alias}{type_postfix}', + alias_to=f'{type_name}{type_postfix}', + ): + rv = self.client.execute( + f'SELECT {{x:{alias}{type_postfix}}}', {'x': x} + ) + self.assertEqual(rv, [(x, )]) + + def _test_type_serialization(self, x, type_pattern, type_postfix=''): + matching_types = self.client.execute( + f"SELECT name FROM system.data_type_families " + f"WHERE match(name, '{type_pattern}')" + ) + self.assertGreaterEqual( + len(matching_types), 1, msg='Matching types not found' + ) + for (matching_type,) in matching_types: + with self.subTest(msg=f'{matching_type}{type_postfix}'): + rv = self.client.execute( + f'SELECT {{x:{matching_type}{type_postfix}}}', {'x': x} + ) + self.assertEqual(rv, [(x, )]) + self._test_type_aliases(x, matching_type, type_postfix) + def test_int(self): - rv = self.client.execute('SELECT {x:Int32}', {'x': 123}) - self.assertEqual(rv, [(123, )]) + self._test_type_serialization(123, '^Int\\d+$') + + def test_uint(self): + self._test_type_serialization(123, '^UInt\\d+$') + + def test_float(self): + # Make sure float is the same in single and double precision + x = struct.unpack('=f', struct.pack('=f', 123.45))[0] + + self._test_type_serialization(x, '^Float\\d+$') + + def test_decimal(self): + x = Decimal(12345) / Decimal(100) + self._test_type_serialization(x, '^Decimal$', '(5,2)') def test_str(self): - rv = self.client.execute('SELECT {x:Int32}', {'x': '123'}) - self.assertEqual(rv, [(123, )]) + x = "123'" + self._test_type_serialization(x, '^String$') + + def test_date(self): + x = date(year=2024, month=1, day=18) + self._test_type_serialization(x, '^Date\\d*$') + + def test_datetime(self): + x = datetime( + year=2024, + month=1, + day=18, + hour=23, + minute=12, + second=27, + microsecond=0, + tzinfo=None, + ) + self._test_type_serialization(x, '^DateTime$') + + def test_datetime64(self): + x = datetime( + year=2024, + month=1, + day=18, + hour=23, + minute=12, + second=27, + microsecond=123, + tzinfo=None, + ) + self._test_type_serialization(x, '^DateTime64$', '(6)') + + def test_enum(self): + class HelloEnum(Enum): + hello = 'hello' + x = HelloEnum.hello + with self.subTest(msg='Enum'): + rv = self.client.execute( + "SELECT {x:Enum('hello')}", {'x': x} + ) + self.assertEqual(rv, [(x.value, )]) + aliases = self.client.execute( + "SELECT name FROM system.data_type_families " + "WHERE alias_to = 'Enum'" + ) + for (alias,) in aliases: + with self.subTest( + msg=f"{alias}('hello')", + alias_to="Enum('hello')", + ): + rv = self.client.execute( + f"SELECT {{x:{alias}('hello')}}", {'x': x} + ) + self.assertEqual(rv, [(x.value, )]) + + def test_bool(self): + x = True + self._test_type_serialization(x, '^Bool$') + + def test_uuid(self): + x = uuid4() + self._test_type_serialization(x, '^UUID') + + def test_ipv4(self): + x = ip_address('127.0.0.1') + self._test_type_serialization(x, '^IPv4$') + + def test_ipv6(self): + x = ip_address('2001:db8::') + self._test_type_serialization(x, '^IPv6$') + + def test_array__int(self): + x = [1, 2, 3] + self._test_type_serialization(x, '^Array$', '(Int32)') + + def test_array__float(self): + x = [1.23, 2.34, 3.45] + self._test_type_serialization(x, '^Array$', '(Float64)') + + def test_array__str(self): + x = ['1', '2', '3'] + self._test_type_serialization(x, '^Array$', '(String)') + + def test_2d_array__int(self): + x = [[1, 2, 3], [5, 6]] + self._test_type_serialization(x, '^Array$', '(Array(Int32))') + + def test_2d_array__float(self): + x = [[1.23, 2.34, 3.45], [5.67, 6.78]] + self._test_type_serialization(x, '^Array$', '(Array(Float64))') + + def test_2d_array__str(self): + x = [['1', '2', '3'], ['5', '6']] + self._test_type_serialization(x, '^Array$', '(Array(String))') + + def test_3d_array__int(self): + x = [[[1, 2, 3], [5, 6]]] + self._test_type_serialization(x, '^Array$', '(Array(Array(Int32)))') + + def test_3d_array__float(self): + x = [[[1.23, 2.34, 3.45], [5.67, 6.78]]] + self._test_type_serialization(x, '^Array$', '(Array(Array(Float64)))') + + def test_3d_array__str(self): + x = [[['1', '2', '3'], ['5', '6']]] + self._test_type_serialization(x, '^Array$', '(Array(Array(String)))') + + def test_tuple(self): + x = (1, 1.23, '123') + self._test_type_serialization(x, '^Tuple$', '(Int32, Float64, String)') + + def test_nested_tuple(self): + x = (1, (1.23, '123'), [('1', 1), ('2', 2), ('3', 3)]) + self._test_type_serialization( + x, + '^Tuple$', + '(Int32, Tuple(Float64, String), Array(Tuple(String, Int32)))' + ) + + def test_map(self): + x = {1: 2, 3: 4} + self._test_type_serialization(x, '^Map$', '(UInt32, UInt32)') + + @unittest.skip('Duplicate keys not supported') + def test_map__duplicate_keys(self): + pass def test_escaped_str(self): rv = self.client.execute( diff --git a/testsrequire.py b/testsrequire.py index 06c98cc4..a83be7f4 100644 --- a/testsrequire.py +++ b/testsrequire.py @@ -5,6 +5,7 @@ tests_require = [ 'pytest', + 'pytest-subtests' 'parameterized', 'freezegun', 'zstd', From e2327e86af4480f816fb4481968524ce20a34ac1 Mon Sep 17 00:00:00 2001 From: mAyty Date: Sun, 19 Jan 2025 02:44:37 +0200 Subject: [PATCH 2/5] fix actions --- .github/workflows/actions.yml | 16 +++++++++------- tests/requirements.txt | 5 +++++ testsrequire.py | 28 ---------------------------- 3 files changed, 14 insertions(+), 35 deletions(-) create mode 100644 tests/requirements.txt delete mode 100644 testsrequire.py diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 2ad445a7..4cc5f149 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -86,18 +86,20 @@ jobs: - name: Build cython extensions with tracing run: CYTHON_TRACE=1 python setup.py build_ext --define CYTHON_TRACE if: ${{ !contains(matrix.python-version, 'pypy') }} - - name: Install requirements + - name: Prepare for install run: | # Newer coveralls do not work with github actions. pip install 'coveralls<3.0.0' pip install cython pip install -U 'setuptools<72.2' - python testsrequire.py - python setup.py develop - # Limit each test time execution. - pip install pytest-timeout - env: - USE_NUMPY: ${{ matrix.use-numpy }} + - name: Install requirements (numpy) + run: pip install -e .[lz4,zstd,numpy] + if: ${{ matrix.use-numpy }} + - name: Install requirements + run: pip install -e .[lz4,zstd] + if: ${{ !matrix.use-numpy }} + - name: Install test requirements + run: pip install -r tests/requirements.txt - name: Run tests run: coverage run -m pytest --timeout=10 -v timeout-minutes: 5 diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 00000000..1eb6248b --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,5 @@ +pytest +pytest-subtests +parameterized +freezegun +pytest-timeout diff --git a/testsrequire.py b/testsrequire.py deleted file mode 100644 index a83be7f4..00000000 --- a/testsrequire.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -import sys - -USE_NUMPY = bool(int(os.getenv('USE_NUMPY', '0'))) - -tests_require = [ - 'pytest', - 'pytest-subtests' - 'parameterized', - 'freezegun', - 'zstd', - 'clickhouse-cityhash>=1.0.2.1' -] - -if sys.implementation.name == 'pypy': - tests_require.append('lz4<=3.0.1') -else: - tests_require.append('lz4') - -if USE_NUMPY: - tests_require.extend(['numpy', 'pandas']) - -try: - from pip import main as pipmain -except ImportError: - from pip._internal import main as pipmain - -pipmain(['install'] + tests_require) From 8ebf9b31b6bb500fade02c37a12260204fe834fb Mon Sep 17 00:00:00 2001 From: mAyty Date: Sun, 19 Jan 2025 04:15:24 +0200 Subject: [PATCH 3/5] fix map serialization --- clickhouse_driver/util/escape.py | 46 ++++++++++++++++++++++---------- tests/test_substitution.py | 6 ++++- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/clickhouse_driver/util/escape.py b/clickhouse_driver/util/escape.py index aea975cf..8e7a212c 100644 --- a/clickhouse_driver/util/escape.py +++ b/clickhouse_driver/util/escape.py @@ -1,6 +1,6 @@ from datetime import date, datetime, time from enum import Enum -from functools import wraps +from functools import wraps, partial from uuid import UUID from pytz import timezone @@ -31,7 +31,7 @@ def escape_datetime(item, context): else: format = '%Y-%m-%d %H:%M:%S' - return "'%s'" % item.strftime(format) + return f"'{item.strftime(format)}'" def maybe_enquote_for_server(f): @@ -49,18 +49,18 @@ def wrapper(*args, **kwargs): if is_str and not isinstance(item, (list, tuple)): if rv[0] == "'": if nested: - return "\\'%s\\'" % rv[1:-1] + return f"\\'{rv[1:-1]}\\'" return rv if nested: - return "\\'%s\\'" % rv - return "'%s'" % rv + return f"\\'{rv}\\'" + return f"'{rv}'" if kwargs.get('for_iterable'): - return '%s' % rv + return str(rv) if nested: - return "\\'%s\\'" % rv - return "'%s'" % rv + return f"\\'{rv!s}\\'" + return f"'{rv!s}'" return wrapper @@ -76,19 +76,19 @@ def escape_param( return escape_datetime(item, context) elif isinstance(item, date): - return "'%s'" % item.strftime('%Y-%m-%d') + return f"'{item.strftime('%Y-%m-%d')}'" elif isinstance(item, time): - return "'%s'" % item.strftime('%H:%M:%S') + return f"'{item.strftime('%H:%M:%S')}'" elif isinstance(item, str): # We need double escaping for server-side parameters. if for_server: item = ''.join(escape_chars_map.get(c, c) for c in item) - return "'%s'" % ''.join(escape_chars_map.get(c, c) for c in item) + return f"'{''.join(escape_chars_map.get(c, c) for c in item)}'" elif isinstance(item, list): - return "[%s]" % ', '.join( + serialized_array = ', '.join( str( escape_param( x, @@ -99,9 +99,10 @@ def escape_param( ) ) for x in item ) + return f'[{serialized_array}]' elif isinstance(item, tuple): - return "(%s)" % ', '.join( + serialized_tuple = ', '.join( str( escape_param( x, @@ -113,11 +114,28 @@ def escape_param( ) for x in item ) + return f'({serialized_tuple})' + + elif isinstance(item, dict): + serializer = partial( + escape_param, + context=context, + for_server=for_server, + for_iterable=True, + nested=True, + ) + + serialized_dict = ', '.join( + f'{serializer(key)!s}: {serializer(value)!s}' + for key, value in item.items() + ) + return f'{{{serialized_dict}}}' + elif isinstance(item, Enum): return escape_param(item.value, context, for_server=for_server) elif isinstance(item, UUID): - return "'%s'" % str(item) + return f"'{item!s}'" else: return item diff --git a/tests/test_substitution.py b/tests/test_substitution.py index f31ca717..c3438a4f 100644 --- a/tests/test_substitution.py +++ b/tests/test_substitution.py @@ -418,10 +418,14 @@ def test_nested_tuple(self): '(Int32, Tuple(Float64, String), Array(Tuple(String, Int32)))' ) - def test_map(self): + def test_map__int(self): x = {1: 2, 3: 4} self._test_type_serialization(x, '^Map$', '(UInt32, UInt32)') + def test_map__string(self): + x = {'1': '34', '2': '45'} + self._test_type_serialization(x, '^Map$', '(String, String)') + @unittest.skip('Duplicate keys not supported') def test_map__duplicate_keys(self): pass From 69a594ea93af7bcaddd9ba719c586c1845d1bf85 Mon Sep 17 00:00:00 2001 From: mAyty Date: Sun, 19 Jan 2025 04:34:21 +0200 Subject: [PATCH 4/5] refactoring --- clickhouse_driver/util/escape.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/clickhouse_driver/util/escape.py b/clickhouse_driver/util/escape.py index 8e7a212c..5b5b3036 100644 --- a/clickhouse_driver/util/escape.py +++ b/clickhouse_driver/util/escape.py @@ -55,11 +55,8 @@ def wrapper(*args, **kwargs): return f"\\'{rv}\\'" return f"'{rv}'" - if kwargs.get('for_iterable'): - return str(rv) - if nested: - return f"\\'{rv!s}\\'" + return str(rv) return f"'{rv!s}'" return wrapper @@ -67,7 +64,7 @@ def wrapper(*args, **kwargs): @maybe_enquote_for_server def escape_param( - item, context, for_server=False, for_iterable=False, nested=False + item, context, for_server=False, nested=False ): if item is None: return 'NULL' @@ -94,7 +91,6 @@ def escape_param( x, context, for_server=for_server, - for_iterable=True, nested=True, ) ) for x in item @@ -108,7 +104,6 @@ def escape_param( x, context, for_server=for_server, - for_iterable=True, nested=True, ) ) for x in item @@ -121,7 +116,6 @@ def escape_param( escape_param, context=context, for_server=for_server, - for_iterable=True, nested=True, ) From 03cfc1fdb9b1167eecd347ee1e5b755141486763 Mon Sep 17 00:00:00 2001 From: mAyty Date: Sun, 19 Jan 2025 16:40:43 +0200 Subject: [PATCH 5/5] fix valgrind action --- .github/workflows/actions.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 4cc5f149..cf593959 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -161,10 +161,8 @@ jobs: echo '127.0.0.1 clickhouse-server' | sudo tee /etc/hosts > /dev/null - name: Install requirements run: | - python testsrequire.py - python setup.py develop - env: - USE_NUMPY: 1 + pip install -e .[lz4,zstd,numpy] + pip install -r tests/requirements.txt - name: Run tests under valgrind run: valgrind -s --error-exitcode=1 --suppressions=valgrind.supp py.test -v env: