Как определить метод __len__ для PyTorch Dataloader, когда у меня есть отдельные наборы данных длины?

В настоящее время я загружаю свои данные с помощью одного класса набора данных. В наборе данных я раздельно разделяю данные поезда, тестирования и проверки. Например:

class Data():
    def __init__(self):
        self.load()

    def load(self):
        with open(file=file_name, mode='r') as f:
            self.data = f.readlines()

        self.train = self.data[:checkpoint]
        self.valid = self.data[checkpoint:halfway]
        self.test = self.data[halfway:]

Многие детали опущены для удобства чтения. По сути, я читаю один большой набор данных и делаю разбиения вручную.

Мой вопрос возникает из-за того, как переопределить метод __len__, когда длина моего поезда, действительных и тестовых данных различается?

Причина, по которой я хочу это сделать, заключается в том, что я хочу сохранить разделенные данные в одном классе, и я также хочу создать отдельные загрузчики данных для каждого и что-то вроде:

def __len__(self):
    return len(self.train)

не подходит для self.test и self.valid.

Возможно, я в корне неправильно понимаю Dataloader, но как мне подойти к этой проблеме? Заранее спасибо.


person Seankala    schedule 08.12.2019    source источник


Ответы (1)


Я думаю, что наиболее подходящий метод для получения длины каждого разделения - это просто использовать:

# Number of training points
len(self.train)

# Number of testing points
len(self.test)

# Number of validation points
len(self.valid)

В качестве альтернативы, если вы хотите указать длину разбиений для конкретного экземпляра вашего объекта:

data = Data()
print(len(data.train))
print(len(data.test))
print(len(data.valid))

__len__ позволяет реализовать способ подсчета элементов объекта. Поэтому я бы реализовал это следующим образом и использовал вышеупомянутые вызовы для получения счетчиков на разбиение:

def __len__(self):
    return len(self.data)
person Giorgos Myrianthous    schedule 08.12.2019
comment
Разве это не вызовет проблем при создании моих объектов Dataloader для каждой настройки? Если я определю __len__ так, как вы предложили, тогда я мог бы просто сделать return len(self.data), а не добавлять три, не так ли? Возможно, мне нужно изучить это более подробно, но я никогда не видел явного вызова метода __len__ при объявлении объектов Dataloader. - person Seankala; 09.12.2019
comment
@Seankala Я не видел, чтобы вы тоже инициализировали self.data. Я обновил свой ответ. - person Giorgos Myrianthous; 09.12.2019