65.9K
CodeProject 正在变化。 阅读更多。
Home

使用神经网络和遗传算法训练的自驾汽车模拟器

★★★★★

5.00/5 (15 票)

2016年12月12日

CPOL

10分钟阅读

浏览量

36618

这个程序模拟了一辆自动驾驶汽车学习在赛道上行驶的过程。

引言

如今,人工智能在许多方面影响着人类生活。一个例子是汽车工业;许多公司正试图让他们的汽车变得更智能。汽车可以自动驾驶、避开障碍物、寻找目的地……而无需人工控制。本文是一个汽车模拟程序,其中汽车本身将在没有任何外部控制的情况下移动。该方法使用神经网络和遗传算法来训练汽车,让汽车在每次未能完成赛道后进行学习。

Using the Code

1) FPS

每台电脑的速度不同,所以我们需要一种机制来统一速度,使游戏在任何电脑上都以相同的速度运行。我们有两个主要方法:update 和 render。通常,游戏以每秒 60 帧运行,因此 update 方法每秒调用 60 次,而 render 方法则以电脑所能达到的最快速度调用。在本程序中,update 的频率由 fps 决定(默认 60),可用上下方向键在 30 到 300 之间调整。

package com.auto.car;
import java.awt.BasicStroke;
import java.awt.Canvas;
import java.awt.Dimension;
import java.awt.Graphics2D;
import java.awt.event.KeyEvent;
import java.awt.image.BufferStrategy;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferInt;
import javax.swing.JFrame;
import com.auto.algorithm.HelperFunction;
import com.auto.car.entity.EntityManager;
import com.auto.graphics.Screen;
import com.auto.input.Keyboard;
import com.auto.input.Mouse;

/*
 * @Description: This class contains main method - entry point to the program.
 * It creates the window, runs the fixed-rate game loop (update/render) on its
 * own thread and maps keyboard shortcuts (ESC, R, UP/DOWN, N, SPACE) to
 * simulation actions.
 */
public class CarDemo extends Canvas implements Runnable {
    private static final long serialVersionUID = 1L;
    private static int width = 600;
    private static int height = width * 9 / 16;
    public static int scale = 2;
    private Thread thread;
    private JFrame frame;
    // Written by start()/stop() and read by the game-loop thread; volatile so a
    // stop request made from another thread becomes visible to the loop.
    private volatile boolean running = false;
    private BufferedImage image = new BufferedImage(width, height,
            BufferedImage.TYPE_INT_RGB);
    // Backing array of `image`: writing into it changes the image directly.
    private int[] pixels = ((DataBufferInt) image.getRaster().getDataBuffer())
            .getData();
    private Screen screen;
    private Keyboard keyboard;
    private EntityManager entityManager;

    public CarDemo() {
        Dimension size = new Dimension(width * scale, height * scale);
        setPreferredSize(size);
        screen = new Screen(width, height);
        frame = new JFrame();
        Mouse mouse = new Mouse();
        addMouseListener(mouse);
        addMouseMotionListener(mouse);
        keyboard = new Keyboard();
        addKeyListener(keyboard);
        entityManager = new EntityManager();
    }

    /*
     * Starts the game loop on a new thread named "Auto Car".
     */
    public synchronized void start() {
        thread = new Thread(this, "Auto Car");
        running = true;
        thread.start();
    }

    /*
     * Asks the game loop to finish. When called from another thread it also
     * waits for the loop thread to terminate.
     *
     * The loop itself calls this method when ESC is pressed; joining the
     * current thread would block forever, so in that case only the running
     * flag is cleared. The join happens outside the monitor so the loop thread
     * can never be blocked waiting for it.
     */
    public void stop() {
        Thread loopThread;
        synchronized (this) {
            running = false;
            loopThread = thread;
        }
        System.out.println("Goodbye");
        if (loopThread == null || loopThread == Thread.currentThread()) {
            return;
        }
        try {
            loopThread.join();
        } catch (InterruptedException e) {
            // Preserve the interrupt for whoever owns the calling thread.
            Thread.currentThread().interrupt();
        }
    }

    // Target number of update() calls per second; changed at runtime with the
    // UP/DOWN keys and clamped to [30, 300] in update().
    public static int fps = 60;

    /*
     * Game loop: runs update() fps times per second (catching up with several
     * updates if rendering fell behind), renders as fast as possible and shows
     * the measured rates in the window title once per second.
     */
    @Override
    public void run() {
        long lastTime = System.nanoTime();
        long timer = System.currentTimeMillis();

        double deltaTime = 0;
        int frameCount = 0;
        int updateCount = 0;
        requestFocus();
        while (running) {
            // Recomputed each pass because fps can change at runtime.
            double nanoSecondsPerCount = 1000000000.0 / fps;
            long now = System.nanoTime();
            deltaTime += (now - lastTime) / nanoSecondsPerCount;
            lastTime = now;
            while (deltaTime >= 1) {
                update();
                updateCount++;
                deltaTime--;
            }
            render();
            frameCount++;
            if (System.currentTimeMillis() - timer > 1000) {
                timer += 1000;
                frame.setTitle(updateCount + "ups, " + frameCount + " frames");
                updateCount = 0;
                frameCount = 0;
            }
        }
        // No trailing stop(): the loop only ends after stop() has cleared the
        // running flag, and calling it again from this thread used to join the
        // thread with itself.
    }

    private void update() {
        keyboard.update();
        if (keyboard.getKeys()[KeyEvent.VK_ESCAPE]) {
            stop();
        } else if (keyboard.getKeys()[KeyEvent.VK_R]) {
            restartSimulation();
        } else if (keyboard.getKeys()[KeyEvent.VK_DOWN]) {
            fps -= 1;
            fps = (int) HelperFunction
                    .getValueInRange(fps, 30, 300);
        } else if (keyboard.getKeys()[KeyEvent.VK_UP]) {
            fps += 1;
            fps = (int) HelperFunction
                    .getValueInRange(fps, 30, 300);
        } else if (keyboard.getKeys()[KeyEvent.VK_N]) {
            entityManager.nextMapIndex();
        } else {
            if (keyboard.getKeys()[KeyEvent.VK_SPACE]) {
                entityManager.forceToNextAgent();
            }
        }
        entityManager.update();
    }

    private void restartSimulation() {
        entityManager = new EntityManager();
    }

    /*
     * Draws one frame: the map is rendered into the screen's pixel buffer and
     * copied into `image`, which is scaled onto the canvas; cars, feelers and
     * text are then drawn directly with Graphics2D on top.
     */
    private void render() {
        BufferStrategy bs = getBufferStrategy();
        if (bs == null) {
            createBufferStrategy(3);
            return;
        }
        Graphics2D g = (Graphics2D) bs.getDrawGraphics();
        screen.setGraphic(g);
        entityManager.renderByPixels(screen);
        System.arraycopy(screen.getPixels(), 0, pixels, 0, pixels.length);
        g.setStroke(new BasicStroke(2));
        g.drawImage(image, 0, 0, getWidth(), getHeight(), null);
        entityManager.renderByGraphics(screen);
        screen.dispose();
        bs.show();
    }

    public static int getWindowWidth() {
        return width * scale;
    }

    public static int getWindowHeight() {
        return height * scale;
    }

    public static void main(String[] args) {
        CarDemo car = new CarDemo();
        car.frame.setResizable(false);
        car.frame.setTitle("Auto Car");
        car.frame.add(car);
        car.frame.pack();
        car.frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        car.frame.setLocationRelativeTo(null);
        car.frame.setVisible(true);
        car.start();
    }
}

2) 图形

