DataLouder в PyTorch / SpykeTorch: проблема извлечения преобразованных данных

Коллеги, я работаю с нейронными сетями в PyTorch и SpykeTorch (на основе PyTorch), и мне нужно создать наборы данных изображений и поместить их в DataLouders для дальнейшей обработки. Полная процедура выглядит следующим образом:
1. генерация тензоров,
2. преобразование их с помощью torchvision.transforms.ToPILImage (),
3. сохранение созданных изображений в каталог,
4.создание ImageFolder на основе каталога с преобразованием изображений (с использованием фильтров),
5.создание DataLoader из ImageFolder.

image_set = torch.rand([10000, 28, 28], dtype=torch.float)   

path = './data/images/'  
os.makedirs(path)  
        
tTPI = torchvision.transforms.ToPILImage()   
    
for i in range(n):   
    single_image = tTPI(image_set[i])     
    image_file = path+f'pic_{i}.jpg'   
    saved_image = single_image.save(f'{path}pic_{i}.jpg')    

kernels = [ SpykeTorch.utils.DoGKernel(7,1,2),
            SpykeTorch.utils.DoGKernel(7,2,1)]
filter = SpykeTorch.utils.Filter(kernels, padding = 3, thresholds = 50)
s1 = S1Transform(filter)

RandomImageFolder = ImageFolder(root='./data/', transform = s1)  
RandomDataLoader = DataLoader(RandomImageFolder, batch_size=len(RandomImageFolder))   

Далее данные из DataLoader используются в работе (например, распознаются нейронной сетью).

for data, target in RandomDataLoader:
    prediction_X, prediction_y = predict(model, data, target)

Проблема в том, что при вытаскивании данных и меток из DataLoader возникает ошибка:

RuntimeError: Given groups = 1, weight of size [2, 1, 7, 7], expected input [1, 3, 28, 28] to have 1 channels, but got 3 channels instead

Судя по размерности [1, 2, 7, 7], ошибка возникает на этапе 4, где для преобразования используется набор фильтров.
Однако с использованием другого набора фильтров. в таких ситуациях не вызывает никаких ошибок.
Как решить проблему, не меняя фильтры?


person Дмитрий    schedule 16.11.2020    source источник


Ответы (1)


Проблема заключалась в том, что сгенерированные файлы * .jpg при загрузке в ImageFolder воспринимались как RGB и имели размер [1, 3, 28, 28] вместо [1, 1, 28, 28].

Я добавил в преобразование:

from PIL import ImageOps
 
gray_image = ImageOps.grayscale(image)
person Дмитрий    schedule 24.11.2020