pythorch-lightning train_dataloader исчерпывает данные

Я начал использовать pytorch-lightning и столкнулся с проблемой моих пользовательских загрузчиков данных:

Я использую собственный набор данных и общий torch.utils.data.DataLoader. Обычно набор данных выбирает путь и загружает данные, соответствующие заданному индексу, который загружает загрузчик данных.

def train_dataloader(self):
    train_set = TextKeypointsDataset(parameters...)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size, num_workers)
    return train_loader 

Когда я использую модули pytorch-lightning train_dataloader и training_step, все работает нормально. Когда я добавляю val_dataloader и validation_step, я сталкиваюсь с этой ошибкой:

Epoch 1:  45%|████▌     | 10/22 [00:02<00:03,  3.34it/s, loss=5.010, v_num=131199]
ValueError: Expected input batch_size (1500) to match target batch_size (5)

В этом случае мой набор данных действительно мал (для проверки функциональности) из 84 образцов, размер моего пакета равен 8. Набор данных для обучения и проверки имеет одинаковую длину (снова только для целей тестирования).

Таким образом, всего 84 * 2 = 168 и 168/8 (размер партии) = 21, что примерно соответствует общему количеству шагов (22), показанных выше. Это означает, что после запуска набора обучающих данных 10 раз (10 * 8 = 80) загрузчик ожидает новую полную выборку из 8, но, поскольку есть только 84 образца, я получаю сообщение об ошибке (по крайней мере, это мое текущее понимание).

Я столкнулся с подобной проблемой в своей собственной реализации (без использования pytorch-lighntning) и использовал этот шаблон для ее решения. В основном я сбрасываю итератор, когда заканчиваются данные:

try:
    data = next(data_iterator)
    source_tensor = data[0]
    target_tensor = data[1]

except StopIteration:  # reinitialize data loader if num_iteration > amount of data
    data_iterator = iter(data_loader)

Прямо сейчас кажется, что я сталкиваюсь с чем-то похожим? Я не знаю, как сбросить / повторно инициализировать загрузчик данных в pytorch-lightning, когда в моем training_dataloader заканчиваются данные. Думаю, должен быть еще один изощренный способ, с которым я не знаком. Спасибо


person Asdf11    schedule 25.05.2020    source источник
comment
Реализация собственного Dataset довольно стандартна, но определение пользовательского _2 _, вероятно, является ошибкой, поскольку он выполняет всевозможные сложные вещи на бэкэнде (многопоточность и т. д.). В самых крайних случаях вы должны иметь возможность определить свой собственный Sampler и, возможно, collate_fn (при необходимости), оба из которых будут предоставлены вашему DataLoader при строительстве.   -  person jodag    schedule 25.05.2020
comment
Я отредактировал свой вопрос, чтобы было понятнее. Я использую собственный набор данных, но не пользовательский загрузчик данных   -  person Asdf11    schedule 26.05.2020


Ответы (1)


Решение было:

Я использовал source_tensor = source_tensor.view(-1, self.batch_size, self.input_size), что позже привело к некоторым ошибкам, теперь я использую source_tensor = source_tensor.permute(1, 0, 2), что устранило проблему.

person Asdf11    schedule 26.05.2020