Я хочу обучить свой режим с помощью tf.estimator.Estimator и загрузить свои данные с помощью Dataset API. Поскольку мои данные, например 'mnist', представляют собой массив (тензор), я пытаюсь загрузить его с помощью 'tf.data. Dataset.from_tensor_slices, но я не знаю, как инициализировать make_initializable_iterator внутри input_fn.
Если я могу использовать make_one_shot_iterator для успешной тренировки, но он медленно загружается перед тренировкой. И 《API более высокого уровня в TensorFlow》 является хороший пример make_initializable_iterator внутри input_fn, но он должен возвращать iterator_initializer_hook другой функции из input_fn. Я хочу знать, есть ли другой способ лучше или элегантнее?
def input_fn():
mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
images = mnist_data.train.images.reshape([-1, 28, 28, 1])
labels = np.asarray(mnist_data.train.labels, dtype=np.int64)
# Build dataset iterator
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.repeat(None) # Infinite iterations
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(100)
iterator = dataset.make_one_shot_iterator()
next_example = iterator.get_next()
# Set runhook to initialize iterator
return next_example