Я начинаю использовать новый API набора данных, и одна вещь, которую я хочу сделать, не описана в документе (https://www.tensorflow.org/programmers_guide/datasets#training_workflows)
Мои данные помещаются в память, поэтому я хочу загрузить их в тензорный поток, чтобы сделать обучение эффективным, и для этого я сейчас вижу 2 способа сделать это:
один загружает данные в график напрямую следующим образом:
dataset = tf.contrib.data.Dataset.from_tensor_slices((X, Y))
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# loop on epochs
for _ in range(5):
# Initialize an iterator over the training dataset.
sess.run(iterator.initializer)
# loop over all the batch
for _ in range(1000):
s = time.time()
try:
sess.run(next_element)
except tf.errors.OutOfRangeError:
print("Finish epoch")
другой - загрузить данные в заполнитель, чтобы данные не сохранялись на графике:
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.contrib.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# loop on epochs
for _ in range(5):
# Initialize an iterator over the training dataset.
sess.run(iterator.initializer, feed_dict={features_placeholder: X, labels_placeholder: Y})
# loop over all the batch
for _ in range(1000):
s = time.time()
try:
sess.run(next_element)
except tf.errors.OutOfRangeError:
print("Finish epoch")
Во-вторых, я считаю, что лучше всего экономить память, но я не хочу передавать данные в каждую эпоху. Это действительно потеря производительности ни за что.
Есть ли способ инициализировать итератор только один раз с заполнителем?
что-то вроде этого:
sess.run(iterator.initializer, feed_dict={features_placeholder: X, labels_placeholder: Y})
# loop on epochs
for _ in range(5):
# Initialize an iterator over the training dataset.
sess.run(iterator.initializer)
# loop over all the batch
for _ in range(1000):
s = time.time()
try:
sess.run(next_element)
except tf.errors.OutOfRangeError:
print("Finish epoch")
Таким образом, мы можем сохранить производительность первого решения и сэкономить память, как и второе решение.
Примечание:
одно из решений состоит в том, чтобы определить количество эпох с помощью метода
dataset.repeat()
, но с ним мы как бы теряем контроль над тем, на каком этапе обучения мы находимся.Я хочу проверять после каждой эпохи (один проход по всем данным) эволюцию потери.