dl4j lstm не удалось

Я пытаюсь скопировать упражнение примерно на полпути вниз по этой ссылке: https://d2l.ai/chapter_recurrent-neural-networks/sequence.html

В этом упражнении функция синуса используется для создания 1000 точек данных в диапазоне от -1 до 1, а рекуррентная сеть используется для аппроксимации функции.

Ниже приведен код, который я использовал. Я вернусь, чтобы подробнее изучить, почему это не работает, поскольку сейчас это не имеет для меня особого смысла, когда я легко смог использовать сеть прямой связи для аппроксимации этой функции.

      //get data
    ArrayList<DataSet> list = new ArrayList();
   
    DataSet dss = DataSetFetch.getDataSet(Constants.DataTypes.math, "sine", 20, 500, 0, 0);

    DataSet dsMain = dss.copy();

    if (!dss.isEmpty()){
        list.add(dss);
    }

   
    if (list.isEmpty()){

        return;
    }

    //format dataset
   list = DataSetFormatter.formatReccurnent(list, 0);

    //get network
    int history = 10;
    ArrayList<LayerDescription> ldlist = new ArrayList<>();
    LayerDescription l = new LayerDescription(1,history, Activation.RELU);
    ldlist.add(l);     
    LayerDescription ll = new LayerDescription(history, 1, Activation.IDENTITY, LossFunctions.LossFunction.MSE);
    ldlist.add(ll);

    ListenerDescription ld = new ListenerDescription(20, true, false);

    MultiLayerNetwork network = Reccurent.getLstm(ldlist, 123, WeightInit.XAVIER, new RmsProp(), ld);


    //train network
    final List<DataSet> lister = list.get(0).asList();
    DataSetIterator iter = new ListDataSetIterator<>(lister, 50);
    network.fit(iter, 50);
    network.rnnClearPreviousState();


    //test network
    ArrayList<DataSet> resList = new ArrayList<>();
    DataSet result = new DataSet();
    INDArray arr = Nd4j.zeros(lister.size()+1);     
    INDArray holder;

    if (list.size() > 1){
        //test on training data
        System.err.println("oops");

    }else{
        //test on original or scaled data
        for (int i = 0; i < lister.size(); i++) {

            holder = network.rnnTimeStep(lister.get(i).getFeatures());
            arr.putScalar(i,holder.getFloat(0));

        }
    }


    //add originaldata
    resList.add(dsMain);
    //result       
    result.setFeatures(dsMain.getFeatures());
  
    result.setLabels(arr);
    resList.add(result);

    //display
    DisplayData.plot2DScatterGraph(resList);

Можете ли вы объяснить код, который мне понадобится для 1 из 10 скрытых и 1 исходящих сетей lstm для аппроксимации синусоидальной функции?

Я не использую нормализацию, так как функция уже -1:1, и я использую ввод Y в качестве функции и следующий ввод Y в качестве метки для обучения сети.

Вы заметили, что я создаю класс, который позволяет упростить построение сетей, и я пытался внести множество изменений в проблему, но мне надоело гадать.

Вот несколько примеров моих результатов. Синий — данные, красный — результат

введите здесь описание изображения

введите здесь описание изображения


person cagney    schedule 22.06.2020    source источник


Ответы (2)


Это один из тех случаев, когда вы переходите от вопроса, почему это не работает, к тому, как, черт возьми, мои первоначальные результаты были такими же хорошими, как и были.

Моя ошибка заключалась в том, что я плохо понимал документацию, а также не понимал BPTT.

В сетях с прямой связью каждая итерация сохраняется в виде строки, а каждый ввод — в виде столбца. Пример: [dataset.size, network inputs.size]

Однако с рекуррентным вводом все наоборот: каждая строка является вводом, а каждый столбец - итерацией по времени, необходимой для активации состояния цепочки событий lstm. Как минимум, мой ввод должен быть [0, networkinputs.size, dataset.size], но также может быть [dataset.size, networkinputs.size, statelength.size]

В моем предыдущем примере я обучал сеть с данными в этом формате [dataset.size, networkinputs.size, 1]. Итак, исходя из моего понимания низкого разрешения, сеть lstm вообще не должна была работать, но каким-то образом производила хоть что-то.

Возможно, также возникла проблема с преобразованием набора данных в список, поскольку я также изменил способ подачи данных в сеть, но я думаю, что основная проблема была связана со структурой данных.

Ниже приведены мои новые результаты Не идеально, но это 5 эпох обучения, которые так хороши, учитывая

person cagney    schedule 23.06.2020

Трудно сказать, что происходит, не видя полного кода. Для начала я не вижу указанного RnnOutputLayer. Вы можете посмотреть это, в котором показано, как построить RNN в DL4J. Если ваша настройка RNN верна, это может быть проблемой настройки. Подробнее о настройке можно узнать здесь. Адам, вероятно, лучший выбор для обновления, чем RMSProp. И tanh, вероятно, является хорошим выбором для активации вашего выходного слоя, поскольку его диапазон составляет (-1,1). Другие вещи, которые нужно проверить / настроить — скорость обучения, количество эпох, настройка ваших данных (например, вы пытаетесь предсказать далеко?).

person Susan Eraly    schedule 22.06.2020
comment
большое спасибо за ответ. Мой код запутан, потому что я создаю автоматический итеративный генератор сети. Я согласен, что ваше решение даст лучшие результаты, но я считаю, что это тоже должно работать. Я пытаюсь понять, почему это не работает, поэтому я больше занимаюсь стратегией, а не алхимией. - person cagney; 23.06.2020