Skip to content

Commit e790cce

Browse files
committed
UUID to ObjectId conversion in apiserver
1 parent c05b628 commit e790cce

File tree

3 files changed

+33
-20
lines changed

3 files changed

+33
-20
lines changed

src/apiserver/api_models.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from __future__ import annotations
2-
import uuid
32
import pydantic
43
import typing
54
import db_models
@@ -11,17 +10,13 @@ class ApiVersion(pydantic.BaseModel):
1110

1211

1312
class Run(pydantic.BaseModel):
14-
id: uuid.UUID
15-
"""
16-
Run identifier, unique in the system.
17-
"""
1813
toolchain_name: str
1914
problem_name: str
20-
user_id: uuid.UUID
15+
user_id: db_models.Id
2116
contest_name: str
2217
status: typing.Mapping[str, str] = pydantic.Field(default_factory=dict)
2318

2419
@staticmethod
2520
def from_db(doc: db_models.RunMainProj) -> Run:
26-
return Run(id=doc['id'], toolchain_name=doc['toolchain_name'],
21+
return Run(toolchain_name=doc['toolchain_name'],
2722
user_id=doc['user_id'], contest_name=doc['contest_name'], problem_name=doc['problem_name'], status=doc['status'])

src/apiserver/db_models.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,29 @@
1-
import uuid
1+
from bson import ObjectId
22
from enum import Enum
3-
import time
43
import typing
54
from pydantic import BaseModel, Field
65

76

7+
class Id(ObjectId):
8+
@classmethod
9+
def __get_validators__(cls):
10+
yield cls.validate
11+
12+
@classmethod
13+
def validate(cls, v):
14+
print(type(v))
15+
if not isinstance(v, ObjectId):
16+
raise TypeError('ObjectId required')
17+
return str(v)
18+
19+
@classmethod
20+
def __modify_schema__(cls, schema):
21+
schema.update({
22+
'Title': 'MongoDB ObjectID',
23+
'type': 'string'
24+
})
25+
26+
827
class RunPhase(Enum):
928
"""
1029
# QUEUED
@@ -24,10 +43,9 @@ class RunPhase(Enum):
2443

2544

2645
class RunMainProj(BaseModel):
27-
id: uuid.UUID
2846
toolchain_name: str
2947
problem_name: str
30-
user_id: uuid.UUID
48+
user_id: Id
3149
contest_name: str
3250
phase: str # RunPhase
3351
status: typing.Mapping[str, str] = Field(default_factory=dict)
@@ -37,7 +55,7 @@ class RunMainProj(BaseModel):
3755
"""
3856

3957

40-
RunMainProj.FIELDS = ['id', 'toolchain_name',
58+
RunMainProj.FIELDS = ['toolchain_name',
4159
'problem_name', 'user_id', 'contest_name', 'status']
4260

4361

src/apiserver/routes.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import fastapi
22
import db_models
33
import api_models
4-
import uuid
54
import typing
65
import base64
76
import pymongo
7+
from bson import ObjectId
88
import pydantic
99

1010

@@ -97,15 +97,15 @@ def route_submit(params: RunSubmitSimpleParams, db: pymongo.database.Database =
9797
fields of request body; `id` will be real id of this run.
9898
"""
9999

100-
run_uuid = uuid.uuid4()
101-
user_id = uuid.UUID('12345678123456781234567812345678')
102-
doc_main = db_models.RunMainProj(id=run_uuid, toolchain_name=params.toolchain,
100+
user_id = ObjectId('507f1f77bcf86cd799439011')
101+
doc_main = db_models.RunMainProj(toolchain_name=params.toolchain,
103102
problem_name=params.problem, user_id=user_id, contest_name=params.contest, phase=str(db_models.RunPhase.QUEUED))
104103
doc_source = db_models.RunSourceProj(
105104
source=base64.b64decode(params.code))
106105
doc = {**dict(doc_main), **dict(doc_source)}
107106
db.runs.insert_one(doc)
108-
return api_models.Run(id=run_uuid, toolchain_name=params.toolchain, problem_name=params.problem, user_id=user_id, contest_name=params.contest)
107+
print(type(user_id))
108+
return api_models.Run(toolchain_name=params.toolchain, problem_name=params.problem, user_id=user_id, contest_name=params.contest)
109109

110110
@app.get('/runs', response_model=typing.List[api_models.Run],
111111
operation_id='listRuns')
@@ -122,7 +122,7 @@ def route_list_runs(db: pymongo.database.Database = fastapi.Depends(db_connect))
122122
return runs
123123

124124
@app.get('/runs/{run_id}', response_model=api_models.Run, operation_id='getRun')
125-
def route_get_run(run_id: uuid.UUID, db: pymongo.database.Database = fastapi.Depends(db_connect)):
125+
def route_get_run(run_id: db_models.Id, db: pymongo.database.Database = fastapi.Depends(db_connect)):
126126
"""
127127
Loads run by id
128128
"""
@@ -139,7 +139,7 @@ def route_get_run(run_id: uuid.UUID, db: pymongo.database.Database = fastapi.Dep
139139
'description': "Run source is not available"
140140
}
141141
})
142-
def route_get_run_source(run_id: uuid.UUID, db: pymongo.database.Database = fastapi.Depends(db_connect)):
142+
def route_get_run_source(run_id: db_models.Id, db: pymongo.database.Database = fastapi.Depends(db_connect)):
143143
"""
144144
Returns run source as base64-encoded JSON string
145145
"""
@@ -154,7 +154,7 @@ def route_get_run_source(run_id: uuid.UUID, db: pymongo.database.Database = fast
154154
return base64.b64encode(doc['source'])
155155

156156
@app.patch('/runs/{run_id}', response_model=api_models.Run, operation_id='patchRun')
157-
def route_run_patch(run_id: uuid.UUID, patch: RunPatch, db: pymongo.database.Database = fastapi.Depends(db_connect)):
157+
def route_run_patch(run_id: db_models.Id, patch: RunPatch, db: pymongo.database.Database = fastapi.Depends(db_connect)):
158158
"""
159159
Modifies existing run
160160

0 commit comments

Comments
 (0)