При использовании итератора Dataset API Iterator tensorflow моей целью является определение RNN, которая работает с тензорами get_next()
итератора в качестве входных данных (см. (1)
в коде).
Однако простое определение dynamic_rnn
с get_next()
в качестве входных данных приводит к ошибке: ValueError: Initializer for variable rnn/basic_lstm_cell/kernel/ is from inside a control-flow construct, such as a loop or conditional. When creating a variable inside a loop or conditional, use a lambda as the initializer.
Теперь я знаю, что одним из обходных путей является просто создать заполнитель для next_batch
, а затем eval()
тензор (потому что вы не можете передать сам тензор) и передать его с помощью feed_dict
(см. X
и (2)
в коде). Однако, если я правильно понимаю, это не эффективное решение, поскольку мы сначала оцениваем, а затем повторно инициализируем тензор.
Есть ли способ:
- Определите
dynamic_rnn
непосредственно поверх вывода Iterator;
or:
- Как-то напрямую передать существующий тензор
get_next()
заполнителю, который является входомdynamic_rnn
?
Полный рабочий пример; версия (1)
— это то, с чем я хотел бы работать, но это не так, а (2)
— обходной путь, который работает.
import tensorflow as tf
from tensorflow.contrib.rnn import BasicLSTMCell
from tensorflow.python.data import Iterator
data = [ [[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ]
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(2)
iterator = Iterator.from_structure(dataset.output_types,
dataset.output_shapes)
next_batch = iterator.get_next()
iterator_init = iterator.make_initializer(dataset)
# (2):
X = tf.placeholder(tf.float32, shape=(None, 3, 1))
cell = BasicLSTMCell(num_units=8)
# (1):
# outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, next_batch, dtype=tf.float32)
# (2):
outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
sess.run(iterator_init)
# (1):
# o, s = sess.run([outputs, states])
# o, s = sess.run([outputs, states])
# (2):
o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
(Используя тензорный поток 1.4.0, Python 3.6.)
Большое спасибо :)