Применение поэлементных условных функций к Theano TensorVariable

Проще всего мне было бы просто опубликовать пустой код, который я пытаюсь выполнить непосредственно в Theano, если это возможно:

tensor = shared(np.random.randn(7, 16, 16)).eval()

tensor2 = tensor[0,:,:].eval()
tensor2[tensor2 < 1] = 0.0
tensor2[tensor2 > 0] = 1.0

new_tensor = [tensor2]
for i in range(1, tensor.shape[0]):
    new_tensor.append(np.multiply(tensor2, tensor[i,:,:].eval()))

output = np.array(new_tensor).reshape(7,16,16)

Если это не сразу очевидно, я пытаюсь использовать значения из одной матрицы тензора, состоящего из 7 разных матриц, и применить это к другим матрицам в тензоре.

На самом деле проблема, которую я решаю, заключается в выполнении условных операторов в целевой функции для полностью сверточной сети в Керасе. По сути, потери для некоторых значений карты объектов будут рассчитываться (и впоследствии взвешиваться) иначе, чем для других, в зависимости от некоторых значений в одной из карт объектов.


person Corey J. Nolet    schedule 17.01.2017    source источник


Ответы (1)


Вы можете легко реализовать условные операторы с оператором switch.

Вот эквивалентный код:

import theano
from theano import tensor as T
import numpy as np


def _check_new(var):
    shape =  var.shape[0]
    t_1, t_2 = T.split(var, [1, shape-1], 2, axis=0)
    ones = T.ones_like(t_1)
    cond = T.gt(t_1, ones)
    mask = T.repeat(cond, t_2.shape[0], axis=0)
    out  = T.switch(mask, t_2, T.zeros_like(t_2))
    output = T.join(0, cond, out)
    return output

def _check_old(var):
    tensor = var.eval()

    tensor2 = tensor[0,:,:]
    tensor2[tensor2 < 1] = 0.0
    tensor2[tensor2 > 0] = 1.0
    new_tensor = [tensor2]

    for i in range(1, tensor.shape[0]):
        new_tensor.append(np.multiply(tensor2, tensor[i,:,:]))

    output = theano.shared(np.array(new_tensor).reshape(7,16,16))
    return output


tensor = theano.shared(np.random.randn(7, 16, 16))
out1 =  _check_new(tensor).eval() 
out2 =  _check_old(tensor).eval()
print out1
print '----------------'
print ((out1-out2) ** 2).mean()

Примечание: поскольку вы маскируете первый фильтр, мне нужно было использовать операции split и join.

person indraforyou    schedule 17.01.2017