Screen 类处理程序的所有图形渲染。此程序有两种渲染方式:直接操作单个像素,或使用 Java 库的内置绘图函数。地图使用操作单个像素的方式;其他对象,如汽车精灵、传感器线和传感器圆圈,则使用 Java 的内置函数,以获得更好的图形效果。此类具有以下主要字段:width、height、pixels、xOffset、yOffset 和 graphics。如果要操作单个像素,就修改 pixels 数组中的数据。当 Car 移动时,xOffset 和 yOffset 这两个值会随之变化。因为我们始终需要在屏幕上看到 Car,所以 Car 固定在屏幕上的一个位置;我们只需调整屏幕的偏移量,就能让人感觉 Car 在移动。此类中有五个主要函数:

renderTile

renderTile 方法在屏幕上渲染一个 Tile。它需要 Tile 的 x, y 坐标和 Tile 本身作为参数。

renderCar

renderCar 方法渲染 Car 对象。它需要 Car 的 x, y 坐标、Car 的朝向 angle 以及 Car 图像的 Sprite 作为参数。

renderLine

renderLine 方法需要 color 参数以及起始点和结束点的 x, y 坐标,以特定颜色渲染这些点之间的线。

renderCircle

renderCircle 方法需要圆心的 x, y 坐标、半径 r 以及圆的颜色 color 作为参数。

dispose

dispose 方法用于在渲染后释放 graphics 对象资源。

package com.auto.graphics;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.geom.AffineTransform;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import com.auto.car.CarDemo;
import com.auto.car.level.tile.Tile;
/*
 * @Description: This class takes care of all graphic rendering on the screen.
 * Map tiles are written into the `pixels` buffer (renderTile); cars, feeler
 * lines, circles and text are drawn with the Graphics2D supplied through
 * setGraphic. Positions passed in are world coordinates and are shifted by
 * the current screen offset (setOffset) before drawing.
 */
public class Screen {
    private int width;
    private int height;
    private int[] pixels;
    private int xOffset, yOffset;
    private Graphics2D graphics;
    private int scale = CarDemo.scale;
    public Screen() {

    }
    public Screen(int w, int h) {
        width = w;
        height = h;
        pixels = new int[w * h];
    }

    public int[] getPixels() {
        return pixels;
    }

    /*
     * @Description: This function renders single Tile into the pixel buffer
     * @param xPosition world x coordinate of the Tile.
     * @param yPosition world y coordinate of the Tile.
     * @param tile the tile to be rendered.
     */
    public void renderTile(int xPosition, int yPosition, Tile tile) {
        // Subtract the screen offset to turn world coordinates into screen
        // coordinates.
        xPosition -= xOffset;
        yPosition -= yOffset;
        int[] tilePixels = tile.sprite.getPixels();
        for (int yTile = 0; yTile < tile.size; yTile++) {
            int yScreen = yTile + yPosition;
            for (int xTile = 0; xTile < tile.size; xTile++) {
                int xScreen = xTile + xPosition;
                // Off-screen: stop this row. The -tile.size margin lets tiles
                // partially left of the screen still be drawn so the map edge
                // shows no black strip.
                if (xScreen < -tile.size || xScreen >= width || yScreen < 0
                        || yScreen >= height) {
                    break;
                }
                if (xScreen < 0) {
                    xScreen = 0;
                }
                pixels[xScreen + yScreen * width] = tilePixels[xTile
                        + yTile * tile.size];
            }
        }
    }

    /*
     * @Description: This function renders the car Sprite rotated by `angle`
     * @param xPosition world x coordinate of Sprite.
     * @param yPosition world y coordinate of Sprite.
     * @param angle The angle of the car, in radians.
     * @param sprite the sprite to be rendered; pure white (0xffffffff) pixels
     *        are treated as background and left transparent.
     */
    public void renderCar(int xPosition, int yPosition, double angle,
            Sprite sprite) {
        xPosition -= xOffset;
        yPosition -= yOffset;
        BufferedImage img = new BufferedImage(Sprite.SIZE, Sprite.SIZE,
                BufferedImage.TYPE_INT_ARGB);
        int[] spritePixels = sprite.getPixels();
        for (int ySprite = 0; ySprite < Sprite.SIZE; ySprite++) {
            for (int xSprite = 0; xSprite < Sprite.SIZE; xSprite++) {
                int color = spritePixels[xSprite + ySprite * Sprite.SIZE];
                if (color != 0xffffffff) {
                    img.setRGB(xSprite, ySprite, color);
                }
            }
        }
        // Save and restore the caller's transform rather than resetting to the
        // identity: the Graphics2D from a BufferStrategy may already carry a
        // non-identity transform (e.g. HiDPI scaling) that later drawing needs.
        AffineTransform saved = graphics.getTransform();
        try {
            graphics.rotate(angle, (xPosition + Sprite.SIZE / 2) * scale,
                    (yPosition + Sprite.SIZE / 2) * scale);
            graphics.drawImage(img, xPosition * scale, yPosition * scale,
                    Sprite.SIZE * scale, Sprite.SIZE * scale, null);
        } finally {
            graphics.setTransform(saved);
        }
    }

    /*
     * @Description: draw a line between 2 points (world coordinates)
     * @param x xCoordinate of first point
     * @param y yCoordinate of first point
     * @param x2 xCoordinate of second point
     * @param y2 yCoordinate of second point
     * @param color the color of line
     */
    public void renderLine(double x, double y, double x2, double y2, Color color) {
        graphics.setColor(color);
        graphics.drawLine(((int) (x - xOffset + Sprite.SIZE / 2)) * scale,
                ((int) (y - yOffset + Sprite.SIZE / 2)) * scale, ((int) (x2
                        - xOffset + Sprite.SIZE / 2))
                        * scale, ((int) (y2 - yOffset + Sprite.SIZE / 2))
                        * scale);
    }

    /*
     * @Description: render a circle around a certain point (world coordinates)
     * @param x xCoordinate of center of circle
     * @param y yCoordinate of center of circle
     * @param r radius of circle
     * @param color the color of circle
     */
    public void renderCircle(double x, double y, double r, Color color) {
        graphics.setColor(color);
        graphics.drawOval((int) (x - xOffset - r + Sprite.SIZE / 2) * scale,
                (int) (y - r - yOffset + Sprite.SIZE / 2) * scale,
                (int) (2 * r * scale), (int) (2 * r * scale));
    }

    public void setGraphic(Graphics2D g) {
        this.graphics = g;
    }

    /*
     * @Description: draws each info line in black, 20pt, at x = 700, one line
     * every 20 pixels starting at y = 20. Draws however many lines are given
     * (previously exactly four were required).
     */
    public void renderStatistics(ArrayList<String> info) {
        graphics.setColor(Color.black);
        graphics.setFont(graphics.getFont().deriveFont(20f));
        for (int i = 0; i < info.size(); i++) {
            graphics.drawString(info.get(i), 700, 20 * (i + 1));
        }
    }

    public void dispose() {
        graphics.dispose();
    }
    public int getHeight() {
        return height;
    }
    public void setHeight(int height) {
        this.height = height;
    }
    public int getWidth() {
        return width;
    }
    public void setWidth(int width) {
        this.width = width;
    }
    public void setOffset(int xOffset, int yOffset) {
        this.xOffset = xOffset;
        this.yOffset = yOffset;
    }
    public int getXOffset() {
        return xOffset;
    }
    public int getYOffset() {
        return yOffset;
    }
}

地图使用 16x16 的瓦片来平铺屏幕。该程序中有三种瓦片:GrassTile、BrickTile 和 VoidTile,它们都继承自 Tile 类。GrassTile 和 VoidTile 是非碰撞瓦片,而 BrickTile 是碰撞瓦片。BrickTile 用于构建赛道的墙壁,以便 Car 能够检测到碰撞。VoidTile 用于填充屏幕上没有任何内容需要渲染的区域。

package com.auto.car.level.tile;
import com.auto.graphics.Screen;
import com.auto.graphics.Sprite;
/*
 * @Description: Tile class is the parent class of all other Tiles. A tile is a
 * square of `size` pixels showing `sprite`. The shared instances below are the
 * tiles handed out for the level map.
 */
