Я пытаюсь использовать Numba для ускорения кода. Код прост, в основном цикл с простыми вычислениями в массиве numpy.
import numpy as np
import time
from numba import jit, double
def MinimizeSquareDiffBudget(x, budget):
if (budget > np.sum(x)):
return x
n = np.size(x,0)
j = 1
i = 0
y = np.zeros((n, 1))
while (budget > 0):
while (x[i] == x[j]) and (j < n-1):
j += 1
i = j - 1
if (np.std(x)<1e-10):
to_give = budget/n
y += to_give
x= x- to_give
break
to_give = min(budget, (x[0] - x[j])*j)
y[0:j] += to_give/j
x[0:j]=x[0:j]-to_give/j
budget = budget - to_give
j = 1
return y
Теперь я попытался оптимизировать его, используя @jit и определив:
fastMinimizeSquareDiffBudget = jit(double[:,:](double[:,:], double[:,:]))(MinimizeSquareDiffBudget)
Однако время примерно такое же, а я ожидал, что Numba будет намного быстрее.
Тестирование кода:
budget = 335.0
x = np.random.uniform(0,1,(1000,1))
x.sort(axis=0)
x = x[::-1]
t = time.process_time()
y = MinimizeSquareDiffBudget(x, budget)
print(time.process_time()-t)
x = np.random.uniform(0,1,(1000,1))
x.sort(axis=0)
x = x[::-1]
t = time.process_time()
y = fastMinimizeSquareDiffBudget(x, budget)
print(time.process_time()-t)
занимает 0,28 секунды для прямой реализации и 0,45 секунды для оптимизированного кода с помощью Numba. Тот же код, написанный на C, выполняется менее чем за 0,001 секунды.
Любые идеи?