Кто знал, что запуск моделей PyTorch в производство может быть таким простым?
Итак, у вас есть модель, которую вы хотите развернуть. Может быть, вы исследователь, энтузиаст или *взволнованно хихикает* новейший стартап в области высоких технологий. И теперь мы все стоим на одном перекрестке — как лучше всего настроить и запустить нашу модель, чтобы мы могли поделиться своими технологиями со всем миром?
Сегодня я расскажу вам, как это сделать для простого случая — развертывания модели классификации изображений PyTorch. Этот код прост в использовании, работает для любой модели классификации изображений PyTorch и разработан таким образом, чтобы его можно было легко обобщить для других задач, таких как обнаружение объектов, или даже таких задач, как классификация/генерация текста.
Прежде чем мы начнем, я отмечу, что я предполагаю, что вы хотя бы на практике знакомы с Flask. Если нет, я бы порекомендовал бегло просмотреть эти два руководства; это должно занять максимум 10 минут.
- Создание простого приложения Flask, быстро
- Глубокое погружение во Flask (перейти к разделу о шаблонах)
Готовый? Пристегнитесь, друзья, это будет дикая поездка!
Монтаж
Во-первых, убедитесь, что у вас установлены PyTorch и Flask. Вы можете установить эти пакеты, используя pip install torch
и pip install flask
.
Встать и бежать
Создайте файл app.py
и маршрут для домашней страницы, добавив импорт для наших двух звездных библиотек.
from flask import Flask import torch app = Flask(__name__) @app.route('/') def home(): return 'Welcome to the PyTorch Flask app!'
Теперь, когда вы запустите python app.py
и посетите http://localhost:5000/, вы должны увидеть простое сообщение — Добро пожаловать в приложение PyTorch Flask! !
Оформление сцены…
Теперь, когда наш скаффолд готов, давайте добавим некоторый код «настройки», который позволит использовать данные, которые мы хотим, в нашу модель!
Добавьте еще пару импортов
from flask import Flask, request, render_template from PIL import Image import torch import torchvision.transforms as transforms
Замените наш основной домашний маршрут на настоящую HTML-страницу.
@app.route('/') def home(): return render_template('home.html')
Создайте папку templates
и добавьте следующий HTML-код для нашего шаблона home.html
.
<html> <head> <title>PyTorch Image Classification</title> </head> <body> <h1>PyTorch Image Classification</h1> <form method="POST" enctype="multipart/form-data" action="/predict"> <input type="file" name="image"> <input type="submit" value="Predict"> </form> </body> </html>
HTML довольно прост — у нас есть кнопка загрузки для загрузки любых данных (в нашем случае изображений), через которые мы хотим запустить нашу модель.
Веселая часть
Мы сделали всю эту настройку, и теперь без лишних слов давайте погрузимся в настоящее мясо — подключим модель.
Над маршрутом home
давайте загрузим нашу модель.
model = torch.jit.load('path/to/model.pth')
Однако в большинстве случаев мы хотим каким-то образом «обработать» входные данные, чтобы преобразовать их в формат, подходящий для нашей модели, — преобразовать их в тензор, изменить размер или выполнить любую форму предварительной обработки.
В нашем случае мы работаем с изображениями, поэтому делаем небольшую рутинную очистку.
def process_image(image): # Preprocess image for model transformation = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) image_tensor = transformation(image).unsqueeze(0) return image_tensor
Быстро создайте массив, чтобы указать наши классы
class_names = ['apple', 'banana'] #REPLACE THIS WITH YOUR CLASSES
Сейчас настоящий момент, в котором мы были — создание маршрута для предсказания
Создайте маршрут, который принимает загруженное изображение, обрабатывает его, делает прогноз с использованием модели и возвращает вероятности для каждого класса.
@app.route('/predict', methods=['POST']) def predict(): # Get uploaded image file image = request.files['image'] # Process image and make prediction image_tensor = process_image(Image.open(image)) output = model(image_tensor) # Get class probabilities probabilities = torch.nn.functional.softmax(output, dim=1) probabilities = probabilities.detach().numpy()[0] # Get the index of the highest probability class_index = probabilities.argmax() # Get the predicted class and probability predicted_class = class_names[class_index] probability = probabilities[class_index] # Sort class probabilities in descending order class_probs = list(zip(class_names, probabilities)) class_probs.sort(key=lambda x: x[1], reverse=True) # Render HTML page with prediction results return render_template('predict.html', class_probs=class_probs, predicted_class=predicted_class, probability=probability)
Эта версия маршрута /predict
сначала получает вероятности класса с помощью функции softmax
, а затем получает индекс наибольшей вероятности. Затем он использует этот индекс для поиска предсказанного класса в списке имен классов и получает вероятность для этого класса. Затем он сортирует вероятности классов в порядке убывания и отображает HTML-страницу с результатами прогнозирования.
В целом наш файл app.py
должен выглядеть примерно так:
from flask import Flask, request, render_template from PIL import Image import torch import torchvision.transforms as transforms model = torch.jit.load('path/to/model.pth') @app.route('/') def home(): return render_template('home.html') def process_image(image): # Preprocess image for model transformation = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) image_tensor = transformation(image).unsqueeze(0) return image_tensor class_names = ['apple', 'banana'] #REPLACE THIS WITH YOUR CLASSES @app.route('/predict', methods=['POST']) def predict(): # Get uploaded image file image = request.files['image'] # Process image and make prediction image_tensor = process_image(Image.open(image)) output = model(image_tensor) # Get class probabilities probabilities = torch.nn.functional.softmax(output, dim=1) probabilities = probabilities.detach().numpy()[0] # Get the index of the highest probability class_index = probabilities.argmax() # Get the predicted class and probability predicted_class = class_names[class_index] probability = probabilities[class_index] # Sort class probabilities in descending order class_probs = list(zip(class_names, probabilities)) class_probs.sort(key=lambda x: x[1], reverse=True) # Render HTML page with prediction results return render_template('predict.html', class_probs=class_probs, predicted_class=predicted_class, probability=probability)
Последний мазок
Теперь грандиозный финал — реализация шаблона predict.html
.
Чтобы создать HTML-страницу, создайте новый файл шаблона с именем predict.html
и используйте следующий код:
<html> <head> <title>Prediction Results</title> </head> <body> <h1>Prediction Results</h1> <p>Predicted Class: {{ predicted_class }}</p> <p>Probability: {{ probability }}</p> <h2>Other Classes</h2> <ul> {% for class_name, prob in class_probs %} <li>{{ class_name }}: {{ prob }}</li> {% endfor %} </ul> </body> </html>
Эта HTML-страница отображает прогнозируемый класс и вероятность, а также список других классов в порядке убывания вероятности.
Примечание. Не стесняйтесь изменять форматирование значений вероятности по своему усмотрению.
Тестирование
Важная часть производства — убедиться, что наше приложение действительно работает. Для начала запустите свой сервер с помощью python app.py
Я предлагаю протестировать вашу модель, (а) загрузив изображение и (б) отправив запрос POST программно.
Вот некоторый код Python, который может сделать (b)
import requests # Set URL for Flask app url = 'http://localhost:5000/predict' # Set image file path image_path = 'path/to/image.jpg' # Read image file and set as payload image = open(image_path, 'rb') payload = {'image': image} # Send POST request with image and get response response = requests.post(url, headers=headers, data=payload) print(response.text)
Этот код отправит запрос POST в приложение Flask с указанным файлом изображения в качестве полезной нагрузки. Затем приложение обработает изображение, сделает прогноз и вернет ответ, который будет напечатан на консоли.
Примечание. Обязательно замените 'path/to/image.jpg'
на фактический путь к файлу изображения, которое вы хотите загрузить, и укажите правильный URL-адрес для вашего приложения Flask.
Вот и все! Поздравляем, вам удалось разработать целое приложение машинного обучения и протестировать его за пару минут!