Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docker-compose.test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
services:
web:
user: root # Run as root for tests to allow global package installation
environment:
- PYTHONPATH=/usr/local/lib/python3.11/site-packages
command: bash -c "pip install faker pytest-asyncio pytest-mock && pytest tests/ -v"
volumes:
- ./tests:/code/tests
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"uvloop>=0.19.0",
"httptools>=0.6.1",
"uuid>=1.30",
"uuid6>=2024.1.12",
"alembic>=1.13.1",
"asyncpg>=0.29.0",
"SQLAlchemy-Utils>=0.41.1",
Expand Down Expand Up @@ -117,4 +118,4 @@ explicit_package_bases = true

[[tool.mypy.overrides]]
module = "src.app.*"
disallow_untyped_defs = true
disallow_untyped_defs = true
6 changes: 5 additions & 1 deletion src/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ async def get_current_user(
user = await crud_users.get(db=db, username=token_data.username_or_email, is_deleted=False)

if user:
return cast(dict[str, Any], user)
# Ensure consistent return type - always return dict
if hasattr(user, 'model_dump'): # It's a Pydantic model
return user.model_dump()
else: # It's already a dict
return user

raise UnauthorizedException("User not authenticated.")

Expand Down
42 changes: 36 additions & 6 deletions src/app/api/v1/posts.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@ async def read_posts(
page: int = 1,
items_per_page: int = 10,
) -> dict:
db_user = await crud_users.get(db=db, username=username, is_deleted=False, schema_to_select=UserRead)
db_user = await crud_users.get(
db=db,
username=username,
is_deleted=False,
schema_to_select=UserRead,
return_as_model=True
)
if not db_user:
raise NotFoundException("User not found")

Expand All @@ -81,7 +87,13 @@ async def read_posts(
async def read_post(
request: Request, username: str, id: int, db: Annotated[AsyncSession, Depends(async_get_db)]
) -> PostRead:
db_user = await crud_users.get(db=db, username=username, is_deleted=False, schema_to_select=UserRead)
db_user = await crud_users.get(
db=db,
username=username,
is_deleted=False,
schema_to_select=UserRead,
return_as_model=True
)
if db_user is None:
raise NotFoundException("User not found")

Expand All @@ -105,7 +117,13 @@ async def patch_post(
current_user: Annotated[dict, Depends(get_current_user)],
db: Annotated[AsyncSession, Depends(async_get_db)],
) -> dict[str, str]:
db_user = await crud_users.get(db=db, username=username, is_deleted=False, schema_to_select=UserRead)
db_user = await crud_users.get(
db=db,
username=username,
is_deleted=False,
schema_to_select=UserRead,
return_as_model=True
)
if db_user is None:
raise NotFoundException("User not found")

Expand All @@ -130,7 +148,13 @@ async def erase_post(
current_user: Annotated[dict, Depends(get_current_user)],
db: Annotated[AsyncSession, Depends(async_get_db)],
) -> dict[str, str]:
db_user = await crud_users.get(db=db, username=username, is_deleted=False, schema_to_select=UserRead)
db_user = await crud_users.get(
db=db,
username=username,
is_deleted=False,
schema_to_select=UserRead,
return_as_model=True
)
if db_user is None:
raise NotFoundException("User not found")

Expand All @@ -152,7 +176,13 @@ async def erase_post(
async def erase_db_post(
request: Request, username: str, id: int, db: Annotated[AsyncSession, Depends(async_get_db)]
) -> dict[str, str]:
db_user = await crud_users.get(db=db, username=username, is_deleted=False, schema_to_select=UserRead)
db_user = await crud_users.get(
db=db,
username=username,
is_deleted=False,
schema_to_select=UserRead,
return_as_model=True
)
if db_user is None:
raise NotFoundException("User not found")

Expand All @@ -161,4 +191,4 @@ async def erase_db_post(
raise NotFoundException("Post not found")

await crud_posts.db_delete(db=db, id=id)
return {"message": "Post deleted from the database"}
return {"message": "Post deleted from the database"}
32 changes: 21 additions & 11 deletions src/app/api/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ async def read_user(request: Request, username: str, db: Annotated[AsyncSession,
return cast(UserRead, db_user)


# In src/app/api/v1/users.py, replace the patch_user function with this:

@router.patch("/user/{username}")
async def patch_user(
request: Request,
Expand All @@ -80,24 +82,32 @@ async def patch_user(
current_user: Annotated[dict, Depends(get_current_user)],
db: Annotated[AsyncSession, Depends(async_get_db)],
) -> dict[str, str]:
db_user = await crud_users.get(db=db, username=username, schema_to_select=UserRead)
db_user = await crud_users.get(db=db, username=username)
if db_user is None:
raise NotFoundException("User not found")

db_user = cast(UserRead, db_user)
if db_user.username != current_user["username"]:
raise ForbiddenException()
# Handle both dict and UserRead object types
if isinstance(db_user, dict):
db_username = db_user["username"]
db_email = db_user["email"]
else:
db_username = db_user.username
db_email = db_user.email

if values.username != db_user.username:
existing_username = await crud_users.exists(db=db, username=values.username)
if existing_username:
raise DuplicateValueException("Username not available")
if db_username != current_user["username"]:
raise ForbiddenException()

if values.email != db_user.email:
existing_email = await crud_users.exists(db=db, email=values.email)
if existing_email:
# Check for email conflicts if email is being updated
if values.email is not None and values.email != db_email:
if await crud_users.exists(db=db, email=values.email):
raise DuplicateValueException("Email is already registered")

# Check for username conflicts if username is being updated
if values.username is not None and values.username != db_username:
if await crud_users.exists(db=db, username=values.username):
raise DuplicateValueException("Username not available")

# Update the user
await crud_users.update(db=db, object=values, username=username)
return {"message": "User updated"}

Expand Down
1 change: 1 addition & 0 deletions src/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class EnvironmentSettings(BaseSettings):

class Settings(
AppSettings,
SQLiteSettings,
PostgresSettings,
CryptSettings,
FirstUserSettings,
Expand Down
3 changes: 2 additions & 1 deletion src/app/core/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ class Base(DeclarativeBase, MappedAsDataclass):
pass


DATABASE_URI = settings.POSTGRES_URI
DATABASE_URI = settings.POSTGRES_URI
DATABASE_PREFIX = settings.POSTGRES_ASYNC_PREFIX
DATABASE_URL = f"{DATABASE_PREFIX}{DATABASE_URI}"


async_engine = create_async_engine(DATABASE_URL, echo=False, future=True)

local_session = async_sessionmaker(bind=async_engine, class_=AsyncSession, expire_on_commit=False)
Expand Down
3 changes: 2 additions & 1 deletion src/app/core/db/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid as uuid_pkg
from uuid6 import uuid7 #126
from datetime import UTC, datetime

from sqlalchemy import Boolean, DateTime, text
Expand All @@ -8,7 +9,7 @@

class UUIDMixin:
uuid: Mapped[uuid_pkg.UUID] = mapped_column(
UUID, primary_key=True, default=uuid_pkg.uuid4, server_default=text("gen_random_uuid()")
UUID(as_uuid=True), primary_key=True, default=uuid7, server_default=text("gen_random_uuid()")
)


Expand Down
3 changes: 2 additions & 1 deletion src/app/core/schemas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid as uuid_pkg
from uuid6 import uuid7 #126
from datetime import UTC, datetime
from typing import Any

Expand All @@ -13,7 +14,7 @@ class HealthCheck(BaseModel):

# -------------- mixins --------------
class UUIDSchema(BaseModel):
uuid: uuid_pkg.UUID = Field(default_factory=uuid_pkg.uuid4)
uuid: uuid_pkg.UUID = Field(default_factory=uuid7)


class TimestampSchema(BaseModel):
Expand Down
10 changes: 6 additions & 4 deletions src/app/core/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,15 @@ async def _delete_keys_by_pattern(pattern: str) -> None:
many keys simultaneously may impact the performance of the Redis server.
"""
if client is None:
raise MissingClientError

cursor = -1
while cursor != 0:
return
cursor = 0 # Make sure cursor starts at 0
while True:
cursor, keys = await client.scan(cursor, match=pattern, count=100)
if keys:
await client.delete(*keys)
if cursor == 0: # cursor returns to 0 when scan is complete
break


def cache(
Expand Down
5 changes: 3 additions & 2 deletions src/app/models/post.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import uuid as uuid_pkg
from datetime import UTC, datetime
from uuid6 import uuid7 #126

from sqlalchemy import DateTime, ForeignKey, String
from sqlalchemy import DateTime, ForeignKey, String,UUID
from sqlalchemy.orm import Mapped, mapped_column

from ..core.db.database import Base
Expand All @@ -14,7 +15,7 @@ class Post(Base):
created_by_user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), index=True)
title: Mapped[str] = mapped_column(String(30))
text: Mapped[str] = mapped_column(String(63206))
uuid: Mapped[uuid_pkg.UUID] = mapped_column(default_factory=uuid_pkg.uuid4, primary_key=True, unique=True)
uuid: Mapped[uuid_pkg.UUID] = mapped_column(UUID(as_uuid=True),default_factory=uuid7, unique=True)
media_url: Mapped[str | None] = mapped_column(String, default=None)

created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default_factory=lambda: datetime.now(UTC))
Expand Down
13 changes: 8 additions & 5 deletions src/app/models/user.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import uuid as uuid_pkg
from uuid6 import uuid7
from datetime import UTC, datetime
import uuid as uuid_pkg

from sqlalchemy import DateTime, ForeignKey, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column

from ..core.db.database import Base
Expand All @@ -10,19 +12,20 @@
class User(Base):
__tablename__ = "user"

id: Mapped[int] = mapped_column("id", autoincrement=True, nullable=False, unique=True, primary_key=True, init=False)

# Option 1: Use integer ID as primary key (recommended for compatibility)
id: Mapped[int] = mapped_column(autoincrement=True, primary_key=True, init=False)

name: Mapped[str] = mapped_column(String(30))
username: Mapped[str] = mapped_column(String(20), unique=True, index=True)
email: Mapped[str] = mapped_column(String(50), unique=True, index=True)
hashed_password: Mapped[str] = mapped_column(String)

profile_image_url: Mapped[str] = mapped_column(String, default="https://profileimageurl.com")
uuid: Mapped[uuid_pkg.UUID] = mapped_column(default_factory=uuid_pkg.uuid4, primary_key=True, unique=True)
uuid: Mapped[uuid_pkg.UUID] = mapped_column(UUID(as_uuid=True), default_factory=uuid7, unique=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default_factory=lambda: datetime.now(UTC))
updated_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), default=None)
deleted_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), default=None)
is_deleted: Mapped[bool] = mapped_column(default=False, index=True)
is_superuser: Mapped[bool] = mapped_column(default=False)

tier_id: Mapped[int | None] = mapped_column(ForeignKey("tier.id"), index=True, default=None, init=False)
tier_id: Mapped[int | None] = mapped_column(ForeignKey("tier.id"), index=True, default=None, init=False)
4 changes: 2 additions & 2 deletions src/scripts/create_first_superuser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import logging
import uuid
from uuid6 import uuid7 #126
from datetime import UTC, datetime

from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, MetaData, String, Table, insert, select
Expand Down Expand Up @@ -37,7 +37,7 @@ async def create_first_user(session: AsyncSession) -> None:
Column("email", String(50), nullable=False, unique=True, index=True),
Column("hashed_password", String, nullable=False),
Column("profile_image_url", String, default="https://profileimageurl.com"),
Column("uuid", UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True),
Column("uuid", UUID(as_uuid=True), default=uuid7, unique=True),
Column("created_at", DateTime(timezone=True), default=lambda: datetime.now(UTC), nullable=False),
Column("updated_at", DateTime),
Column("deleted_at", DateTime),
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ def sample_user_data():
@pytest.fixture
def sample_user_read():
"""Generate a sample UserRead object."""
import uuid
from uuid6 import uuid7

from src.app.schemas.user import UserRead

return UserRead(
id=1,
uuid=uuid.uuid4(),
uuid=uuid7(),
name=fake.name(),
username=fake.user_name(),
email=fake.email(),
Expand Down
4 changes: 2 additions & 2 deletions tests/helpers/generators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import uuid as uuid_pkg
from uuid6 import uuid7 #126

from sqlalchemy.orm import Session

Expand All @@ -14,7 +14,7 @@ def create_user(db: Session, is_super_user: bool = False) -> models.User:
email=fake.email(),
hashed_password=get_password_hash(fake.password()),
profile_image_url=fake.image_url(),
uuid=uuid_pkg.uuid4(),
uuid=uuid7,
is_superuser=is_super_user,
)

Expand Down
13 changes: 9 additions & 4 deletions tests/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,14 @@ class TestPatchUser:
async def test_patch_user_success(self, mock_db, current_user_dict, sample_user_read):
"""Test successful user update."""
username = current_user_dict["username"]
sample_user_read.username = username # Make sure usernames match
user_update = UserUpdate(name="New Name")

# Convert the UserRead model to a dictionary for the mock
user_dict = sample_user_read.model_dump()
user_dict["username"] = username

with patch("src.app.api.v1.users.crud_users") as mock_crud:
mock_crud.get = AsyncMock(return_value=sample_user_read)
mock_crud.get = AsyncMock(return_value=user_dict) # Return dict instead of UserRead
mock_crud.exists = AsyncMock(return_value=False) # No conflicts
mock_crud.update = AsyncMock(return_value=None)

Expand All @@ -134,11 +137,13 @@ async def test_patch_user_success(self, mock_db, current_user_dict, sample_user_
async def test_patch_user_forbidden(self, mock_db, current_user_dict, sample_user_read):
"""Test user update when user tries to update another user."""
username = "different_user"
sample_user_read.username = username
user_update = UserUpdate(name="New Name")
# Convert the UserRead model to a dictionary for the mock
user_dict = sample_user_read.model_dump()
user_dict["username"] = username

with patch("src.app.api.v1.users.crud_users") as mock_crud:
mock_crud.get = AsyncMock(return_value=sample_user_read)
mock_crud.get = AsyncMock(return_value=user_dict) # Return dict instead of UserRead

with pytest.raises(ForbiddenException):
await patch_user(Mock(), user_update, username, current_user_dict, mock_db)
Expand Down
Loading