Bug: Can't have nullable relationship #227

Open
1 of 4 tasks
ste-pool opened this issue Jul 1, 2024 · 6 comments
Labels
bug Something isn't working

Comments

@ste-pool

ste-pool commented Jul 1, 2024

Description

If I want a nullable one-to-one relationship, I can create some tables like this:

class Foo(UUIDBase):
    __tablename__ = "foo"
    bar: Mapped[Optional[Bar]] = relationship(back_populates="foo")

class Bar(UUIDBase):
    __tablename__ = "bar"
    foo_id: Mapped[UUID] = mapped_column(ForeignKey("foo.id"))
    foo: Mapped[Foo] = relationship(back_populates="bar")

However, if I use a GET route in Litestar to retrieve a foo row that has no associated bar, I get:

TypeError: Missing required argument 'bar'

I think this is because the function detect_nullable_relationship doesn't check whether the relationship itself is annotated as Optional when its direction is one-to-many?

URL to code causing the issue

No response

MCVE

from __future__ import annotations

import traceback
from typing import Optional
from uuid import UUID

from advanced_alchemy.base import UUIDBase
from advanced_alchemy.config import AsyncSessionConfig
from litestar import Litestar, get, post
from litestar.contrib.sqlalchemy.plugins import SQLAlchemyAsyncConfig, SQLAlchemyPlugin
from sqlalchemy import ForeignKey, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload


connection_string = "sqlite+aiosqlite:////tmp/test1.sqlite"
sqlalchemy_config = SQLAlchemyAsyncConfig(
    create_all=True,
    session_config=AsyncSessionConfig(expire_on_commit=False),
    connection_string=connection_string,
)
sqlalchemy_plugin = SQLAlchemyPlugin(config=sqlalchemy_config)


class Foo(UUIDBase):
    __tablename__ = "foo"
    bar: Mapped[Optional[Bar]] = relationship(back_populates="foo")


class Bar(UUIDBase):
    __tablename__ = "bar"
    foo_id: Mapped[UUID] = mapped_column(ForeignKey("foo.id"))
    foo: Mapped[Foo] = relationship(back_populates="bar")


@post("/add")
async def add_foo(db_session: AsyncSession) -> None:
    row = Foo()
    db_session.add(row)
    await db_session.commit()
    return None


@get("/get_foo")
async def get_foo(db_session: AsyncSession) -> list[Foo]:
    resp = await db_session.scalars(select(Foo).options(selectinload(Foo.bar)))
    return resp.all()


def plain_text_exception_handler(_, exc):
    print(traceback.format_exc())


app = Litestar(
    route_handlers=[add_foo, get_foo],
    plugins=[sqlalchemy_plugin],
    exception_handlers={Exception: plain_text_exception_handler},
)

Package Version

0.16.0

Platform

  • Linux
  • Mac
  • Windows
  • Other (Please specify in the description above)
@ste-pool ste-pool added the bug Something isn't working label Jul 1, 2024
@provinzkraut
Member

Are you sure this is the code you're running? I can't seem to reproduce this

@ste-pool
Author

ste-pool commented Jul 1, 2024

Yup, can still see it :/

I just recreated it in a new Python 3.10.6 venv to be doubly sure, with the code above saved as app.py:

pip install litestar[sqlalchemy] aiosqlite uvicorn
uvicorn app:app &
curl -X POST localhost:8000/add
curl localhost:8000/get_foo
...
  File "<string>", line 2, in func
  File "<string>", line 2, in <genexpr>
  File "<string>", line 59, in func
TypeError: Missing required argument 'bar'
...

pip freeze:

advanced_alchemy==0.17.0
aiosqlite==0.20.0
alembic==1.13.2
anyio==4.4.0
certifi==2024.6.2
click==8.1.7
exceptiongroup==1.2.1
Faker==26.0.0
greenlet==3.0.3
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
idna==3.7
litestar==2.9.1
Mako==1.3.5
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
msgspec==0.18.6
multidict==6.0.5
polyfactory==2.16.0
Pygments==2.18.0
python-dateutil==2.9.0.post0
PyYAML==6.0.1
rich==13.7.1
rich-click==1.8.3
six==1.16.0
sniffio==1.3.1
SQLAlchemy==2.0.31
typing_extensions==4.12.2
uvicorn==0.30.1

Full traceback if it's helpful:

ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/litestar/middleware/_internal/exceptions/middleware.py", line 159, in __call__
    await self.app(scope, receive, capture_response_started)
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/litestar/_asgi/asgi_router.py", line 99, in __call__
    await asgi_app(scope, receive, send)
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/litestar/routes/http.py", line 80, in handle
    response = await self._get_response_for_request(
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/litestar/routes/http.py", line 132, in _get_response_for_request
    return await self._call_handler_function(
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/litestar/routes/http.py", line 156, in _call_handler_function
    response: ASGIApp = await route_handler.to_response(app=scope["app"], data=response_data, request=request)
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/litestar/handlers/http_handlers/base.py", line 554, in to_response
    data = return_dto_type(request).data_to_encodable_type(data)
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/litestar/dto/base_dto.py", line 101, in data_to_encodable_type
    return backend.encode_data(data)
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/litestar/dto/_codegen_backend.py", line 161, in encode_data
    return cast("LitestarEncodableType", self._encode_data(data))
  File "<string>", line 2, in func
  File "<string>", line 2, in <genexpr>
  File "<string>", line 59, in func
TypeError: Missing required argument 'bar'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/uvicorn/protocols/http/h11_impl.py", line 396, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/uvicorn/middleware/proxy_headers.py", line 70, in __call__
    return await self.app(scope, receive, send)
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/litestar/app.py", line 591, in __call__
    await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope))  # type: ignore[arg-type]
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/litestar/middleware/_internal/exceptions/middleware.py", line 176, in __call__
    await self.handle_request_exception(
  File "/Users/anon/.pyenv/versions/3.10.6/envs/test2/lib/python3.10/site-packages/litestar/middleware/_internal/exceptions/middleware.py", line 208, in handle_request_exception
    await response.to_asgi_response(app=None, request=request, type_encoders=type_encoders)(
AttributeError: 'NoneType' object has no attribute 'to_asgi_response'

For experiment's sake, if I change the return statement of detect_nullable_relationship to:

return (elem.direction == RelationshipDirection.MANYTOONE and all(c.nullable for c in elem.local_columns)) or (elem.direction == RelationshipDirection.ONETOMANY)

then it works (I know it's not correct, I'm just not sure how to get the mapped type hint of elem).

@ste-pool
Author

Any ideas on this? The workaround I suggested is obviously not the best!

@ftsartek

I had a similar discussion (actually referencing this issue) on Discord. I could reproduce the same behaviour with SQLAlchemy itself, so I think it's specific to Litestar rather than Advanced Alchemy.

Regardless, Cofin is aware of the issue: https://discord.com/channels/919193495116337154/1262488171023695919

@ste-pool
Author

Ah okay, thanks! I'll keep a look out 👀

I wasn't sure where the problem was but given I could change a file in here to fix it, I assumed this was the place 🤦

@atom-andrew

Setting experimental_codegen_backend=False in the DTOConfig allowed me to work around this problem when I encountered it (presumably at a significant cost in performance).

4 participants