65.9K
CodeProject 正在变化。 阅读更多。
Home

使用 XGBoost 和 C# 解决鸢尾花分类问题

starIconstarIconstarIconstarIconemptyStarIcon

4.00/5 (2投票s)

2019年9月4日

CPOL

2分钟阅读

viewsIcon

13749

如何在 C# 应用程序中嵌入极端梯度提升等机器学习算法。

目录

引言

Sample Image - maximum width is 600 pixels

图片来源:维基百科

在本文中,我演示了如何使用流行的 XGBoost 非托管库的 C# 包装器XGBoost 代表 "Extreme Gradient Boosting"(极端梯度提升)。我使用著名的 IRIS 数据集来训练和测试模型。我的目标是分享我关于如何在 C# 应用程序中嵌入像极端梯度提升这样的机器学习算法的学习心得。在继续之前,我必须向 XGBoost 非托管库的开发者和 .NET 包装器库的开发者表示感谢。

顶部

背景

本文假设用户对以下内容具有中级知识:

  • 决策树算法
  • 梯度提升算法
  • 数据规范化
  • C#

本文和随附代码避免提供决策树和梯度提升算法的深入教程。 我提供了 YouTube 培训视频的链接,我认为这些视频具有极高的教育意义。

顶部

梯度提升分类算法概述

决策树入门 (StatQuest)

顶部

理解构建决策树时的基尼指数

顶部

AdaBoost 入门

顶部

梯度提升入门

顶部

