API Tensorflow Estimator: запоминайте состояние LSTM из предыдущего пакета для следующего пакета с динамическим batch_size

Я знаю, что подобный вопрос уже задавался несколько раз здесь, в stackoverflow и в Интернете, но я просто не могу найти решение следующей проблемы: я пытаюсь построить модель LSTM с отслеживанием состояния в тензорном потоке и ее API оценки. Я попробовал решение Tensorflow, лучший способ сохранить состояние в RNN?, который работает, пока я использую статический batch_size. Наличие динамического batch_size вызывает следующую проблему:

ValueError: initial_value должен иметь указанную форму: Tensor("DropoutWrapperZeroState/MultiRNNCellZeroState/DropoutWrapperZeroState/LSTMCellZeroState/zeros:0", shape=(?, 200), dtype=float32)

Параметр tf.Variable(...., validate_shape=False) просто перемещает проблему дальше по графику:

Traceback (most recent call last):
  File "model.py", line 576, in <module>
    tf.app.run(main=run_experiment)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "model.py", line 137, in run_experiment
    hparams=params  # HParams
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/learn_runner.py", line 210, in run
    return _execute_schedule(experiment, schedule)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/learn_runner.py", line 47, in _execute_schedule
    return task()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/experiment.py", line 495, in train_and_evaluate
    self.train(delay_secs=0)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/experiment.py", line 275, in train
    hooks=self._train_monitors + extra_hooks)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/experiment.py", line 660, in _call_train
    hooks=hooks)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 241, in train
    loss = self._train_model(input_fn=input_fn, hooks=hooks)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 560, in _train_model
    model_fn_lib.ModeKeys.TRAIN)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 545, in _call_model_fn
    features=features, labels=labels, **kwargs)
  File "model.py", line 218, in model_fn
    output, state = get_model(features, params)
  File "model.py", line 567, in get_model
    model = lstm(inputs, params)
  File "model.py", line 377, in lstm
    output, new_states = tf.nn.dynamic_rnn(multicell, inputs=inputs, initial_state = states)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 574, in dynamic_rnn
    dtype=dtype)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 737, in _dynamic_rnn_loop
    swap_memory=swap_memory)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2770, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2599, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2549, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 722, in _time_step
    (output, new_state) = call_cell()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 708, in <lambda>
    call_cell = lambda: cell(input_t, state)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 752, in __call__
    output, new_state = self._cell(inputs, state, scope)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 180, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/layers/base.py", line 441, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 916, in call
    cur_inp, new_state = cell(cur_inp, cur_state)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 752, in __call__
    output, new_state = self._cell(inputs, state, scope)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 180, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/layers/base.py", line 441, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 542, in call
    lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1002, in _linear
    raise ValueError("linear is expecting 2D arguments: %s" % shapes)
ValueError: linear is expecting 2D arguments: [TensorShape([Dimension(None), Dimension(62)]), TensorShape(None)]

В соответствии с выпуском github 2838 в любом случае НЕ рекомендуется использовать необучаемые переменные (? ??), поэтому я продолжал искать другие решения.

Теперь я использую заполнители и что-то в этом роде (также предлагается в ветке github) в моем model_fn:

def rnn_placeholders(state):
    """Convert RNN state tensors to placeholders with the zero state as default."""
    if isinstance(state, tf.contrib.rnn.LSTMStateTuple):
        c, h = state
        c = tf.placeholder_with_default(c, c.shape, c.op.name)
        h = tf.placeholder_with_default(h, h.shape, h.op.name)
        return tf.contrib.rnn.LSTMStateTuple(c, h)
    elif isinstance(state, tf.Tensor):
        h = state
        h = tf.placeholder_with_default(h, h.shape, h.op.name)
        return h
    else:
        structure = [rnn_placeholders(x) for x in state]
        return tuple(structure)


state = rnn_placeholders(cell.zero_state(batch_size, tf.float32))

for tensor in flatten(state):
    tf.add_to_collection('rnn_state_input', tensor)

x, new_state = tf.nn.dynamic_rnn(...)

for tensor in flatten(new_state):
    tf.add_to_collection('rnn_state_output', tensor) 

Но, к сожалению, я не знаю, как использовать заполнитель new_state для передачи его значений заполнителю state на каждой итерации, при использовании tf.Estimator API и т. д.. Поскольку я новичок в Tensorflow, я думаю, что мне здесь не хватает концептуальных знаний. Можно ли использовать пользовательский SessionRunHook?:

class UpdateHook(tf.train.SessionRunHook):

        def before_run(self, run_context):
            run_args = super(UpdateHook, self).before_run(run_context)
            run_args = tf.train.SessionRunArgs(new_state)

            #print(run_args)
            return run_args

        def after_run(self, run_context, run_values):
            #run_values gives the actual value of new_state.
            # How to update now the state placeholder??

Есть ли кто-нибудь, у кого есть идея, как решить эту проблему? Советы и рекомендации приветствуются!!! Большое спасибо!

PS: Если что-то неясно, дайте мне знать ;)

РЕДАКТИРОВАТЬ: К сожалению, я использую новый API tf.data и не могу использовать StateSavingRNNEstimator, как предложил Юджин.


person Pete    schedule 06.10.2017    source источник
comment
Не могли бы вы найти реальное решение? Я создал заполнитель для batch_size и получаю сообщение об ошибке.   -  person ARAT    schedule 27.06.2018


Ответы (2)


этот ответ может быть поздним. У меня была аналогичная проблема несколько месяцев назад. Я решил это, используя настроенный SessionRunHook. Это может быть не идеально с точки зрения производительности, но вы можете попробовать.

class LSTMStateHook(tf.train.SessionRunHook):

 def __init__(self, params):
    self.init_states  = None
    self.current_state = np.zeros((params.rnn_layers, 2, params.batch_size, params.state_size))

 def before_run(self, run_context):
    run_args = tf.train.SessionRunArgs([tf.get_default_graph().get_tensor_by_name('LSTM/output_states:0')],{self.init_states:self.current_state,},)
    return run_args

 def after_run(self, run_context, run_values):
    self.current_state = run_values[0][0] //depends on your session run arguments!!!!!!!


 def begin(self):
    self.init_states = tf.get_default_graph().get_tensor_by_name('LSTM/init_states:0')

В вашем коде, где вы определяете свой график lstm, вам нужно что-то вроде этого:

if self.stateful is True:
        init_states = multicell.zero_state(self.batch_size, tf.float32)
        init_states = tf.identity(init_states, "init_states")

        l = tf.unstack(init_states, axis=0)
        rnn_tuple_state = tuple([tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1]) for idx in range(self.rnn_layers)])

    else:
        rnn_tuple_state = multicell.zero_state(self.batch_size, tf.float32)

# Unroll RNN
output, output_states = tf.nn.dynamic_rnn(multicell, inputs=inputs, initial_state = rnn_tuple_state)

if self.stateful is True:
  output_states = tf.identity(output_states, "output_states")
  return output
person th-w    schedule 04.01.2018