Как выполнить итерацию набора данных несколько раз с помощью API набора данных TensorFlow?

Как вывести значение в наборе данных несколько раз? (набор данных создается Dataset API TensorFlow)

import tensorflow as tf

dataset = tf.contrib.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
epoch = 10

for i in range(epoch):
   for j in range(100):
      value = sess.run(next_element)
      assert j == value
      print(j)

Сообщение об ошибке:

tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]

Как заставить это работать?


person void    schedule 02.11.2017    source источник


Ответы (4)


Прежде всего я советую вам прочитать Руководство по набору данных. Здесь описаны все подробности DataSet API.

Ваш вопрос касается повторения данных несколько раз. Вот два решения для этого:

  1. Перебор сразу всех эпох, нет информации о конце отдельных эпох.
import tensorflow as tf

epoch   = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0
j = 0
while True:
    try:
        value = sess.run(next_element)
        assert j == value
        j += 1
        num_batch += 1
        if j > 99: # new epoch
            j = 0
    except tf.errors.OutOfRangeError:
        break

print ("Num Batch: ", num_batch)
  1. Второй вариант информирует вас о завершении каждой эпохи, так что вы можете ex. проверить потерю валидации:
import tensorflow as tf

epoch = 10
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0

for e in range(epoch):
    print ("Epoch: ", e)
    j = 0
    sess.run(iterator.initializer)
    while True:
        try:
            value = sess.run(next_element)
            assert j == value
            j += 1
            num_batch += 1
        except tf.errors.OutOfRangeError:
            break

print ("Num Batch: ", num_batch)
person melgor89    schedule 10.11.2017

Если ваша версия tenorflow 1.3+, я рекомендую высокоуровневый API tf.train.MonitoredTrainingSession. sess, созданный этим API, может автоматически обнаруживать tf.errors.OutOfRangeError с помощью sess.should_stop(). В большинстве ситуаций обучения вам необходимо перемешивать данные и получать пакет на каждом шаге, я добавил это в следующий код.

import tensorflow as tf

epoch = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
dataset = dataset.batch(batch_size=32)     # batch_size=1 if you want to get only one element per step
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

num_batch = 0
with tf.train.MonitoredTrainingSession() as sess:
    while not sess.should_stop():
        value = sess.run(next_element)
        num_batch += 1
        print("Num Batch: ", num_batch)
person Tom    schedule 21.12.2017

Попробуй это

while True:
  try:
    print(sess.run(value))
  except tf.errors.OutOfRangeError:
    break

Когда итератор набора данных достигает конца данных, он вызывает tf.errors.OutOfRangeError, вы можете поймать его с помощью except и запустить набор данных с начала.

person Grigor Carran    schedule 27.03.2018
comment
Вы должны объяснить свой код или также включить комментарии - person Michael; 27.03.2018

Подобно ответу Тома, для tensorflow 2+ вы можете использовать следующие вызовы API высокого уровня (код, предложенный в его ответе, устарел в tensorflow 2+):

epoch = 10
batch_size = 32
dataset = tf.data.Dataset.range(100) 

dataset = dataset.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.repeat(epoch)

num_batch = 0
for batch in dataset:
        num_batch += 1
        print("Num Batch: ", num_batch)

Полезный призыв для отслеживания прогресса - это общее количество пакетов, которые будут повторяться (используется после вызовов batch и repeat):

num_batches = tf.data.experimental.cardinality(dataset)

Обратите внимание, что в настоящее время (тензорный поток 2.1) метод cardinality все еще экспериментальный.

person miwe    schedule 05.02.2020