public class Tile {
    // Level-image colours (ARGB) that are mapped to grass and brick tiles.
    public static final int grassColor = 0xff00ff00;
    public static final int brickColor = 0xffFFD800;

    // Shared tile instances, each 16 pixels square.
    public static Tile grass = new GrassTile(Sprite.grass, 16);
    public static Tile brick = new BrickTile(Sprite.brick, 16);
    public static Tile voidTile = new VoidTile(Sprite.voidSprite, 16);

    public int x, y;
    public Sprite sprite;
    public int size;

    public Tile(Sprite tileSprite, int tileSize) {
        this.size = tileSize;
        this.sprite = tileSprite;
    }

    /* The base tile draws nothing; subclasses render themselves. */
    public void render(int x, int y, Screen screen) {
    }

    /* @return whether this tile blocks movement (never, for the base Tile). */
    public boolean solid() {
        return false;
    }
}
package com.auto.car.level.tile;
import com.auto.graphics.Screen;
import com.auto.graphics.Sprite;

public class BrickTile extends Tile {
    public BrickTile(Sprite sprite, int size) {
        super(sprite, size);
    }
    public void render(int x, int y, Screen screen) {
        screen.renderTile(x, y, this);
    }
    /*
     * BrickTile is solid so we can detect collision
     */
    public boolean solid() {
        return true;
    }
}
package com.auto.car.level.tile;
import com.auto.graphics.Screen;
import com.auto.graphics.Sprite;

public class GrassTile extends Tile {
    public GrassTile(Sprite sprite, int size) {
        super(sprite, size);
    }
    public void render(int x, int y, Screen screen) {
        screen.renderTile(x, y, this);
    }
}
package com.auto.car.level.tile;
import com.auto.graphics.Screen;
import com.auto.graphics.Sprite;
/*
 * @Description: VoidTile is render at where there's nothing on screen 
 */
public class VoidTile extends Tile {
    public VoidTile(Sprite sprite, int size) {
        super(sprite, size);
    }
    public void render(int x, int y, Screen screen) {
        screen.renderTile(x, y, this);
    }
}

         (Sprite Sheet)

一个 Sprite 对象保存着 Sprite Sheet 中一个 16x16 图像的像素。Sprite 用于构成程序中的对象,例如 Car、Brick 和 Grass。Tile 类可以加载 Sprite 图像,也可以简单地用一种颜色填充自身。在此程序中,VoidTile 仅用颜色填充,而 Brick、Grass 和 Car 的 Sprite 则从 SpriteSheet 加载图像。Sprite 的背景(白色)不会被渲染。

package com.auto.graphics;
/*
 * @Description: An individual square Sprite of SIZE x SIZE pixels, either cut
 * out of a SpriteSheet or filled with a single colour. SIZE is fixed at 16 in
 * this program.
 */
public class Sprite {
    // Edge length of a sprite, in pixels.
    public static final int SIZE = 16;
    // log2(SIZE): shifting by this converts between pixel and tile coordinates.
    public static final int SIZE_2BASED = 4;// 2^4 = SIZE

    // Top-left corner of this sprite inside the sheet, in pixels (the
    // constructor receives it in sprite units and multiplies by SIZE).
    private int x, y;
    // Row-major SIZE * SIZE pixel data.
    private int[] pixels;
    private SpriteSheet sheet;

    /*
     * Preloaded Sprites
     */
    public static Sprite grass = new Sprite(0, 0, SpriteSheet.tiles16x16);
    public static Sprite brick = new Sprite(2, 0, SpriteSheet.tiles16x16);
    public static Sprite voidSprite = new Sprite(0xE6FFA3);
    public static Sprite carSprite = new Sprite(4, 0, SpriteSheet.tiles16x16);

    /*
     * Cuts the sprite at cell (x, y) - counted in whole sprites - out of sheet.
     */
    public Sprite(int x, int y, SpriteSheet sheet) {
        this.pixels = new int[SIZE * SIZE];
        this.x = SIZE * x;
        this.y = SIZE * y;
        this.sheet = sheet;
        load();
    }

    /*
     * Creates a sprite whose every pixel has the given colour.
     */
    public Sprite(int colour) {
        this.pixels = new int[SIZE * SIZE];
        setColour(colour);
    }

    private void setColour(int colour) {
        java.util.Arrays.fill(pixels, colour);
    }

    /*
     * Copies this sprite's SIZE x SIZE region out of the sheet's pixels.
     */
    private void load() {
        for (int row = 0; row < SIZE; row++) {
            for (int col = 0; col < SIZE; col++) {
                int target = row * SIZE + col;
                pixels[target] = sheet.pixels[(row + this.y) * sheet.getSize()
                        + this.x + col];
            }
        }
    }

    public int[] getPixels() {
        return pixels;
    }
}

         (Level)

Level 是由瓦片组成的小地图。一个 Level 会加载一个图像文件,然后根据每个像素的颜色确定应该渲染哪个 Tile。图像中的一个像素在屏幕上渲染时,会被替换为程序中一个 16x16 像素的 Tile。在此程序中,颜色 0x00FF00(绿色)代表 GrassTile,颜色 0xFFD800(深黄色)代表 BrickTile;任何其他未定义的颜色都将被 VoidTile 替换。SpawnLevel 类继承自 Level 类,并覆盖了 loadLevel 方法。此方法从指定路径加载图像,并将数据存储在 pixels 数组中。之后,图像上的每种颜色会被替换为对应的 Tile,以在屏幕上渲染。下图是 Level 的示例。

package com.auto.car.level;
import com.auto.car.level.tile.Tile;
import com.auto.graphics.Screen;
import com.auto.graphics.Sprite;
/*
 * @Description: Level class represents the map of our program. Each pixel of
 * the level data stands for one Sprite.SIZE x Sprite.SIZE tile, and its colour
 * selects which Tile is used (see getTile).
 */
public class Level {
    protected int width, height;
    protected int[] pixels;

    /*
     * @param path location of the level data, passed on to loadLevel.
     * NOTE(review): loadLevel is overridable and runs from this constructor,
     * i.e. before any subclass field initializers have executed.
     */
    public Level(String path) {
        loadLevel(path);
    }

    /*
     * Hook for subclasses to fill width, height and pixels from `path`. The
     * base implementation loads nothing.
     */
    protected void loadLevel(String path) {
    }

    public void update() {
    }

    /*
     * @Description: render map on screen
     * @param xScroll: the xOffset of screen
     * @param yScroll: the yOffset of screen
     */
    public void render(int xScroll, int yScroll, Screen screen) {
        screen.setOffset(xScroll, yScroll);
        // Every pixel of the Level covers one Sprite, so convert the visible
        // screen area into Level (tile) coordinates.
        int xMostLeft = xScroll >> Sprite.SIZE_2BASED;
        int xMostRight = (xScroll + screen.getWidth() + Sprite.SIZE) >> Sprite.SIZE_2BASED;
        int yMostTop = yScroll >> Sprite.SIZE_2BASED;
        int yMostBottom = (yScroll + screen.getHeight() + Sprite.SIZE) >> Sprite.SIZE_2BASED;
        for (int y = yMostTop; y < yMostBottom; y++) {
            for (int x = xMostLeft; x < xMostRight; x++) {
                if (x < 0 || y < 0 || x >= width || y >= height) {
                    // Convert the Level coordinate back to a world coordinate
                    // before rendering it on screen.
                    Tile.voidTile.render(x << Sprite.SIZE_2BASED,
                            y << Sprite.SIZE_2BASED, screen);
                    continue;
                }
                getTile(x, y).render(x << Sprite.SIZE_2BASED,
                        y << Sprite.SIZE_2BASED, screen);
            }
        }
    }

