Как мне поступить вместо использования numpy.vectorize в CuPy?

Как мне применить определенную функцию к cupy.array вместо np.vectorize? Реализована ли аналогичная функция в cupy?

Я пишу программу моделирования в Python 3.6.9.

Я хотел бы выполнить моделирование на графическом процессоре (GTX1060, NVIDIA) с помощью CuPy (6.0.0 для CUDA10.1).

В исходном коде функция numpy.vectorize использовалась для применения определенной функции к np.array. Однако эта же функция еще не реализована в CuPy.

Исходный код (с использованием numpy) выглядит следующим образом:

#For define function
def rate(tmean,x,y,z):
    rate = 1/z/(1 + math.exp(-x*(tmean-y)))
    #DVR<0
    if rate < 0:
        rate = 0
    return rate

#tmean is temperature data(365,100,100) and loaded as np.array
#paras is parameter(3,100,100)
#vectorized
f = np.vectorize(rate)
#roop
for i in range(365):
    #calc developing rate(by function "rate") and accumulate
    dvi[i,:,:] = dvi[i-1,:,:] + f(tmean[i,:,:],paras[0],paras[1],paras[2])

Я знаю, что почти все функции numpy реализованы в CuPy. поэтому я изменил

f = np.vectorized(rate) 

to

f= cp.vectorized(rate)

но произошла ошибка AttributeError.


person mas    schedule 04.11.2019    source источник


Ответы (1)


GPU не может распараллелить произвольный код Python. Напишите все в операциях, совместимых с NumPy, например

def rate_(xp, tmean,x,y,z):
    rate = 1/z/(1 + xp.exp(-x*(tmean-y)))
    rate[rate < 0] = 0
    return rate

f = functools.partial(rate_, xp=cupy)

Для ускорения вы можете использовать cupy.ElementwiseKernel (https://docs-cupy.chainer.org/en/stable/tutorial/kernel.html), который создает единое ядро ​​для векторизованной операции.

f = cupy.ElementwiseKernel(
    'T tmean, T x, T y, T z',
    'T rate',
    '''
    rate = 1/z/(1 + exp(-x*(tmean-y)));
    // DVR<0
    if (rate < 0) {
        rate = 0;
    }
    '''
)

Чтобы создать ядро ​​из кода Python, попробуйте cupy.fuse.

@cupy.fuse()
def f(tmean,x,y,z):
    rate = 1/z/(1 + cupy.exp(-x*(tmean-y)))
    return cupy.where(rate < 0, 0, rate)  # __setitem__ is not fully supported
person tos    schedule 05.11.2019