Построение новой модели с минимальным дополнительным кодом

Этот пост в блоге является частью мини-серии, в которой рассказывается о различных аспектах создания проекта глубокого обучения PyTorch с использованием вариационных автоэнкодеров.

Часть 1: Математические основы и реализация
Часть 2: Суперзарядка с помощью PyTorch Lightning
Часть 3: Сверточный VAE, наследование и модульное тестирование
Часть 4: Веб-приложение Streamlit и развертывание

В этом разделе мы рассмотрим, как мы можем использовать код, который мы написали в предыдущем разделе, и использовать его для создания сверточной VAE. Этот VAE будет лучше определять важные функции на изображениях и, таким образом, создавать еще более качественные изображения.

Самое приятное то, что эту новую модель можно построить с минимальным дополнительным кодом благодаря модулям PyTorch и наследованию классов.

Что такое сверточный VAE?

Свертка — это операция, обычно используемая при обработке изображений для извлечения признаков конкретного изображения. Изображения, как правило, полны ненужной информации и масштабируются в любой пиксель, окружающие пиксели, вероятно, имеют очень похожий цвет. В сверточных нейронных сетях (CNN) многие сверточные фильтры автоматически изучаются для получения функций, полезных при классификации и идентификации изображений. Мы просто заимствуем эти принципы, чтобы использовать сверточные слои для построения VAE.

Создавая сверточный VAE, мы стремимся улучшить процесс извлечения признаков. Несмотря на то, что мы не выполняем никаких задач классификации/регрессии, мы хотим, чтобы скрытое представление было как можно более информативным. Благодаря более мощному извлечению признаков декодер может генерировать более убедительные точки данных.

Несмотря на то, что эта новая модель использует новую архитектуру, мы хотим писать код эффективно. В хорошем и эффективном коде используется принцип DRY (не повторяйтесь). Чтобы избежать ненужного повторения кода, мы будем использовать наследование — мощную концепцию для построения нашей модели.

Что такое наследование?

Наследование — очень мощная концепция, присутствующая в языках объектно-ориентированного программирования (ООП). Это позволяет пользователям определять объекты, а затем создавать новые объекты, сохраняя при этом некоторые функциональные возможности исходного объекта. Наследование — довольно обширная тема, и есть такие вещи, как множественное наследование, о которых я не буду вдаваться в подробности. ознакомьтесь с этим для получения дополнительной информации об объектно-ориентированном программировании в Python и наследовании.



На самом деле наследование настолько распространено, что мы уже использовали наследование в Части 1. Даже не зная об этом, наследование широко используется в PyTorch, где каждая нейронная сеть наследуется от базового класса nn.Module.

Из-за этого нам нужно только определить методы __init__ и forward, а базовый класс сделает все остальное. Модель, которую мы собираемся построить, пойдет дальше и будет основываться на VAE, созданной в предыдущем разделе.

Наследование позволяет нам строить сложные модели на разных этапах. Созданная нами предыдущая модель VAE действует как скелет. Он выполняет репараметризацию и реализует потерю KL-дивергенции.

Затем мы можем наследовать этот класс и создать лучшую модель, использующую архитектуру, более подходящую для задачи.

Затем эту модель можно адаптировать, изменив кодировщик и декодер. Кодер просто выполняет обучение представлению, а декодер выполняет генерацию. Эти подсети могут быть простыми линейными слоями или сложными сетями.

В нашем сверточном VAE мы хотим изменить эти компоненты, сохранив при этом все остальное идентичным. Это легко сделать с помощью наследования

Это позволяет нам избежать повторения большого количества кода. Такие методы класса, как forward, training_step, train_loader , останутся точно такими же, а наследование позволит нам автоматически копировать их.

Рефакторинг кода

Если внимательно приметить, в предыдущей модели. Шаг вперед включал выравнивание вектора перед его подачей в кодировщик. Для Convolution VAE мы не хотим делать это сглаживание, так как это мешает нам использовать 2D-свертки.

Похоже, чтобы наследование заработало, нам нужно провести рефакторинг кода!

По сути, рефакторинг кода — это внесение некоторых изменений в код при сохранении той же внешней функциональности. Это означает, что код по-прежнему будет иметь такое же поведение с точки зрения входных и выходных данных. Рефакторинг может быть выполнен, чтобы сделать код быстрее или, в нашем случае, упростить код, чтобы мы могли повторно использовать его где-то еще.

Вместо того, чтобы переписывать весь шаг вперед, мы можем реорганизовать наш код так, чтобы сглаживание входного тензора и изменение его формы обратно до 28 x 28 происходило внутри self.encoder и self.decoder, а не внутри прямой функции.

Это позволяет модели быть более универсальной, поскольку она может работать с различными кодировщиками, такими как свертка, когда мы не хотим сглаживать входной вектор.

Модульные тесты

Но держись! Давайте пока не будем делать рефакторинг. Худшее, что вы хотите, это чтобы ваша модель сломалась, когда вы вносите изменения в код. Мы хотим убедиться, что модель VAE по-прежнему делает то же самое после рефакторинга.

