Научитесь создавать API классификации изображений с помощью Tensorflow и FastAPI с нуля.

FastAPI - это высокопроизводительный асинхронный фреймворк для создания API на Python.

Видеоурок также доступен для этого блога.

Исходный код этого блога доступен aniketmaurya / tensorflow-fastapi-starter-pack

Начнем с простого примера hello-world

Сначала мы импортируем FastAPI класс и создаем объект app. Этот класс имеет полезные параметры, например, мы можем передать заголовок и описание пользовательского интерфейса Swagger.

from fastapi import FastAPI
app = FastAPI(title='Hello world')

Мы определяем функцию и украшаем ее @app.get. Это означает, что наш API /index поддерживает метод GET. Определенная здесь функция является асинхронной, FastAPI автоматически заботится об асинхронности и без асинхронных методов, создавая пул потоков для обычных функций def, и использует цикл событий async для асинхронных функций.

@app.get('/index')
async def hello_world():
    return "hello world"

API распознавания изображений

Мы создадим API для классификации изображений, назовем его predict/image. Мы будем использовать Tensorflow для создания модели классификации изображений.

Учебник Классификация изображений с помощью Tensorflow

Мы создаем функцию load_model, которая будет возвращать модель MobileNet CNN с предварительно обученными весами, то есть она уже обучена классифицировать 1000 уникальных категорий изображений.

import tensorflow as tf
def load_model():
    model = tf.keras.applications.MobileNetV2(weights="imagenet")
    print("Model loaded")
    return model
model = load_model()

Мы определяем predict функцию, которая принимает изображение и возвращает прогнозы. Мы изменяем размер изображения до 224x224 и нормализуем значения пикселей, чтобы они были в [-1, 1].

from tensorflow.keras.applications.imagenet_utils import decode_predictions

decode_predictions используется для декодирования имени класса прогнозируемого объекта. Здесь мы вернем второй вероятный класс.

def predict(image: Image.Image):
    image = np.asarray(image.resize((224, 224)))[..., :3]
    image = np.expand_dims(image, 0)
    image = image / 127.5 - 1.0
    result = decode_predictions(model.predict(image), 2)[0]
    response = []
    for i, res in enumerate(result):
        resp = {}
        resp["class"] = res[1]
        resp["confidence"] = f"{res[2]*100:0.2f} %"
        response.append(resp)
    return response

Теперь мы создадим API /predict/image, поддерживающий загрузку файлов. Мы отфильтруем расширение файла, чтобы поддерживать изображения только в форматах jpg, jpeg и png.

Мы будем использовать Pillow для загрузки загруженного изображения.

def read_imagefile(file) -> Image.Image:
    image = Image.open(BytesIO(file))
    return image
@app.post("/predict/image")
async def predict_api(file: UploadFile = File(...)):
    extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
    if not extension:
        return "Image must be jpg or png format!"
    image = read_imagefile(await file.read())
    prediction = predict(image)
    return prediction

Окончательный код

import uvicorn
from fastapi import FastAPI, File, UploadFile
from application.components import predict, read_imagefile
app = FastAPI()
@app.post("/predict/image")
async def predict_api(file: UploadFile = File(...)):
    extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
    if not extension:
        return "Image must be jpg or png format!"
    image = read_imagefile(await file.read())
    prediction = predict(image)
    return prediction

@app.post("/api/covid-symptom-check")
def check_risk(symptom: Symptom):
    return symptom_check.get_risk_level(symptom)

if __name__ == "__main__":
    uvicorn.run(app, debug=True)

Документация FastAPI - лучшее место, чтобы узнать больше об основных концепциях фреймворка.

Надеюсь, вам понравилась статья.

Не стесняйтесь задавать свои вопросы в комментариях или свяжитесь со мной лично

👉 Twitter: https://twitter.com/aniketmaurya

👉 Linkedin: https://linkedin.com/in/aniketmaurya