Будучи энтузиастом исследований и любителем PyTorch (извините, Tensorflow), мне всегда нравилось использовать библиотеку PyTorch для моих проектов и рабочих процессов глубокого обучения. Как я заметил, большинство исследователей также предпочитают писать свой код на PyTorch. Но как бы мы ни любили PyTorch, мы могли бы согласиться с тем фактом, что есть несколько задач, которые можно было бы упростить, просто чтобы сэкономить время, включая, помимо прочего, создание цикла обучения, создание этапы проверки и тестирования и т. д. Я использовал Torch Snippets для одного и того же в течение очень долгого времени, просто чтобы сэкономить время, пока не узнал о PyTorch Lightning.
В тот день, когда я узнал о PyTorch Lightning, моя жизнь стала намного проще (по крайней мере, с точки зрения написания кода глубокого обучения), и мой рабочий процесс стал намного эффективнее.

В этом блоге я пишу о том, как вы можете использовать PyTorch Lightning для рабочего процесса вашего проекта, с примером реализации Variational Auto Encoder (VAE) с использованием библиотеки PyTorch Lightning.
Я также покажу вам, как это упрощает нашу жизнь по сравнению с написанием того же кода в PyTorch.

Внедрение VAE на PyTorch

Я не буду вдаваться в математику VAE, так как это выходит за рамки этой статьи, но если вам интересно узнать об этом, вы можете обратиться к ЭТОМУ блогу. На этом фоне давайте посмотрим, как реализовать VAE с помощью библиотеки PyTorch.

import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        self.mu = nn.Linear(hidden_dim, latent_dim)
        self.logvar = nn.Linear(hidden_dim, latent_dim)
        
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )
        
    def encode(self, x):
        h = self.encoder(x)
        mu = self.mu(h)
        logvar = self.logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z
    
    def decode(self, z):
        x_hat = self.decoder(z)
        return x_hat

    def loss_function(self, x_hat, x, mu, logvar):
        bce_loss = F.binary_cross_entropy(x_hat, x, reduction='sum')
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return bce_loss + kld_loss
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return x_hat, mu, logvar

Это простая реализация VAE. Обратите внимание, как у нас есть методы encode, decode, reparameterize, loss_function и forward (необходимые для реализации модуля PyTorch) в классе VAE.

Чтобы обучить модель, вам нужно перебирать DataLoader для каждой эпохи, примерно так:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

vae = VAE(input_dim=784, hidden_dim=256, latent_dim=20)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

train_dataset = MNIST(root='data/', train=True, transform=ToTensor(), download=True)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

num_epochs = 10
for epoch in range(num_epochs):
    for batch_idx, (x, _) in enumerate(train_dataloader):
        x_hat, mu, logvar = vae(x)
        loss = vae.loss_function(x_hat, x, mu, logvar)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}] Batch [{batch_idx}/{len(train_dataloader)}] Loss: {loss.item():.4f}")

Кроме того, как вы, возможно, уже знаете, вы должны убедиться, что вы настраиваете оптимизатор, вычисляете потери, обновляете параметры модели на основе градиентов, вычисленных с помощью loss.backward(), отключаете вычисление градиента во время проверки (с помощью torch .no_grad()) и т. д. все самостоятельно.

Используя PyTorch Lightning, можно избежать всех этих шагов, и мы можем писать более чистый и лучший код.

Реализация VAE с помощью PyTorch Lightning

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class VAE(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        self.mu = nn.Linear(hidden_dim, latent_dim)
        self.logvar = nn.Linear(hidden_dim, latent_dim)
        
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )
        
    def encode(self, x):
        h = self.encoder(x)
        mu = self.mu(h)
        logvar = self.logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z
    
    def decode(self, z):
        x_hat = self.decoder(z)
        return x_hat
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return x_hat, mu, logvar
    
    def loss_function(self, x_hat, x, mu, logvar):
        bce_loss = F.binary_cross_entropy(x_hat, x, reduction='sum')
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return bce_loss + kld_loss

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x_hat, mu, logvar = self(x)
        loss = self.loss_function(x_hat, x, mu, logvar)
        self.log('train_loss', loss)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

Класс VAE наследуется от pl.LightningModule вместо nn.Module. Посмотрите, у нас в классе определены два дополнительных метода: training_step и configure_optimizers. Эти методы являются встроенными методами pl.LightningModule, которые используются для определения шага обучения и настройки оптимизатора.

Чтобы обучить модель сейчас, мы можем использовать следующий код:

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import pytorch_lightning as pl


train_dataset = MNIST(root='data/', train=True, transform=ToTensor(), download=True)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

vae = VAE(input_dim=784, hidden_dim=256, latent_dim=20)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(vae, train_dataloader)

Нам не нужно перебирать DataLoader для каждой эпохи, чтобы обучить модель. Настройки градиента и т. д. обрабатываются объектом тренера автоматически, нам не нужно об этом беспокоиться. Нам просто нужно определить объект модели, создать объект pl.Trainer с желаемым количеством эпох и вызвать метод fit, и ВУАЛЯ, начнется обучение нашей модели.

Также для сохранения модели мы можем использовать следующий код:

from pytorch_lightning.callbacks import ModelCheckpoint

# Define a callback to save the model after every epoch
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints/',
    filename='vae-{epoch:02d}-{val_loss:.2f}',
    save_top_k=-1,   # save all checkpoints
    mode='min',
    save_last=True   # save the model after every epoch
)

# Create a trainer and fit the model
trainer = pl.Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, train_dataloader, val_dataloader)

В этом примере мы определяем обратный вызов ModelCheckpoint и устанавливаем параметр save_last=True для сохранения модели после каждой эпохи. Мы также установили save_top_k=-1, чтобы сохранить все файлы контрольных точек, а не только лучший.

Во время обучения PyTorch Lightning автоматически сохраняет модель после каждой эпохи в указанный каталог с форматом имени файла, который включает номер эпохи и потерю проверки. Окончательный файл модели также будет сохранен в том же формате, но с номером эпохи, установленным на «-1».

Используя обратный вызов ModelCheckpoint в PyTorch Lightning, вы можете легко сохранять модель после каждой эпохи во время обучения и обеспечивать доступ ко всем файлам промежуточных контрольных точек на случай, если вам потребуется загрузить предыдущее состояние модели.

Наряду со всем этим мы также получаем преимущество от индикатора выполнения, которого мы не могли получить только с PyTorch, если бы не использовали какие-то внешние библиотеки. Может показаться, что это не имеет большого значения, но для таких любителей PyTorch, как я, это имеет огромное значение.

PyTorch Lightning ТАК ХОРОШО, ПРАВДА?
Спасибо, что читаете :)

Я Data Scientist с сильной страстью к компьютерному зрению и глубокому обучению. Мне нравится работать над решением задач с помощью искусственного интеллекта. Если вам понравился этот блог, не забудьте подписаться на меня в СРЕДНЕМ. Также вы можете найти меня в LinkedIn, давайте создадим сообщество вместе ❤.