Как выбрать заданный набор индексов из NDArray в ND4j аналогично массиву numpy[arrayIndex]?

Я разрабатываю научное приложение, в значительной степени полагающееся на манипулирование массивами в Java с использованием ND4j (в настоящее время версия 1.0.0-beta5). В моем конвейере мне необходимо динамически выбирать несмежное подмножество матрицы [2,195102] (точнее, несколько десятков/сотен столбцов). Любая идея, как добиться этого в этой структуре?

Короче говоря, я пытаюсь выполнить эту операцию python/numpy:

import numpy as np
arrayData = np.array([[1, 5, 0, 6, 2, 0, 9, 0, 5, 2],
       [3, 6, 1, 0, 4, 3, 1, 4, 8, 1]])
arrayIndex = np.array((1,5,6))
res  = arrayData[:, arrayIndex]
# res value is
# array([[5, 0, 9],
#        [6, 3, 1]])

До сих пор мне удалось выбрать нужный столбец с помощью NDArray.getColumns (вместе с NDArray.data().asInt() из indexArray для предоставления значений индекс). Проблема в том, что в документации прямо говорится о получении информации во время вычислений: «Обратите внимание, что ЭТО НЕ ДОЛЖНО ИСПОЛЬЗОВАТЬ ДЛЯ СКОРОСТИ» (см. документацию NDArray.ToIntMatrix(), чтобы увидеть полное сообщение — другой метод, та же операция).

Я просмотрел различные прототипы для NDArray.get(), и ни один из них не отвечает всем требованиям. Я предполагаю, что NDArray.getWhere() может работать - если, как я предполагаю, только возвращает элементы, которые удовлетворяют условию, но до сих пор не смогли его использовать. Документация относительно легкая, когда дело доходит до объяснения необходимых аргументов/использования.

Спасибо всем за ваше время и помощь :)

РЕДАКТИРОВАТЬ (11.04.2019): некоторая точность в отношении того, что я пробовал. Я поиграл с NDArray.get() и использовал индексы:

INDArray arrayData = Nd4j.create(new int[]
                    {1, 5, 0, 6, 2, 0, 9, 0, 5, 2,
                     3, 6, 1, 0, 4, 3, 1, 4, 8, 1},   new long[]{2, 10}, DataType.INT);
INDArray arrayIndex = Nd4j.create(new int[]{1, 5, 6}, new long[]{1,  3}, DataType.INT);

INDArray colSelection = null;

//index free version
colSelection = arrayData.getColumns(arrayIndex.toIntVector());
/*
* colSelection value is
* [[5, 0, 9],
*  [6, 3, 1]]
* but the toIntVector() call pulls the data from the back-end storage
* and re-inject them. That is presumed to be slow.
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 4001 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 5339 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 7016 ms for 100000 iterations
*/

//index version
colSelection = arrayData.get(NDArrayIndex.all(), NDArrayIndex.indices(arrayIndex.toLongVector()));
/*
* Same result, but same problem regarding toLongVector() this time around.
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 3200 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 4269 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 5252 ms for 100000 iterations
*/

//weird but functional version (that I just discovered)
colSelection = arrayData.transpose().get(arrayIndex); // the transpose operation is necessary to not hit an IllegalArgumentException: Illegal slice 5
// note that transposing the arrayIndex leads to an IllegalArgumentException: Illegal slice 6 (as it is trying to select the element at the line idx 1, column 5, depth 6, which does not exist)
/*
* colSelection value is
* [5, 6, 0, 3, 9, 1]
* The array is flattened... calling a reshape(arrayData.shape()[0],arrayIndex.shape()[1]) yields
* [[5, 6, 0],
*  [3, 9, 1]]
* which is wrong.
*/
colSelection = colSelection.reshape(arrayIndex.shape()[1],arrayData.shape()[0]).transpose();
/* yields the right result
* [[5, 0, 9],
*  [6, 3, 1]]
* While this seems to be the correct way to handle the memory the performance are low:
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 8225 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 8980 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 9453 ms for 100000 iterations
Plus, this is very roundabout method for such a "simple" operation
* if the repacking of the data is commented out, the timing become:
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 6987 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 7976 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 8336 ms for 100000 iterations
*/

Эти скорости кажутся мне нормальными, не зная, на какой машине я работаю, но эквивалентный код Python дает:

  • Выбрано 2 столбца (arrayIndex = {1, 5}), ==> 171 мс для 100000 итераций
  • Выбрано 3 столбца (arrayIndex = {1, 5, 6}), ==> 173 мс для 100000 итераций
  • Выбрано 4 столбца (arrayIndex = {1, 5, 6 ,2}), ==> 173 мс для 100000 итераций

Эти реализации Java в лучшем случае в 20 раз медленнее, чем реализация python-numpy.


person TimWetone    schedule 03.11.2019    source источник


Ответы (1)


org.nd4j.linalg.api.ndarray.INDArray arr = org.nd4j.linalg.factory.Nd4j.create(new double[][]{
                {1, 5, 0, 6, 2, 0, 9, 0, 5, 2},
                {3, 6, 1, 0, 4, 3, 1, 4, 8, 1}
        });

        org.nd4j.linalg.indexing.INDArrayIndex indices[] = {
                org.nd4j.linalg.indexing.NDArrayIndex.all(),
                new org.nd4j.linalg.indexing.SpecifiedIndex(1,5,6)
        };

        org.nd4j.linalg.api.ndarray.INDArray selected = arr.get(indices);
        System.out.println(selected);
    }

Это должно сработать для вас. Это печатает: SLF4J: не удалось загрузить класс «org.slf4j.impl.StaticLoggerBinder». SLF4J: по умолчанию используется реализация регистратора без операций (NOP) SLF4J: см. http://www.slf4j.org/codes.html#StaticLoggerBinder для получения дополнительной информации.

[[    5.0000,         0,    9.0000], 
 [    6.0000,    3.0000,    1.0000]]

Процесс завершен с кодом выхода 0

person Adam Gibson    schedule 04.11.2019
comment
Привет Адам, спасибо за ваш ответ. Я добавил некоторую точность в качестве редактирования моего вопроса. Действительно, у меня есть функциональные реализации. Но ни один из них не является действительно эффективным. - person TimWetone; 04.11.2019
comment
Попробуйте мой отредактированный ответ? Это воспроизводит то, что вы делаете в numpy. - person Adam Gibson; 06.11.2019
comment
Привет Адам, я нашел время, чтобы попробовать это в конце недели. Это действительно работает, с производительностью, аналогичной решению, которое я пробовал, используя NDArrayIndex.indices(arrayIndex.toLongVector()). Узким местом является получение данных из arrayIndex (и, следовательно, перенос памяти из неуправляемой в управляемую память). Тем не менее я приму это как правильный ответ, поскольку он отвечает на заголовок вопроса. Спасибо :) - person TimWetone; 11.11.2019
comment
Правда, вы не можете избежать этого, хотя. Для этого требуется копия. Нет никакого способа сделать это с представлением. Numpy имеет аналогичные крайние случаи. - person Adam Gibson; 12.11.2019