Я хотел бы управлять своим обучением с помощью tf.estimator.Estimator
, но возникли проблемы с его использованием вместе с API tf.data
.
У меня есть что-то вроде этого:
def model_fn(features, labels, params, mode):
# Defines model's ops.
# Initializes with tf.train.Scaffold.
# Returns an tf.estimator.EstimatorSpec.
def input_fn():
dataset = tf.data.TextLineDataset("test.txt")
# map, shuffle, padded_batch, etc.
iterator = dataset.make_initializable_iterator()
return iterator.get_next()
estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)
Поскольку я не могу использовать make_one_shot_iterator
для своего варианта использования, моя проблема в том, что input_fn
содержит итератор, который должен быть инициализирован в model_fn
(здесь я использую tf.train.Scaffold
для инициализации локальных операций).
Кроме того, я понял, что мы не можем использовать только input_fn = iterator.get_next
, иначе другие операции не будут добавлены в тот же граф.
Каков рекомендуемый способ инициализации итератора?
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
в input_fn()? - person reese0106   schedule 14.02.2018input_fn()
непосредственно передreturn iterator.get_next()
. - person guillaumekln   schedule 14.02.2018