Pytorch: объясните torch.argmax

Здравствуйте, у меня есть следующий код:

import torch
x = torch.zeros(1,8,4,576) # create a 4 dimensional tensor
x[0,4,2,333] = 1.0 # put on 1 on a random spot

# I want to find the index of the highest value (0,4,2,333)
print(x.argmax()) # this should return the index

Это возвращает

тензор (10701)

В чем смысл этого 10701?

Как мне получить актуальные индексы 0,4,2,333?


person relot    schedule 18.07.2020    source источник


Ответы (1)


Данные в 4-мерном массиве линейно сохраняются в памяти, и argmax() возвращает соответствующий индекс этого плоского представления.

В Numpy есть функция для распутывания индекса (преобразования из индекса плоского массива в соответствующие многомерные индексы).

import numpy as np
np.unravel_index(10701, (1,8,4,576))
person dannyadam    schedule 19.07.2020