    /*
     * @Description: each pixel in the Level represents a Tile.
     * @param x: xCoordinate in tiles
     * @param y: yCoordinate in tiles
     * @return the grass or brick tile for that cell, or Tile.voidTile when the
     *         cell is outside the map, no data is loaded, or the colour is not
     *         mapped.
     */
    public Tile getTile(int x, int y) {
        // Check each axis on its own: checking only the flat index lets an
        // out-of-range x wrap around into the neighbouring row.
        if (pixels == null || x < 0 || y < 0 || x >= width || y >= height) {
            return Tile.voidTile;
        }
        int index = x + y * width;
        if (index >= pixels.length) {
            // Declared size larger than the loaded data.
            return Tile.voidTile;
        }
        switch (pixels[index]) {
        case Tile.grassColor:
            return Tile.grass;
        case Tile.brickColor:
            return Tile.brick;
        default:
            return Tile.voidTile;
        }
    }
}

Entity 类是 Mob 类的父类,因而也是 Car 类的(间接)父类(图 9)。Mob 代表 Mobile,用于创建可在屏幕上移动的对象。Entity 有用于 x, y 坐标的字段(以及所在的 level),并提供 x 和 y 的 getter 和 setter。

Mob 类继承自 Entity 类,它有两个函数:move 和 isCollided。Mob 对象只有在新位置没有碰撞时才能移动过去。为了检测碰撞,我们检查该位置的 Tile:如果 Tile 是实心的,则发生碰撞;否则没有碰撞。Mob 对象还有一个 Sprite,用于在屏幕上渲染。

Car 类继承自 Mob 类。它包含汽车行驶方向的 angle 信息。其 deltaDistance 字段是汽车刚刚移动的距离。它的 5 个传感器 EAST、NORTH EAST、NORTH、NORTH WEST 和 WEST 用于在 Car 绕赛道行驶时检测汽车与墙壁之间的距离。intersections 是这 5 个传感器与赛道墙壁的交点,在程序中用黄色小圆圈表示。检测到的距离经过标准化后被送入神经网络以做出决策。Car 实体的主要方法是:

buildFeelers

此方法将计算每个触角头部和尾部的坐标。所有触角的尾部都将与汽车当前位置具有相同的坐标。其他触角的头部将根据汽车当前行驶的 angle 进行计算。

detectFeelerIntersection

为了检测触角与墙壁之间的交点,首先我们构建一条穿过该触角头部和尾部的线。在计算完线的相关信息后,我们将从该触角的尾部到头部进行迭代,以确定碰撞点。在这里,我们需要 Level 对象的 getTile 函数的帮助。我们将线上的点的坐标传递给此函数以获取 Tile 对象。如果 Tile 对象不是 null 并且是实心的,那么我们就在这一点找到了触角与墙壁之间的交点。

update

汽车将不断构建触角,检测与墙壁的交点,然后将数据发送到神经网络。神经网络将调用其 update 方法,然后给出左右力的信息,帮助汽车向左或向右转弯。汽车利用这些数据计算将要转弯的 angle 和将要移动的 deltaDistance,然后将此信息发送到其 move 方法。

render

在调用 buildFeelers 和 detectFeelerIntersection 两个方法之后,此方法使用它们的输出绘制 Car 本身、其传感器的线以及交点圆圈。

package com.auto.car.entity;
import com.auto.car.level.Level;
import com.auto.graphics.Screen;

/*
 * Base class of everything placed in the level: an integer world position
 * plus the Level it lives in. update() and render() do nothing by default.
 */
public abstract class Entity {
    protected int x;
    protected int y;
    protected Level level;

    public void update() {
        // nothing to do for a plain entity
    }

    public void render(Screen screen) {
        // nothing to draw for a plain entity
    }

    public int getX() {
        return x;
    }

    public void setX(int x) {
        this.x = x;
    }

    public int getY() {
        return y;
    }

    public void setY(int y) {
        this.y = y;
    }
}
package com.auto.car.entity.mob;

import com.auto.car.entity.Entity;
import com.auto.graphics.Sprite;

public abstract class Mob extends Entity {
    protected Sprite sprite;
    protected boolean collided = false;
    public void move(int xPosition, int yPosition) {
        if (xPosition != 0 && yPosition != 0) {
            move(xPosition, 0);
            move(0, yPosition);
            return;
        }
        if (!isCollided(xPosition, yPosition)) {
            x += xPosition;
            y += yPosition;
        }
    }
    public void update() {
    }
    private boolean isCollided(int xPosition, int yPosition) {
        for (int corner = 0; corner < 4; corner++) {
            int xt = ((x + xPosition) + (corner % 2) * 7 + 5) >> 4;
            int yt = ((y + yPosition) + (corner / 2) * 12 + 3) >> 4;
            if (level.getTile(xt, yt).solid()) {
                collided = true;
            }
        }
        return collided;
    }

    public void render() {

    }
}
package com.auto.car.entity.mob;

import java.awt.Color;
import java.awt.geom.Line2D;
import java.awt.geom.Point2D;
import java.awt.geom.Rectangle2D;
import com.auto.algorithm.NeuralNetwork;
import com.auto.car.entity.EntityManager;
import com.auto.car.level.Level;
import com.auto.car.level.tile.Tile;
import com.auto.graphics.Screen;
import com.auto.graphics.Sprite;

/*
 * Description: Car has 5 sensors to measure the distance from its center
 * to the wall. It use NeuralNetwork, pass in these 5 sensors's values to 
 * get out the decision of making turn left or right or going straight.
 */
public class Car extends Mob {
    // Number of feelers (sensors) of the agent, one per SensorFeelers value.
    public static int FEELER_COUNT = 5;
    // Length of every feeler, in the same units as the car's x/y.
    public static float FEELER_LENGTH = 32.0f;
    // NOTE(review): applied once per update() call rather than per second, and
    // 30/180 ~ 0.167 rad is not 30 degrees in radians (PI/6) - confirm intent.
    public static final float MAX_ROTATION_PER_SECOND = 30.0f / 180;

    // The identifiers for the feelers of the agent.
    public enum SensorFeelers {
        FEELER_EAST, FEELER_NORTH_EAST, FEELER_NORTH, FEELER_NORTH_WEST, FEELER_WEST
    }

    // Array indices of each feeler inside Sensor and the per-feeler arrays.
    public int eastIdx = SensorFeelers.FEELER_EAST.ordinal();
    public int northEastIdx = SensorFeelers.FEELER_NORTH_EAST.ordinal();
    public int northIdx = SensorFeelers.FEELER_NORTH.ordinal();
    public int northWestIdx = SensorFeelers.FEELER_NORTH_WEST.ordinal();
    public int westIdx = SensorFeelers.FEELER_WEST.ordinal();

    // Heading of the car in radians.
    private double angle;
    // Length of the step computed by the last update().
    private double deltaDistance;
    private Sensor sensor;
    private NeuralNetwork neural;
    // Wall hit point of each feeler, or null when the feeler sees no wall.
    private Point2D[] intersections;
    // Per feeler: 0 when no wall is seen, approaching 1 as the wall nears.
    private double[] normalizedIntersectionDepths;

    /*
     * Start (tail) and end (head) points of every feeler, indexed by the
     * feeler index fields above.
     */
    public class Sensor {
        public Point2D[] feelerTails = new Point2D[FEELER_COUNT];
        public Point2D[] feelerHeads = new Point2D[FEELER_COUNT];

        public Sensor() {
        }
    }

    /*
     * Creates a car at world position (x, y), heading `angle` radians, on
     * level lv. The feelers and their wall intersections are computed
     * immediately so render() works before the first update().
     * NOTE(review): a network must still be attached with attach() before
     * update() is called.
     */
    public Car(int x, int y, double angle, Level lv) {
        this.level = lv;
        this.x = x;
        this.y = y;
        this.angle = angle;
        this.sensor = new Sensor();
        buildFeelers();
        detectFeelerIntersection();
    }

