Проще говоря, PyTorch Lightning — это дополнение к PyTorch, которое значительно упрощает модели обучения. Модули PyTorch Lightning имеют методы класса по умолчанию, которые могут уменьшить количество ненужного шаблонного кода, необходимого при обучении модели. При проведении экспериментов мы обычно хотим внести небольшие изменения, что может быть очень сложно для больших моделей. Помимо оптимизации кода, Pytorch Lightning также содержит множество полезных функций, таких как автоматическое определение скорости обучения. Давайте посмотрим, как перевести код Pytorch в Pytorch Lightning.

Этапы обучения и проверки

В стандартном классе PyTorch необходимо определить только два метода: метод __init__, определяющий архитектуру модели, и метод forward, определяющий прямой проход. Все остальные операции, такие как загрузка набора данных, обучение и проверка, являются функциями, выполняемыми вне класса.

В PyTorch Lightning мы используем training_step для определения операций, которые происходят на этапе обучения. Это может включать такие операции, как регистрация, расчет потерь и обратное распространение. Однако при использовании PyTorch Lightning требуется значительно меньше кода, и довольно просто повторно использовать уже определенные методы класса.

Метод training_step будет вызываться PL. На основе потерь, рассчитанных в training_step , PL проведет обратное распространение потерь, рассчитает градиент и оптимизирует веса модели. В коде выше мы видим, что есть метод configure_optimizers, это простой метод класса, который возвращает оптимизатор, который будет использоваться для обучения модели. В отличие от обычного PyTorch, для которого требуются optimizer.step() и loss.backward(), эти шаги абстрагируются и автоматически выполняются как часть метода training_step.

Выполнение шага проверки — это почти тот же код. Потеря проверки будет рассчитана и возвращена функцией обратного распространения, а обновления параметров не будут выполняться.

PL тренер

Еще одним важным дополнением к пониманию того, как использовать PyTorch Lightning, является класс Trainer. Этот класс необходим для подгонки любых данных к модели. С помощью trainer.fit(model,dataloader) мы указываем, на каких данных обучать модель. При создании экземпляра объекта тренера также есть множество различных параметров, которые мы можем передать для определения процесса обучения.

  • Параметры обучения
  • Количество эпох
  • Количество графических процессоров для обучения
  • auto_lr_find — автоматически определяет, какую скорость обучения использовать. Действительно полезно для быстрого определения базовых показателей и обучения модели без дополнительных экспериментов.

Загрузка набора данных

Для случаев, когда источник данных для обучения, проверки и тестирования фиксирован, мы можем дополнительно расширить LightningModule, определив DataLoaders внутри класса. Затем PyTorch Lightning автоматически получит данные от соответствующих Dataloaders и использует это для обучения моделей.

Обратите внимание, что мы можем просто использовать trainer.fit(model) без указания DataLoader для обучения или проверки. В этом случае, поскольку MNIST является общим набором данных, код для train_dataloader относительно короткий, но для более сложных примеров этот шаг может включать несколько операций предварительной обработки, прежде чем преобразовать загрузчик данных в подходящий формат для обучения.

Дополнительные методы и обратные вызовы

Теперь вы можете подумать про себя.

«Проблема, которую я пытаюсь решить, не так проста, что, если мне нужно:»

  • Меняйте набор данных каждую эпоху
  • Уменьшить скорость обучения
  • Сохраните часть вывода

К счастью, базовый класс pl.LightningModule имеет множество дополнительных методов, помогающих выполнять любые операции, которые могут быть полезны при обучении вашей нейронной сети. Здесь мы реализуем validation_epoch_end. Это будет сохранять образец изображений после каждой эпохи, чтобы мы могли проверить, правильно ли обучается наша модель VAE.

validation_epoch_end принимает все выходные данные из validation_step. В validation_step результат прямого вызова возвращается вместе с потерей. Мы можем просто взять образец из x_out, изменить его форму до нужного размера и сохранить массив как изображение.

Помимо использования дополнительных методов класса, еще одним способом настроить процесс обучения является использование обратных вызовов. Обратные вызовы содержат очень похожие методы по сравнению с функциями класса, но более настраиваемые. Обратные вызовы следует использовать, если они вызываются в нескольких разных точках в течение цикла обучения, чтобы предотвратить повторение одних и тех же фрагментов кода. Существуют также готовые обратные вызовы, которые могут выполнять стандартные операции, такие как сохранение контрольных точек модели или ранняя остановка.

Отслеживание тензорной доски

Если вы заметили ранее, для этапов обучения и проверки мы используем self.log(). PyTorch Lightning позволяет чрезвычайно легко регистрировать различные показатели во время обучения, и к этим данным можно легко получить доступ с помощью TensorBoard. Все значения, зарегистрированные во время цикла обучения, будут храниться в lightning_logs.

Запустив tensorboard --logdir lightning_logs/ в терминале, можно визуализировать и отслеживать все зарегистрированные показатели/потери. Это поможет вам отслеживать различные эксперименты и настраивать некоторые используемые гиперпараметры.

Окончательный результат

Сравнивая изображения из первой и последней эпох, мы видим, что модель успешно научилась на изображениях. На левом изображении большинство элементов нечеткие, а цифры на правом изображении значительно четче. Важно отметить, что изображения все еще довольно размыты. Одна интересная особенность VAE заключается в том, что их можно использовать вместе с архитектурой выбора функций, которая лучше подходит для этой задачи. В этом примере мы используем сглаженное векторное представление цифры MNIST, но это определенно не лучший подход. Для задач компьютерного зрения использование нескольких слоев свертки позволит лучше извлекать признаки и поможет модели достичь значительно лучших результатов. В этом случае мы все еще можем добиться достойной производительности просто потому, что MNIST — это «простой» набор данных. Использование сверточной VAE позволит достичь значительно более высокой производительности для рекреационных потерь просто потому, что функции, извлеченные из узкого места, будут более полезными.

Чтобы построить VAE для изображений, мы можем начать с нескольких слоев шагов свертки, взять функцию max, чтобы сгладить вектор, и использовать этот вектор для получения векторов μ и σ (см. Часть 1) в узком месте. Архитектура нейронной сети, которая выполняет регуляризацию для VAE, встречается только в узком месте, и этот компонент может использоваться в других архитектурах нейронных сетей.

Помимо изображений, VAE были успешно применены ко многим различным наборам данных и достигли довольно замечательных результатов в задачах обработки естественного языка (NLP). Более поздние исследования VAE также привели к созданию новых архитектур, таких как MMD-VAE и VQ-VAE, которые обеспечивают еще более высокую производительность. Перейдите по ссылкам ниже для получения дополнительной информации о различных архитектурах VAE.

Не стесняйтесь проверить полный код на GitHub, и мы будем очень признательны за любые отзывы!

Github: https://github.com/reoneo97/pytorch-lightning-vae
LinkedIn: https://www.linkedin.com/in/reo-neo/

Полезные ресурсы

(1) Реализация PyTorch различных архитектур VAE



Полезная компиляция различных архитектур VAE, показывающая соответствующую реализацию PyTorch и результаты.

(2) Обучение нейронному дискретному представлению



Статья о векторно-квантованном VAE (VQ-VAE). Интересная статья о том, как еще больше улучшить VAE для дискретных данных. Выделяет некоторые из существующих проблем с VAE и то, как VQ-VAE могут решать эти проблемы.

(3) MMD-VAE



Интересный пост в блоге о другом подходе к регуляризации распределения скрытых вероятностей без использования функции потерь KL-дивергенции.