Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion gel/_internal/_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,13 +1172,14 @@ def make_save_executor_constructor(
refetch: bool,
warn_on_large_sync_set: bool = False,
save_postcheck: bool = False,
executor_type: type,
) -> Callable[[], SaveExecutor]:
plan = make_plan(
objs,
refetch=refetch,
warn_on_large_sync_set=warn_on_large_sync_set,
)
return lambda: SaveExecutor(
return lambda: executor_type(
objs=objs,
create_batches=plan.create_batches,
updates=plan.update_batch,
Expand Down
36 changes: 30 additions & 6 deletions gel/_internal/_testbase/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,17 @@
if self._instance is not None:
return

if self._server_addr is not None:
server_addr = self._server_addr
if server_addr is None:
server_addr = {
"host": "localhost",
"port": 5656,
"tls_ca_file": "/home/dnwpark/work/dev-3.12/edgedb/tmp/devdatadir/edbtlscert.pem",

Check failure on line 364 in gel/_internal/_testbase/_base.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (E501)

gel/_internal/_testbase/_base.py:364:80: E501 Line too long (98 > 79)
}

if server_addr is not None:
await self._set_up_running_instance(
self._server_addr,
server_addr,
self._server_version,
)
else:
Expand Down Expand Up @@ -634,6 +642,7 @@
# of 10s.
DEFAULT_CONNECT_TIMEOUT = 30

CLIENT_TYPE: ClassVar[type[TestClient | TestAsyncIOClient] | None]
client: ClassVar[TestClient | TestAsyncIOClient]

@classmethod
Expand Down Expand Up @@ -682,7 +691,9 @@
if self.ISOLATED_TEST_BRANCHES:
cls = type(self)
testdb = cls.loop.run_until_complete(self.setup_branch_copy())
client = cls.make_test_client(database=testdb)._with_debug(
client = cls.make_test_client(
database=testdb, client_class=self.CLIENT_TYPE
)._with_debug(
save_postcheck=True,
)
self.client = client # type: ignore[misc]
Expand Down Expand Up @@ -721,6 +732,7 @@
def make_test_client(
cls,
*,
client_class: type[TestClient | TestAsyncIOClient] | None = None,
connection_class: type[
asyncio_client.AsyncIOConnection
| blocking_client.BlockingIOConnection
Expand Down Expand Up @@ -762,14 +774,17 @@
cls,
*,
instance: _server.BaseInstance,
client_class: type[TestClient] | None = None,
connection_class: type[blocking_client.BlockingIOConnection]
| None = None,
**kwargs: Any,
) -> TestClient:
if client_class is None:
client_class = TestClient
if connection_class is None:
connection_class = blocking_client.BlockingIOConnection
client = instance.create_blocking_client(
client_class=TestClient,
client_class=client_class,
connection_class=connection_class,
**cls.get_connect_args(instance, **kwargs),
)
Expand Down Expand Up @@ -805,13 +820,16 @@
cls,
*,
instance: _server.BaseInstance,
client_class: type[TestAsyncIOClient] | None = None,
connection_class: type[asyncio_client.AsyncIOConnection] | None = None,
**kwargs: Any,
) -> TestAsyncIOClient:
if client_class is None:
client_class = TestAsyncIOClient
if connection_class is None:
connection_class = asyncio_client.AsyncIOConnection
client = instance.create_async_client(
client_class=TestAsyncIOClient,
client_class=client_class,
connection_class=connection_class,
**cls.get_connect_args(instance, **kwargs),
)
Expand Down Expand Up @@ -889,7 +907,9 @@
await cls._create_empty_branch(dbname)

if not cls.ISOLATED_TEST_BRANCHES:
cls.client = cls.make_test_client(database=dbname)
cls.client = cls.make_test_client(
database=dbname, client_class=cls.CLIENT_TYPE
)
if isinstance(cls.client, gel.AsyncIOClient):
await cls.client.ensure_connected()
else:
Expand Down Expand Up @@ -1029,11 +1049,13 @@
def make_test_client( # pyright: ignore [reportIncompatibleMethodOverride]
cls,
*,
client_class: type[TestAsyncIOClient] | None = None,
connection_class: type[asyncio_client.AsyncIOConnection] | None = None, # type: ignore [override]
**kwargs: str,
) -> TestAsyncIOClient:
return cls.make_async_test_client(
instance=cls.instance,
client_class=client_class,
connection_class=connection_class,
**kwargs,
)
Expand Down Expand Up @@ -1070,12 +1092,14 @@
def make_test_client( # pyright: ignore [reportIncompatibleMethodOverride]
cls,
*,
client_class: type[TestClient] | None = None,
connection_class: type[blocking_client.BlockingIOConnection] # type: ignore [override]
| None = None,
**kwargs: str,
) -> TestClient:
return cls.make_blocking_test_client(
instance=cls.instance,
client_class=client_class,
connection_class=connection_class,
**kwargs,
)
Expand Down
3 changes: 2 additions & 1 deletion gel/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from .protocol import asyncio_proto # type: ignore [attr-defined, unused-ignore]
from .protocol.protocol import InputLanguage, OutputFormat

from ._internal._save import make_save_executor_constructor
from ._internal._save import make_save_executor_constructor, SaveExecutor

if typing.TYPE_CHECKING:
from ._internal._qbmodel._pydantic import GelModel
Expand Down Expand Up @@ -675,6 +675,7 @@ async def _save_impl(
refetch=refetch,
save_postcheck=opts.save_postcheck,
warn_on_large_sync_set=warn_on_large_sync_set,
executor_type=SaveExecutor,
)

async for tx in self._batch():
Expand Down
83 changes: 65 additions & 18 deletions gel/blocking_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
from .protocol import blocking_proto # type: ignore [attr-defined, unused-ignore]
from .protocol.protocol import InputLanguage, OutputFormat

from ._internal._save import make_save_executor_constructor
from ._internal._save import (
QueryBatch,
QueryRefetch,
SaveExecutor,
make_save_executor_constructor,
)

if typing.TYPE_CHECKING:
from ._internal._qbmodel._pydantic import GelModel
Expand Down Expand Up @@ -681,6 +686,7 @@ class Client(

__slots__ = ()
_impl_class = _PoolImpl
_save_executor_type = SaveExecutor

def _save_impl(
self,
Expand All @@ -689,12 +695,9 @@ def _save_impl(
objs: tuple[GelModel, ...],
warn_on_large_sync_set: bool = False,
) -> None:
opts = self._get_debug_options()

make_executor = make_save_executor_constructor(
objs,
make_executor = self._get_make_save_executor(
refetch=refetch,
save_postcheck=opts.save_postcheck,
objs=objs,
warn_on_large_sync_set=warn_on_large_sync_set,
)

Expand All @@ -703,23 +706,13 @@ def _save_impl(
executor = make_executor()

for batches in executor:
for batch in batches:
tx.send_query(batch.query, batch.args)
batch_ids = tx.wait()
batch_ids = self._send_batch_queries(tx, batches)
for ids, batch in zip(batch_ids, batches, strict=True):
batch.record_inserted_data(ids)

if refetch:
ref_queries = executor.get_refetch_queries()
for ref in ref_queries:
tx.send_query(
ref.query,
spec=ref.args.spec,
new=ref.args.new,
existing=ref.args.existing,
)

refetch_data = tx.wait()
refetch_data = self._send_refetch_queries(tx, ref_queries)

for ref_data, ref in zip(
refetch_data, ref_queries, strict=True
Expand All @@ -728,6 +721,60 @@ def _save_impl(

executor.commit()

def _get_make_save_executor(
self,
*,
refetch: bool,
objs: tuple[GelModel, ...],
warn_on_large_sync_set: bool = False,
) -> typing.Callable[[], SaveExecutor]:
opts = self._get_debug_options()

return make_save_executor_constructor(
objs,
refetch=refetch,
save_postcheck=opts.save_postcheck,
warn_on_large_sync_set=warn_on_large_sync_set,
executor_type=self._save_executor_type,
)

def _send_batch_queries(
self,
tx: BatchIteration,
batches: list[QueryBatch],
) -> list[Any]:
for batch in batches:
self._send_batch_query(tx, batch)
return tx.wait()

def _send_refetch_queries(
self,
tx: BatchIteration,
ref_queries: list[QueryRefetch],
) -> list[Any]:
for ref in ref_queries:
self._send_refetch_query(tx, ref)
return tx.wait()

def _send_batch_query(
self,
tx: BatchIteration,
batch: QueryBatch,
) -> None:
tx.send_query(batch.query, batch.args)

def _send_refetch_query(
self,
tx: BatchIteration,
ref: QueryRefetch,
) -> None:
tx.send_query(
ref.query,
spec=ref.args.spec,
new=ref.args.new,
existing=ref.args.existing,
)

def save(
self,
*objs: GelModel,
Expand Down
Loading
Loading