    /* Places the car on a different level (used for collision and sensing). */
    public void setLevel(Level lv) {
        this.level = lv;
    }

    /*
     * Advances the car by one simulation step; does nothing once it has
     * collided. Rebuilds the feelers, measures how close the walls are, feeds
     * the normalized depths to the neural network, turns by the difference of
     * the left/right forces and then steps forward along the new heading.
     *
     * @throws IllegalStateException if no neural network has been attached
     */
    public void update() {
        if (this.collided) {
            return;
        }
        if (neural == null) {
            throw new IllegalStateException(
                    "attach(NeuralNetwork) must be called before update()");
        }
        buildFeelers();
        detectFeelerIntersection();
        neural.setInput(normalizedIntersectionDepths);
        neural.update();
        double leftForce = neural
                .getOutput(EntityManager.NeuralNetOuputs.NN_OUTPUT_LEFT_FORCE
                        .ordinal());
        double rightForce = neural
                .getOutput(EntityManager.NeuralNetOuputs.NN_OUTPUT_RIGHT_FORCE
                        .ordinal());
        // (The former per-step println of both forces was removed: it ran on
        // every update - up to 300 per second - flooding stdout and slowing
        // the game loop.)
        // Convert the outputs to a proportion of how much to turn.
        double leftTheta = MAX_ROTATION_PER_SECOND * leftForce;
        double rightTheta = MAX_ROTATION_PER_SECOND * rightForce;
        angle += (leftTheta - rightTheta) * 2;
        // Step of length 2 along the heading (angle 0 points up the screen).
        double movingX = Math.sin(angle) * 2;
        double movingY = -Math.cos(angle) * 2;
        deltaDistance = Math.sqrt(movingX * movingX + movingY * movingY);
        // NOTE(review): the int casts truncate toward zero, so the actual
        // displacement is quantized and can differ from deltaDistance.
        move((int) movingX, (int) movingY);
    }

    /*
     * private void turn(int x, int y, int direction) { if (Mouse.getButton() ==
     * 1) { double dx = Mouse.getX() - CarDemo.getWindowWidth() / 2 - 24; double
     * dy = Mouse.getY() - CarDemo.getWindowHeight() / 2 - 24; angle =
     * Math.atan2(dy, dx); }
     * 
     * }
     */

    /*
     * @return a Sprite.SIZE square centred on the car's (x, y).
     */
    public Rectangle2D getCarBound() {
        int half = Sprite.SIZE / 2;
        int left = x - half;
        int top = y - half;
        return new Rectangle2D.Double(left, top, Sprite.SIZE, Sprite.SIZE);
    }

    /*
     * Recomputes the five feeler segments for the car's current position and
     * heading. Every tail sits at the car's (x, y); each head lies
     * FEELER_LENGTH away in a direction offset from the heading by +90 (east),
     * +45 (north east), 0 (north), -45 (north west) and -90 (west) degrees.
     */
    private void buildFeelers() {
        for (int i = 0; i < FEELER_COUNT; i++) {
            Point2D tail = new Point2D.Float();
            tail.setLocation(x, y);
            sensor.feelerHeads[i] = new Point2D.Float();
            sensor.feelerTails[i] = tail;
        }
        placeFeelerHead(eastIdx, angle + Math.PI / 2);
        placeFeelerHead(northEastIdx, angle + Math.PI / 4);
        placeFeelerHead(northIdx, angle);
        placeFeelerHead(northWestIdx, angle - Math.PI / 4);
        placeFeelerHead(westIdx, angle - Math.PI / 2);
    }

    /* Puts the head of feeler `idx` FEELER_LENGTH from the car along `direction`. */
    private void placeFeelerHead(int idx, double direction) {
        double mirrored = Math.PI - direction;
        double headX = x + Math.sin(mirrored) * FEELER_LENGTH;
        double headY = y + Math.cos(mirrored) * FEELER_LENGTH;
        sensor.feelerHeads[idx].setLocation(headX, headY);
    }

    /*
     * Measures, for every feeler, the nearest wall point between the car's
     * centre (tail) and the feeler's head. Stores the point in intersections[k]
     * (null when there is none) and the depth
     * 1 - distance / FEELER_LENGTH in normalizedIntersectionDepths[k]
     * (0 when there is no wall).
     */
    public void detectFeelerIntersection() {
        intersections = new Point2D[FEELER_COUNT];
        normalizedIntersectionDepths = new double[FEELER_COUNT];
        for (int k = 0; k < FEELER_COUNT; k++) {
            intersections[k] = findNearestWallPoint(sensor.feelerTails[k],
                    sensor.feelerHeads[k]);
            if (intersections[k] != null) {
                double distance = intersections[k].distance(x, y);
                normalizedIntersectionDepths[k] = 1 - distance / FEELER_LENGTH;
            } else {
                normalizedIntersectionDepths[k] = 0;
            }
        }
    }

    // Spacing, in pixels, between consecutive samples along a feeler.
    private static final double FEELER_SAMPLE_STEP = 0.01;

    /*
     * Walks the segment from tail to head and returns the first sampled point
     * that lies on a solid tile, or null if none does.
     * The segment is walked parametrically so every orientation gets the same
     * spatial resolution; stepping along x with a fixed step (as before)
     * sampled almost nothing for near-vertical feelers and missed walls.
     */
    private Point2D findNearestWallPoint(Point2D tail, Point2D head) {
        double dx = head.getX() - tail.getX();
        double dy = head.getY() - tail.getY();
        double length = Math.hypot(dx, dy);
        int samples = Math.max(1, (int) Math.ceil(length / FEELER_SAMPLE_STEP));
        for (int s = 0; s <= samples; s++) {
            double t = (double) s / samples;
            double px = tail.getX() + dx * t;
            double py = tail.getY() + dy * t;
            if (isSolidAt(px, py)) {
                Point2D hit = new Point2D.Float();
                hit.setLocation(px, py);
                return hit;
            }
        }
        return null;
    }

    /*
     * True when world point (px, py) lies on a solid tile. Uses floor so
     * negative coordinates map to the correct (negative) tile index instead of
     * being truncated toward tile 0.
     */
    private boolean isSolidAt(double px, double py) {
        int xTile = (int) Math.floor((px + Sprite.SIZE / 2) / Sprite.SIZE);
        int yTile = (int) Math.floor((py + Sprite.SIZE / 2) / Sprite.SIZE);
        Tile tile = level.getTile(xTile, yTile);
        return tile != null && tile.solid();
    }

    /* Connects the neural network that drives this car in update(). */
    public void attach(NeuralNetwork neuralNet) {
        neural = neuralNet;
    }

    /* Moves the car to the given point; coordinates are truncated to ints. */
    public void setPosition(Point2D defaultPosition) {
        double px = defaultPosition.getX();
        double py = defaultPosition.getY();
        x = (int) px;
        y = (int) py;
    }

    /* Resets the collision flag so the car can move again. */
    public void clearFailure() {
        this.collided = false;
    }

    /* @return true once the car has hit a wall (and not been cleared). */
    public boolean hasFailed() {
        return this.collided;
    }

    /* @return the step length computed by the last update(). */
    public double getDistanceDelta() {
        return this.deltaDistance;
    }

    /*
     * Draws the car sprite and its five feeler lines: the straight-ahead
     * (north) feeler in black, the others in yellow. Then draws a yellow
     * circle of radius FEELER_LENGTH around the car and a small yellow circle
     * on every detected wall intersection.
     */
    public void render(Screen screen) {
        screen.renderCar(x, y, angle, Sprite.carSprite);
        int[] feelerOrder = { eastIdx, northEastIdx, northIdx, northWestIdx, westIdx };
        Color[] feelerColors = { Color.YELLOW, Color.YELLOW, Color.black,
                Color.YELLOW, Color.YELLOW };
        for (int n = 0; n < feelerOrder.length; n++) {
            Point2D head = sensor.feelerHeads[feelerOrder[n]];
            Point2D tail = sensor.feelerTails[feelerOrder[n]];
            screen.renderLine(head.getX(), head.getY(), tail.getX(),
                    tail.getY(), feelerColors[n]);
        }
        screen.renderCircle(x, y, FEELER_LENGTH, Color.YELLOW);
        for (int k = 0; k < FEELER_COUNT; k++) {
            Point2D hit = intersections[k];
            if (hit == null) {
                continue;
            }
            screen.renderCircle(hit.getX(), hit.getY(), 3, Color.YELLOW);
        }
    }

