From 3fa5dfa184a44588c8e12348341f37eeafe841af Mon Sep 17 00:00:00 2001
From: cjeongmin
Date: Tue, 18 Feb 2025 17:16:07 +0900
Subject: [PATCH 1/2] feat: add batch prediction function for classifier model

---
 apps/classifier/model.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/apps/classifier/model.py b/apps/classifier/model.py
index 68d1541..0576f31 100644
--- a/apps/classifier/model.py
+++ b/apps/classifier/model.py
@@ -49,3 +49,33 @@ def predict(text: str, type: str) -> tuple[str, float]:
     predicted_probability = probabilities[0, predicted_label].item()
 
     return inv_label_map[predicted_label], predicted_probability
+
+
+def predict_batch(texts: list[str], type: str) -> list[tuple[str, float]]:
+    inputs = tokenizer(
+        texts,
+        return_tensors="pt",
+        truncation=True,
+        padding="max_length",
+        max_length=512,
+    )
+
+    model = models[type]
+    inv_label_map = inv_label_maps[type]
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    logits = outputs.logits
+    probabilities = torch.softmax(logits, dim=-1)
+    predicted_labels = torch.argmax(probabilities, dim=-1)
+    predicted_probabilities = probabilities[
+        torch.arange(probabilities.size(0)), predicted_labels
+    ]
+
+    return [
+        (inv_label_map[label], prob)
+        for label, prob in zip(
+            predicted_labels.tolist(), predicted_probabilities.tolist()
+        )
+    ]

From 6c6705aa760c277fb1d4773a8e5896c3ed123ec6 Mon Sep 17 00:00:00 2001
From: cjeongmin
Date: Tue, 18 Feb 2025 17:16:17 +0900
Subject: [PATCH 2/2] feat: update slang prediction endpoint to support batch input

---
 apps/classifier/app.py     | 10 ++++++----
 apps/classifier/schemas.py | 16 +++++++++-------
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/apps/classifier/app.py b/apps/classifier/app.py
index e3040cf..8012a95 100644
--- a/apps/classifier/app.py
+++ b/apps/classifier/app.py
@@ -5,7 +5,7 @@
     PredictionRequest,
     PredictionResponse,
 )
-from model import predict
+from model import predict, predict_batch
 
 app = FastAPI()
 
@@ -26,6 +26,8 @@ async def improve_reply_predict(data: PredictionRequest):
 
 @app.post("/slang-predict", response_model=SlangPredictionResponse)
 async def slang_predict(data: SlangPredictionRequest):
-    text = data.input
-    predicted = predict(text, type="slang")
-    return {"predicted": predicted[0], "probability": predicted[1]}
+    texts = data.inputs
+    predicted = predict_batch(texts, type="slang")
+    return {
+        "predictions": [{"predicted": p[0], "probability": p[1]} for p in predicted]
+    }
diff --git a/apps/classifier/schemas.py b/apps/classifier/schemas.py
index 3b6d410..4dd44ab 100644
--- a/apps/classifier/schemas.py
+++ b/apps/classifier/schemas.py
@@ -1,27 +1,29 @@
 from pydantic import BaseModel
+from typing import List
 
 
 class SlangPredictionRequest(BaseModel):
-    input: str
+    inputs: List[str]
 
     class Config:
         json_schema_extra = {
             "example": {
-                "input": "X같네",
+                "inputs": ["X같네"],
             }
         }
 
 
-class SlangPredictionResponse(BaseModel):
+class SlangPredictionItem(BaseModel):
     predicted: str
     probability: float
 
+
+class SlangPredictionResponse(BaseModel):
+    predictions: List[SlangPredictionItem]
+
     class Config:
         json_schema_extra = {
-            "example": {
-                "predicted": "욕설",
-                "probability": 0.99,
-            }
+            "example": {"predictions": [{"predicted": "욕설", "probability": 0.99}]}
         }
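
Usage note (PATCH 1/2): a minimal sketch of how the new predict_batch could be
called, assuming the module-level tokenizer, models, and inv_label_maps objects
that the existing predict() already relies on are loaded; the "slang" model key
is taken from PATCH 2/2, and the sample inputs are illustrative.

    from model import predict_batch

    texts = ["X같네", "좋은 하루 보내세요"]  # illustrative inputs

    # One tokenizer call and one forward pass cover the whole batch,
    # instead of one predict() call per text.
    for text, (label, probability) in zip(texts, predict_batch(texts, type="slang")):
        print(f"{text!r} -> {label} ({probability:.4f})")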
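
Usage note (PATCH 2/2): a sketch of the updated request/response shape for
POST /slang-predict, exercised with FastAPI's TestClient; the commented output
mirrors the json_schema_extra example in schemas.py, and the actual labels and
probabilities depend on the loaded "slang" model.

    from fastapi.testclient import TestClient

    from app import app

    client = TestClient(app)

    # The endpoint now accepts a list of inputs and returns one item per input.
    response = client.post("/slang-predict", json={"inputs": ["X같네"]})
    print(response.json())
    # e.g. {"predictions": [{"predicted": "욕설", "probability": 0.99}]}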