Building Machine Learning API with FastAPI and Tensorflow

Tutorial on FastAPI - high performance asynchronous framework for faster development of production ready APIs.

July 26, 2020


tensorflow, fastapi, python, web development, machine learning, computer vision

Youtube tutorial version of this blog is also available

FastAPI is a high-performance asynchronous framework for building APIs in Python. It provides support for Swagger UI out of the box.

Source code for this blog is available aniketmaurya/tensorflow-fastapi-starter-pack

hello-world of FastAPI

First, we import FastAPI class and create an object app. This class has useful parameters like we can pass the title and description for Swagger UI.

from fastapi import FastAPI
app = FastAPI(title='Hello world', description='This is a hello world example', version='0.0.1')

We define a function and decorate it with @app.get. This means that our API /index supports the GET method. The function defined here is async, FastAPI automatically takes care of async and without async methods by creating a thread pool for the normal def functions and it uses an async event loop for async functions.

async def hello_world():
    return "hello world"

Pydantic support

One of my favorite features offered by FastAPI is Pydantic support. We can define Pydantic models and request-response will be handled by FastAPI for these models. Let’s create a COVID-19 symptom checker API to understand this.

Covid-19 symptom checker API

We create a request body, it is the format in which the client should send the request. It will be used by Swagger UI.

from pydantic import BaseModel

class Symptom(BaseModel):
    fever: bool = False
    dry_cough: bool = False
    tiredness: bool = False
    breathing_problem: bool = False

Let’s create a function to assign a risk level based on the inputs.

This is just for learning and should not be used in real life, better consult a doctor.

def get_risk_level(symptom: Symptom):
    if not (symptom.fever or symptom.dry_cough or symptom.tiredness or symptom.breathing_problem):
        return 'Low risk level. THIS IS A DEMO APP'

    if not (symptom.breathing_problem or symptom.dry_cough):
        if symptom.fever:
            return 'moderate risk level. THIS IS A DEMO APP'

    if symptom.breathing_problem:
        return 'High-risk level. THIS IS A DEMO APP'

    return 'THIS IS A DEMO APP'

Let’s create the API for checking the symptoms'/api/covid-symptom-check')
def check_risk(symptom: Symptom):
    return get_risk_level(symptom)

Image recognition API

We will create an API to classify images, we name it predict/image. We will use Tensorflow for creating the image classification model.

Tutorial for Image Classification with Tensorflow

We create a function load_model, which will return a MobileNet CNN Model with pre-trained weights i.e. it is already trained to classify 1000 unique categories of images.

import tensorflow as tf

def load_model():
    model = tf.keras.applications.MobileNetV2(weights="imagenet")
    print("Model loaded")
    return model

model = load_model()

We define a predict function that will accept an image and returns the predictions. We resize the image to 224x224 and normalize the pixel values to be in [-1, 1].

from tensorflow.keras.applications.imagenet_utils import decode_predictions

decode_predictions is used to decode the class name of the predicted object. Here we will return the top-2 probable class.

def predict(image: Image.Image):

    image = np.asarray(image.resize((224, 224)))[..., :3]
    image = np.expand_dims(image, 0)
    image = image / 127.5 - 1.0

    result = decode_predictions(model.predict(image), 2)[0]

    response = []
    for i, res in enumerate(result):
        resp = {}
        resp["class"] = res[1]
        resp["confidence"] = f"{res[2]*100:0.2f} %"


    return response

Now we will create an API /predict/image which supports file upload. We will filter the file extension to support only jpg, jpeg, and png format of images.

We will use Pillow to load the uploaded image.

def read_imagefile(file) -> Image.Image:
    image =
    return image"/predict/image")
async def predict_api(file: UploadFile = File(...)):
    extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
    if not extension:
        return "Image must be jpg or png format!"
    image = read_imagefile(await
    prediction = predict(image)

    return prediction

Final code

import uvicorn
from fastapi import FastAPI, File, UploadFile

from application.components import predict, read_imagefile

app = FastAPI()"/predict/image")
async def predict_api(file: UploadFile = File(...)):
    extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
    if not extension:
        return "Image must be jpg or png format!"
    image = read_imagefile(await
    prediction = predict(image)

    return prediction"/api/covid-symptom-check")
def check_risk(symptom: Symptom):
    return symptom_check.get_risk_level(symptom)

if __name__ == "__main__":, debug=True)

FastAPI documentation is the best place to learn more about core concepts of the framework.

> Hope you liked the article.

Feel free to ask your questions in the comments or reach me out personally

👉 Twitter:

👉 Linkedin: