DataLoader портит преобразованные данные

Я тестирую набор данных MNIST в Pytorch, и после того, как я применил преобразование к данным X, кажется, что DataLoader выводит все значения из исходного порядка, что потенциально может испортить шаг обучения.

Мое преобразование состоит в том, чтобы разделить все значения на 255. Следует заметить, что само преобразование не меняет позиций, как показано на первых диаграммах рассеяния. Но после того, как данные переданы в DataLoader и я извлек их обратно, они вышли из строя. Если я не сделаю трансформацию, все в порядке (не показано). Распределение значений одинаково между before, after1 (деленное на 255 / до DataLoader) и after2 (деленное на 255 / после DataLoader ) (также не показан), кажется, что это влияет только на порядок.

import torch
from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

transform = transforms.ToTensor()

train = datasets.MNIST(root = '.', train = True, download = True, transform = transform)
test = datasets.MNIST(root = '.', train = False, download = True, transform = transform)

before = train.data[0]

train.data = train.data.float()/255
after1 = train.data[0]

train_loader = torch.utils.data.DataLoader(train, batch_size = 128)
test_loader = torch.utils.data.DataLoader(test, batch_size = 128)

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after1.view(-1))), after1.view(-1))
ax[1].set_title('After1')

after2 = next(iter(train_loader))[0][0]

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after2.view(-1))), after2.view(-1))
ax[1].set_title('After2')

fig, ax = plt.subplots(1, 3)
ax[0].imshow(before, cmap = 'gray')
ax[1].imshow(after1, cmap = 'gray')
ax[2].imshow(after2.view(28, 28), cmap = 'gray')

Я знаю, что это может быть не лучший способ справиться с этими данными (transforms.Normalize должен решить эту проблему), но мне действительно хотелось бы понять, что происходит.

Спасибо!


person Denny Ceccon    schedule 13.09.2019    source источник


Ответы (2)


Итак ... Я разместил этот же вопрос на странице Pytorch GitHub, и они ответили на следующий:

Это не связано с загрузчиком данных. Вы возитесь с атрибутом конкретного объекта набора данных, однако фактический __getitem__ этого объекта делает гораздо больше: https://github.com/pytorch/vision/blob/6de158c473b83cf43344a0651d7c01128c7850e6/torchvision/datasets/mnist.py#L92

В частности, эта строка (mode='L') предполагает ввод uint8. Поскольку вы заменили его на float, это неверно.

Тогда я думаю, что предпочтительным подходом было бы применить преобразование при подготовке набора данных в самом начале моего кода.

person Denny Ceccon    schedule 14.09.2019
comment
Этот набор данных MNIST(VisionDataset): намного хуже, чем вы могли ожидать. У меня есть аналогичный ответ, который может оказаться полезным. - person prosti; 18.09.2019

Изначально я не тестировал написанный вами код. Переписал оригинал:

import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
import matplotlib.pyplot as plt

transform = transforms.ToTensor()

train = datasets.MNIST(root = '.', train = True, download = True, transform = transform)
test = datasets.MNIST(root = '.', train = False, download = True, transform = transform)

dl = DataLoader(train)

images = dl.dataset.data.float()/255
labels = dl.dataset.targets

train_ds = TensorDataset(images, labels)
train_loader = DataLoader(train_ds, batch_size=128)
# img, target = next(iter(train_loader))

before = train.data[0]
train.data = train.data.float()/255
after1 = train.data[0]

# train_loader = torch.utils.data.DataLoader(train, batch_size = 128)
test_loader = torch.utils.data.DataLoader(test, batch_size = 128)

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after1.view(-1))), after1.view(-1))
ax[1].set_title('After1')

after2 = next(iter(train_loader))[0][0]

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after2.view(-1))), after2.view(-1))
ax[1].set_title('After2')

fig, ax = plt.subplots(1, 3)
ax[0].imshow(before, cmap = 'gray')
ax[1].imshow(after1, cmap = 'gray')
ax[2].imshow(after2.view(28, 28), cmap = 'gray')
person prosti    schedule 13.09.2019
comment
Это не сработало. Аргумент shuffle относится только ко всем экземплярам, ​​а не к структуре данных внутри экземпляра. Фактически, именно эта внутренняя структура делает возможным любое предсказание, перетасовывать его не имеет смысла. Кстати, оставление shuffle = True не влияет на структуру данных, если я не применяю это преобразование. (Также по умолчанию shuffle = False.) - person Denny Ceccon; 14.09.2019
comment
D, я не мог понять, почему исходный код не работал, если вы разберетесь, дайте мне знать. - person prosti; 14.09.2019