XGBoost 库 (C#)

托管包装器

原始 XGBoost 库的 C/C++ 源代码可在 Github 上找到。 您可以找到 Windows 的构建说明。 感谢 PicNet 的努力,我们可以跳过编译非托管源的步骤,直接跳转到托管包装器。

顶部

简单线性分类问题

我们将进行一个简单的练习,我们将训练一个模型来对两个可以很好地线性分离的点簇进行分类。

        /// <summary>
        /// Two classes of vectors - Class-Blue and Class-Red
        /// Class-Blue  - The vectors are centered around the point (+0.5,+0.5) 
        /// and label value=1
        /// Class-Red   - The vectors are centered around the point (-0.5,-0.5) 
        /// and label value=0
        /// <summary>
        [TestMethod]
        public void LinearClassification1()
        {
            var xgb = new XGBoost.XGBClassifier();
            float[][] vectorsTrain = new float[][]
            {
                new[] {0.5f,0.5f},
                new[] {0.6f,0.6f},
                new[] {0.6f,0.4f},
                new[] {0.4f,0.6f},
                new[] {0.4f,0.4f},

                new[] {-0.5f,-0.5f},
                new[] {-0.6f,-0.6f},
                new[] {-0.6f,-0.4f},
                new[] {-0.4f,-0.6f},
                new[] {-0.4f,-0.4f},
            };
            var lablesTrain = new[]
            {
                1.0f,
                1.0f,
                1.0f,
                1.0f,
                1.0f,

                0.0f,
                0.0f,
                0.0f,
                0.0f,
                0.0f,
            };
            ///
            /// Ensure count of training labels=count of training vectors
            ///
            Assert.AreEqual(vectorsTrain.Length, lablesTrain.Length);
            ///
            /// Train the model
            ///
            xgb.Fit(vectorsTrain, lablesTrain);
            ///
            /// Test the model using test vectors
            ///
            float[][] vectorsTest = new float[][]
            {
                new[] {0.55f,0.55f},
                new[] {0.55f,0.45f},
                new[] {0.45f,0.55f},
                new[] {0.45f,0.45f},

                new[] {-0.55f,-0.55f},
                new[] {-0.55f,-0.45f},
                new[] {-0.45f,-0.55f},
                new[] {-0.45f,-0.45f},
            };
            var labelsTestExpected = new[]
            {
                1.0f,
                1.0f,
                1.0f,
                1.0f,

                0.0f,
                0.0f,
                0.0f,
                0.0f,
            };
            float[] labelsTestPredicted = xgb.Predict(vectorsTest);
            ///
            /// Verify that predicted labels match the expected labels
            ///
            CollectionAssert.AreEqual(labelsTestPredicted, labelsTestExpected);
        }

顶部

实现异或逻辑

异或逻辑比线性分类更复杂。 数据点不是直接线性可分的。

异或真值表

        X | Y | OUTPUT
        --------------
        1 | 0 |   1
        --------------
        0 | 1 |   1
        --------------
        0 | 0 |   0
        --------------
        1 | 1 |   0
        --------------    

示例代码

        [TestMethod]
        public void TestMethod1()
        {
            var xgb = new XGBoost.XGBClassifier();
            ///
            /// Generate training vectors
            ///
            int countTrainingPoints = 50;
            entity.XGBArray trainClass_0_1 = 
                   Util.GenerateRandom2dPoints(countTrainingPoints / 2, 
                0.0, 0.5,
                0.5, 1.0, 1.0);//0,1
            entity.XGBArray trainClass_1_0 = 
                   Util.GenerateRandom2dPoints(countTrainingPoints / 2,
                0.5, 1.0,
                0.0, 0.5, 1.0);//1,0
            entity.XGBArray trainClass_0_0 = 
                   Util.GenerateRandom2dPoints(countTrainingPoints / 2,
                0.0, 0.5,
                0.0, 0.5, 0.0);//0,0
            entity.XGBArray trainClass_1_1 = 
                   Util.GenerateRandom2dPoints(countTrainingPoints / 2,
                0.5, 1.0,
                0.5, 1.0, 0.0);//1,1
            ///
            /// Train the model
            ///
            entity.XGBArray allVectorsTraining = 
                   Util.UnionOfXGBArrays(trainClass_0_1,trainClass_1_0,
                                         trainClass_0_0,trainClass_1_1);
            xgb.Fit(allVectorsTraining.Vectors, allVectorsTraining.Labels);
            ///
            /// Test the model
            ///
            int countTestingPoints = 10;
            entity.XGBArray testClass_0_1 = 
                   Util.GenerateRandom2dPoints(countTestingPoints ,
                0.1, 0.4,
                0.6, 0.9, 1.0);//0,1
            entity.XGBArray testClass_1_0 = 
                   Util.GenerateRandom2dPoints(countTestingPoints,
                0.6, 0.9,
                0.1, 0.4, 1.0);//1,0
            entity.XGBArray testClass_0_0 = 
                   Util.GenerateRandom2dPoints(countTestingPoints,
                0.1, 0.4,
                0.1, 0.4, 0.0);//0,0
            entity.XGBArray testClass_1_1 = 
                   Util.GenerateRandom2dPoints(countTestingPoints,
                0.6, 0.9,
                0.6, 0.9, 0.0);//1,1
            entity.XGBArray allVectorsTest = 
                   Util.UnionOfXGBArrays(testClass_0_1, testClass_1_0,
                                         testClass_0_0,testClass_1_1);
            var resultsActual = xgb.Predict(allVectorsTest.Vectors);
            CollectionAssert.AreEqual(resultsActual, allVectorsTest.Labels);
        }    

顶部

将模型持久化到文件

一旦模型经过训练并发现产生令人满意的结果,您就可以在生产中使用该模型。 方法 SaveModelToFile 将模型持久化到二进制文件。 static 方法 LoadClassifierFromFile 将重新加载保存的模型。

        var xgbTrainer = new XGBoost.XGBClassifier();
        ///
        ///Train the model
        ///
        xgbTrainer.SaveModelToFile("SimpleLinearClassifier.dat");
        ///
        ///Load the persisted model
        ///
        var xgbProduction = XGBoost.XGBClassifier.LoadClassifierFromFile(fileModel);    

Iris 数据集

概述

来源:维基百科

该数据集包含来自 Iris 花的三个物种中每一个物种的 50 条记录。 此数据集是演示许多统计分类技术的测试用例。 描述列:

  1. Iris-setosa(山鸢尾)
  2. Iris-versicolor(变色鸢尾)
  3. Iris-virginica(维吉尼亚鸢尾)

顶部

数据结构

来源:维基百科

顶部

从 CSV 解析 IRIS 记录

    ///
    ///The C# class Iris will be used for capturing a single data row
    ///
    public class Iris
    {
        public float Col1 { get; set; }
        public float Col2 { get; set; }
        public float Col3 { get; set; }
        public float Col4 { get; set; }
        public string Petal { get; set; }
    }
    ///
    ///The function LoadIris will read the specified file line by line 
    ///and create an instance of the Iris POCO
    ///The class TextFieldParser from the assembly Microsoft.VisualBasic 
    ///is being used here
    ///
    private Iris[] LoadIris(string filename)
    {
        string pathFull = System.IO.Path.Combine(Util.GetProjectDir2(), filename);
        List<Iris> records = new List<Iris>();
        using (var parser = new TextFieldParser(pathFull))
        {
            parser.TextFieldType = FieldType.Delimited;
            parser.SetDelimiters(",");
            while (!parser.EndOfData)
            {
                var fields = parser.ReadFields();
                Iris oRecord = new Iris();
                oRecord.Col1 = float.Parse(fields[0]);
                oRecord.Col2 = float.Parse(fields[1]);
                oRecord.Col3 = float.Parse(fields[2]);
                oRecord.Col4 = float.Parse(fields[3]);
                oRecord.Petal = fields[4];
                records.Add(oRecord);
            }
        }

顶部

从 CSV 创建特征向量

        /// <summary>
        /// Create XGBoost consumable feature vector from Iris POCO classes
        /// </summary>
        internal static XGVector<Iris>[] ConvertFromIrisToFeatureVectors(Iris[] records)
        {
            List<XGVector<Iris>> vectors = new List<XGVector<Iris>>();
            foreach (var rec in records)
            {
                XGVector<Iris> newVector = new XGVector<Iris>();
                newVector.Original = rec;
                newVector.Features = new float[]
                {
                    rec.Col1, rec.Col2,rec.Col3,rec.Col4
                };
                newVector.Label = ConvertLabelFromStringToNumeric(rec.Petal);
                vectors.Add(newVector);
            }
            return vectors.ToArray();
        }

        /// <summary>
        /// Converts the string based name of the petal to a numeric representation
        /// </summary>
        internal static float ConvertLabelFromStringToNumeric(string petal)
        {
            if (petal.Contains("setosa"))
            {
                return 0;
            }
            else if (petal.Contains("versicolor"))
            {
                return 1.0f;
            }
            else if (petal.Contains("virginica"))
            {
                return 2.0f;
            }
            else
            {
                throw new NotImplementedException();
            }
        }

顶部

加载 IRIS - 将所有内容整合在一起

        [TestMethod]
        public void BasicLoadData()
        {
            string filename = "Iris\\Iris.train.data";
            iris.Iris[] records = IrisUtils.LoadIris(filename);
            entity.XGVector<iris.Iris>[] vectors = 
                   IrisUtils.ConvertFromIrisToFeatureVectors(records);
            Assert.IsTrue(records.Length >= 140);
        }

顶部

训练和测试 IRIS

        [TestMethod]
        public void TrainAndTestIris()
        {
            ///
            /// Load training vectors
            ///
            string filenameTrain = "Iris\\Iris.train.data";
            iris.Iris[] recordsTrain = IrisUtils.LoadIris(filenameTrain);
            entity.XGVector<iris.Iris>[] vectorsTrain = 
                            IrisUtils.ConvertFromIrisToFeatureVectors(recordsTrain);
            ///
            /// Load testingvectors
            ///
            string filenameTest = "Iris\\Iris.test.data";
            iris.Iris[] recordsTest = IrisUtils.LoadIris(filenameTest);
            entity.XGVector<iris.Iris>[] vectorsTest = 
                   IrisUtils.ConvertFromIrisToFeatureVectors(recordsTest);

            int noOfClasses = 3;
            var xgbc = new XGBoost.XGBClassifier
                       (objective: "multi:softprob", numClass:3);
            entity.XGBArray arrTrain = Util.ConvertToXGBArray(vectorsTrain);
            entity.XGBArray arrTest = Util.ConvertToXGBArray(vectorsTest);
            xgbc.Fit(arrTrain.Vectors, arrTrain.Labels);
            var outcomeTest=xgbc.Predict(arrTest.Vectors);
            for(int index=0;index<arrTest.Vectors.Length;index++)
            {
                string sExpected = IrisUtils.ConvertLabelFromNumericToString
                                   (arrTest.Labels[index]);
                float[] arrResults = new float[]
                {
                    outcomeTest[index*noOfClasses +0],
                    outcomeTest[index*noOfClasses +1],
                    outcomeTest[index*noOfClasses +2]
                };
                float max = arrResults.Max();
                int indexWithMaxValue = Util.GetIndexWithMaxValue(arrResults);
                string sActualClass = IrisUtils.ConvertLabelFromNumericToString
                                      ((float)indexWithMaxValue);
                Trace.WriteLine($"{index}       Expected={sExpected}  
                                                Actual={sActualClass}");
                Assert.AreEqual(sActualClass, sExpected);
            }
            string pathFull = System.IO.Path.Combine(Util.GetProjectDir2(), 
                              _fileModelIris);
            xgbc.SaveModelToFile(pathFull);
        }

顶部

Using the Code

Github

解决方案结构

        |
        |-----XGBoost
        |
        |-----XGBoostTests
        |           |
        |           |---iris
        |           |     |
        |           |     |--Iris.data
        |           |     |
        |           |     |--Iris.test.data
        |           |     |
        |           |     |--Iris.train.data
        |           |     |
        |           |     |--Iris.cs
        |           |     |
        |           |     
        |           |---IrisUtils.cs
        |           |
        |           |---IrisUnitTest.cs
        |           |
        |           |---SimpleLinearClassifierTests.cs
        |           |
        |           |---XORClassifierTests.cs
        |           |
        |
        |    

顶部

历史

  • 2019 年 9 月 4 日:初始版本
© . All rights reserved.