Skip to content

Commit

Permalink
Refactor project structure
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigo.arenas committed Feb 11, 2021
1 parent 84c7d8e commit 56e89a7
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 16 deletions.
20 changes: 5 additions & 15 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,18 @@
import models.ml.classifier as clf
import numpy as np
from fastapi import FastAPI
from joblib import load
from models.iris import Iris, IrisPredictionResponse
from routes.v1.iris_predict import app_iris_predict_v1
from routes.home import app_home


app = FastAPI(title="Iris ML API", description="API for iris dataset ml model", version="1.0")


@app.on_event('startup')
async def load_model():
clf.model = load('models/ml/iris_dt_v1.joblib')
np.seterr(divide='warn')


@app.get('/', tags=["Intro"])
async def hello():
return {"message": "Hello!"}

app.include_router(app_home)
app.include_router(app_iris_predict_v1, prefix='/v1')

@app.post('/predict', tags=["predictions"],
response_model=IrisPredictionResponse)
async def get_prediction(iris: Iris):
data = dict(iris)['data']
prediction = clf.model.predict(data).tolist()
log_probability = clf.model.predict_log_proba(data).tolist()
return {"prediction": prediction,
"log_probability": log_probability}
1 change: 1 addition & 0 deletions models/iris.py → models/schemas/iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ class Iris(BaseModel):

class IrisPredictionResponse(BaseModel):
prediction: List[int]
probability: List[Any]
log_probability: List[Any]
8 changes: 8 additions & 0 deletions routes/home.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from fastapi import APIRouter

app_home = APIRouter()


@app_home.get('/', tags=["Intro"])
async def hello():
return {"message": "Hello!"}
18 changes: 18 additions & 0 deletions routes/v1/iris_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from fastapi import APIRouter
from models.schemas.iris import Iris, IrisPredictionResponse
import models.ml.classifier as clf

app_iris_predict_v1 = APIRouter()


@app_iris_predict_v1.post('/iris/predict',
tags=["Predictions"],
response_model=IrisPredictionResponse)
async def get_prediction(iris: Iris):
data = dict(iris)['data']
prediction = clf.model.predict(data).tolist()
probability = clf.model.predict_proba(data).tolist()
log_probability = clf.model.predict_log_proba(data).tolist()
return {"prediction": prediction,
"probability": probability,
"log_probability": log_probability}
1 change: 0 additions & 1 deletion tests/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def predict(self):
self.client.post('/predict', json=request_body)



class IrisLoadTest(HttpUser):
tasks = [IrisPredict]
host = 'http://127.0.0.1'
Expand Down

0 comments on commit 56e89a7

Please sign in to comment.