    /* Moves the car to world position (x, y) through the Entity setters. */
    public void setPosition(int x, int y) {
        this.setX(x);
        this.setY(y);
    }

3) 遗传算法

遗传算法用于训练神经网络;它帮助神经网络做出更好的决策。

Genome 类

一个基因组(图 10)包含三部分信息:ID、fitness 和 weights。fitness 是汽车在没有碰撞的情况下能够行驶的距离,weights 是一组随机的 Sigmoid 值,范围从 -1 到 1。

GeneticAlgorithm 类

GeneticAlgorithm 类有一个名为 population 的基因组列表。每个 population 都包含具有高适应度的优秀基因组。这些基因组将被混合和变异以创建新的基因组 population。此类包含以下主要方法:

generateNewPopulation

此函数将生成一个新的基因组 population。基因组的 weights 将被随机赋予 Sigmoid 值;ID 和 fitness 值将被赋为 0。

breedPopulation

此函数将使用旧的 population 生成新的 population。首先,从旧的 population 中选择 4 个具有高适应度的基因组。之后,我们将对这 4 个基因组进行变异和交叉,并添加到新的 population 中。新 population 的剩余位置将由新的随机 Genome 填充。

crossOver

crossOver 函数是一个将两个基因组混合并创建另外两个新基因组的函数。为了混合它们,我们首先随机选择交叉点,然后将两个基因组从该点分成四个部分,然后将四个部分混合在一起。

mutate

mutate 函数随机选取基因组中的一个基因并修改其权重:如果生成的随机 Sigmoid 值小于 MUTATION_RATE = 0.15,则该基因的新权重等于原权重加上"随机 Sigmoid 值 × 常数 MAX_PERMUTATION = 0.3"。

package com.auto.algorithm;
import java.util.ArrayList;
/*
 * @Description: A Genome is one candidate solution of the genetic algorithm.
 * It carries an identifier, a fitness score and the list of neural-network
 * weights that the candidate encodes.
 */
public class Genome {
    /** Identifier handed out by the genetic algorithm. */
    public int ID;

    /** Fitness score: how far the car got with these weights (higher is better). */
    public double fitness;

    /**
     * Neural-network weights. They are initialised with random values between
     * -1 and 1; mutation may later move individual values outside that range.
     */
    public ArrayList<Double> weights;
}

package com.auto.algorithm;
import java.util.ArrayList;
import java.util.Random;

/*
 * @Description: GeneticAlgorithm is used to train the neural network. It keeps a
 * list of genomes called population. Once every genome has been given a fitness
 * (how far the car managed to drive), breedPopulation() takes the fittest genomes,
 * crosses them over in pairs and mutates the children to form the next generation;
 * any remaining slots are filled with fresh random genomes. Unlike a textbook
 * genetic algorithm, mutation does not switch a gene on or off: it nudges the
 * gene's weight by a small random amount.
 */
public class GeneticAlgorithm {
    public static final float MAX_PERMUTATION = 0.3f;
    public static final float MUTATION_RATE = 0.15f;
    // Number of top genomes used as parents by breedPopulation().
    private static final int PARENT_COUNT = 4;
    // Parent index pairs crossed over by breedPopulation(), in breeding order.
    private static final int[][] BREEDING_PAIRS = { { 0, 1 }, { 0, 2 }, { 0, 3 },
            { 1, 2 }, { 1, 3 } };

    // One generator per instance. Constructing new Random(System.nanoTime()) on
    // every call gives correlated values for calls made close together.
    private final Random random = new Random();
    private int currentGenome;
    private int totalPopulation;
    private int genomeID;
    private int generation;
    private ArrayList<Genome> population;

    public GeneticAlgorithm() {
        currentGenome = -1;
        totalPopulation = 0;
        genomeID = 0;
        generation = 1;
        population = new ArrayList<Genome>();
    }

    /*
     * Replace the population with totalPop random genomes (each with
     * totalWeights random Sigmoid weights, fitness 0 and a fresh ID) and reset
     * the generation counter to 1.
     */
    public void generateNewPopulation(int totalPop, int totalWeights) {
        generation = 1;
        clearPopulation();
        currentGenome = -1;
        totalPopulation = totalPop;
        for (int i = 0; i < totalPopulation; i++) {
            population.add(createNewGenome(totalWeights));
        }
    }

    /*
     * Record the fitness of the genome at index; out-of-range indexes are ignored.
     */
    public void setGenomeFitness(double fitness, int index) {
        if (index >= population.size() || index < 0)
            return;
        population.get(index).fitness = fitness;
    }

    /*
     * Advance to the next genome to evaluate. Returns null once every genome of
     * the current generation has been handed out.
     */
    public Genome getNextGenome() {
        currentGenome++;
        if (currentGenome >= population.size())
            return null;
        return population.get(currentGenome);
    }

    public void clearPopulation() {
        population.clear();
    }

    /*
     * Build the next generation from the best (up to 4) genomes, i.e. those
     * with the highest fitness above 0:
     * - a mutated copy of the best genome is carried over;
     * - parents (0,1), (0,2), (0,3), (1,2) and (1,3) are crossed over, and both
     *   children of each pair are mutated; pairs whose parents are missing
     *   (fewer than 4 genomes scored above 0) are skipped;
     * - the remaining slots are filled with new random genomes.
     * If no genome scored above 0, the whole generation is random. The
     * population size stays equal to totalPopulation.
     *
     * @throws IllegalStateException if the population is empty
     */
    public void breedPopulation() {
        if (population.isEmpty()) {
            throw new IllegalStateException(
                    "Cannot breed an empty population; call generateNewPopulation first");
        }
        int totalWeights = population.get(0).weights.size();
        ArrayList<Genome> bestGenomes = getBestGenomes(PARENT_COUNT);
        ArrayList<Genome> children = new ArrayList<Genome>();

        if (!bestGenomes.isEmpty()) {
            // Carry on a mutated copy of the best genome. The weights are copied so
            // that mutating the child does not alter the parent that is still used
            // for crossing over below.
            Genome best = bestGenomes.get(0);
            Genome bestGenome = new Genome();
            bestGenome.ID = best.ID;
            bestGenome.fitness = 0.0;
            bestGenome.weights = new ArrayList<Double>(best.weights);
            mutate(bestGenome);
            children.add(bestGenome);

            for (int[] pair : BREEDING_PAIRS) {
                // pair[0] < pair[1], so checking the second index is enough.
                if (pair[1] >= bestGenomes.size()) {
                    continue;
                }
                ArrayList<Genome> crossed = crossOver(bestGenomes.get(pair[0]),
                        bestGenomes.get(pair[1]));
                for (Genome child : crossed) {
                    mutate(child);
                    children.add(child);
                }
            }
        }

        // For the remaining population, add some random genomes.
        while (children.size() < totalPopulation) {
            children.add(createNewGenome(totalWeights));
        }
        // Keep the population size constant even when totalPopulation is smaller
        // than the number of bred children.
        while (children.size() > totalPopulation) {
            children.remove(children.size() - 1);
        }

        clearPopulation();
        population = children;
        currentGenome = -1;
        generation++;
    }

