import logging
import os
import pickle
import string
import tempfile
import time
from datetime import date, datetime, timezone
from typing import TYPE_CHECKING, NamedTuple

# NOTE: ``pytest`` and the ``conn_cnx`` fixture are not defined in this
# section; they are expected to come from the rest of the test module.


def _generate_lob(length: int) -> str:
    """Return a deterministic string of exactly ``length`` characters.

    The string is built by repeating ``string.printable`` with the comma and
    the newline removed (presumably so that the value stays a single field of
    a single record when it is written verbatim to a CSV file).

    Args:
        length: Number of characters to generate; must be non-negative.

    Returns:
        A string whose ``len()`` equals ``length``.

    Raises:
        ValueError: If ``length`` is negative.
    """
    if length < 0:
        # Without this guard the modulo below would yield a non-empty string.
        raise ValueError(f"length must be non-negative, got {length}")
    base = string.printable.replace(",", "").replace("\n", "")
    times, remainder = divmod(length, len(base))
    return base * times + base[:remainder]


_MB = 1024 * 1024
# 16 MB, 32 MB, 64 MB and 128 MB.
_LOB_SIZES = [(2**x) * 16 * _MB for x in range(4)]
_LOB_RESULT_FORMATS = ["Arrow", "JSON"]
_LOB_TABLE = "my_lob_test"


def _lob_session_parameters(result_format: str) -> dict:
    """Build the session parameters shared by the LOB tests."""
    return {
        "ALLOW_LARGE_LOBS_IN_EXTERNAL_SCAN": True,
        "python_connector_query_result_format": result_format,
    }


@pytest.mark.parametrize("lob_size", _LOB_SIZES)
@pytest.mark.parametrize("result_format", _LOB_RESULT_FORMATS)
def test_lob_insert_select(conn_cnx, lob_size, result_format):
    """Insert a large value through a qmark binding and read it back."""
    with conn_cnx(
        session_parameters=_lob_session_parameters(result_format),
    ) as con, con.cursor() as cur:
        cur.execute(f"create or replace table {_LOB_TABLE}(c1 varchar({lob_size}))")
        try:
            lob = _generate_lob(lob_size)
            cur.execute(
                f"insert into {_LOB_TABLE} values (?)",
                params=(lob,),
                _force_qmark_paramstyle=True,
            )
            fetched_lob = cur.execute(f"select c1 from {_LOB_TABLE}").fetchall()[0][0]
            assert lob == fetched_lob
        finally:
            # Always drop the table so a failed run does not leave it behind.
            cur.execute(f"drop table if exists {_LOB_TABLE}")


@pytest.mark.parametrize("lob_size", _LOB_SIZES)
@pytest.mark.parametrize("result_format", _LOB_RESULT_FORMATS)
def test_lob_put_get(conn_cnx, lob_size, result_format):
    """PUT a file holding a large value, COPY it into a table and read it back."""
    with conn_cnx(
        session_parameters=_lob_session_parameters(result_format),
    ) as con, con.cursor() as cur:
        cur.execute(f"create or replace table {_LOB_TABLE}(c1 varchar({lob_size}))")
        try:
            lob = _generate_lob(lob_size)
            # Write to a regular file inside a temporary directory and close it
            # before PUT: a NamedTemporaryFile that is still open cannot be
            # reopened by name on Windows.
            with tempfile.TemporaryDirectory() as tmp_dir:
                file_name = "lob.csv"
                file_path = os.path.join(tmp_dir, file_name)
                # newline="" keeps the payload byte-exact on every platform.
                with open(file_path, "w", encoding="utf-8", newline="") as tmp_input:
                    tmp_input.write(lob)
                # Backslashes (Windows paths) must be escaped inside the SQL
                # string literal.
                escaped_file_path = file_path.replace("\\", "\\\\")
                put_result = cur.execute(
                    f"PUT 'file://{escaped_file_path}' @%{_LOB_TABLE}"
                ).fetchall()[0]
            assert put_result[0] == file_name
            assert put_result[1] == file_name + ".gz"
            assert put_result[5] == "GZIP"
            assert put_result[6] == "UPLOADED"

            ls_result = cur.execute(f"ls @%{_LOB_TABLE}").fetchall()[0]
            assert ls_result[0] == file_name + ".gz"

            cur.execute(
                f"copy into {_LOB_TABLE} from @%{_LOB_TABLE} file_format=(type=csv compression='gzip')"
            )

            fetched_lob = cur.execute(f"select c1 from {_LOB_TABLE}").fetchall()[0][0]
            assert lob == fetched_lob
        finally:
            # Always drop the table (and its table stage) even on failure.
            cur.execute(f"drop table if exists {_LOB_TABLE}")