Научитесь создавать API классификации изображений с помощью Tensorflow и FastAPI с нуля.
FastAPI - это высокопроизводительный асинхронный фреймворк для создания API на Python.
Видеоурок также доступен для этого блога.
Исходный код этого блога доступен aniketmaurya / tensorflow-fastapi-starter-pack
Начнем с простого примера hello-world
Сначала мы импортируем FastAPI
класс и создаем объект app
. Этот класс имеет полезные параметры, например, мы можем передать заголовок и описание пользовательского интерфейса Swagger.
from fastapi import FastAPI
app = FastAPI(title='Hello world')
Мы определяем функцию и украшаем ее @app.get
. Это означает, что наш API /index
поддерживает метод GET. Определенная здесь функция является асинхронной, FastAPI автоматически заботится об асинхронности и без асинхронных методов, создавая пул потоков для обычных функций def, и использует цикл событий async для асинхронных функций.
@app.get('/index')
async def hello_world():
return "hello world"
API распознавания изображений
Мы создадим API для классификации изображений, назовем его predict/image
. Мы будем использовать Tensorflow для создания модели классификации изображений.
Мы создаем функцию load_model
, которая будет возвращать модель MobileNet CNN с предварительно обученными весами, то есть она уже обучена классифицировать 1000 уникальных категорий изображений.
import tensorflow as tf
def load_model(): model = tf.keras.applications.MobileNetV2(weights="imagenet") print("Model loaded") return model
model = load_model()
Мы определяем predict
функцию, которая принимает изображение и возвращает прогнозы. Мы изменяем размер изображения до 224x224 и нормализуем значения пикселей, чтобы они были в [-1, 1].
from tensorflow.keras.applications.imagenet_utils import decode_predictions
decode_predictions
используется для декодирования имени класса прогнозируемого объекта. Здесь мы вернем второй вероятный класс.
def predict(image: Image.Image):
image = np.asarray(image.resize((224, 224)))[..., :3] image = np.expand_dims(image, 0) image = image / 127.5 - 1.0
result = decode_predictions(model.predict(image), 2)[0]
response = [] for i, res in enumerate(result): resp = {} resp["class"] = res[1] resp["confidence"] = f"{res[2]*100:0.2f} %"
response.append(resp)
return response
Теперь мы создадим API /predict/image
, поддерживающий загрузку файлов. Мы отфильтруем расширение файла, чтобы поддерживать изображения только в форматах jpg, jpeg и png.
Мы будем использовать Pillow для загрузки загруженного изображения.
def read_imagefile(file) -> Image.Image: image = Image.open(BytesIO(file)) return image
@app.post("/predict/image") async def predict_api(file: UploadFile = File(...)): extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png") if not extension: return "Image must be jpg or png format!" image = read_imagefile(await file.read()) prediction = predict(image)
return prediction
Окончательный код
import uvicorn from fastapi import FastAPI, File, UploadFile
from application.components import predict, read_imagefile
app = FastAPI()
@app.post("/predict/image") async def predict_api(file: UploadFile = File(...)): extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png") if not extension: return "Image must be jpg or png format!" image = read_imagefile(await file.read()) prediction = predict(image)
return prediction
@app.post("/api/covid-symptom-check") def check_risk(symptom: Symptom): return symptom_check.get_risk_level(symptom)
if __name__ == "__main__": uvicorn.run(app, debug=True)
Документация FastAPI - лучшее место, чтобы узнать больше об основных концепциях фреймворка.
Надеюсь, вам понравилась статья.
Не стесняйтесь задавать свои вопросы в комментариях или свяжитесь со мной лично
👉 Twitter: https://twitter.com/aniketmaurya
👉 Linkedin: https://linkedin.com/in/aniketmaurya