Skip to content

Commit 88d5cb8

Browse files
committed
fix all tests
1 parent baabab8 commit 88d5cb8

File tree

7 files changed

+214
-130
lines changed

7 files changed

+214
-130
lines changed

poetry.lock

Lines changed: 86 additions & 86 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ select = ["F", "E"]
4949

5050
[tool.pytest.ini_options]
5151
asyncio_mode = "auto"
52-
addopts = "-p no:warnings"
52+
addopts = "--tb native -v -r fxX -p no:warnings"
5353

5454
[[tool.mypy.overrides]]
5555
module = "ydb.*"

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,11 @@ async def driver(endpoint, database, event_loop):
5555
yield driver
5656

5757
await driver.stop(timeout=10)
58+
del driver
5859

5960

6061
@pytest.fixture
61-
async def session_pool(driver: ydb.aio.Driver, event_loop):
62+
async def session_pool(driver: ydb.aio.Driver):
6263
session_pool = ydb.aio.QuerySessionPool(driver)
6364
async with session_pool:
6465
await session_pool.execute_with_retries(

tests/test_connection.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,32 @@ async def _test_isolation_level_read_only(
1414
isolation_level: str,
1515
read_only: bool,
1616
):
17-
await connection.cursor().execute(
18-
"CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))"
19-
)
20-
connection.set_isolation_level(isolation_level)
17+
async with connection.cursor() as cursor:
18+
with suppress(dbapi.DatabaseError):
19+
await cursor.execute("DROP TABLE foo")
2120

22-
cursor = connection.cursor()
21+
async with connection.cursor() as cursor:
22+
await cursor.execute(
23+
"CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))"
24+
)
25+
26+
connection.set_isolation_level(isolation_level)
2327

2428
await connection.begin()
2529

26-
query = "UPSERT INTO foo(id) VALUES (1)"
27-
if read_only:
28-
with pytest.raises(dbapi.DatabaseError):
30+
async with connection.cursor() as cursor:
31+
query = "UPSERT INTO foo(id) VALUES (1)"
32+
if read_only:
33+
with pytest.raises(dbapi.DatabaseError):
34+
await cursor.execute(query)
35+
await cursor.finish_query()
36+
else:
2937
await cursor.execute(query)
30-
else:
31-
await cursor.execute(query)
3238

3339
await connection.rollback()
3440

35-
await connection.cursor().execute("DROP TABLE foo")
36-
await connection.cursor().close()
41+
async with connection.cursor() as cursor:
42+
cursor.execute("DROP TABLE foo")
3743

3844
async def _test_connection(self, connection: dbapi.Connection):
3945
await connection.commit()
@@ -42,6 +48,7 @@ async def _test_connection(self, connection: dbapi.Connection):
4248
cur = connection.cursor()
4349
with suppress(dbapi.DatabaseError):
4450
await cur.execute("DROP TABLE foo")
51+
await cur.finish_query()
4552

4653
assert not await connection.check_exists("/local/foo")
4754
with pytest.raises(dbapi.ProgrammingError):
@@ -50,6 +57,7 @@ async def _test_connection(self, connection: dbapi.Connection):
5057
await cur.execute(
5158
"CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))"
5259
)
60+
await cur.finish_query()
5361

5462
assert await connection.check_exists("/local/foo")
5563

@@ -66,10 +74,12 @@ async def _test_cursor_raw_query(self, connection: dbapi.Connection):
6674

6775
with suppress(dbapi.DatabaseError):
6876
await cur.execute("DROP TABLE test")
77+
await cur.finish_query()
6978

7079
await cur.execute(
7180
"CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))"
7281
)
82+
await cur.finish_query()
7383

7484
await cur.execute(
7585
"""
@@ -91,6 +101,7 @@ async def _test_cursor_raw_query(self, connection: dbapi.Connection):
91101
)
92102
},
93103
)
104+
await cur.finish_query()
94105

95106
await cur.execute("DROP TABLE test")
96107

@@ -104,6 +115,7 @@ async def _test_errors(self, connection: dbapi.Connection):
104115

105116
with suppress(dbapi.DatabaseError):
106117
await cur.execute("DROP TABLE test")
118+
await cur.finish_query()
107119

108120
with pytest.raises(dbapi.DataError):
109121
await cur.execute("SELECT 18446744073709551616")
@@ -118,8 +130,11 @@ async def _test_errors(self, connection: dbapi.Connection):
118130
await cur.execute("SELECT * FROM test")
119131

120132
await cur.execute("CREATE TABLE test(id Int64, PRIMARY KEY (id))")
133+
await cur.finish_query()
121134

122135
await cur.execute("INSERT INTO test(id) VALUES(1)")
136+
await cur.finish_query()
137+
123138
with pytest.raises(dbapi.IntegrityError):
124139
await cur.execute("INSERT INTO test(id) VALUES(1)")
125140

@@ -143,10 +158,10 @@ async def connection(self, endpoint, database):
143158
[
144159
(dbapi.IsolationLevel.SERIALIZABLE, False),
145160
(dbapi.IsolationLevel.AUTOCOMMIT, False),
146-
# (dbapi.IsolationLevel.ONLINE_READONLY, True),
147-
# (dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True),
148-
# (dbapi.IsolationLevel.STALE_READONLY, True),
149-
# (dbapi.IsolationLevel.SNAPSHOT_READONLY, True),
161+
(dbapi.IsolationLevel.ONLINE_READONLY, True),
162+
(dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True),
163+
(dbapi.IsolationLevel.STALE_READONLY, True),
164+
(dbapi.IsolationLevel.SNAPSHOT_READONLY, True),
150165
],
151166
)
152167
async def test_isolation_level_read_only(

ydb_dbapi/connection.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ def __init__(
6868
self._tx_context: Optional[ydb.QueryTxContext] = None
6969
self._tx_mode: ydb.BaseQueryTxMode = ydb.QuerySerializableReadWrite()
7070

71+
self._current_cursor: Optional[Cursor] = None
72+
self.interactive_transaction: bool = False
73+
7174
async def _wait(self, timeout: int = 5):
7275
try:
7376
await self._driver.wait(timeout, fail_fast=True)
@@ -81,15 +84,22 @@ async def _wait(self, timeout: int = 5):
8184
) from e
8285

8386
def cursor(self):
84-
return Cursor(
85-
session_pool=self._session_pool, tx_context=self._tx_context
87+
if self._current_cursor and not self._current_cursor._closed:
88+
raise RuntimeError(
89+
"Unable to create new Cursor before closing existing one."
90+
)
91+
self._current_cursor = Cursor(
92+
session_pool=self._session_pool,
93+
tx_context=self._tx_context,
94+
autocommit=(not self.interactive_transaction),
8695
)
96+
return self._current_cursor
8797

8898
async def begin(self):
8999
self._tx_context = None
90100
self._session = await self._session_pool.acquire()
91101
self._tx_context = self._session.transaction(self._tx_mode)
92-
await self._tx_context.begin()
102+
# await self._tx_context.begin()
93103

94104
async def commit(self):
95105
if self._tx_context and self._tx_context.tx_id:
@@ -107,6 +117,10 @@ async def rollback(self):
107117

108118
async def close(self):
109119
await self.rollback()
120+
121+
if self._current_cursor:
122+
await self._current_cursor.close()
123+
110124
if not self._shared_session_pool:
111125
await self._session_pool.stop()
112126
await self._driver.stop()

ydb_dbapi/cursors.py

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import dataclasses
21
import itertools
32
from typing import (
43
Any,
@@ -12,8 +11,8 @@
1211
)
1312

1413
import ydb
15-
from .errors import Error, DatabaseError
16-
from .utils import handle_ydb_errors, AsyncFromSyncIterator
14+
from .errors import Error, DatabaseError, InterfaceError, ProgrammingError
15+
from .utils import handle_ydb_errors, AsyncFromSyncIterator, CursorStatus
1716

1817

1918
ParametersType = Dict[
@@ -26,12 +25,6 @@
2625
]
2726

2827

29-
@dataclasses.dataclass
30-
class YdbQuery:
31-
yql_text: str
32-
is_ddl: bool = False
33-
34-
3528
def _get_column_type(type_obj: Any) -> str:
3629
return str(ydb.convert.type_to_native(type_obj))
3730

@@ -56,13 +49,26 @@ def __init__(
5649
self._rows: Optional[Iterator[Dict]] = None
5750
self._rows_count: int = -1
5851

52+
self._closed: bool = False
53+
self._state = CursorStatus.ready
54+
55+
@property
56+
def description(self):
57+
return self._description
58+
59+
@property
60+
def rowcount(self):
61+
return self._rows_count
62+
63+
@handle_ydb_errors
5964
async def _execute_generic_query(
6065
self, query: str, parameters: Optional[ParametersType] = None
6166
) -> List[ydb.convert.ResultSet]:
6267
return await self._session_pool.execute_with_retries(
6368
query=query, parameters=parameters
6469
)
6570

71+
@handle_ydb_errors
6672
async def _execute_transactional_query(
6773
self, query: str, parameters: Optional[ParametersType] = None
6874
) -> AsyncIterator:
@@ -76,10 +82,14 @@ async def _execute_transactional_query(
7682
commit_tx=self._autocommit,
7783
)
7884

79-
@handle_ydb_errors
8085
async def execute(
81-
self, query: str, parameters: Optional[ParametersType] = None
86+
self,
87+
query: str,
88+
parameters: Optional[ParametersType] = None,
89+
prefetch_first_set: bool = True,
8290
):
91+
self._check_cursor_closed()
92+
self._check_pending_query()
8393
if self._tx_context is not None:
8494
self._stream = await self._execute_transactional_query(
8595
query=query, parameters=parameters
@@ -93,8 +103,10 @@ async def execute(
93103
if self._stream is None:
94104
return
95105

96-
result_set = await self._stream.__anext__()
97-
self._update_result_set(result_set)
106+
self._begin_query()
107+
108+
if prefetch_first_set:
109+
await self.nextset()
98110

99111
def _update_result_set(self, result_set: ydb.convert.ResultSet):
100112
self._update_description(result_set)
@@ -144,31 +156,65 @@ async def fetchmany(self, size: Optional[int] = None):
144156
async def fetchall(self):
145157
return list(self._rows or iter([])) or None
146158

159+
@handle_ydb_errors
147160
async def nextset(self):
148161
if self._stream is None:
149162
return False
150163
try:
151164
result_set = await self._stream.__anext__()
152165
self._update_result_set(result_set)
153-
except (StopIteration, RuntimeError):
166+
except (StopIteration, StopAsyncIteration, RuntimeError):
167+
self._state = CursorStatus.finished
154168
return False
169+
except ydb.Error as e:
170+
self._state = CursorStatus.finished
171+
raise e
155172
return True
156173

174+
async def finish_query(self):
175+
self._check_cursor_closed()
176+
177+
if not self._state == CursorStatus.running:
178+
return
179+
180+
next_set_available = True
181+
while next_set_available:
182+
next_set_available = await self.nextset()
183+
184+
self._state = CursorStatus.finished
185+
157186
def setinputsizes(self):
158187
pass
159188

160189
def setoutputsize(self):
161190
pass
162191

163192
async def close(self):
164-
next_set_available = True
165-
while next_set_available:
166-
next_set_available = await self.nextset()
193+
if self._closed:
194+
return
167195

168-
@property
169-
def description(self):
170-
return self._description
196+
await self.finish_query()
197+
self._state = CursorStatus.closed
198+
self._closed = True
171199

172-
@property
173-
def rowcount(self):
174-
return self._rows_count
200+
def _begin_query(self):
201+
self._state = CursorStatus.running
202+
203+
def _check_pending_query(self):
204+
if self._state == CursorStatus.running:
205+
raise ProgrammingError(
206+
"Some records have not been fetched. "
207+
"Fetch the remaining records before executing the next query."
208+
)
209+
210+
def _check_cursor_closed(self):
211+
if self._state == CursorStatus.closed:
212+
raise InterfaceError(
213+
"Could not perform operation: Cursor is closed."
214+
)
215+
216+
async def __aenter__(self):
217+
return self
218+
219+
async def __aexit__(self, exc_type, exc, tb):
220+
await self.close()

ydb_dbapi/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from enum import Enum
12
import functools
23
from typing import Iterator
34
import ydb
@@ -50,6 +51,13 @@ async def wrapper(*args, **kwargs):
5051
return wrapper
5152

5253

54+
class CursorStatus(str, Enum):
55+
ready = "ready"
56+
running = "running"
57+
finished = "finished"
58+
closed = "closed"
59+
60+
5361
class AsyncFromSyncIterator:
5462
def __init__(self, sync_iter: Iterator):
5563
self._sync_iter = sync_iter

0 commit comments

Comments
 (0)