Ошибка Tensorflow при использовании tf.gradients и tf.hessian: TypeError: аргумент Fetch. Нет недопустимого типа ‹type 'NoneType'›

Я только начал изучать tensorflow и столкнулся со следующей ошибкой при использовании функций tf.gradients и tf.hessain. Ниже приведен код и ошибка для tf.gradients.

import tensorflow as tf
a = tf.placeholder(tf.float32,shape = (2,2))
b = [[1.0,2.0],[3.0,4.0]]
c = a[0,0]*a[0,1]*a[1,0] + a[0,1]*a[1,0]*a[1,1]
e = tf.reshape(b,[4])
d = tf.gradients(c,e)
sess = tf.Session()
print(sess.run(d,feed_dict={a:b}))

Я получаю следующую ошибку для последней строки

>>> print(sess.run(d,feed_dict={a:b}))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/share/apps/tensorflow/20170218/python2.7/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/share/apps/tensorflow/20170218/python2.7/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 952, in _run
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
  File "/share/apps/tensorflow/20170218/python2.7/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 408, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/share/apps/tensorflow/20170218/python2.7/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 230, in for_fetch
    return _ListFetchMapper(fetch)
  File "/share/apps/tensorflow/20170218/python2.7/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 337, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/share/apps/tensorflow/20170218/python2.7/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 227, in for_fetch
    (fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <type 'NoneType'>

любые идеи относительно того, как я могу отладить это?


person noob_eggplant    schedule 28.04.2017    source источник


Ответы (1)


Это потому, что c рассчитывается на основе a, а не e. Вы можете изменить линию тензора градиента, как показано ниже.

d = tf.gradients(c,a)

Кстати, в исходном коде, если вы напечатаете d, вы обнаружите, что это [None]

person Chengji Yao    schedule 28.04.2017