-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
rodrigo.arenas
committed
Feb 11, 2021
1 parent
84c7d8e
commit 56e89a7
Showing
5 changed files
with
32 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,18 @@ | ||
import models.ml.classifier as clf | ||
import numpy as np | ||
from fastapi import FastAPI | ||
from joblib import load | ||
from models.iris import Iris, IrisPredictionResponse | ||
from routes.v1.iris_predict import app_iris_predict_v1 | ||
from routes.home import app_home | ||
|
||
|
||
app = FastAPI(title="Iris ML API", description="API for iris dataset ml model", version="1.0") | ||
|
||
|
||
@app.on_event('startup') | ||
async def load_model(): | ||
clf.model = load('models/ml/iris_dt_v1.joblib') | ||
np.seterr(divide='warn') | ||
|
||
|
||
@app.get('/', tags=["Intro"]) | ||
async def hello(): | ||
return {"message": "Hello!"} | ||
|
||
app.include_router(app_home) | ||
app.include_router(app_iris_predict_v1, prefix='/v1') | ||
|
||
@app.post('/predict', tags=["predictions"], | ||
response_model=IrisPredictionResponse) | ||
async def get_prediction(iris: Iris): | ||
data = dict(iris)['data'] | ||
prediction = clf.model.predict(data).tolist() | ||
log_probability = clf.model.predict_log_proba(data).tolist() | ||
return {"prediction": prediction, | ||
"log_probability": log_probability} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from fastapi import APIRouter | ||
|
||
app_home = APIRouter() | ||
|
||
|
||
@app_home.get('/', tags=["Intro"]) | ||
async def hello(): | ||
return {"message": "Hello!"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from fastapi import APIRouter | ||
from models.schemas.iris import Iris, IrisPredictionResponse | ||
import models.ml.classifier as clf | ||
|
||
app_iris_predict_v1 = APIRouter() | ||
|
||
|
||
@app_iris_predict_v1.post('/iris/predict', | ||
tags=["Predictions"], | ||
response_model=IrisPredictionResponse) | ||
async def get_prediction(iris: Iris): | ||
data = dict(iris)['data'] | ||
prediction = clf.model.predict(data).tolist() | ||
probability = clf.model.predict_proba(data).tolist() | ||
log_probability = clf.model.predict_log_proba(data).tolist() | ||
return {"prediction": prediction, | ||
"probability": probability, | ||
"log_probability": log_probability} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters