Как использовать инициализируемые итераторы tf.data в input_fn tf.estimator?

Я хотел бы управлять своим обучением с помощью tf.estimator.Estimator, но возникли проблемы с его использованием вместе с API tf.data.

У меня есть что-то вроде этого:

def model_fn(features, labels, params, mode):
  # Defines model's ops.
  # Initializes with tf.train.Scaffold.
  # Returns an tf.estimator.EstimatorSpec.

def input_fn():
  dataset = tf.data.TextLineDataset("test.txt")
  # map, shuffle, padded_batch, etc.

  iterator = dataset.make_initializable_iterator()

  return iterator.get_next()

estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)

Поскольку я не могу использовать make_one_shot_iterator для своего варианта использования, моя проблема в том, что input_fn содержит итератор, который должен быть инициализирован в model_fn (здесь я использую tf.train.Scaffold для инициализации локальных операций).

Кроме того, я понял, что мы не можем использовать только input_fn = iterator.get_next, иначе другие операции не будут добавлены в тот же граф.

Каков рекомендуемый способ инициализации итератора?


person guillaumekln    schedule 10.07.2017    source источник
comment
@guillaumeklin -- вы добавили tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer) в input_fn()?   -  person reese0106    schedule 14.02.2018
comment
Да, вы можете добавить эту строку в input_fn() непосредственно перед return iterator.get_next().   -  person guillaumekln    schedule 14.02.2018


Ответы (1)


Начиная с TensorFlow 1.5, можно заставить input_fn возвращать tf.data.Dataset, например:

def input_fn():
  dataset = tf.data.TextLineDataset("test.txt")
  # map, shuffle, padded_batch, etc.
  return dataset

См. c294fcfd.


Для предыдущих версий вы можете добавить инициализатор итератора в коллекции tf.GraphKeys.TABLE_INITIALIZERS и полагаться на инициализатор по умолчанию.

tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
person guillaumekln    schedule 10.07.2017
comment
Спасибо! +1. Просто чтобы уточнить ответ: нужно добавить строку tf.add_to_collection... перед возвратом input_fn(), и тогда все работает нормально, и не нужно ничего делать с Scaffold и local_init_ops. - person Pekka; 12.12.2017
comment
Простите, а можно ли указать имена для каждого поля набора данных, используя первый способ? Например, в моем наборе данных есть 2 поля: возраст и пол, и я хочу, чтобы словарь выглядел так: {возраст: тензор1, пол: тензор2}. - person soloice; 09.10.2018
comment
@Pekka @guillaumekln вы добавили строку tf.add_to_collection(...) в def input_fn() или где-то еще в model_fn()? Если бы это было добавлено в model_fn(), то строка по-прежнему была бы tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer) или нужно было бы изменить iterator.initializer на что-то другое? - person reese0106; 23.10.2018
comment
Вы должны добавить его в input_fn() сразу после создания итератора. - person guillaumekln; 23.10.2018