Reinventing Neural Networks - Part 2

December 11, 2017

CPOL

In Part 2, we take the neural network created in Part 1 and test it in an environment built in Unity, so that we can see how well it performs.

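Full Series

  • Part 1: We create the whole NeuralNetwork class from scratch.
  • Part 2: We create an environment in Unity in which to test the neural network.
  • Part 3: We make a major improvement to the neural network by adding a new type of mutation to the code.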

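Introduction

A few days ago, I published Part 1 of this series, which explains how to implement a neural network from scratch in C#. In that article, however, the network was only trained on the XOR function. As promised, we are now going to train simple cars to drive in Unity! Here is what we are aiming for:

After recording the video, I felt it looked a bit like a creepy clip from the 90s, but it does the job...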

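Background

To follow along with this article, you need basic programming knowledge of C# and Unity. You also need to have read my previous article, in which I first implemented the NeuralNetwork class.

Resources Before You Start Coding

If you are new to C#, you can always search the MSDN documentation for anything you are not familiar with. For anything specific to Unity, search Unity's Scripting Reference or Unity's Manual.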

Using the Code

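First, you need to know all the classes that will be used in the project:

  • Car: the main script that controls the movement of the car object (driven either by a NeuralNetwork or by the user).
  • Wall: a simple script attached to every wall. If a car hits an object carrying this script, the script sends the car a "Die" message.
  • Checkpoint: a simple script that increases a car's fitness (score) when the car passes through it.
  • EvolutionManager: this script simply waits for all the cars to die, then spawns a new generation from the best car.
  • CameraFollow: the script that moves the camera so that it follows the best car.

Here is how they work together:

  • There is a track with a series of checkpoints along it.
  • Once a car passes a checkpoint, its fitness increases.
  • If a car hits a wall, it is destroyed.
  • Once all the cars are destroyed, a new generation is created from the best car of the previous generation.

Now, we will go through each script in turn and explain it in more detail.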

NeuralNetwork

A whole article was dedicated to this one...

Car

First, we need to define some variables:

    [SerializeField] bool UseUserInput = false;     // Defines whether the car 
                                                    // uses a NeuralNetwork or user input
    [SerializeField] LayerMask SensorMask;          // Defines the layer of the walls ("Wall")
    [SerializeField] float FitnessUnchangedDie = 5; // The number of seconds to wait 
                                                    // before checking if the fitness 
                                                    // didn't increase
	
    public static NeuralNetwork NextNetwork = new NeuralNetwork
           (new uint[] { 6, 4, 3, 2 }, null);       // public NeuralNetwork that refers to 
                                                    // the next neural network to be set to 
                                                    // the next instantiated car

    public string TheGuid { get; private set; }     // The Unique ID of the current car

    public int Fitness { get; private set; }        // The fitness/score of the current car. 
                                                    // Represents the number of checkpoints 
                                                    // that this car hit.

    public NeuralNetwork TheNetwork { get; private set; } // The NeuralNetwork of 
                                                          // the current car

    Rigidbody TheRigidbody;                         // The Rigidbody of the current car
    LineRenderer TheLineRenderer;                   // The LineRenderer of the current car

When a new car is created, this is what we do:

    private void Awake()
    {
        TheGuid = Guid.NewGuid().ToString(); // Assigns a new Unique ID for the current car
        
        TheNetwork = NextNetwork;            // Sets the current network to the Next Network
        NextNetwork = new NeuralNetwork(NextNetwork.Topology, null); // Make sure the 
           // Next Network is reassigned to avoid having another car use the same network

        TheRigidbody = GetComponent<Rigidbody>(); // Assign Rigidbody
        TheLineRenderer = GetComponent<LineRenderer>(); // Assign LineRenderer

        StartCoroutine(IsNotImproving());   // Start checking if the score stayed 
                                            // the same for a lot of time

        TheLineRenderer.positionCount = 17; // Make sure the line is long enough
    }

Here is the IsNotImproving function:

    // Checks every few seconds if the car didn't make any improvement
    IEnumerator IsNotImproving ()
    {
        while(true)
        {
            int OldFitness = Fitness;                             // Save the initial fitness
            yield return new WaitForSeconds(FitnessUnchangedDie); // Wait for some time
            if (OldFitness == Fitness)              // Check if the fitness didn't change yet
                WallHit();                                        // Kill this car
        }
    }

Here is the Move function, which (wait for it...) "moves" the car:

    // The main function that moves the car.
    public void Move (float v, float h)
    {
        TheRigidbody.velocity = transform.right * v * 4;
        TheRigidbody.angularVelocity = transform.up * h * 3;
    }

Next comes the CastRay function, which casts a ray and makes it visible. It will be used later:

    // Casts a ray and makes it visible through the line renderer
    double CastRay (Vector3 RayDirection, Vector3 LineDirection, int LinePositionIndex)
    {
        float Length = 4; // Maximum length of each ray

        RaycastHit Hit;
        if (Physics.Raycast(transform.position, RayDirection, 
                            out Hit, Length, SensorMask)) // Cast a ray
        {
            float Dist = Vector3.Distance
            (Hit.point, transform.position); // Get the distance of the hit in the line
            TheLineRenderer.SetPosition(LinePositionIndex, 
            Dist * LineDirection);           // Set the position of the line

            return Dist;                     // Return the distance
        }
        else
        {
            TheLineRenderer.SetPosition(LinePositionIndex, 
            LineDirection * Length);         // Set the distance of the hit in the line 
                                             // to the maximum distance

            return Length;                   // Return the maximum distance
        }
    }

Next is the GetNeuralInputAxis function, which does a lot of the work for us:

    // Casts all the rays, puts them through the NeuralNetwork and outputs the Move Axis
    void GetNeuralInputAxis (out float Vertical, out float Horizontal)
    {
        double[] NeuralInput = new double[TheNetwork.Topology[0]]; // One input per sensor ray

        // Cast forward, back, right and left
        NeuralInput[0] = CastRay(transform.forward, Vector3.forward, 1) / 4;
        NeuralInput[1] = CastRay(-transform.forward, -Vector3.forward, 3) / 4;
        NeuralInput[2] = CastRay(transform.right, Vector3.right, 5) / 4;
        NeuralInput[3] = CastRay(-transform.right, -Vector3.right, 7) / 4;

        // Cast forward-right and forward-left
        float SqrtHalf = Mathf.Sqrt(0.5f);
        NeuralInput[4] = CastRay(transform.right * SqrtHalf + 
                         transform.forward * SqrtHalf, Vector3.right * SqrtHalf + 
                         Vector3.forward * SqrtHalf, 9) / 4;
        NeuralInput[5] = CastRay(transform.right * SqrtHalf + -transform.forward * SqrtHalf, 
                         Vector3.right * SqrtHalf + -Vector3.forward * SqrtHalf, 13) / 4;

        // Feed through the network
        double[] NeuralOutput = TheNetwork.FeedForward(NeuralInput);
        
        // Get Vertical Value
        if (NeuralOutput[0] <= 0.25f)
            Vertical = -1;
        else if (NeuralOutput[0] >= 0.75f)
            Vertical = 1;
        else
            Vertical = 0;

        // Get Horizontal Value
        if (NeuralOutput[1] <= 0.25f)
            Horizontal = -1;
        else if (NeuralOutput[1] >= 0.75f)
            Horizontal = 1;
        else
            Horizontal = 0;

        // If the output is just standing still, then move the car forward
        if (Vertical == 0 && Horizontal == 0)
            Vertical = 1;
    }
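
The thresholding at the end of GetNeuralInputAxis can be read as a small pure function. Below is a minimal, Unity-free sketch of that mapping, assuming each network output lies in [0, 1]. The names AxisMapping, ToAxis and ToAxes are hypothetical and are not part of the original project.

```csharp
using System;

// Hypothetical helper class (not in the original project).
public static class AxisMapping
{
    // Maps one network output (assumed to lie in [0, 1]) to a discrete
    // axis value of -1, 0 or 1, using the same 0.25 / 0.75 thresholds
    // as GetNeuralInputAxis.
    public static float ToAxis(double output)
    {
        if (output <= 0.25) return -1f;
        if (output >= 0.75) return 1f;
        return 0f;
    }

    // Mirrors the final rule in GetNeuralInputAxis: a car that would
    // otherwise stand still is pushed forward instead.
    public static void ToAxes(double[] outputs, out float vertical, out float horizontal)
    {
        vertical = ToAxis(outputs[0]);
        horizontal = ToAxis(outputs[1]);

        if (vertical == 0f && horizontal == 0f)
            vertical = 1f;
    }
}
```

Keeping the mapping separate from the ray casting makes it possible to check the thresholds without opening Unity, for example by calling ToAxes with a hand-written output array.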

Then, here is what we do 50 times per second (FixedUpdate runs at Unity's default fixed timestep of 0.02 s):

    private void FixedUpdate()
    {
        if (UseUserInput) // If we're gonna use user input
            Move(Input.GetAxisRaw("Vertical"), 
            Input.GetAxisRaw("Horizontal")); // Moves the car according to the input
        else // if we're gonna use a neural network
        {
            float Vertical;
            float Horizontal;

            GetNeuralInputAxis(out Vertical, out Horizontal);

            Move(Vertical, Horizontal); // Moves the car
        }
    }

We also need a few functions that will be called from other scripts (Checkpoint and Wall):

    // This function is called through all the checkpoints when the car hits any.
    public void CheckpointHit ()
    {
        Fitness++; // Increase Fitness/Score
    }

    // Called by walls when hit by the car
    public void WallHit()
    {
        EvolutionManager.Singleton.CarDead(this, Fitness); // Tell the Evolution Manager 
                                                           // that the car is dead
        gameObject.SetActive(false);                       // Make sure the car is inactive
    }

Wall

The Wall script simply notifies any car that hits it:

using UnityEngine;

public class Wall : MonoBehaviour
{
    [SerializeField] string LayerHitName = "CarCollider"; // The name of the layer 
                                                          // set on each car

    private void OnCollisionEnter(Collision collision)    // Once anything hits the wall
    {
        if (collision.gameObject.layer == LayerMask.NameToLayer(LayerHitName)) // Make sure 
                                                                               // it's a car
        {
            collision.transform.GetComponent<Car>().WallHit(); // If it is a car, 
                                                          // tell it that it just hit a wall
        }
    }
}

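Checkpoint

A Checkpoint does almost the same thing as a Wall, but with a twist. A Checkpoint uses a trigger rather than a solid collider, and it makes sure that it increases the fitness of each car only once. This is why every Car has a unique ID: each Checkpoint simply stores the Guids of all the cars it has already credited.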

using System.Collections.Generic;
using UnityEngine;

public class Checkpoint : MonoBehaviour
{
    [SerializeField] string LayerHitName = "CarCollider"; // The name of the layer set 
                                                          // on each car

    List<string> AllGuids = new List<string>();           // The list of Guids of all the 
                                                          // cars increased

    private void OnTriggerEnter(Collider other)           // Once anything goes through the checkpoint
    {
        if(other.gameObject.layer == LayerMask.NameToLayer(LayerHitName))  // If this object 
                                                                           // is a car
        {
            Car CarComponent = other.transform.parent.GetComponent<Car>(); // Get the component 
                                                                           // of the car
            string CarGuid = CarComponent.TheGuid;        // Get the Unique ID of the car

            if (!AllGuids.Contains(CarGuid))              // If we didn't increase 
                                                          // the car before
            {
                AllGuids.Add(CarGuid);                    // Make sure we don't 
                                                          // increase it again
                CarComponent.CheckpointHit();             // Increase the car's fitness
            }
        }
    }
}
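
The "credit each car only once" bookkeeping does not depend on Unity. Here is a minimal sketch of the same idea using a HashSet<string> in place of the List<string>, so that the membership test and the insertion happen in a single call. CheckpointLedger and TryCredit are hypothetical names, not part of the original project.

```csharp
using System.Collections.Generic;

// Hypothetical, Unity-free stand-in for the bookkeeping inside Checkpoint.
public class CheckpointLedger
{
    // Guids of the cars that this checkpoint has already credited.
    private readonly HashSet<string> creditedGuids = new HashSet<string>();

    // Returns true only the first time a given car Guid passes through.
    // HashSet<T>.Add returns false when the item is already present,
    // so no separate Contains check is needed.
    public bool TryCredit(string carGuid)
    {
        return creditedGuids.Add(carGuid);
    }
}
```

In the Checkpoint script, the equivalent change would be to call CheckpointHit only when the add succeeds. With 100 cars per generation the difference in speed is unlikely to matter; the main gain is that the intent is stated in one line.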

EvolutionManager

You can't write a script without variables:

    public static EvolutionManager Singleton = null; // The current EvolutionManager Instance

    [SerializeField] int CarCount = 100;             // The number of cars per generation
    [SerializeField] GameObject CarPrefab;           // The Prefab of the car to be created 
                                                     // for each instance
    [SerializeField] Text GenerationNumberText;      // Some text to write the generation number

    int GenerationCount = 0;                         // The current generation number

    List<Car> Cars = new List<Car>();                // The list of cars currently alive

    NeuralNetwork BestNeuralNetwork = null;          // The best NeuralNetwork 
                                                     // currently available
    int BestFitness = -1;                            // The Fitness of the 
                                                     // best NeuralNetwork ever created

When the program starts:

    // On Start
    private void Start()
    {
        if (Singleton == null) // If no other instances were created
            Singleton = this;  // Make the only instance this one
        else
        {
            gameObject.SetActive(false); // There is another instance already in place. 
                                         // Make this one inactive.
            return;                      // Don't let a duplicate manager start a generation
        }

        BestNeuralNetwork = new NeuralNetwork(Car.NextNetwork); // Set the BestNeuralNetwork 
                                                                // to a random new network

        StartGeneration(); 
    }

Here is how a new generation is created:

    // Starts a whole new generation
    void StartGeneration ()
    {
        GenerationCount++; // Increment the generation count
        GenerationNumberText.text = "Generation: " + GenerationCount; // Update generation text

        for (int i = 0; i < CarCount; i++)
        {
            if (i == 0)
                Car.NextNetwork = BestNeuralNetwork; // Make sure one car uses the best network
            else
            {
                Car.NextNetwork = new NeuralNetwork(BestNeuralNetwork); // Clone the best 
                                          // neural network and set it to be for the next car
                Car.NextNetwork.Mutate(); // Mutate it
            }

            Cars.Add(Instantiate(CarPrefab, transform.position, 
                     Quaternion.identity, transform).GetComponent<Car>()); // Instantiate 
                                          // a new car and add it to the list of cars
        }
    }
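
The pattern in StartGeneration (keep the best network unchanged in one slot, fill every other slot with a mutated clone) can be tried out without Unity. The sketch below uses a plain double[] as a toy genome and a simple uniform perturbation as the mutation. Both are assumptions made for illustration and are not the NeuralNetwork class or the Mutate method from Part 1. GenerationSketch and its members are hypothetical names.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical, Unity-free illustration of the StartGeneration pattern.
public static class GenerationSketch
{
    // Toy stand-in for NeuralNetwork.Mutate (not the Part 1 implementation):
    // nudges every weight by a uniform random amount in [-strength, strength).
    public static void Mutate(double[] genome, Random rng, double strength)
    {
        for (int i = 0; i < genome.Length; i++)
            genome[i] += (rng.NextDouble() * 2.0 - 1.0) * strength;
    }

    // Builds a population the way StartGeneration does: slot 0 keeps the
    // best genome untouched, every other slot gets a mutated copy of it.
    public static List<double[]> NextGeneration(double[] best, int count, Random rng)
    {
        var population = new List<double[]>(count);

        for (int i = 0; i < count; i++)
        {
            if (i == 0)
            {
                population.Add(best);               // Same instance, unchanged
            }
            else
            {
                var clone = (double[])best.Clone(); // Copy, so the best stays intact
                Mutate(clone, rng, 0.1);
                population.Add(clone);
            }
        }

        return population;
    }
}
```

Keeping one unmutated copy means a generation can never start out worse than the best network found so far, while the mutated clones explore nearby variations of it.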

And here is what the cars call when they die:

    // Gets called by cars when they die
    public void CarDead (Car DeadCar, int Fitness)
    {
        Cars.Remove(DeadCar);        // Remove the car from the list
        Destroy(DeadCar.gameObject); // Destroy the dead car

        if (Fitness > BestFitness)   // If it is better than the current best car
        {
            BestNeuralNetwork = DeadCar.TheNetwork; // Make sure it becomes the best car
            BestFitness = Fitness;   // And also set the best fitness
        }

        if (Cars.Count <= 0)         // If there are no cars left
            StartGeneration();       // Create a new generation
    }
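
Two details of CarDead are easy to miss: BestFitness starts at -1, so even a car with a fitness of 0 replaces the initial network, and the comparison is strict, so a tie keeps the earlier holder. The sketch below isolates that rule in a small generic class. BestTracker and Report are hypothetical names, not part of the original project.

```csharp
// Hypothetical, Unity-free illustration of the "best so far" rule in CarDead.
public class BestTracker<T> where T : class
{
    public T Best { get; private set; }

    // Starts below any reachable fitness so that the first report always wins.
    public int BestFitness { get; private set; } = -1;

    // Returns true when the candidate replaced the current best.
    // The comparison is strict, so a tie keeps the earlier holder.
    public bool Report(T candidate, int fitness)
    {
        if (fitness <= BestFitness)
            return false;

        Best = candidate;
        BestFitness = fitness;
        return true;
    }
}
```

Because BestFitness in the EvolutionManager is never reset between generations, the stored network only changes when a car beats the best score seen in any generation so far.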

CameraFollow

Just another simple all-in-one script that gets the job done:

using UnityEngine;

public class CameraFollow : MonoBehaviour
{
    Vector3 SmoothPosVelocity;      // Velocity of Position Smoothing
    Vector3 SmoothRotVelocity;      // Velocity of Rotation  Smoothing

    void FixedUpdate ()
    {
        Car BestCar = transform.GetChild(0).GetComponent<Car>(); // The best car in 
                                                                 // the bunch is the first one

        for (int i = 1; i < transform.childCount; i++)           // Loop over all the cars
        {
            Car CurrentCar = transform.GetChild(i).GetComponent<Car>(); // Get the component 
                                                                        // of the current car

            if (CurrentCar.Fitness > BestCar.Fitness) // If the current car is better than 
                                                      // the best car
            {
                BestCar = CurrentCar;                 // Then, the best car is the current car
            }
        }

        Transform BestCarCamPos = BestCar.transform.GetChild(0); // The target position 
                                                      // of the camera relative to the best car

        Camera.main.transform.position = Vector3.SmoothDamp
               (Camera.main.transform.position, BestCarCamPos.position, 
                ref SmoothPosVelocity, 0.7f);         // Smoothly set the position

        Camera.main.transform.rotation = Quaternion.Lerp(Camera.main.transform.rotation,
                                         Quaternion.LookRotation(BestCar.transform.position - 
                                         Camera.main.transform.position),
                                         0.1f);       // Smoothly set the rotation
    }
}

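Points of Interest

Now that we have explained all the scripts in detail, you can sleep soundly knowing that the NeuralNetwork class implemented earlier works well and was not a waste of time. It feels great to watch these cars learn, step by step, how to drive around the track. Moreover, the cars rely on their on-board sensors, which means that they can drive on tracks they were never trained on. Once I finished this, I felt like my binary children had learned how to drive!

I did my best to keep this implementation as simple as possible, so that people who don't want to dig deep into Unity can still follow it. And... never think that we are done here. My current goal is to implement 3 crossover operators to make the evolution more efficient and to give developers more variety. After that, backpropagation is the goal.

Update: February 20, 2018

Part 3 is up! It shows a substantial improvement over the system discussed in Parts 1 and 2. Tell me what you think!

History

  • December 11, 2017: Version 1.0: Main implementation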