Более быстрая реализация производной ReLu в python?

Я реализовал производную ReLu как:

def relu_derivative(x):
     return (x>0)*np.ones(x.shape)

Я также пробовал:

def relu_derivative(x):
   x[x>=0]=1
   x[x<0]=0
   return x

Размер X=(3072,10000). Но вычисление занимает много времени. Есть ли другое оптимизированное решение?


person Talha Yousuf    schedule 03.03.2019    source источник


Ответы (1)


Подход №1: Использование numexpr

При работе с большими данными мы можем использовать numexprмодуль, который поддерживает многоядерную обработку, если предполагаемые операции можно выразить как арифметические. Здесь один из способов -

(X>=0)+0

Таким образом, чтобы решить наше дело, было бы -

import numexpr as ne

ne.evaluate('(X>=0)+0')

Подход № 2: Использование NumPy views

Еще один прием — использовать views, рассматривая маску сравнений как массив int, например:

(X>=0).view('i1')

По производительности это должно быть идентично созданию X>=0.

Время

Сравнение всех опубликованных решений на случайном массиве -

In [14]: np.random.seed(0)
    ...: X = np.random.randn(3072,10000)

In [15]: # OP's soln-1
    ...: def relu_derivative_v1(x):
    ...:      return (x>0)*np.ones(x.shape)
    ...: 
    ...: # OP's soln-2     
    ...: def relu_derivative_v2(x):
    ...:    x[x>=0]=1
    ...:    x[x<0]=0
    ...:    return x

In [16]: %timeit ne.evaluate('(X>=0)+0')
10 loops, best of 3: 27.8 ms per loop

In [17]: %timeit (X>=0).view('i1')
100 loops, best of 3: 19.3 ms per loop

In [18]: %timeit relu_derivative_v1(X)
1 loop, best of 3: 269 ms per loop

In [19]: %timeit relu_derivative_v2(X)
1 loop, best of 3: 89.5 ms per loop

Основанный на numexpr был с 8 потоками. Таким образом, с большим количеством потоков, доступных для вычислений, это должно улучшиться. Related post о том, как управлять многоядерными функциями.

Подход № 3: Подход № 1 + № 2 -

Смешайте оба из них для наиболее оптимального для больших массивов -

In [27]: np.random.seed(0)
    ...: X = np.random.randn(3072,10000)

In [28]: %timeit ne.evaluate('X>=0').view('i1')
100 loops, best of 3: 14.7 ms per loop
person Divakar    schedule 03.03.2019
comment
Отличный ответ (+1 и, конечно же, удалил свой собственный, несмотря на то, что уже проголосовал) - person desertnaut; 03.03.2019