индексация в тензорном потоке медленнее, чем сбор

Я пытаюсь индексировать тензор, чтобы получить срез или отдельный элемент из 1d тензоров. Я обнаружил, что существует значительная разница в производительности при использовании numpy способа индексации [:] и slice vs tf.gather (почти 30-40%).

Также я замечаю, что tf.gather имеет значительные накладные расходы при использовании на скалярах (зацикливание на нестекированном тензоре), в отличие от tensor . Это известная проблема ?

пример кода (неэффективный):

for node_idxs in graph.nodes():
    node_indice_list = tf.unstack(node_idxs)
    result = []
    for nodeid in node_indices_list:
        x = tf.gather(..., nodeid)
        y = tf.gather(..., nodeid)
        result.append(tf.mul(x,y))
return tf.stack(result)

в отличие от примера кода (эффективного):

for node_idxs in graph.nodes():
    x = tf.gather(..., node_idxs)
    y = tf.gather(..., node_idxs)
return tf.mul(x, y)

Я понимаю, что первая неэффективная реализация выполняет больше работы по распаковке, укладке в стек, а затем зацикливанию и увеличению количества операций сбора, но я не ожидал 100-кратного замедления, когда порядок узлов, с которыми я работаю, составляет несколько сотен узлов (распаковка и накладные расходы на сбор на одном медленном скаляре, в первом случае у меня гораздо больше операций сбора, каждая из которых работает с одним элементом, а не с тензором смещений). Есть ли более быстрый способ индексации, я попробовал numpy и slice, которые оказались медленнее, чем сбор.


person user179156    schedule 05.09.2017    source источник


Ответы (1)


Во-первых, код на самом деле не сравнивает сбор и индексирование Numpy — он сравнивает векторизованное индексирование (tf.gather) с циклическим индексированием (Python for loop). Неудивительно, что зацикливание происходит медленно.

Обратите внимание, что Numpy-подобное индексирование tensor[idxs] в любом случае ограничено в Tensorflow:

Допустимыми индексами являются только целые числа, срезы (:), многоточие (...), tf.newaxis (None) и скалярные тензоры tf.int32/tf.int64.

Так что используйте tf.gather для общих приложений.

person Maciej S.    schedule 19.12.2020