Простое многократное обучение нейронной сети на необученных данных дает большие ошибки

Я сделал небольшую нейронную сеть умножения, используя библиотеку encog с сигмовидной функцией активации в базовой сети. Моя проблема в том, что я получил большие ошибки на необученных данных. Как я могу улучшить необученные данные для получения хороших результатов. Меньше ошибок.

Сначала я попробовал: от train.getError() > 0,00001 до train.getError() > 0,0000001 уменьшение ошибки до меньшего даст более четкие результаты. Но это не помогло.

Не помогло и увеличение скрытого слоя: network.addLayer(new BasicLayer(new ActivationSigmoid(),false,128));

я пытался увеличить количество нейронов на слой, но тоже не помогло

Как я могу получить более четкие результаты?

Что такое предвзятость? Когда его использовать?

Я видел: http://www.heatonresearch.com/wiki/Activation_Function Но я только с помощью сигмоида. Когда использовать другие или мне нужно изменить функцию активации?

Вот мой код:

    package org.encog.examples.neural.xor;

    import org.encog.Encog;
    import org.encog.engine.network.activation.ActivationSigmoid;
    import org.encog.ml.data.MLData;
    import org.encog.ml.data.MLDataPair;
    import org.encog.ml.data.MLDataSet;
    import org.encog.ml.data.basic.BasicMLDataSet;
    import org.encog.neural.networks.BasicNetwork;
    import org.encog.neural.networks.layers.BasicLayer;
    import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;

    import java.awt.*;
    import java.text.DecimalFormat;
    import java.text.NumberFormat;


    public class MulHelloWorld {

        /**
         * The input necessary for MUL.
         */
        public static double MUL_INPUT[][] = { { 0.0, 0.0 }, { 1.0, 0.0 },
                { 0.2, 0.4 }, { 0.3, 0.2 } , {0.12 , 0.11} , {0.7,0.2} , {0.32,0.42} , {0.9,0.3} , {0.5,0.2} , { 0.4 , 0.6 } , {0.9,0.1} };

        /**
         * The ideal data necessary for MUL.
         */
        public static double MUL_IDEAL[][] = { { 0.0 }, { 0.0 }, { 0.08 }, { 0.06 } , {0.0132} , {0.14} , {0.1344} , {0.27} , {0.1} , {0.24} , {0.09} };


        private static BasicNetwork network;
        private static NumberFormat formatter = new DecimalFormat("###.#####");


        public static final void retrain() {
            network = new BasicNetwork();
            network.addLayer(new BasicLayer(null,true,2));
            network.addLayer(new BasicLayer(new ActivationSigmoid(),false,128));
            network.addLayer(new BasicLayer(new ActivationSigmoid(),false,128));
            network.addLayer(new BasicLayer(new ActivationSigmoid(),false,128));
            network.addLayer(new BasicLayer(new ActivationSigmoid(),false,1));
            network.getStructure().finalizeStructure();
            network.reset();

            // create training data
            MLDataSet trainingSet = new BasicMLDataSet(MUL_INPUT, MUL_IDEAL);

            // train the neural network
            final ResilientPropagation train = new ResilientPropagation(network, trainingSet );

            int epoch = 1;

            do {
                train.iteration();
                System.out.println("Epoch #" + epoch + " Error:" + formatter.format(train.getError()));
                epoch++;
            } while(train.getError() > 0.00001);
            train.finishTraining();

            // test the neural network
            System.out.println("Neural Network Results:");

            for(MLDataPair pair: trainingSet ) {
                final MLData output = network.compute(pair.getInput());
                System.out.println(pair.getInput().getData(0) + "," + pair.getInput().getData(1)
                        + ", actual=" + output.getData(0) + ",ideal=" + pair.getIdeal().getData(0));
            }
        }

        /**
         * The main method.
         * @param args No arguments are used.
         */
        public static void main(final String args[]) {
            // create a neural network, without using a factory

            retrain();

            final double computedValue = compute(network, 0.01, 0.01);
            final double diff = computedValue - 0.0001;
            do {
                if (diff < 0.001 && diff > -0.001) {
                    String f = formatter.format(computedValue);
                    System.out.println("0.0001:"+f);
                    System.out.println("0.0002:"+formatter.format(compute(network, 0.02, 0.01)));//0.0002
                    System.out.println("0.001:"+formatter.format(compute(network, 0.05, 0.02)));//0.001
                    Toolkit.getDefaultToolkit().beep();
                    try { Thread.sleep(7000); } catch (Exception epx) {}
                    retrain();
                } else {
                    String f = formatter.format(computedValue);
                    System.out.println("0.0001:"+f);
                    System.out.println("0.0002:"+formatter.format(compute(network, 0.02, 0.01)));//0.0002
                    System.out.println("0.001:"+formatter.format(compute(network, 0.05, 0.02)));//0.001
                    System.exit(0);
                }
            } while (diff < 0.001 && diff > -0.001);

            Encog.getInstance().shutdown();
        }


        public static final double compute(BasicNetwork network, double x, double y) {
            final double value[] = new double[1];
            network.compute( new double[] { x , y } , value );
            return value[0];
        }
    }

