Ускорение следующего кода с помощью Numba

Я пытаюсь использовать Numba для ускорения кода. Код прост, в основном цикл с простыми вычислениями в массиве numpy.

import numpy as np
import time
from numba import jit, double

def MinimizeSquareDiffBudget(x, budget):
    if (budget > np.sum(x)):
        return x
    n = np.size(x,0)
    j = 1
    i = 0
    y = np.zeros((n, 1))
    while (budget > 0):
        while (x[i] == x[j]) and (j < n-1):
            j += 1
        i = j - 1
        if (np.std(x)<1e-10):
            to_give = budget/n
            y += to_give
            x= x- to_give
            break
        to_give = min(budget, (x[0] - x[j])*j)
        y[0:j] += to_give/j
        x[0:j]=x[0:j]-to_give/j
        budget = budget - to_give
        j = 1
    return y

Теперь я попытался оптимизировать его, используя @jit и определив:

fastMinimizeSquareDiffBudget = jit(double[:,:](double[:,:], double[:,:]))(MinimizeSquareDiffBudget)

Однако время примерно такое же, а я ожидал, что Numba будет намного быстрее.

Тестирование кода:

budget = 335.0

x = np.random.uniform(0,1,(1000,1))
x.sort(axis=0)
x = x[::-1]
t = time.process_time()
y = MinimizeSquareDiffBudget(x, budget)
print(time.process_time()-t)

x = np.random.uniform(0,1,(1000,1))
x.sort(axis=0)
x = x[::-1]

t = time.process_time()
y = fastMinimizeSquareDiffBudget(x, budget)
print(time.process_time()-t)

занимает 0,28 секунды для прямой реализации и 0,45 секунды для оптимизированного кода с помощью Numba. Тот же код, написанный на C, выполняется менее чем за 0,001 секунды.

Любые идеи?


person Gil    schedule 13.03.2017    source источник


Ответы (1)


Когда вы измеряете только одно выполнение jit-функции, вы видите как время выполнения, так и время, которое требуется Numba для jit-кода. Если вы запустите код во второй раз, вы увидите фактическое ускорение, так как Numba использует кеш в памяти скомпилированной функции, поэтому вы платите за время компиляции только один раз для каждого типа аргумента.

На моей машине с использованием python 3.6 и numba 0.31.0 чистая функция python занимает 0,32 секунды. В первый раз, когда я вызываю fastMinimizeSquareDiffBudget, это занимает 0,57 секунды, а во второй раз — 0,31 секунды.

Теперь причина, по которой вы не видите огромного ускорения, заключается в том, что у вас есть функция, которую Numba не может скомпилировать в режиме nopython, поэтому она возвращается к гораздо более медленному режиму object mode. Если вы передадите nopython=True методу jit, вы сможете увидеть, где он не может скомпилироваться. Две проблемы, которые я заметил, заключались в том, что вы должны использовать x.shape[0] вместо np.size(x,0), и вы не можете использовать min так, как вы это делаете.

person JoshAdel    schedule 13.03.2017