Нейронная сеть, возвращающая NaN в качестве вывода

Я пытаюсь написать свою первую нейронную сеть, чтобы играть в игру, соединяющую четыре. Я использую Java и deeplearning4j. Я пытался реализовать генетический алгоритм, но когда я некоторое время обучаю сеть, выходы сети перескакивают на NaN, и я не могу сказать, где я так сильно напортачил, чтобы это произошло. Я опубликую все 3 класса ниже, где Game — игровая логика и правила, VGFrame — пользовательский интерфейс, а Main — все остальное.

У меня есть пул из 35 нейронных сетей, и на каждой итерации я позволяю 5 лучшим жить и размножаться, а вновь созданные немного рандомизирую. Чтобы оценить сети, я позволяю им сражаться друг с другом и даю очки победителю и очки за проигрыш позже. Поскольку я наказываю, бросая камень в столбец, который уже заполнен, я ожидал, что нейронные сети, по крайней мере, смогут играть в игру по правилам через некоторое время, но они не могут этого сделать. Я погуглил проблему NaN, и, похоже, это проблема экспонирующего градиента, но, насколько я понимаю, это не должно происходить в генетическом алгоритме? Любые идеи, где я мог бы искать ошибку или что вообще не так с моей реализацией?

Главный

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;

public class Main {
    final int numRows = 7;
    final int numColums = 6;
    final int randSeed = 123;
    MultiLayerNetwork[] models;

    static Random random = new Random();
    private static final Logger log = LoggerFactory.getLogger(Main.class);
    final float learningRate = .8f;
    int batchSize = 64; // Test batch size
    int nEpochs = 1; // Number of training epochs
    // --
    public static Main current;
    Game mainGame = new Game();

    public static void main(String[] args) {
        current = new Main();
        current.frame = new VGFrame();
        current.loadWeights();
    }

    private VGFrame frame;
    private final double mutationChance = .05;

    public Main() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
                .activation(Activation.RELU).seed(randSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Nesterovs(0.1, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build())
                .layer(new DenseLayer.Builder().nIn(30).nOut(15).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(15).nOut(7)
                        .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build())
                .build();
        models = new MultiLayerNetwork[35];
        for (int i = 0; i < models.length; i++) {
            models[i] = new MultiLayerNetwork(conf);
            models[i].init();
        }

    }

    public void addChip(int i, boolean b) {
        if (mainGame.gameState == 0)
            mainGame.addChip(i, b);
        if (mainGame.gameState == 0) {
            float[] f = Main.rowsToInput(mainGame.rows);
            INDArray input = Nd4j.create(f);
            INDArray output = models[0].output(input);
            for (int i1 = 0; i1 < 7; i1++) {
                System.out.println(i1 + ": " + output.getDouble(i1));
            }
            System.out.println("----------------");
            mainGame.addChip(Main.getHighestOutput(output), false);
        }
        getFrame().paint(getFrame().getGraphics());
    }

    public void newGame() {
        mainGame = new Game();
        getFrame().paint(getFrame().getGraphics());
    }

