Skip to content

Commit

Permalink
Fix grid table result type
Browse files Browse the repository at this point in the history
  • Loading branch information
BielStela committed Nov 22, 2024
1 parent e99bf14 commit 7a0b8e5
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
2 changes: 1 addition & 1 deletion api/app/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""Validate API key."""
if credentials is None or credentials.credentials != get_settings().auth_token:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized")
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized")
6 changes: 5 additions & 1 deletion api/app/models/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Annotated, Literal

from fastapi import Query
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from pydantic_extra_types.color import Color
from sqlalchemy.sql import column, desc, nullslast, select, table

Expand Down Expand Up @@ -138,10 +138,14 @@ def to_sql_query(self, table_name: str) -> str:
return str(query.compile(compile_kwargs={"literal_binds": True}))


H3Index = Annotated[str, Field(description="H3 cell index", examples=["81a8fffffffffff"])]


class TableResultColumn(BaseModel):
column: Annotated[str, Field(title="column", description="Column name")]
values: Annotated[list, Field(description="Check dataset metadata for type info")]


class TableResults(BaseModel):
table: list[TableResultColumn]
cells: list[H3Index]
17 changes: 10 additions & 7 deletions api/app/routers/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pydantic import ValidationError

from app.config.config import get_settings
from app.models.grid import MultiDatasetMeta, TableFilters, TableResults
from app.models.grid import MultiDatasetMeta, TableFilters, TableResultColumn, TableResults

log = logging.getLogger("uvicorn.error") # Show the logs in the uvicorn runner logs

Expand All @@ -36,7 +36,9 @@ class ArrowIPCResponse(Response): # noqa: D101
media_type = "application/octet-stream"


def get_tile(tile_index: str, columns: list[str]) -> tuple[pl.LazyFrame, int]:
def get_tile(
tile_index: Annotated[str, Path(description="The `h3` index of the tile")], columns: list[str]
) -> tuple[pl.LazyFrame, int]:
"""Get the tile from filesystem filtered by column and the resolution of the tile index"""
try:
z = h3.api.basic_str.h3_get_resolution(tile_index)
Expand Down Expand Up @@ -66,15 +68,14 @@ def polars_to_string_ipc(df: pl.DataFrame) -> bytes:
# a custom string type. As of today, the frontend library @loadrs.gl/arrow only supports
# `string` type so we need to downcast with pyarrow
table: pa.Table = df.to_arrow()

schema = table.schema
schema = schema.set(schema.get_field_index("cell"), pa.field("cell", pa.string()))
schema = table.schema.set(table.schema.get_field_index("cell"), pa.field("cell", pa.string()))
table = table.cast(schema)
sink = io.BytesIO()
with pa.ipc.new_file(sink, table.schema) as writer:
writer.write_table(table)
return sink.getvalue()


@grid_router.get(
"/tile/{tile_index}",
summary="Get a grid tile",
Expand Down Expand Up @@ -176,5 +177,7 @@ def read_table(
except pl.exceptions.ComputeError as e: # raised if wrong type in compare.
log.exception(e)
raise HTTPException(status_code=422, detail=str(e)) from None

return TableResults(table=[{"column": k, "values": v} for k, v in res.to_dict(as_series=False).items()])
columns = res.to_dict(as_series=False)
table = [TableResultColumn(column=k, values=v) for k, v in columns.items() if k != "cell"]
cells = columns["cell"]
return TableResults(table=table, cells=cells)
8 changes: 4 additions & 4 deletions api/tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,10 @@ def test_grid_table(grid_dataset):
assert response.status_code == 200
assert json.loads(response.read()) == {
"table": [
{"column": "cell", "values": ["865f00007ffffff", "895f4261e03ffff"]},
{"column": "landcover", "values": [4, 1]},
{"column": "population", "values": [200, 100]},
]
],
"cells": ["865f00007ffffff", "895f4261e03ffff"],
}


Expand All @@ -241,10 +241,10 @@ def test_grid_table_geojson(grid_dataset, geojson):
assert response.status_code == 200
assert json.loads(response.read()) == {
"table": [
{"column": "cell", "values": ["895f4261e03ffff"]},
{"column": "landcover", "values": [1]},
{"column": "population", "values": [100]},
]
],
"cells": ["895f4261e03ffff"],
}


Expand Down

0 comments on commit 7a0b8e5

Please sign in to comment.