    /*
     * Create a genome with a fresh ID, fitness 0 and totalWeights random
     * Sigmoid weights.
     */
    private Genome createNewGenome(int totalWeights) {
        Genome genome = new Genome();
        genome.ID = genomeID;
        genome.fitness = 0.0;
        genome.weights = new ArrayList<Double>(totalWeights);
        for (int j = 0; j < totalWeights; j++) {
            genome.weights.add(HelperFunction.RandomSigmoid());
        }
        genomeID++;
        return genome;
    }

    /*
     * Single-point crossover. A random cut point is chosen; the first child takes
     * g1's genes before the cut and g2's genes from the cut on, and the second
     * child takes the opposite. Both children get fresh IDs and fitness 0.
     * Assumes g2 has at least as many weights as g1.
     */
    private ArrayList<Genome> crossOver(Genome g1, Genome g2) {
        int totalWeights = g1.weights.size();
        // nextInt(bound) is always in [0, bound); Math.abs(nextInt()) % bound is
        // negative when nextInt() returns Integer.MIN_VALUE.
        int crossover = totalWeights > 0 ? random.nextInt(totalWeights) : 0;

        Genome genome1 = new Genome();
        genome1.ID = genomeID;
        genome1.fitness = 0.0;
        genome1.weights = new ArrayList<Double>(totalWeights);
        genomeID++;
        Genome genome2 = new Genome();
        genome2.ID = genomeID;
        genome2.fitness = 0.0;
        genome2.weights = new ArrayList<Double>(totalWeights);
        genomeID++;

        for (int i = 0; i < totalWeights; i++) {
            boolean beforeCut = i < crossover;
            genome1.weights.add(beforeCut ? g1.weights.get(i) : g2.weights.get(i));
            genome2.weights.add(beforeCut ? g2.weights.get(i) : g1.weights.get(i));
        }

        ArrayList<Genome> genomes = new ArrayList<Genome>(2);
        genomes.add(genome1);
        genomes.add(genome2);
        return genomes;
    }

    /*
     * Mutate each weight independently with probability MUTATION_RATE by adding
     * a random Sigmoid value scaled by MAX_PERMUTATION. The probability draw and
     * the perturbation are independent. Reusing one value for both made roughly
     * 64% of genes mutate, always by a value below 0.15 * 0.3, so the weights
     * drifted towards negative values.
     */
    private void mutate(Genome genome) {
        for (int i = 0; i < genome.weights.size(); ++i) {
            if (random.nextDouble() < MUTATION_RATE) {
                double delta = HelperFunction.RandomSigmoid() * MAX_PERMUTATION;
                genome.weights.set(i, genome.weights.get(i) + delta);
            }
        }
    }

    /*
     * Return up to totalGenomes genomes with the highest fitness, in descending
     * order of fitness. Only genomes with fitness strictly above 0 qualify.
     * Genomes sharing an ID with an already selected one are skipped. Ties are
     * resolved in favour of the earlier genome in the population. The result
     * may therefore hold fewer than totalGenomes genomes, or none at all.
     */
    private ArrayList<Genome> getBestGenomes(int totalGenomes) {
        ArrayList<Genome> bestGenomes = new ArrayList<Genome>();
        while (bestGenomes.size() < totalGenomes) {
            Genome best = null;
            double bestFitness = 0;
            for (Genome candidate : population) {
                if (candidate.fitness > bestFitness
                        && !containsId(bestGenomes, candidate.ID)) {
                    best = candidate;
                    bestFitness = candidate.fitness;
                }
            }
            if (best == null) {
                break; // no further qualifying genome
            }
            bestGenomes.add(best);
        }
        return bestGenomes;
    }

    // True if any genome in genomes has the given ID.
    private static boolean containsId(ArrayList<Genome> genomes, int id) {
        for (Genome genome : genomes) {
            if (genome.ID == id) {
                return true;
            }
        }
        return false;
    }

    public int getCurrentGeneration() {
        return generation;
    }

    public int getCurrentGenomeIndex() {
        return currentGenome;
    }

}

4) 神经网络

Neuron 类

神经元是神经网络中的基本元素。它有两个字段:输入数量 numberOfInputs,以及作用于这些输入的权重 weights(最后一个权重用作偏置)。在此程序中,神经元从基因组中获取数据并存储到其 weights 中。这些 weights 将被神经元层用于计算输出。

NeuronsLayer 类

NeuronsLayer 类(图 13)包含一个神经元列表,并使用这些神经元的 weights 来计算输出。神经元层有两种类型:隐藏层和输出层,它们都由神经网络管理。此类的主要方法是 evaluate。对每个神经元,它先将 inputs 的各个值与该神经元对应的 weights 相乘并求和,再加上最后一个权重(偏置权重)与常数 BIAS = -1 的乘积。偏置项相当于一个可学习的阈值:即使所有输入都为 0,神经元的激活值也可以被平移。最后,求和结果通过 Sigmoid 函数映射到 0 到 1 之间,并存入 outputs。

NeuralNetwork 类

NeuralNetwork 类(图 14)包含 1 个输出层和 1 个或多个隐藏层。这些是此类主要方法:

setInput

此方法接收来自汽车传感器的输入并存储它们。

getOutput

此方法接收一个索引,返回输出层在该索引处的输出值(当索引大于等于输出数量时返回 0)。

update

NeuralNetwork 不断从汽车传感器接收数据,将数据传递到隐藏层进行处理,然后将输出传输到输出层进行第二次处理。之后,输出层将决策反馈给汽车进行转向。

fromGenome

此函数将从 GeneticAlgorithm 的基因组中获取 weights,以存储在神经元层中。

package com.auto.algorithm;

import java.util.ArrayList;

/*
 * @Description: Neuron is the basic element of the neural network. It records how
 * many inputs it reads (numberOfInputs) and the weights applied to them. In this
 * program a hidden-layer neuron presumably reads the car's sensor values and an
 * output-layer neuron reads the hidden-layer outputs (the original text says 5
 * sensors and 8 hidden neurons). The actual counts are supplied by the caller
 * through init().
 */
public class Neuron {
    protected int numberOfInputs;
    protected ArrayList<Double> weights;

    /**
     * Configures this neuron.
     *
     * @param weightsIn   weight list to use; it is stored by reference, not copied
     * @param numOfInputs number of inputs this neuron reads
     */
    public void init(ArrayList<Double> weightsIn, int numOfInputs) {
        weights = weightsIn;
        numberOfInputs = numOfInputs;
    }
}
package com.auto.algorithm;

import java.util.ArrayList;

/*
 * @Description: NeuronsLayer contains Neurons. It evaluates these neurons to give
 * out a decision.
 */
public class NeuronsLayer {
    public static final float BIAS = -1.0f;
    private int totalNeurons;
    private ArrayList<Neuron> neurons;

    /*
     * Evaluate the inputs (from the sensors or from a previous layer) and append
     * one output per neuron to outputs.
     *
     * For each neuron the activation is the sum over all numberOfInputs inputs of
     * inputs[j] * weights[j], plus weights[numberOfInputs] * BIAS. The activation
     * is passed through the sigmoid, so every output lies in (0, 1). Each neuron
     * therefore needs numberOfInputs + 1 weights. Every neuron reads the inputs
     * from index 0. If fewer inputs than numberOfInputs are supplied, the
     * missing ones contribute nothing.
     *
     * The previous loop bound (numberOfInputs - 1) silently ignored the last
     * input, for example the last sensor.
     */
    public void evaluate(ArrayList<Double> inputs, ArrayList<Double> outputs) {
        for (int i = 0; i < totalNeurons; i++) {
            Neuron neuron = neurons.get(i);
            int numOfInputs = neuron.numberOfInputs;
            int usableInputs = Math.min(numOfInputs, inputs.size());
            // Accumulate in double to avoid repeated float rounding; narrow only
            // once for the Sigmoid call.
            double activation = 0.0;
            for (int j = 0; j < usableInputs; j++) {
                activation += inputs.get(j) * neuron.weights.get(j);
            }
            // The weight after the input weights is the bias weight.
            activation += neuron.weights.get(numOfInputs) * BIAS;
            outputs.add(HelperFunction.Sigmoid((float) activation, 1.0f));
        }
    }

