Будучи энтузиастом исследований и любителем PyTorch (извините, Tensorflow), мне всегда нравилось использовать библиотеку PyTorch для моих проектов и рабочих процессов глубокого обучения. Как я заметил, большинство исследователей также предпочитают писать свой код на PyTorch. Но как бы мы ни любили PyTorch, мы могли бы согласиться с тем фактом, что есть несколько задач, которые можно было бы упростить, просто чтобы сэкономить время, включая, помимо прочего, создание цикла обучения, создание этапы проверки и тестирования и т. д. Я использовал Torch Snippets для одного и того же в течение очень долгого времени, просто чтобы сэкономить время, пока не узнал о PyTorch Lightning.
В тот день, когда я узнал о PyTorch Lightning, моя жизнь стала намного проще (по крайней мере, с точки зрения написания кода глубокого обучения), и мой рабочий процесс стал намного эффективнее.
В этом блоге я пишу о том, как вы можете использовать PyTorch Lightning для рабочего процесса вашего проекта, с примером реализации Variational Auto Encoder (VAE) с использованием библиотеки PyTorch Lightning.
Я также покажу вам, как это упрощает нашу жизнь по сравнению с написанием того же кода в PyTorch.
Внедрение VAE на PyTorch
Я не буду вдаваться в математику VAE, так как это выходит за рамки этой статьи, но если вам интересно узнать об этом, вы можете обратиться к ЭТОМУ блогу. На этом фоне давайте посмотрим, как реализовать VAE с помощью библиотеки PyTorch.
import torch import torch.nn as nn import torch.nn.functional as F class VAE(nn.Module): def __init__(self, input_dim, hidden_dim, latent_dim): super(VAE, self).__init__() self.input_dim = input_dim self.hidden_dim = hidden_dim self.latent_dim = latent_dim self.encoder = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU() ) self.mu = nn.Linear(hidden_dim, latent_dim) self.logvar = nn.Linear(hidden_dim, latent_dim) self.decoder = nn.Sequential( nn.Linear(latent_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Sigmoid() ) def encode(self, x): h = self.encoder(x) mu = self.mu(h) logvar = self.logvar(h) return mu, logvar def reparameterize(self, mu, logvar): std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) z = mu + eps * std return z def decode(self, z): x_hat = self.decoder(z) return x_hat def loss_function(self, x_hat, x, mu, logvar): bce_loss = F.binary_cross_entropy(x_hat, x, reduction='sum') kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) return bce_loss + kld_loss def forward(self, x): mu, logvar = self.encode(x) z = self.reparameterize(mu, logvar) x_hat = self.decode(z) return x_hat, mu, logvar
Это простая реализация VAE. Обратите внимание, как у нас есть методы encode, decode, reparameterize, loss_function и forward (необходимые для реализации модуля PyTorch) в классе VAE.
Чтобы обучить модель, вам нужно перебирать DataLoader для каждой эпохи, примерно так:
import torch import torch.nn as nn import torch.nn.functional as F from torchvision.datasets import MNIST from torch.utils.data import DataLoader from torchvision.transforms import ToTensor vae = VAE(input_dim=784, hidden_dim=256, latent_dim=20) optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3) train_dataset = MNIST(root='data/', train=True, transform=ToTensor(), download=True) train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True) num_epochs = 10 for epoch in range(num_epochs): for batch_idx, (x, _) in enumerate(train_dataloader): x_hat, mu, logvar = vae(x) loss = vae.loss_function(x_hat, x, mu, logvar) optimizer.zero_grad() loss.backward() optimizer.step() if batch_idx % 100 == 0: print(f"Epoch [{epoch}/{num_epochs}] Batch [{batch_idx}/{len(train_dataloader)}] Loss: {loss.item():.4f}")
Кроме того, как вы, возможно, уже знаете, вы должны убедиться, что вы настраиваете оптимизатор, вычисляете потери, обновляете параметры модели на основе градиентов, вычисленных с помощью loss.backward(), отключаете вычисление градиента во время проверки (с помощью torch .no_grad()) и т. д. все самостоятельно.
Используя PyTorch Lightning, можно избежать всех этих шагов, и мы можем писать более чистый и лучший код.
Реализация VAE с помощью PyTorch Lightning
import torch import torch.nn as nn import torch.nn.functional as F import pytorch_lightning as pl class VAE(pl.LightningModule): def __init__(self, input_dim, hidden_dim, latent_dim): super(VAE, self).__init__() self.input_dim = input_dim self.hidden_dim = hidden_dim self.latent_dim = latent_dim self.encoder = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU() ) self.mu = nn.Linear(hidden_dim, latent_dim) self.logvar = nn.Linear(hidden_dim, latent_dim) self.decoder = nn.Sequential( nn.Linear(latent_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Sigmoid() ) def encode(self, x): h = self.encoder(x) mu = self.mu(h) logvar = self.logvar(h) return mu, logvar def reparameterize(self, mu, logvar): std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) z = mu + eps * std return z def decode(self, z): x_hat = self.decoder(z) return x_hat def forward(self, x): mu, logvar = self.encode(x) z = self.reparameterize(mu, logvar) x_hat = self.decode(z) return x_hat, mu, logvar def loss_function(self, x_hat, x, mu, logvar): bce_loss = F.binary_cross_entropy(x_hat, x, reduction='sum') kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) return bce_loss + kld_loss def training_step(self, batch, batch_idx): x, _ = batch x_hat, mu, logvar = self(x) loss = self.loss_function(x_hat, x, mu, logvar) self.log('train_loss', loss) return loss def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer
Класс VAE наследуется от pl.LightningModule вместо nn.Module. Посмотрите, у нас в классе определены два дополнительных метода: training_step и configure_optimizers. Эти методы являются встроенными методами pl.LightningModule, которые используются для определения шага обучения и настройки оптимизатора.
Чтобы обучить модель сейчас, мы можем использовать следующий код:
from torchvision.datasets import MNIST from torch.utils.data import DataLoader from torchvision.transforms import ToTensor import pytorch_lightning as pl train_dataset = MNIST(root='data/', train=True, transform=ToTensor(), download=True) train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True) vae = VAE(input_dim=784, hidden_dim=256, latent_dim=20) trainer = pl.Trainer(max_epochs=10) trainer.fit(vae, train_dataloader)
Нам не нужно перебирать DataLoader для каждой эпохи, чтобы обучить модель. Настройки градиента и т. д. обрабатываются объектом тренера автоматически, нам не нужно об этом беспокоиться. Нам просто нужно определить объект модели, создать объект pl.Trainer с желаемым количеством эпох и вызвать метод fit, и ВУАЛЯ, начнется обучение нашей модели.
Также для сохранения модели мы можем использовать следующий код:
from pytorch_lightning.callbacks import ModelCheckpoint # Define a callback to save the model after every epoch checkpoint_callback = ModelCheckpoint( monitor='val_loss', dirpath='checkpoints/', filename='vae-{epoch:02d}-{val_loss:.2f}', save_top_k=-1, # save all checkpoints mode='min', save_last=True # save the model after every epoch ) # Create a trainer and fit the model trainer = pl.Trainer(callbacks=[checkpoint_callback]) trainer.fit(model, train_dataloader, val_dataloader)
В этом примере мы определяем обратный вызов ModelCheckpoint
и устанавливаем параметр save_last=True
для сохранения модели после каждой эпохи. Мы также установили save_top_k=-1
, чтобы сохранить все файлы контрольных точек, а не только лучший.
Во время обучения PyTorch Lightning автоматически сохраняет модель после каждой эпохи в указанный каталог с форматом имени файла, который включает номер эпохи и потерю проверки. Окончательный файл модели также будет сохранен в том же формате, но с номером эпохи, установленным на «-1».
Используя обратный вызов ModelCheckpoint
в PyTorch Lightning, вы можете легко сохранять модель после каждой эпохи во время обучения и обеспечивать доступ ко всем файлам промежуточных контрольных точек на случай, если вам потребуется загрузить предыдущее состояние модели.
Наряду со всем этим мы также получаем преимущество от индикатора выполнения, которого мы не могли получить только с PyTorch, если бы не использовали какие-то внешние библиотеки. Может показаться, что это не имеет большого значения, но для таких любителей PyTorch, как я, это имеет огромное значение.
PyTorch Lightning ТАК ХОРОШО, ПРАВДА?
Спасибо, что читаете :)
Я Data Scientist с сильной страстью к компьютерному зрению и глубокому обучению. Мне нравится работать над решением задач с помощью искусственного интеллекта. Если вам понравился этот блог, не забудьте подписаться на меня в СРЕДНЕМ. Также вы можете найти меня в LinkedIn, давайте создадим сообщество вместе ❤.