Как использовать итератор API набора данных tensorflow в качестве входных данных (рекуррентной) нейронной сети?

При использовании итератора Dataset API Iterator tensorflow моей целью является определение RNN, которая работает с тензорами get_next() итератора в качестве входных данных (см. (1) в коде).

Однако простое определение dynamic_rnn с get_next() в качестве входных данных приводит к ошибке: ValueError: Initializer for variable rnn/basic_lstm_cell/kernel/ is from inside a control-flow construct, such as a loop or conditional. When creating a variable inside a loop or conditional, use a lambda as the initializer.

Теперь я знаю, что одним из обходных путей является просто создать заполнитель для next_batch, а затем eval() тензор (потому что вы не можете передать сам тензор) и передать его с помощью feed_dict (см. X и (2) в коде). Однако, если я правильно понимаю, это не эффективное решение, поскольку мы сначала оцениваем, а затем повторно инициализируем тензор.

Есть ли способ:

  1. Определите dynamic_rnn непосредственно поверх вывода Iterator;

or:

  1. Как-то напрямую передать существующий тензор get_next() заполнителю, который является входом dynamic_rnn?

Полный рабочий пример; версия (1) — это то, с чем я хотел бы работать, но это не так, а (2) — обходной путь, который работает.

import tensorflow as tf

from tensorflow.contrib.rnn import BasicLSTMCell
from tensorflow.python.data import Iterator

data = [ [[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ]
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(2)
iterator = Iterator.from_structure(dataset.output_types,
                                   dataset.output_shapes)
next_batch = iterator.get_next()
iterator_init = iterator.make_initializer(dataset)

# (2):
X = tf.placeholder(tf.float32, shape=(None, 3, 1))

cell = BasicLSTMCell(num_units=8)

# (1):
# outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, next_batch, dtype=tf.float32)

# (2):
outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    sess.run(iterator_init)

    # (1):
    # o, s = sess.run([outputs, states])
    # o, s = sess.run([outputs, states])

    # (2):
    o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
    o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})

(Используя тензорный поток 1.4.0, Python 3.6.)

Большое спасибо :)


person myke    schedule 20.11.2017    source источник
comment
Это не связано с вашим вопросом, но ваш код помог мне понять, что я должен использовать eval() при передаче элементов из набора данных в session.run. Спасибо!   -  person Armin Meisterhirn    schedule 18.02.2018


Ответы (1)


Оказывается, загадочная ошибка, вероятно, является ошибкой в ​​​​tensorflow, см. https://github.com/tensorflow/tensorflow/issues/14729. В частности, ошибка действительно возникает из-за подачи неправильного типа данных (в моем примере выше массив data содержит int32 значений, но он должен содержать числа с плавающей запятой).

Вместо получения ошибки ValueError: Initializer for variable rnn/basic_lstm_cell/kernel/ is from inside a control-flow construct
tensorflow должен возвращать:
TypeError: Tensors in list passed to 'values' of 'ConcatV2' Op have types [int32, float32] that don't all match. (см. 1).

Чтобы решить эту проблему, просто замените
data = [ [[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ]
на
data = np.array([[ [1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ], dtype=np.float32).

и тогда следующий код должен работать правильно:

import tensorflow as tf
import numpy as np

from tensorflow.contrib.rnn import BasicLSTMCell
from tensorflow.python.data import Iterator

data = np.array([[ [1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ], dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(2)
iterator = Iterator.from_structure(dataset.output_types,
                                   dataset.output_shapes)
next_batch = iterator.get_next()
iterator_init = iterator.make_initializer(dataset)

# (2):
# X = tf.placeholder(tf.float32, shape=(None, 3, 1))

cell = BasicLSTMCell(num_units=8)

# (1):
outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, next_batch, dtype=tf.float32)

# (2):
# outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    sess.run(iterator_init)

    # (1):
    o, s = sess.run([outputs, states])
    o, s = sess.run([outputs, states])

    # (2):
    # o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
    # o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
person myke    schedule 21.11.2017
comment
Но почему форма выходов (1,3,8)? Есть 3 наблюдения. - person ARAT; 20.11.2018
comment
Посмотрите, что возвращает dynamic_rnn. 1 — это размер пакета (поскольку мы определяем его только для одного пакета, next_batch), 3 — это количество меток времени в одном пакете, 8 — это количество единиц RNN. - person myke; 21.11.2018
comment
Но в вашем примере вы группируете 2 наблюдения в своем экземпляре. - person ARAT; 21.11.2018
comment
А, ты прав, не заметил. В этом случае форма будет (2,3,8). Где ты взял (1,3,8)? - person myke; 21.11.2018
comment
Минимальный рабочий пример, который вы привели, выдает форму o как (1,3,8). - person ARAT; 21.11.2018
comment
Вы технически правильны, хотя он также выдает форму (2,3,8). Проще говоря, поскольку набор данных состоит из трех элементов, первая итерация выводит форму (2,3,8), а вторая итерация выводит форму (1,3,8). Код может обрабатывать пакеты любого размера. Как правило, у вас будет цикл, перебирающий набор данных. Здесь, для ясности, RNN просто запускается дважды, сначала для первых двух элементов набора данных, а затем для последнего (третьего) элемента. - person myke; 22.11.2018
comment
ой! стрелять! Я такой глупый! Да, последний o перезаписывает первый o. Стрелять! Я понял! спасибо! - person ARAT; 22.11.2018