Я тренирую свою нейронную сеть, используя структуру PyTorch. Данные представляют собой изображения в формате Full HD (1920x1080). Но в каждой итерации мне просто нужно вырезать из этих изображений случайный патч 256x256. Моя сеть относительно мала (5 конверсионных слоев), и, следовательно, узкое место вызвано загрузкой данных. Я предоставил свой текущий код ниже. Есть ли способ оптимизировать загрузку данных и ускорить обучение?
Код:
from pathlib import Path
import numpy
import skimage.io
import torch.utils.data as data
import Imath
import OpenEXR
class Ours(data.Dataset):
"""
Loads patches of resolution 256x256. Patches are selected such that they contain atleast 1 unknown pixel
"""
def __init__(self, data_dirpath, split_name, patch_size):
super(Ours, self).__init__()
self.dataroot = Path(data_dirpath) / split_name
self.video_names = []
for video_path in sorted(self.dataroot.iterdir()):
for i in range(4):
for j in range(11):
view_num = i * 12 + j
self.video_names.append((video_path.stem, view_num))
self.patch_size = patch_size
return
def __getitem__(self, index):
video_name, view_num = self.video_names[index]
patch_start_pt = (numpy.random.randint(1080), numpy.random.randint(1920))
frame1_path = self.dataroot / video_name / f'render/rgb/{view_num + 1:04}.png'
frame2_path = self.dataroot / video_name / f'render/rgb/{view_num + 2:04}.png'
depth_path = self.dataroot / video_name / f'render/depth/{view_num + 1:04}.exr'
mask_path = self.dataroot / video_name / f'render/masks/{view_num + 1:04}.png'
frame1 = self.get_image(frame1_path, patch_start_pt)
frame2 = self.get_image(frame2_path, patch_start_pt)
mask = self.get_mask(mask_path, patch_start_pt)
depth = self.get_depth(depth_path, patch_start_pt, mask)
data_dict = {
'frame1': frame1,
'frame2': frame2,
'mask': mask,
'depth': depth,
}
return data_dict
def __len__(self):
return len(self.video_names)
@staticmethod
def get_mask(path: Path, patch_start_point: tuple):
h, w = patch_start_point
mask = skimage.io.imread(path.as_posix())[h:h + self.patch_size, w:w + self.patch_size][None]
return mask
def get_image(self, path: Path, patch_start_point: tuple):
h, w = patch_start_point
image = skimage.io.imread(path.as_posix())
image = image[h:h + self.patch_size, w:w + self.patch_size, :3]
image = image.astype(numpy.float32) / 255 * 2 - 1
image_cf = numpy.moveaxis(image, [0, 1, 2], [1, 2, 0])
return image_cf
def get_depth(self, path: Path, patch_start_point: tuple, mask: numpy.ndarray):
h, w = patch_start_point
exrfile = OpenEXR.InputFile(path.as_posix())
raw_bytes = exrfile.channel('B', Imath.PixelType(Imath.PixelType.FLOAT))
depth_vector = numpy.frombuffer(raw_bytes, dtype=numpy.float32)
height = exrfile.header()['displayWindow'].max.y + 1 - exrfile.header()['displayWindow'].min.y
width = exrfile.header()['displayWindow'].max.x + 1 - exrfile.header()['displayWindow'].min.x
depth = numpy.reshape(depth_vector, (height, width))
depth = depth[h:h + self.patch_size, w:w + self.patch_size]
depth = depth[None]
depth = depth.astype(numpy.float32)
depth = depth * mask
return depth
Наконец, я создаю DataLoader следующим образом:
train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)
Что я уже пробовал:
- Я искал, можно ли прочитать часть изображения. К сожалению, никаких зацепок я не получил. Похоже, библиотеки Python читают полное изображение.
- Я планирую читать больше исправлений с одного изображения, поэтому мне нужно будет читать меньше изображений. Но в среде PyTorch функция
get_item()
должна возвращать один образец, а не пакет. Итак, в каждомget_item()
я могу прочитать только патч. - Я планирую обойти это следующим образом: прочитать 4 исправления в
get_item()
и вернуть исправления формы(4,3,256,256)
вместо(3,256,256)
. Позже, когда я прочитаю пакет с помощью загрузчика данных, я получу пакет формы(BS,4,3,256,256)
вместо(BS,3,256,256)
. Затем я могу объединить данные поdim=1
, чтобы преобразовать(BS,4,3,256,256)
в(BS*4,3,256,256)
. Таким образом я могу уменьшитьbatch_size
(BS
) на 4 и, надеюсь, это ускорит загрузку данных в 4 раза.
Есть ли другие варианты? Я открыт для всех видов предложений. Спасибо!