Как использовать подаваемый итератор из API набора данных Tensorflow вместе с MonitoredTrainingSession?

Руководство программиста Tensorflow рекомендует использовать итератор с возможностью передачи для переключения между набором данных для обучения и проверки без повторной инициализации итератора. . В основном требуется кормить ручку, чтобы выбрать между ними.

Как использовать его вместе с tf.train.MonitoredTrainingSession?

Следующий метод завершается с ошибкой «RuntimeError: Graph завершен и не может быть изменен». ошибка.

with tf.train.MonitoredTrainingSession() as sess:
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())

Как добиться одновременного удобства MonitoredTrainingSession и повторения наборов данных для обучения и проверки?


person Michael Jaison G    schedule 08.09.2017    source источник


Ответы (3)


Я получил ответ из проблемы Tensorflow GitHub - https://github.com/tensorflow/tensorflow/issues/12859

Решение состоит в том, чтобы вызвать iterator.string_handle() перед созданием MonitoredSession.

import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator

dataset_train = Dataset.range(10)
dataset_val = Dataset.range(90, 100)

iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

handle = tf.placeholder(tf.string, shape=[])
iterator = Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
next_batch = iterator.get_next()

with tf.train.MonitoredTrainingSession() as sess:
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

    for step in range(10):
        print('train', sess.run(next_batch, feed_dict={handle: handle_train}))

        if step % 3 == 0:
            print('val', sess.run(next_batch, feed_dict={handle: handle_val}))

Output:
('train', 0)
('val', 90)
('train', 1)
('train', 2)
('val', 91)
('train', 3)
person Michael Jaison G    schedule 09.09.2017

@Michael Jaison G ответ правильный. Однако это не работает, когда вы также хотите использовать определенные session_run_hooks, которые должны оценивать части графа, например, например. LoggingTensorHook или SummarySaverHook. Пример ниже вызовет ошибку:

import tensorflow as tf

dataset_train = tf.data.Dataset.range(10)
dataset_val = tf.data.Dataset.range(90, 100)

iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
feature = iterator.get_next()

pred = feature * feature
tf.summary.scalar('pred', pred)
global_step = tf.train.create_global_step()

summary_hook = tf.train.SummarySaverHook(save_steps=5,
                                         output_dir="summaries", summary_op=tf.summary.merge_all())

with tf.train.MonitoredTrainingSession(hooks=[summary_hook]) as sess: 
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

    for step in range(10):
        feat = sess.run(feature, feed_dict={handle: handle_train})
        pred_ = sess.run(pred, feed_dict={handle: handle_train})
        print('train: ', feat)
        print('pred: ', pred_)

        if step % 3 == 0:
            print('val', sess.run(feature, feed_dict={handle: handle_val}))

Это завершится ошибкой:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype string
     [[Node: Placeholder = Placeholder[dtype=DT_STRING, shape=[], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
     [[Node: cond/Switch_1/_15 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_18_cond/Switch_1", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

Причина в том, что хук попытается оценить граф уже при первом вызове session.run([iter_train_handle, iter_val_handle]), который явно еще не содержит дескриптора в feed_dict.

Обходное решение состоит в том, чтобы перезаписать хуки, вызывающие проблему, и изменить код в before_run и after_run, чтобы оценивать только вызовы session.run, содержащие дескриптор в feed_dict (вы можете получить доступ к feed_dict текущего вызова session.run через run_context аргумент before_run и after_run)

Или вы можете использовать последний мастер Tensorflow (после 1.4), который добавляет функцию run_step_fn в MonitoredSession, которая позволяет вам указать следующий step_fn, который позволит избежать ошибки (за счет оценки оператора if TrainingIteration несколько раз... )

def step_fn(step_context):
  if handle_train is None:
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])
  return step_context.run_with_hooks(fetches=..., feed_dict=...)
person Max F.    schedule 04.12.2017

Существует демонстрация использования заполнителя в mot_session с помощью SessionRunHook. Эта демонстрация посвящена переключению наборов данных путем передачи diff handle_string.

Кстати, я пробовал все решения, но работает только это.

переключение_набора_данных

person spark    schedule 08.05.2018
comment
В этой ссылке используется make_one_shot_iterator, тогда как другие используют другие типы итераторов, и это единственный вариант, с которым я тоже мог работать. Если вы застряли, эта ссылка может быть чрезвычайно полезной. Спасибо, что поделились! - person Asher Mancinelli; 28.06.2018