Оптимизируйте загрузчик данных pytorch для чтения небольших фрагментов изображений в формате Full HD.

Я тренирую свою нейронную сеть, используя структуру PyTorch. Данные представляют собой изображения в формате Full HD (1920x1080). Но в каждой итерации мне просто нужно вырезать из этих изображений случайный патч 256x256. Моя сеть относительно мала (5 конверсионных слоев), и, следовательно, узкое место вызвано загрузкой данных. Я предоставил свой текущий код ниже. Есть ли способ оптимизировать загрузку данных и ускорить обучение?

Код:

from pathlib import Path

import numpy
import skimage.io
import torch.utils.data as data

import Imath
import OpenEXR


class Ours(data.Dataset):
    """
    Loads patches of resolution 256x256. Patches are selected such that they contain atleast 1 unknown pixel
    """

    def __init__(self, data_dirpath, split_name, patch_size):
        super(Ours, self).__init__()
        self.dataroot = Path(data_dirpath) / split_name
        self.video_names = []
        for video_path in sorted(self.dataroot.iterdir()):
            for i in range(4):
                for j in range(11):
                    view_num = i * 12 + j
                    self.video_names.append((video_path.stem, view_num))
        self.patch_size = patch_size
        return

    def __getitem__(self, index):
        video_name, view_num = self.video_names[index]

        patch_start_pt = (numpy.random.randint(1080), numpy.random.randint(1920))

        frame1_path = self.dataroot / video_name / f'render/rgb/{view_num + 1:04}.png'
        frame2_path = self.dataroot / video_name / f'render/rgb/{view_num + 2:04}.png'
        depth_path = self.dataroot / video_name / f'render/depth/{view_num + 1:04}.exr'
        mask_path = self.dataroot / video_name / f'render/masks/{view_num + 1:04}.png'
        frame1 = self.get_image(frame1_path, patch_start_pt)
        frame2 = self.get_image(frame2_path, patch_start_pt)
        mask = self.get_mask(mask_path, patch_start_pt)
        depth = self.get_depth(depth_path, patch_start_pt, mask)

        data_dict = {
            'frame1': frame1,
            'frame2': frame2,
            'mask': mask,
            'depth': depth,
        }
        return data_dict

    def __len__(self):
        return len(self.video_names)

    @staticmethod
    def get_mask(path: Path, patch_start_point: tuple):
        h, w = patch_start_point
        mask = skimage.io.imread(path.as_posix())[h:h + self.patch_size, w:w + self.patch_size][None]
        return mask

    def get_image(self, path: Path, patch_start_point: tuple):
        h, w = patch_start_point
        image = skimage.io.imread(path.as_posix())
        image = image[h:h + self.patch_size, w:w + self.patch_size, :3]
        image = image.astype(numpy.float32) / 255 * 2 - 1
        image_cf = numpy.moveaxis(image, [0, 1, 2], [1, 2, 0])
        return image_cf

    def get_depth(self, path: Path, patch_start_point: tuple, mask: numpy.ndarray):
        h, w = patch_start_point

        exrfile = OpenEXR.InputFile(path.as_posix())
        raw_bytes = exrfile.channel('B', Imath.PixelType(Imath.PixelType.FLOAT))
        depth_vector = numpy.frombuffer(raw_bytes, dtype=numpy.float32)
        height = exrfile.header()['displayWindow'].max.y + 1 - exrfile.header()['displayWindow'].min.y
        width = exrfile.header()['displayWindow'].max.x + 1 - exrfile.header()['displayWindow'].min.x
        depth = numpy.reshape(depth_vector, (height, width))

        depth = depth[h:h + self.patch_size, w:w + self.patch_size]
        depth = depth[None]
        depth = depth.astype(numpy.float32)
        depth = depth * mask
        return depth

Наконец, я создаю DataLoader следующим образом:

train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)

Что я уже пробовал:

  1. Я искал, можно ли прочитать часть изображения. К сожалению, никаких зацепок я не получил. Похоже, библиотеки Python читают полное изображение.
  2. Я планирую читать больше исправлений с одного изображения, поэтому мне нужно будет читать меньше изображений. Но в среде PyTorch функция get_item() должна возвращать один образец, а не пакет. Итак, в каждом get_item() я могу прочитать только патч.
  3. Я планирую обойти это следующим образом: прочитать 4 исправления в get_item() и вернуть исправления формы (4,3,256,256) вместо (3,256,256). Позже, когда я прочитаю пакет с помощью загрузчика данных, я получу пакет формы (BS,4,3,256,256) вместо (BS,3,256,256). Затем я могу объединить данные по dim=1, чтобы преобразовать (BS,4,3,256,256) в (BS*4,3,256,256). Таким образом я могу уменьшить batch_size (BS) на 4 и, надеюсь, это ускорит загрузку данных в 4 раза.

Есть ли другие варианты? Я открыт для всех видов предложений. Спасибо!


person Nagabhushan S N    schedule 29.11.2020    source источник
comment
рассматривали ли вы нарезку патчей в качестве шага предварительной обработки, сохранение небольших изображений и чтение только соответствующего патча в get_image? если у вас должно быть случайное увеличение обрезки, вы можете сделать так, чтобы патчи перекрывались и имели размер больше 256 пикселей, это может иметь более или менее тот же эффект, что и ваше увеличение.   -  person Roni    schedule 03.05.2021
comment
альтернативным решением было бы сохранить bmps и прочитать только соответствующие части, как в этом ответе: stackoverflow.com/questions/19695249/   -  person Roni    schedule 03.05.2021
comment
Сохранение обрезанных кадров у меня не сработает, так как размер сильно увеличится. Альтернативное решение выглядит интересно. Я проверю это. Спасибо   -  person Nagabhushan S N    schedule 03.05.2021