Кто знал, что запуск моделей PyTorch в производство может быть таким простым?

Итак, у вас есть модель, которую вы хотите развернуть. Может быть, вы исследователь, энтузиаст или *взволнованно хихикает* новейший стартап в области высоких технологий. И теперь мы все стоим на одном перекрестке — как лучше всего настроить и запустить нашу модель, чтобы мы могли поделиться своими технологиями со всем миром?

Сегодня я расскажу вам, как это сделать для простого случая — развертывания модели классификации изображений PyTorch. Этот код прост в использовании, работает для любой модели классификации изображений PyTorch и разработан таким образом, чтобы его можно было легко обобщить для других задач, таких как обнаружение объектов, или даже таких задач, как классификация/генерация текста.

Прежде чем мы начнем, я отмечу, что я предполагаю, что вы хотя бы на практике знакомы с Flask. Если нет, я бы порекомендовал бегло просмотреть эти два руководства; это должно занять максимум 10 минут.

Готовый? Пристегнитесь, друзья, это будет дикая поездка!

Монтаж

Во-первых, убедитесь, что у вас установлены PyTorch и Flask. Вы можете установить эти пакеты, используя pip install torch и pip install flask.

Встать и бежать

Создайте файл app.py и маршрут для домашней страницы, добавив импорт для наших двух звездных библиотек.

from flask import Flask
import torch


app = Flask(__name__)

@app.route('/')
def home():
    return 'Welcome to the PyTorch Flask app!'

Теперь, когда вы запустите python app.py и посетите http://localhost:5000/, вы должны увидеть простое сообщение — Добро пожаловать в приложение PyTorch Flask! !

Оформление сцены…

Теперь, когда наш скаффолд готов, давайте добавим некоторый код «настройки», который позволит использовать данные, которые мы хотим, в нашу модель!

Добавьте еще пару импортов

from flask import Flask, request, render_template
from PIL import Image
import torch
import torchvision.transforms as transforms

Замените наш основной домашний маршрут на настоящую HTML-страницу.

@app.route('/')
def home():
    return render_template('home.html')

Создайте папку templates и добавьте следующий HTML-код для нашего шаблона home.html.

<html>
  <head>
    <title>PyTorch Image Classification</title>
  </head>
  <body>
    <h1>PyTorch Image Classification</h1>
    <form method="POST" enctype="multipart/form-data" action="/predict">
      <input type="file" name="image">
      <input type="submit" value="Predict">
    </form>
  </body>
</html>

HTML довольно прост — у нас есть кнопка загрузки для загрузки любых данных (в нашем случае изображений), через которые мы хотим запустить нашу модель.

Веселая часть

Мы сделали всю эту настройку, и теперь без лишних слов давайте погрузимся в настоящее мясо — подключим модель.

Над маршрутом home давайте загрузим нашу модель.

model = torch.jit.load('path/to/model.pth')

Однако в большинстве случаев мы хотим каким-то образом «обработать» входные данные, чтобы преобразовать их в формат, подходящий для нашей модели, — преобразовать их в тензор, изменить размер или выполнить любую форму предварительной обработки.

В нашем случае мы работаем с изображениями, поэтому делаем небольшую рутинную очистку.

def process_image(image):
    # Preprocess image for model
    transformation = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image_tensor = transformation(image).unsqueeze(0)
    
    return image_tensor

Быстро создайте массив, чтобы указать наши классы

class_names = ['apple', 'banana'] #REPLACE THIS WITH YOUR CLASSES

Сейчас настоящий момент, в котором мы были — создание маршрута для предсказания

Создайте маршрут, который принимает загруженное изображение, обрабатывает его, делает прогноз с использованием модели и возвращает вероятности для каждого класса.

