Skip to content

Commit 78ad6bc

Browse files
committed
Add profiling for sync.
1 parent fddc68f commit 78ad6bc

File tree

5 files changed

+987
-25
lines changed

5 files changed

+987
-25
lines changed

gel/_internal/_save.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,7 @@ def make_save_executor_constructor(
10741074
refetch: bool,
10751075
warn_on_large_sync_set: bool = False,
10761076
save_postcheck: bool = False,
1077+
executor_type: type,
10771078
) -> Callable[[], SaveExecutor]:
10781079
(
10791080
create_batches,
@@ -1085,7 +1086,7 @@ def make_save_executor_constructor(
10851086
refetch=refetch,
10861087
warn_on_large_sync_set=warn_on_large_sync_set,
10871088
)
1088-
return lambda: SaveExecutor(
1089+
return lambda: executor_type(
10891090
objs=objs,
10901091
create_batches=create_batches,
10911092
updates=updates,

gel/_testbase.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -479,17 +479,21 @@ def make_test_client(
479479
if cls.is_client_async
480480
else blocking_client.BlockingIOConnection
481481
)
482-
client_class = (
483-
TestAsyncIOClient
484-
if issubclass(connection_class, asyncio_client.AsyncIOConnection)
485-
else TestClient
486-
)
482+
client_class = cls._get_client_class(connection_class)
487483
return client_class(
488484
connection_class=connection_class,
489485
max_concurrency=1,
490486
**conargs,
491487
)
492488

489+
@classmethod
490+
def _get_client_class(cls, connection_class):
491+
return (
492+
TestAsyncIOClient
493+
if issubclass(connection_class, asyncio_client.AsyncIOConnection)
494+
else TestClient
495+
)
496+
493497
@classmethod
494498
def get_connect_args(
495499
cls, *, cluster=None, database="edgedb", user="edgedb", password="test"

gel/asyncio_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from .protocol import asyncio_proto # type: ignore [attr-defined, unused-ignore]
3838
from .protocol.protocol import InputLanguage, OutputFormat
3939

40-
from ._internal._save import make_save_executor_constructor
40+
from ._internal._save import make_save_executor_constructor, SaveExecutor
4141

4242
if typing.TYPE_CHECKING:
4343
from ._internal._qbmodel._pydantic import GelModel
@@ -600,6 +600,7 @@ class AsyncIOClient(
600600

601601
__slots__ = ()
602602
_impl_class = _AsyncIOPoolImpl
603+
_save_executor_type = SaveExecutor
603604

604605
async def check_connection(self) -> base_client.ConnectionInfo:
605606
return await self._impl.ensure_connected()
@@ -647,6 +648,7 @@ async def _save_impl(
647648
refetch=refetch,
648649
save_postcheck=opts.save_postcheck,
649650
warn_on_large_sync_set=warn_on_large_sync_set,
651+
executor_type=self._save_executor_type,
650652
)
651653

652654
async for tx in self._batch():

gel/blocking_client.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@
3939
from .protocol import blocking_proto # type: ignore [attr-defined, unused-ignore]
4040
from .protocol.protocol import InputLanguage, OutputFormat
4141

42-
from ._internal._save import make_save_executor_constructor
42+
from ._internal._save import (
43+
QueryBatch,
44+
QueryRefetch,
45+
SaveExecutor,
46+
make_save_executor_constructor,
47+
)
4348

4449
if typing.TYPE_CHECKING:
4550
from ._internal._qbmodel._pydantic import GelModel
@@ -651,6 +656,7 @@ class Client(
651656

652657
__slots__ = ()
653658
_impl_class = _PoolImpl
659+
_save_executor_type = SaveExecutor
654660

655661
def _save_impl(
656662
self,
@@ -659,12 +665,9 @@ def _save_impl(
659665
objs: tuple[GelModel, ...],
660666
warn_on_large_sync_set: bool = False,
661667
) -> None:
662-
opts = self._get_debug_options()
663-
664-
make_executor = make_save_executor_constructor(
665-
objs,
668+
make_executor = self._get_make_save_executor(
666669
refetch=refetch,
667-
save_postcheck=opts.save_postcheck,
670+
objs=objs,
668671
warn_on_large_sync_set=warn_on_large_sync_set,
669672
)
670673

@@ -674,28 +677,74 @@ def _save_impl(
674677

675678
with executor:
676679
for batches in executor:
677-
for batch in batches:
678-
tx.send_query(batch.query, batch.args)
679-
batch_ids = tx.wait()
680+
batch_ids = self._send_batch_queries(tx, batches)
680681
for ids, batch in zip(batch_ids, batches, strict=True):
681682
batch.feed_db_data(ids)
682683

683684
if refetch:
684685
ref_queries = executor.get_refetch_queries()
685-
for ref in ref_queries:
686-
tx.send_query(
687-
ref.query,
688-
spec=ref.args.spec,
689-
new=ref.args.new,
690-
existing=ref.args.existing,
691-
)
692-
693-
refetch_data = tx.wait()
686+
refetch_data = self._send_refetch_queries(
687+
tx, ref_queries
688+
)
694689
for ref_data, ref in zip(
695690
refetch_data, ref_queries, strict=True
696691
):
697692
ref.feed_db_data(ref_data)
698693

694+
def _get_make_save_executor(
695+
self,
696+
*,
697+
refetch: bool,
698+
objs: tuple[GelModel, ...],
699+
warn_on_large_sync_set: bool = False,
700+
) -> typing.Callable[[], SaveExecutor]:
701+
opts = self._get_debug_options()
702+
703+
return make_save_executor_constructor(
704+
objs,
705+
refetch=refetch,
706+
save_postcheck=opts.save_postcheck,
707+
warn_on_large_sync_set=warn_on_large_sync_set,
708+
executor_type=self._save_executor_type,
709+
)
710+
711+
def _send_batch_queries(
712+
self,
713+
tx: BatchIteration,
714+
batches: list[QueryBatch],
715+
) -> list[Any]:
716+
for batch in batches:
717+
self._send_batch_query(tx, batch)
718+
return tx.wait()
719+
720+
def _send_refetch_queries(
721+
self,
722+
tx: BatchIteration,
723+
ref_queries: list[QueryRefetch],
724+
) -> list[Any]:
725+
for ref in ref_queries:
726+
self._send_refetch_query(tx, ref)
727+
return tx.wait()
728+
729+
def _send_batch_query(
730+
self,
731+
tx: BatchIteration,
732+
batch: QueryBatch,
733+
) -> None:
734+
tx.send_query(batch.query, batch.args)
735+
736+
def _send_refetch_query(
737+
self,
738+
tx: BatchIteration,
739+
ref: QueryRefetch,
740+
) -> None:
741+
tx.send_query(
742+
ref.query,
743+
spec=ref.args.spec,
744+
new=ref.args.new,
745+
existing=ref.args.existing,
746+
)
747+
699748
def save(
700749
self,
701750
*objs: GelModel,
@@ -723,6 +772,7 @@ def __debug_save__(self, *objs: GelModel) -> SaveDebug:
723772
make_executor = make_save_executor_constructor(
724773
objs,
725774
refetch=False, # TODO
775+
executor_type=self._save_executor_type,
726776
)
727777
plan_time = time.monotonic_ns() - ns
728778

0 commit comments

Comments
 (0)