Вот моя последняя попытка кажется немного более эффективной, но пока не очень хорошей:

    package org.encog.examples.neural.xor;

    import org.encog.Encog;
    import org.encog.engine.network.activation.ActivationSigmoid;
    import org.encog.ml.data.MLData;
    import org.encog.ml.data.MLDataPair;
    import org.encog.ml.data.MLDataSet;
    import org.encog.ml.data.basic.BasicMLDataSet;
    import org.encog.neural.networks.BasicNetwork;
    import org.encog.neural.networks.layers.BasicLayer;
    import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;

    import java.awt.*;
    import java.text.DecimalFormat;
    import java.text.NumberFormat;
    import java.util.ArrayList;


    public class MulHelloWorld {

        /**
         * The input necessary for MUL.
         */
        public static double MUL_INPUT[][] = {
                { 0.0, 0.0 }, { 1.0, 0.0 }, { 0.2, 0.4 }, { 0.3, 0.2 } ,
                {0.12 , 0.11} , {0.7,0.2} , {0.32,0.42} , {0.9,0.3} ,
                {0.5,0.2} , { 0.4 , 0.6 } , {0.9,0.1} , {0.1,0.1} ,
                {0.34,0.42} , {0.3,0.3}
        };

        /**
         * The ideal data necessary for MUL.
         */
        public static double MUL_IDEAL[][] = {
                { 0.0 }, { 0.0 }, { 0.08 }, { 0.06 } ,
                {0.0132} , {0.14} , {0.1344} , {0.27} ,
                {0.1} , {0.24} , {0.09} , {0.01} ,
                {0.1428} , {0.09} };


        private static BasicNetwork network;
        private static NumberFormat formatter = new DecimalFormat("###.##########");
        private static final double acceptableDiff = 0.01;


        public static final void retrain() {
            network = new BasicNetwork();
            network.addLayer(new BasicLayer(null,true,2));
            network.addLayer(new BasicLayer(new ActivationSigmoid(),true,32));
            network.addLayer(new BasicLayer(new ActivationSigmoid(),true,32));
            network.addLayer(new BasicLayer(new ActivationSigmoid(),true,1));
            network.getStructure().finalizeStructure();
            network.reset();

            ArrayList<Double> inputs = new ArrayList<Double>();
            ArrayList<Double> inputs2 = new ArrayList<Double>();
            ArrayList<Double> outputs = new ArrayList<Double>();
            double j = 0;
            int size = 64;
            for (int i = 0; i < size; i++) {
                final double random1 = Math.random();
                final double random2 = Math.random();
                inputs.add( random1 );
                inputs2.add( random2 );
                outputs.add( random1*random2 );
            }
            final Double x1[] = new Double[size];
            final Double x2[] = new Double[size];
            final Double x3[] = new Double[size];

            final Double[] inputz1 = inputs.toArray(x1);
            final Double[] inputz2 = inputs2.toArray(x2);
            final Double[] outz = outputs.toArray(x3);

            final double inputsAll[][] = new double[inputz1.length][2];
            final double outputsAll[][] = new double[inputz1.length][1];

            final int inputz1Size = inputz1.length;
            for (int x = 0; x < inputz1Size ; x++) {
                inputsAll[x][0] = inputz1[x];
                inputsAll[x][1] = inputz2[x];

                outputsAll[x][0] = outz[x];
            }

            // create training data
            MLDataSet trainingSet = new BasicMLDataSet(inputsAll, outputsAll );

            // train the neural network
            final ResilientPropagation train = new ResilientPropagation(network, trainingSet );

            int epoch = 1;
            do {
                train.iteration();
                System.out.println("Epoch #" + epoch + " Error:" + formatter.format(train.getError()));
                epoch++;
            } while(train.getError() > acceptableDiff);
            train.finishTraining();

            // test the neural network
            System.out.println("Neural Network Results:");

            for(MLDataPair pair: trainingSet ) {
                final MLData output = network.compute(pair.getInput());
                System.out.println(pair.getInput().getData(0) + "," + pair.getInput().getData(1)
                        + ", actual=" + output.getData(0) + ",ideal=" + pair.getIdeal().getData(0));
            }
        }

        /**
         * The main method.
         * @param args No arguments are used.
         */
        public static void main(final String args[]) {
            // create a neural network, without using a factory

            retrain();


            double random3 = Math.random();
            double random4 = Math.random();
            double v2 = random3 * random4;
            double computedValue = compute(network, random3, random4);
            System.out.println(formatter.format(v2) + ":" + formatter.format(computedValue));

            final double diff = computedValue - v2;
            do {
                if (diff <  acceptableDiff || diff > -acceptableDiff ) {
                    String f = formatter.format(computedValue);
                    {
                        double random = Math.random();
                        double random1 = Math.random();
                        double v = random * random1;
                        System.out.println(formatter.format(v) + ":" + formatter.format(compute(network, random, random1)));
                    }

                    {
                        double random = Math.random();
                        double random1 = Math.random();
                        double v = random * random1;
                        System.out.println(formatter.format(v) + ":" + formatter.format(compute(network, random, random1)));
                    }

                    {
                        double random = Math.random();
                        double random1 = Math.random();
                        double v = random * random1;
                        System.out.println(formatter.format(v) + ":" + formatter.format(compute(network, random, random1)));
                    }

                    Toolkit.getDefaultToolkit().beep();
                    try { Thread.sleep(1000); } catch (Exception epx) {}
                    retrain();
                } else {
                    String f = formatter.format(computedValue);
                    System.out.println("0.0001:"+f);
                    System.out.println("0.0002:"+formatter.format(compute(network, 0.02, 0.01)));//0.0002
                    System.out.println("0.001:"+formatter.format(compute(network, 0.05, 0.02)));//0.001
                    System.exit(0);
                }
            } while (diff < acceptableDiff || diff > -acceptableDiff);

            Encog.getInstance().shutdown();
        }


        public static final double compute(BasicNetwork network, double x, double y) {
            final double value[] = new double[1];
            network.compute( new double[] { x , y } , value );
            return value[0];
        }
    }

person Kadir BASOL    schedule 07.04.2014    source источник
comment
Это может помочь исправить ваш тренировочный набор, вторая запись. 1*0 равно 0, а не 1, как указано в MUL_IDEAL[0][1]   -  person Miichi    schedule 07.04.2014


Ответы (1)


Вы можете обнаружить, что на самом деле вы слишком близко подходите к тренировочному набору, и поэтому ваша сеть плохо обобщается. Лучшей стратегией было бы иметь третий набор для проверки. Вы можете использовать эти данные, чтобы установить ошибку обучения для получения эффективного результата, а затем протестировать сеть на необученных данных.

Я не знаком с этим конкретным пакетом, но вы также можете ознакомиться с другими методами обучения. Я обнаружил, что масштабированные сопряженные градиенты часто немного лучше, чем базовое обратное распространение.

person Philip Graham    schedule 07.04.2014
comment
я обновил код сверху. Можем ли мы сделать лучше, чем последнее обновление? - person Kadir BASOL; 08.04.2014