Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions src/strawchemy/sqlalchemy/repository/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from strawchemy.strawberry.typing import QueryNodeType
from strawchemy.typing import SupportedDialect

RowLike: TypeAlias = "Row[Any] | NamedTuple"


__all__ = ("InsertOrUpdate", "RowLike", "SQLAlchemyGraphQLRepository")

Expand Down Expand Up @@ -228,22 +230,42 @@ def _m2m_values(
self, model: DeclarativeBase, parent: Union[RowLike, DeclarativeBase], relationship: RelationshipProperty[Any]
) -> dict[str, Any]:
assert relationship.local_remote_pairs
return {
remote.key: getattr(model, local.key) if local.table is model.__table__ else getattr(parent, local.key)
for local, remote in relationship.local_remote_pairs
if local.key and remote.key
}

# Local optimization: avoid attribute access twice per key by pulling out __table__ and .__dict__
model_table = model.__table__
model_dict = model.__dict__
parent_dict = parent.__dict__ if hasattr(parent, "__dict__") else None

result = {}
for local, remote in relationship.local_remote_pairs:
if local.key and remote.key:
# Use __dict__ for direct attribute access if possible, avoids getattr descriptor resolution
if local.table is model_table:
value = model_dict[local.key]
else:
value = (
parent_dict[local.key]
if parent_dict is not None and local.key in parent_dict
else getattr(parent, local.key)
)
result[remote.key] = value
return result

def _update_values(
self, model: DeclarativeBase, parent: Union[RowLike, DeclarativeBase], relationship: RelationshipProperty[Any]
) -> dict[str, Any]:
assert relationship.local_remote_pairs
if relationship.secondary is None:
return {column.key: getattr(model, column.key) for column in model.__mapper__.primary_key if column.key} | {
remote.key: getattr(parent, local.key)
for local, remote in relationship.local_remote_pairs
if local.key and remote.key
}
# Local optimization: use dict update instead of dict union for performance,
# and re-use dicts for primary_key and remote mapping.
d = {}
for column in model.__mapper__.primary_key:
if column.key:
d[column.key] = getattr(model, column.key)
for local, remote in relationship.local_remote_pairs:
if local.key and remote.key:
d[remote.key] = getattr(parent, local.key)
return d
return self._m2m_values(model, parent, relationship)

def _to_one_nested_create_params(self, level: LevelInput) -> QueryParams:
Expand Down