Skip to content

Commit d3dafb0

Browse files
committed
Add bulk upsert to connection
1 parent 7b4fac3 commit d3dafb0

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

tests/test_connections.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,57 @@ def _test_errors(
149149
maybe_await(cur.execute_scheme("DROP TABLE test"))
150150
maybe_await(cur.close())
151151

152+
def _test_bulk_upsert(self, connection: dbapi.Connection) -> None:
153+
cursor = connection.cursor()
154+
with suppress(dbapi.DatabaseError):
155+
maybe_await(cursor.execute_scheme("DROP TABLE pet"))
156+
157+
maybe_await(cursor.execute_scheme(
158+
"""
159+
CREATE TABLE pet (
160+
pet_id INT,
161+
name TEXT NOT NULL,
162+
pet_type TEXT NOT NULL,
163+
birth_date TEXT NOT NULL,
164+
owner TEXT NOT NULL,
165+
PRIMARY KEY (pet_id)
166+
);
167+
"""
168+
))
169+
170+
column_types = (
171+
ydb.BulkUpsertColumns()
172+
.add_column("pet_id", ydb.OptionalType(ydb.PrimitiveType.Int32))
173+
.add_column("name", ydb.PrimitiveType.Utf8)
174+
.add_column("pet_type", ydb.PrimitiveType.Utf8)
175+
.add_column("birth_date", ydb.PrimitiveType.Utf8)
176+
.add_column("owner", ydb.PrimitiveType.Utf8)
177+
)
178+
179+
rows = [
180+
{
181+
"pet_id": 3,
182+
"name": "Lester",
183+
"pet_type": "Hamster",
184+
"birth_date": "2020-06-23",
185+
"owner": "Lily"
186+
},
187+
{
188+
"pet_id": 4,
189+
"name": "Quincy",
190+
"pet_type": "Parrot",
191+
"birth_date": "2013-08-11",
192+
"owner": "Anne"
193+
},
194+
]
195+
196+
maybe_await(connection.bulk_upsert("pet", rows, column_types))
197+
198+
maybe_await(cursor.execute("SELECT * FROM pet"))
199+
assert cursor.rowcount == 2
200+
201+
maybe_await(cursor.execute_scheme("DROP TABLE pet"))
202+
152203

153204
class TestConnection(BaseDBApiTestSuit):
154205
@pytest.fixture
@@ -191,6 +242,9 @@ def test_cursor_raw_query(self, connection: dbapi.Connection) -> None:
191242
def test_errors(self, connection: dbapi.Connection) -> None:
192243
self._test_errors(connection)
193244

245+
def test_bulk_upsert(self, connection: dbapi.Connection) -> None:
246+
self._test_bulk_upsert(connection)
247+
194248

195249
class TestAsyncConnection(BaseDBApiTestSuit):
196250
@pytest_asyncio.fixture
@@ -244,3 +298,9 @@ async def test_cursor_raw_query(
244298
@pytest.mark.asyncio
245299
async def test_errors(self, connection: dbapi.AsyncConnection) -> None:
246300
await greenlet_spawn(self._test_errors, connection)
301+
302+
@pytest.mark.asyncio
303+
async def test_bulk_upsert(
304+
self, connection: dbapi.AsyncConnection
305+
) -> None:
306+
await greenlet_spawn(self._test_bulk_upsert, connection)

ydb_dbapi/connections.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import posixpath
4+
from collections.abc import Sequence
45
from enum import Enum
56
from typing import NamedTuple
67

@@ -301,6 +302,25 @@ def callee() -> ydb.Directory:
301302
result.extend(self._get_table_names(child_abs_path))
302303
return result
303304

305+
@handle_ydb_errors
306+
def bulk_upsert(
307+
self,
308+
table_name: str,
309+
rows: Sequence,
310+
column_types: ydb.BulkUpsertColumns,
311+
) -> None:
312+
settings = self._get_request_settings()
313+
abs_table_path = posixpath.join(
314+
self.database, self.table_path_prefix, table_name
315+
)
316+
317+
self._driver.table_client.bulk_upsert(
318+
abs_table_path,
319+
rows=rows,
320+
column_types=column_types,
321+
settings=settings,
322+
)
323+
304324

305325
class AsyncConnection(BaseConnection):
306326
_driver_cls = ydb.aio.Driver
@@ -446,6 +466,25 @@ async def callee() -> ydb.Directory:
446466
result.extend(await self._get_table_names(child_abs_path))
447467
return result
448468

469+
@handle_ydb_errors
470+
async def bulk_upsert(
471+
self,
472+
table_name: str,
473+
rows: Sequence,
474+
column_types: ydb.BulkUpsertColumns,
475+
) -> None:
476+
settings = self._get_request_settings()
477+
abs_table_path = posixpath.join(
478+
self.database, self.table_path_prefix, table_name
479+
)
480+
481+
await self._driver.table_client.bulk_upsert(
482+
abs_table_path,
483+
rows=rows,
484+
column_types=column_types,
485+
settings=settings,
486+
)
487+
449488

450489
def connect(*args: tuple, **kwargs: dict) -> Connection:
451490
conn = Connection(*args, **kwargs) # type: ignore

0 commit comments

Comments
 (0)