Как использовать итератор make_initializable_iterator тензорного потока внутри input_fn?

Я хочу обучить свой режим с помощью tf.estimator.Estimator и загрузить свои данные с помощью Dataset API. Поскольку мои данные, например 'mnist', представляют собой массив (тензор), я пытаюсь загрузить его с помощью 'tf.data. Dataset.from_tensor_slices, но я не знаю, как инициализировать make_initializable_iterator внутри input_fn.

Если я могу использовать make_one_shot_iterator для успешной тренировки, но он медленно загружается перед тренировкой. И 《API более высокого уровня в TensorFlow》 является хороший пример make_initializable_iterator внутри input_fn, но он должен возвращать iterator_initializer_hook другой функции из input_fn. Я хочу знать, есть ли другой способ лучше или элегантнее?

    def input_fn():

    mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
    images = mnist_data.train.images.reshape([-1, 28, 28, 1])
    labels = np.asarray(mnist_data.train.labels, dtype=np.int64)

    # Build dataset iterator
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.repeat(None)  # Infinite iterations
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(100)
    iterator = dataset.make_one_shot_iterator()
    next_example = iterator.get_next()
    # Set runhook to initialize iterator

    return next_example

person liushui    schedule 05.02.2018    source источник


Ответы (2)


В TensorFlow версии 1.5 и новее tf.estimator.Estimator автоматически создаст и инициализирует инициализируемый итератор, когда вы вернете tf.data.Dataset из своего input_fn. Это позволяет вам написать следующий код, не беспокоясь об инициализации или перехватах:

def input_fn():
    mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
    images = mnist_data.train.images.reshape([-1, 28, 28, 1])
    labels = np.asarray(mnist_data.train.labels, dtype=np.int64)

    # Build dataset.
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.repeat(None)  # Infinite iterations
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(100)
    return dataset
person mrry    schedule 05.02.2018

Внутри вашего кода добавьте это:

      self.hooks.append(utils_hooks.DatasetHook(iter))

В run_loop.py перед вызовом в свой fn добавьте это

 for hook in dataset_hooks:
        sess.run(hook.iterator().initializer)

Тогда все должно быть в порядке.

person Y00    schedule 22.07.2019