    /*
     * Return all weights of all neurons in this layer, neuron by neuron, as one
     * flat list.
     */
    public ArrayList<Double> getWeights() {
        ArrayList<Double> out = new ArrayList<Double>();
        for (int i = 0; i < this.totalNeurons; i++) {
            out.addAll(neurons.get(i).weights);
        }
        return out;
    }

    /*
     * Use the given neurons as this layer (the list is stored by reference).
     */
    public void loadLayer(ArrayList<Neuron> neurons) {
        totalNeurons = neurons.size();
        this.neurons = neurons;
    }
}
package com.auto.algorithm;

import java.util.ArrayList;

/*
 * @Description: NeuralNetwork is used to make decisions for the car: whether it
 * should turn right, turn left or go straight. It holds a list of hidden
 * NeuronsLayers (fromGenome builds exactly one) and one output NeuronsLayer.
 * The layers are rebuilt from a genome each time the car crashes into the wall,
 * so the next try of the car uses the weights of the next genome.
 */
public class NeuralNetwork {
    private int inputAmount;
    private int outputAmount;
    private ArrayList<Double> outputs;
    private ArrayList<Double> inputs;
    // HiddenLayer produces the input for OutputLayer
    private ArrayList<NeuronsLayer> hiddenLayers;
    // OutputLayer will receive input from HiddenLayer
    private NeuronsLayer outputLayer;

    public NeuralNetwork() {
        outputs = new ArrayList<Double>();
        inputs = new ArrayList<Double>();
    }

    /*
     * Receive the input from the car's sensors, i.e. the normalized distances
     * from the car to the wall.
     */
    public void setInput(double[] normalizedDepths) {
        inputs.clear();
        for (double depth : normalizedDepths) {
            inputs.add(depth);
        }
    }

    /*
     * Run a forward pass: the sensor inputs go through every hidden layer in
     * order, and the last hidden layer's outputs are fed to the output layer.
     * The results are available through getOutput(). Each layer writes into a
     * fresh list, so a layer never reads from the list it is appending to.
     * The per-frame debug println was removed because it flooded the console.
     */
    public void update() {
        ArrayList<Double> layerInput = inputs;
        for (NeuronsLayer hiddenLayer : hiddenLayers) {
            ArrayList<Double> layerOutput = new ArrayList<Double>();
            hiddenLayer.evaluate(layerInput, layerOutput);
            layerInput = layerOutput;
        }
        outputs.clear();
        // The output layer will give out the final outputs
        outputLayer.evaluate(layerInput, outputs);
    }

    /*
     * Return the output at index, or 0 when index is out of range (including
     * before the first update()).
     */
    public double getOutput(int index) {
        if (index < 0 || index >= outputAmount || index >= outputs.size())
            return 0.0;
        return outputs.get(index);
    }

    public void releaseNet() {
        outputLayer = null;
        hiddenLayers = null;
    }

    /*
     * Rebuild the hidden layer and the output layer from the genome's weights.
     * The genome is read as consecutive, non-overlapping slices:
     * - first, neuronsPerHidden hidden neurons with (numOfInputs + 1) weights each
     *   (the last weight of each slice is the bias);
     * - then, numOfOutputs output neurons with (neuronsPerHidden + 1) weights each.
     * Previously both layers were read from index 0 with stride neuronsPerHidden,
     * so the output neurons reused (overlapping) hidden-layer genes.
     * A null genome is ignored.
     *
     * NOTE(review): the genome must now hold at least
     * neuronsPerHidden * (numOfInputs + 1) + numOfOutputs * (neuronsPerHidden + 1)
     * weights. Confirm that the totalWeights passed to generateNewPopulation
     * matches.
     *
     * @throws IllegalArgumentException if the genome has too few weights
     */
    public void fromGenome(Genome genome, int numOfInputs,
            int neuronsPerHidden, int numOfOutputs) {
        if (genome == null)
            return;
        int hiddenWeights = neuronsPerHidden * (numOfInputs + 1);
        int requiredWeights = hiddenWeights + numOfOutputs * (neuronsPerHidden + 1);
        // Validate before releasing the current net, so a bad genome leaves it intact.
        if (genome.weights.size() < requiredWeights) {
            throw new IllegalArgumentException("Genome has " + genome.weights.size()
                    + " weights but the network needs " + requiredWeights);
        }
        releaseNet();
        hiddenLayers = new ArrayList<NeuronsLayer>();
        outputAmount = numOfOutputs;
        inputAmount = numOfInputs;

        NeuronsLayer hidden = new NeuronsLayer();
        hidden.loadLayer(buildNeurons(genome, 0, neuronsPerHidden, numOfInputs));
        hiddenLayers.add(hidden);

        outputLayer = new NeuronsLayer();
        outputLayer.loadLayer(buildNeurons(genome, hiddenWeights, numOfOutputs,
                neuronsPerHidden));
    }

    /*
     * Build count neurons with inputsPerNeuron inputs each. The weights are read
     * from consecutive genes starting at offset, inputsPerNeuron + 1 per neuron
     * (the last one being the bias). Each neuron gets its own copy of its slice.
     */
    private static ArrayList<Neuron> buildNeurons(Genome genome, int offset,
            int count, int inputsPerNeuron) {
        int stride = inputsPerNeuron + 1;
        ArrayList<Neuron> neurons = new ArrayList<Neuron>(count);
        for (int i = 0; i < count; i++) {
            int start = offset + i * stride;
            ArrayList<Double> weights = new ArrayList<Double>(
                    genome.weights.subList(start, start + stride));
            Neuron n = new Neuron();
            n.init(weights, inputsPerNeuron);
            neurons.add(n);
        }
        return neurons;
    }
}
package com.auto.algorithm;

import java.util.Random;

/*
 * Description: Global helper functions
 */
public class HelperFunction {
    // Shared generator (java.util.Random is thread-safe). The previous code built
    // new Random(System.nanoTime()) on every call. Nearly identical seeds give
    // strongly correlated first outputs in java.util.Random, so consecutive
    // "random" weights were close to each other.
    private static final Random RANDOM = new Random();

    /*
     * Logistic sigmoid 1 / (1 + e^(-a / p)). It maps any a into the open
     * interval (0, 1), with Sigmoid(0, p) == 0.5. p controls the steepness:
     * a larger p gives a flatter curve.
     */
    public static double Sigmoid(float a, float p) {
        float ap = (-a) / p;
        return (1 / (1 + Math.exp(ap)));
    }

    /*
     * Random number in the open interval (-1, 1): the difference of two uniform
     * [0, 1) draws, so values near 0 are more likely than values near +-1.
     */
    public static double RandomSigmoid() {
        return RANDOM.nextDouble() - RANDOM.nextDouble();
    }

    /*
     * Clamp a into [b, c]: return b if a < b, c if a > c, otherwise a.
     */
    public static double getValueInRange(double a, double b, double c) {
        if (a < b) {
            return b;
        } else if (a > c) {
            return c;
        }
        return a;
    }
}

关注点

完成这个项目时我感觉非常好。这是我在马歇尔大学攻读硕士学位的毕业设计项目。这个项目是一个实际应用。它帮助初学者了解人工智能、神经网络和遗传算法。

参考文献

很抱歉,我忘记在之前的版本中提及参考文献。我的想法和代码是向这位作者 https://github.com/matthewrdev/Neural-Network 学习的,他是用 C++ 编写的。图形方面,我是从这个 Youtube 频道学习的 https://www.youtube.com/user/TheChernoProject。:)

历史

  • 2016 年 12 月 15 日:初始版本
© . All rights reserved.