Один из хороших способов убедиться, что ваш проект по-прежнему функционирует после внесения изменений, — это написать модульные тесты.

Модульные тесты — это простые сценарии, которые вы можете запустить, чтобы убедиться, что ваш код работает правильно. В контексте нашей модели это необходимо для того, чтобы созданная нами модель по-прежнему могла обучаться, а градиенты по-прежнему могли хорошо распространяться обратно.

Для этого мы будем использовать pytest, мощную библиотеку для написания модульных тестов, которая также содержит полезные инструменты отладки, чтобы выяснить, почему тесты терпят неудачу.

Сначала мы создаем папку в нашем каталоге с именем tests. Внутри этой папки мы создаем файл с именем test_model.py. В нем будут храниться все необходимые модульные тесты.

Давайте определим простой тест:

Еще одна интересная особенность pytest заключается в том, что он автоматически ищет тестовые функции в пакете. Пока имя функции начинается с test, pytest будет запускать тест соответствующим образом.

Запустив pytest в командной строке, мы можем подтвердить, что тест пройден.

Модульность Pytorch

Теперь, когда у нас есть надлежащие системы тестирования, мы можем начать вносить изменения в наш код.

Одна важная вещь, которую нужно понять о модулях PyTorch, заключается в том, что они в основном являются функциями. Когда ввод передается в любой модуль PyTorch, он просто выполняет некоторые операции и обратно распространяет градиенты.

Это означает, что даже простые операции изменения формы можно инициализировать как объект PyTorch.

Мы просто берем первую строку старой прямой функции и добавляем ее как модуль. Таким образом, размещение Flatten() в качестве модуля в кодировщике приводит к тому же результату.

Теперь давайте напишем код для модуля стека. Для набора данных MNIST этот модуль возвращает тензору его первоначальную форму (1,28,28).

В нашем случае, поскольку изображения черно-белые, есть только один канал, но давайте создадим модуль стека, который может работать и с цветными изображениями.

Для этого мы должны хранить информацию о наборе данных в самой модели. Это делается путем передачи исходной формы данных в качестве параметров в модуль. Этими параметрами являются каналы, высота и ширина. В этом случае прямая операция будет операцией view, аналогичной модулю Flatten.

Чтобы сохранить эти параметры, нам нужно использовать функцию __init__. Это позволяет нам хранить эти параметры как переменные класса. Для этого мы сначала инициализируем его как модуль PyTorch, и это делается путем вызова super(self,Stack).__init__() в функции __init__.

Теперь, когда мы абстрагировали эти функции преобразования в их собственные объекты, мы можем использовать nn.Sequential для определения этих операций как части модулей кодировщика и декодера.

Точно так же операции по изменению формы являются частью self.encoder и self.decoder.

Давайте запустим наши модульные тесты, чтобы убедиться, что код работает.

Хороший! Тест проходит, и код работает, как ожидалось.



Создание сверточного кодировщика

С этими изменениями мы можем начать создавать Conv_VAE. Начнем с энкодера.

Первая часть кодировщика представляет собой последовательные шаги из Conv2d слоев вместе с ReLU активациями и BatchNorm2d для ускорения обучения. На этом шаге выполняется извлечение признаков при уменьшении размера.

Следующая часть — это шаг сглаживания, чтобы преобразовать вектор обратно в одно измерение. Скрытое представление в VAE — это единый вектор, который нам нужен, чтобы привести входные данные к той же форме. Это можно сделать с помощью модуля Flatten(), который мы определили ранее. Просто импортируйте его из файла VAE, и мы сможем использовать его в кодировщике.

Переходим к декодеру

Архитектура декодера действительно похожа. Это в основном то же самое, но в противоположном направлении.

1. Слой прямой связи (nn.Linear)

2. Модуль Stack для преобразования линейного слоя в 2D-формы с каналами.

3. ConvTranspose2d слоев, повышающих дискретизацию изображений и генерирующих изображения большей высоты и ширины. (Напротив Conv2d)

4. Слой Conv2d для очистки окончательного вывода

Последний слой должен выводить что-то, что имеет тот же размер, что и исходная форма, и потери MSE можно легко применить.

И верьте или нет, мы закончили здесь! Классы Python наследуют все методы по умолчанию, поэтому все остальные функции за пределами __init__ не нужно определять заново. Это означает, что все, от обучения, проверки и даже save_images, будет автоматически представлено для использования в новом Conv VAE.

Результаты!

Изображения, полученные с помощью Convolutional VAE, кажутся более определенными, и в изображениях больше изменчивости. Несмотря на использование того же измерения скрытого пространства, новая модель способна захватывать и воссоздавать изображения с большим разнообразием. Это демонстрирует мощь улучшенной сети кодер-декодер, и эта разница будет еще более заметной применительно к цветным изображениям.

В следующем (и последнем) разделе я рассмотрю шаги, необходимые для полного развертывания модели на Heroku и создания игровой площадки для взаимодействия с ними!