    public void startTraining(int iterations) {

        // --------------------------
        for (int gameNumber = 0; gameNumber < iterations; gameNumber++) {
            System.out.println("Iteration " + gameNumber + " of " + iterations);
            float[] evaluation = new float[models.length];
            for (int i = 0; i < models.length; i++) {
                for (int j = 0; j < models.length; j++) {
                    if (i != j) {
                        Game g = new Game();
                        g.playFullGame(models[i], models[j]);
                        if (g.gameState == 1) {
                            evaluation[i] += 45;
                            evaluation[j] += g.turnNumber;
                        }
                        if (g.gameState == 2) {
                            evaluation[j] += 45;
                            evaluation[i] += g.turnNumber;
                        }
                    }
                }
            }

            float[] evaluationSorted = evaluation.clone();
            Arrays.sort(evaluationSorted);
            // keep the best 4
            int n1 = 0, n2 = 0, n3 = 0, n4 = 0, n5 = 0;
            for (int i = 0; i < evaluation.length; i++) {
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 1])
                    n1 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 2])
                    n2 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 3])
                    n3 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 4])
                    n4 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 5])
                    n5 = i;
            }
            models[0] = models[n1];
            models[1] = models[n2];
            models[2] = models[n3];
            models[3] = models[n4];
            models[4] = models[n5];

            for (int i = 3; i < evaluationSorted.length; i++) {
                // random parent/keep w8ts
                double r = Math.random();
                if (r > .3) {
                    models[i] = models[random.nextInt(3)].clone();

                } else if (r > .1) {
                    models[i].setParams(breed(models[random.nextInt(3)], models[random.nextInt(3)]));
                }
                // Mutate
                INDArray params = models[i].params();
                models[i].setParams(mutate(params));
            }
        }
    }

    private INDArray mutate(INDArray params) {
        double[] d = params.toDoubleVector();
        for (int i = 0; i < d.length; i++) {
            if (Math.random() < mutationChance)
                d[i] += (Math.random() - .5) * learningRate;

        }
        return Nd4j.create(d);
    }

    private INDArray breed(MultiLayerNetwork m1, MultiLayerNetwork m2) {
        double[] d = m1.params().toDoubleVector();
        double[] d2 = m2.params().toDoubleVector();
        for (int i = 0; i < d.length; i++) {
            if (Math.random() < .5)
                d[i] += d2[i];
        }
        return Nd4j.create(d);
    }

    static int getHighestOutput(INDArray output) {
        int x = 0;
        for (int i = 0; i < 7; i++) {
            if (output.getDouble(i) > output.getDouble(x))
                x = i;
        }
        return x;
    }

    static float[] rowsToInput(byte[][] rows) {
        float[] f = new float[7 * 6];
        for (int i = 0; i < 6; i++) {
            for (int j = 0; j < 7; j++) {
                // f[j + i * 7] = rows[j][i] / 2f;
                f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);
            }
        }
        return f;
    }

    public void saveWeights() {
        log.info("Saving model");
        for (int i = 0; i < models.length; i++) {
            File resourcesDirectory = new File("src/resources/model" + i);
            try {
                models[i].save(resourcesDirectory, true);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    public void loadWeights() {
        if (new File("src/resources/model0").exists()) {
            for (int i = 0; i < models.length; i++) {
                File resourcesDirectory = new File("src/resources/model" + i);
                try {

                    models[i] = MultiLayerNetwork.load(resourcesDirectory, true);
                } catch (IOException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
        }
        System.out.println("col: " + models[0].params().shapeInfoToString());
    }

    public VGFrame getFrame() {
        return frame;
    }

}

VGFrame

import java.awt.Color;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import javax.swing.BorderFactory;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JTextField;

public class VGFrame extends JFrame {
    JTextField iterations;
    /**
     * 
     */
    private static final long serialVersionUID = 1L;

    public VGFrame() {
        super("Vier Gewinnt");
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        this.setSize(1300, 800);
        this.setVisible(true);
        JPanel panelGame = new JPanel();
        panelGame.setBorder(BorderFactory.createLineBorder(Color.black, 2));
        this.add(panelGame);

        var handler = new Handler();
        var menuHandler = new MenuHandler();

        JButton b1 = new JButton("1");
        JButton b2 = new JButton("2");
        JButton b3 = new JButton("3");
        JButton b4 = new JButton("4");
        JButton b5 = new JButton("5");
        JButton b6 = new JButton("6");
        JButton b7 = new JButton("7");
        b1.addActionListener(handler);
        b2.addActionListener(handler);
        b3.addActionListener(handler);
        b4.addActionListener(handler);
        b5.addActionListener(handler);
        b6.addActionListener(handler);
        b7.addActionListener(handler);
        panelGame.add(b1);
        panelGame.add(b2);
        panelGame.add(b3);
        panelGame.add(b4);
        panelGame.add(b5);
        panelGame.add(b6);
        panelGame.add(b7);

        JButton buttonTrain = new JButton("Train");
        JButton buttonNewGame = new JButton("New Game");
        JButton buttonSave = new JButton("Save Weights");
        JButton buttonLoad = new JButton("Load Weights");

        iterations = new JTextField("1000");

        buttonTrain.addActionListener(menuHandler);
        buttonNewGame.addActionListener(menuHandler);
        buttonSave.addActionListener(menuHandler);
        buttonLoad.addActionListener(menuHandler);
        iterations.addActionListener(menuHandler);

        panelGame.add(iterations);
        panelGame.add(buttonTrain);
        panelGame.add(buttonNewGame);
        panelGame.add(buttonSave);
        panelGame.add(buttonLoad);

        this.validate();
    }

    @Override
    public void paint(Graphics g) {
        super.paint(g);
        if (Main.current.mainGame.rows == null)
            return;
        var rows = Main.current.mainGame.rows;
        for (int i = 0; i < rows.length; i++) {
            for (int j = 0; j < rows[0].length; j++) {
                if (rows[i][j] == 0)
                    break;

                g.setColor((rows[i][j] == 1 ? Color.yellow : Color.red));
                g.fillOval(80 + 110 * i, 650 - 110 * j, 100, 100);
            }
        }
    }

    public void update() {
    }
}

class Handler implements ActionListener {

    @Override
    public void actionPerformed(ActionEvent event) {
        if (Main.current.mainGame.playersTurn)
            Main.current.addChip(Integer.parseInt(event.getActionCommand()) - 1, true);
    }
}

class MenuHandler implements ActionListener {

    @Override
    public void actionPerformed(ActionEvent event) {
        switch (event.getActionCommand()) {
        case "New Game":
            Main.current.newGame();
            break;
        case "Train":
            Main.current.startTraining(Integer.parseInt(Main.current.getFrame().iterations.getText()));
            break;
        case "Save Weights":
            Main.current.saveWeights();
            break;
        case "Load Weights":
            Main.current.loadWeights();
            break;
        }

    }
}

Игра

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class Game {

    int turnNumber = 0;
    byte[][] rows = new byte[7][6];
    boolean playersTurn = true;

    int gameState = 0; // 0:running, 1:Player1, 2:Player2, 3:Draw

    public boolean isRunning() {
        return this.gameState == 0;
    }

    public void addChip(int x, boolean player1) {
        turnNumber++;
        byte b = nextRow(x);
        if (b == 6) {
            gameState = player1 ? 2 : 1;
            return;
        }
        rows[x][b] = (byte) (player1 ? 1 : 2);
        gameState = checkWinner(x, b);
    }

    private byte nextRow(int x) {
        for (byte i = 0; i < rows[x].length; i++) {
            if (rows[x][i] == 0)
                return i;
        }
        return 6;
    }

    // 0 continue, 1 Player won, 2 ai won, 3 Draw
    private int checkWinner(int x, int y) {
        int color = rows[x][y];
        // Vertikal
        if (getCount(x, y, 1, 0) + getCount(x, y, -1, 0) >= 3)
            return rows[x][y];

        // Horizontal
        if (getCount(x, y, 0, 1) + getCount(x, y, 0, -1) >= 3)
            return rows[x][y];

        // Diagonal1
        if (getCount(x, y, 1, 1) + getCount(x, y, -1, -1) >= 3)
            return rows[x][y];
        // Diagonal2
        if (getCount(x, y, -1, 1) + getCount(x, y, 1, -1) >= 3)
            return rows[x][y];
        
        for (byte[] bs : rows) {
            for (byte s : bs) {
                if (s == 0)
                    return 0;
            }
        }
        return 3; // Draw
    }

    private int getCount(int x, int y, int dirX, int dirY) {
        int color = rows[x][y];
        int count = 0;
        while (true) {
            x += dirX;
            y += dirY;
            if (x < 0 | x > 6 | y < 0 | y > 5)
                break;
            if (color != rows[x][y])
                break;
            count++;
        }
        return count;
    }

    public void playFullGame(MultiLayerNetwork m1, MultiLayerNetwork m2) {
        boolean player1 = true;
        while (this.gameState == 0) {
            float[] f = Main.rowsToInput(this.rows);
            INDArray input = Nd4j.create(f);
            this.addChip(Main.getHighestOutput(player1 ? m1.output(input) : m2.output(input)), player1);
            player1 = !player1;
        }
    }
}

person David    schedule 05.02.2021    source источник


Ответы (2)


При беглом взгляде и на основе анализа ваших вариантов множителя кажется, что NaN создается арифметическим недостатком, вызванным слишком маленькими градиентами (слишком близкими к абсолютному 0 ).

Это самая подозрительная часть кода:

 f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);

Если rows[j][i] == 1, то сохраняется 0f. Я не знаю, как этим управляет нейронная сеть (или даже java), но с математической точки зрения число с плавающей запятой конечного размера не может содержать ноль.

Даже если ваш код изменит 0f с некоторой дополнительной солью, результирующие значения этих значений массива будут иметь некоторый риск стать слишком близкими к нулю. Из-за ограниченной точности при представлении действительных чисел значения, очень близкие к нулю, не могут быть представлены, поэтому NaN.

Эти значения имеют очень понятное название: субнормальные числа .

Любое ненулевое число, величина которого меньше наименьшего нормального числа, является субнормальным.

введите здесь описание изображения

IEEE_754

Как и в IEEE 754-1985, стандарт рекомендует 0 для сигнализации NaN, 1 для тихой NaN, так что сигнализацию NaN можно отключить, изменив только этот бит на 1, в то время как обратное может дают кодировку бесконечности.

Здесь важен приведенный выше текст: согласно стандарту, вы фактически указываете NaN с любым сохраненным значением 0f.


Даже если название вводит в заблуждение, Float.MIN_VALUE – это положительное значение, больше 0:

введите здесь описание изображения

Фактически реальное минимальное значение float составляет: -Float.MAX_VALUE.

Является ли математика с плавающей запятой ненормальной?


Нормализация градиентов

Если вы проверите, что проблема связана только со значениями 0f, вы можете просто изменить их для других значений, которые представляют что-то подобное; Float.MIN_VALUE, Float.MIN_NORMAL и так далее. Что-то вроде этого, а также в других возможных частях кода, где может произойти этот сценарий. Возьмите их просто как примеры и поиграйте со следующими диапазонами:

rows[j][i] == 1 ? Float.MIN_VALUE : 1f;

rows[j][i] == 1 ?  Float.MIN_NORMAL : Float.MAX_VALUE/2;

rows[j][i] == 1 ? -Float.MAX_VALUE/2 : Float.MAX_VALUE/2;

Тем не менее, это также может привести к NaN в зависимости от того, как эти значения изменены. Если это так, вы должны нормализовать значения. Вы можете попробовать применить GradientNormalizer для этого. При инициализации вашей сети для каждого уровня (или для тех, у кого есть проблемы) должно быть определено нечто подобное:

new NeuralNetConfiguration
  .Builder()
  .weightInit(WeightInit.XAVIER)
  (...)
  .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
        .weightInit(WeightInit.XAVIER)
        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) //this   
        .build())
  
  (...)

Существуют разные нормализаторы, поэтому выберите, какой из них лучше всего подходит для вашей схемы, и какие слои должны его включать. Варианты:

Нормализация градиента

  • ПеренормироватьL2PerLayer

    Изменение масштаба градиентов путем деления на норму L2 всех градиентов для слоя.

  • ПеренормироватьL2PerParamType

    Изменение масштаба градиентов путем деления на норму градиентов L2 отдельно для каждого типа параметра в слое. Это отличается от RenormalizeL2PerLayer тем, что здесь каждый тип параметра (вес, смещение и т. д.) нормализуется отдельно. Например, в сети MLP/FeedForward (где G — вектор градиента) вывод будет следующим:

    GOut_weight = G_weight / l2(G_weight) GOut_bias = G_bias / l2(G_bias)

  • ClipElementWiseAbsoluteValue

    Обрезка градиентов для каждого элемента. Для каждого градиента g установите g ‹- sign(g) max(maxAllowedValue,|g|). т. е. если градиент параметра имеет абсолютное значение больше порогового значения, усеките его. Например, если порог = 5, то значения в диапазоне -5‹g‹5 не изменяются; значения ‹-5 устанавливаются равными -5; значения ›5 устанавливаются равными 5.

  • КлипL2PerLayer

    Условная перенормировка. В чем-то похожая на RenormalizeL2PerLayer, эта стратегия масштабирует градиенты тогда и только тогда, когда норма градиентов L2 (для всего слоя) превышает указанный порог. В частности, если G — вектор градиента для слоя, то:

    GOut = G, если l2Norm(G) ‹ пороговое значение (т. е. без изменений) GOut = пороговое значение * G / l2Norm(G)

  • ClipL2PerParamType

    Условная перенормировка. Очень похоже на ClipL2PerLayer, однако вместо отсечения по слою выполняйте отсечение по каждому типу параметра отдельно. Например, в рекуррентной нейронной сети входные градиенты веса, рекуррентные градиенты веса и градиент смещения отсекаются отдельно.


Здесь вы можете найти полный пример применения этих GradientNormalizers.

person aran    schedule 05.02.2021
comment
Спасибо за подробное объяснение. Я считал, что градиент имеет значение только при обратном распространении, разве это не правильно? Я изменил свой ввод на Float.Min_Value и добавил нормализацию градиента ко всем слоям, но проблема все еще возникает - person David; 05.02.2021
comment
@Дэвид привет! Я работал над нейронными сетями, но это было так много лет назад (уф!!). Это также может произойти, если некоторые поля массива хранят абсолютные нулевые значения. Поправьте меня, если я ошибаюсь, но ваши градиенты не становятся слишком большими, верно? Они колеблются между 0 и некоторыми другими небольшими значениями. Вот что я вам скажу, я не позволю этому быть, я заинтригован, чтобы это сработало. Поэтому, пожалуйста, не стесняйтесь спрашивать некоторую информацию. Есть ли здесь какие-то разделения? Каков ваш ожидаемый результат для значений слоя? - person aran; 06.02.2021
comment
Кроме того, есть ли у вас доступ к трассировке стека, чтобы знать, где все усложняется? Мы собираемся сделать эту работу, несмотря ни на что! Даю вам слово, слово баска. Я убью половину людей, чтобы достичь этого необходимо. - person aran; 06.02.2021
comment
Значения 0f и 1f представляют противоположные значения? Я имею в виду, могут ли они быть изменены, например, на -FLOAT_MAX.VALUE/2 и FLOAT_MAX.VALUE/2, например? Или еще проще, -10f и 10f?. Тем не менее, я думаю, что первое, что нужно проверить, это то, что поблизости нет 0 с плавающей запятой. - person aran; 06.02.2021
comment
Обходной путь также может заключаться в замене поплавков на двойные, поскольку им не хватает такой точности и обычно они не переполняются для таких значений. Вы потеряете точность, но в качестве теста это может помочь определить основную проблему. Плавающая точка не работает, приятель - person aran; 06.02.2021
comment
@David в качестве обновления, взгляните на спецификацию IEEE: Как и в случае с IEEE 754-1985, стандарт рекомендует 0 для сигнализации NaN и 1 для тихих NaN — это действительно может быть проблематично - person aran; 08.02.2021

Кажется, я наконец понял это. Я пытался визуализировать сеть с помощью deeplearning4j-ui, но получил несколько ошибок несовместимости версий. После смены версий я получил новую ошибку, в которой говорилось, что ввод сети ожидает двумерный массив, и я обнаружил в Интернете, что это ожидается во всех версиях.

Так что я изменил

float[] f = new float[7 * 6];
Nd4j.create(f);

to

float[][] f = new float[1][7 * 6];
Nd4j.createFromArray(f);

И значения NaN окончательно исчезли. @aran Итак, я предполагаю, что неправильные входные данные были определенно правильным направлением. Спасибо большое за вашу помощь :)

person David    schedule 08.02.2021
comment
получил некоторую информацию, связанную .. обновит в течение этих дней - person aran; 24.02.2021