@app.route('/predict', methods=['POST'])
def predict():
    # Get uploaded image file
    image = request.files['image']

    # Process image and make prediction
    image_tensor = process_image(Image.open(image))
    output = model(image_tensor)

    # Get class probabilities
    probabilities = torch.nn.functional.softmax(output, dim=1)
    probabilities = probabilities.detach().numpy()[0]

    # Get the index of the highest probability
    class_index = probabilities.argmax()

    # Get the predicted class and probability
    predicted_class = class_names[class_index]
    probability = probabilities[class_index]

    # Sort class probabilities in descending order
    class_probs = list(zip(class_names, probabilities))
    class_probs.sort(key=lambda x: x[1], reverse=True)

    # Render HTML page with prediction results
    return render_template('predict.html', class_probs=class_probs,
                           predicted_class=predicted_class, probability=probability)

Эта версия маршрута /predict сначала получает вероятности класса с помощью функции softmax, а затем получает индекс наибольшей вероятности. Затем он использует этот индекс для поиска предсказанного класса в списке имен классов и получает вероятность для этого класса. Затем он сортирует вероятности классов в порядке убывания и отображает HTML-страницу с результатами прогнозирования.

В целом наш файл app.py должен выглядеть примерно так:

from flask import Flask, request, render_template
from PIL import Image
import torch
import torchvision.transforms as transforms


model = torch.jit.load('path/to/model.pth')

@app.route('/')
def home():
    return render_template('home.html')

def process_image(image):
    # Preprocess image for model
    transformation = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image_tensor = transformation(image).unsqueeze(0)
    
    return image_tensor


class_names = ['apple', 'banana'] #REPLACE THIS WITH YOUR CLASSES

@app.route('/predict', methods=['POST'])
def predict():
    # Get uploaded image file
    image = request.files['image']

    # Process image and make prediction
    image_tensor = process_image(Image.open(image))
    output = model(image_tensor)

    # Get class probabilities
    probabilities = torch.nn.functional.softmax(output, dim=1)
    probabilities = probabilities.detach().numpy()[0]

    # Get the index of the highest probability
    class_index = probabilities.argmax()

    # Get the predicted class and probability
    predicted_class = class_names[class_index]
    probability = probabilities[class_index]

    # Sort class probabilities in descending order
    class_probs = list(zip(class_names, probabilities))
    class_probs.sort(key=lambda x: x[1], reverse=True)

    # Render HTML page with prediction results
    return render_template('predict.html', class_probs=class_probs,
                           predicted_class=predicted_class, probability=probability)

Последний мазок

Теперь грандиозный финал — реализация шаблона predict.html.

Чтобы создать HTML-страницу, создайте новый файл шаблона с именем predict.html и используйте следующий код:

<html>
  <head>
    <title>Prediction Results</title>
  </head>
  <body>
    <h1>Prediction Results</h1>
    <p>Predicted Class: {{ predicted_class }}</p>
    <p>Probability: {{ probability }}</p>
    <h2>Other Classes</h2>
    <ul>
      {% for class_name, prob in class_probs %}
        <li>{{ class_name }}: {{ prob }}</li>
      {% endfor %}
    </ul>
  </body>
</html>

Эта HTML-страница отображает прогнозируемый класс и вероятность, а также список других классов в порядке убывания вероятности.

Примечание. Не стесняйтесь изменять форматирование значений вероятности по своему усмотрению.

Тестирование

Важная часть производства — убедиться, что наше приложение действительно работает. Для начала запустите свой сервер с помощью python app.py

Я предлагаю протестировать вашу модель, (а) загрузив изображение и (б) отправив запрос POST программно.

Вот некоторый код Python, который может сделать (b)

import requests

# Set URL for Flask app
url = 'http://localhost:5000/predict'

# Set image file path
image_path = 'path/to/image.jpg'

# Read image file and set as payload
image = open(image_path, 'rb')
payload = {'image': image}

# Send POST request with image and get response
response = requests.post(url, headers=headers, data=payload)

print(response.text)

Этот код отправит запрос POST в приложение Flask с указанным файлом изображения в качестве полезной нагрузки. Затем приложение обработает изображение, сделает прогноз и вернет ответ, который будет напечатан на консоли.

Примечание. Обязательно замените 'path/to/image.jpg' на фактический путь к файлу изображения, которое вы хотите загрузить, и укажите правильный URL-адрес для вашего приложения Flask.

Вот и все! Поздравляем, вам удалось разработать целое приложение машинного обучения и протестировать его за пару минут!