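Solving the Iris Classification Problem with XGBoost and C#
How to embed machine learning algorithms such as Extreme Gradient Boosting in a C# application.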
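Introduction
Image source: Wikipedia
In this article, I demonstrate how to use a C# wrapper for the popular XGBoost unmanaged library. XGBoost stands for "Extreme Gradient Boosting". I use the well-known Iris data set to train and test the model. My goal is to share what I learned about embedding a machine learning algorithm such as Extreme Gradient Boosting in a C# application. Before going further, I must thank the developers of the XGBoost unmanaged library and the developers of the .NET wrapper library.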
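Background
This article assumes the reader has an intermediate knowledge of the following:
- Decision tree algorithms
- Gradient boosting algorithms
- Data normalization
- C#
This article and the accompanying code do not attempt an in-depth tutorial on decision trees or gradient boosting. Instead, I have provided links to YouTube training videos that I found extremely educational.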
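Overview of Gradient Boosting Classification Algorithms
Introduction to decision trees (StatQuest)
Understanding the Gini index when building decision trees
Introduction to AdaBoost
Introduction to gradient boosting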
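The XGBoost Library (C#)
Managed wrapper
Original XGBoost
The C/C++ source code for the library is available on GitHub, along with build instructions for Windows. Thanks to the efforts of PicNet, we can skip compiling the unmanaged sources and go straight to the managed wrapper.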
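A Simple Linear Classification Problem
We start with a simple exercise: training a model to classify two clusters of points that are cleanly linearly separable.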
/// <summary>
/// Two classes of vectors - Class-Blue and Class-Red
/// Class-Blue - The vectors are centered around the point (+0.5,+0.5)
/// and label value=1
/// Class-Red - The vectors are centered around the point (-0.5,-0.5)
/// and label value=0
/// </summary>
[TestMethod]
public void LinearClassification1()
{
    var xgb = new XGBoost.XGBClassifier();
    float[][] vectorsTrain = new float[][]
    {
        new[] {0.5f, 0.5f},
        new[] {0.6f, 0.6f},
        new[] {0.6f, 0.4f},
        new[] {0.4f, 0.6f},
        new[] {0.4f, 0.4f},
        new[] {-0.5f, -0.5f},
        new[] {-0.6f, -0.6f},
        new[] {-0.6f, -0.4f},
        new[] {-0.4f, -0.6f},
        new[] {-0.4f, -0.4f},
    };
    var labelsTrain = new[]
    {
        1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
        0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
    };

    // Ensure count of training labels = count of training vectors
    Assert.AreEqual(vectorsTrain.Length, labelsTrain.Length);

    // Train the model
    xgb.Fit(vectorsTrain, labelsTrain);

    // Test the model using test vectors
    float[][] vectorsTest = new float[][]
    {
        new[] {0.55f, 0.55f},
        new[] {0.55f, 0.45f},
        new[] {0.45f, 0.55f},
        new[] {0.45f, 0.45f},
        new[] {-0.55f, -0.55f},
        new[] {-0.55f, -0.45f},
        new[] {-0.45f, -0.55f},
        new[] {-0.45f, -0.45f},
    };
    var labelsTestExpected = new[]
    {
        1.0f, 1.0f, 1.0f, 1.0f,
        0.0f, 0.0f, 0.0f, 0.0f,
    };
    float[] labelsTestPredicted = xgb.Predict(vectorsTest);

    // Verify that predicted labels match the expected labels
    // (MSTest expects the arguments in the order: expected, actual)
    CollectionAssert.AreEqual(labelsTestExpected, labelsTestPredicted);
}
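To make "linearly separable" concrete: for the two clusters used in this test, the single straight line x + y = 0 already splits Class-Blue from Class-Red. The sketch below is independent of XGBoost, and the class and method names are invented for illustration only.

```csharp
using System;

public static class LinearSeparabilityDemo
{
    // Hypothetical helper: label a 2-D point by the side of the line x + y = 0 it falls on.
    // Points around (+0.5, +0.5) give 1 (Class-Blue); points around (-0.5, -0.5) give 0 (Class-Red).
    public static float ClassifyByLine(float x, float y)
    {
        return (x + y) > 0f ? 1.0f : 0.0f;
    }

    // Returns true when the line rule reproduces every expected label.
    public static bool SeparatesAll(float[][] vectors, float[] labels)
    {
        if (vectors.Length != labels.Length)
            throw new ArgumentException("Vectors and labels must have the same length.");
        for (int i = 0; i < vectors.Length; i++)
        {
            if (ClassifyByLine(vectors[i][0], vectors[i][1]) != labels[i])
                return false;
        }
        return true;
    }
}
```

Passing the training vectors and labels from the test above to SeparatesAll should return true, which is why even a very small boosted model is expected to solve this case.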
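Implementing XOR Logic
XOR logic is more complex than linear classification: the data points cannot be separated by a single straight line.
XOR truth table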
X | Y | OUTPUT
--------------
1 | 0 | 1
--------------
0 | 1 | 1
--------------
0 | 0 | 0
--------------
1 | 1 | 0
--------------
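The truth table can be carried over to continuous 2-D points by treating a coordinate below 0.5 as logical 0 and a coordinate of 0.5 or more as logical 1, which matches the coordinate ranges used in the sample code. The following is a minimal sketch of that labeling rule; the class name, method names, and the 0.5 threshold parameter are assumptions made for illustration.

```csharp
public static class XorLabeling
{
    // Assumption: values below the threshold count as logical 0, everything else as logical 1.
    public static bool ToBit(double value, double threshold = 0.5)
    {
        return value >= threshold;
    }

    // Returns 1.0f when exactly one coordinate is "high", otherwise 0.0f.
    public static float XorLabel(double x, double y)
    {
        return (ToBit(x) ^ ToBit(y)) ? 1.0f : 0.0f;
    }
}
```

For example, XorLabel(0.75, 0.25) and XorLabel(0.25, 0.75) give 1.0f, while XorLabel(0.25, 0.25) and XorLabel(0.75, 0.75) give 0.0f. No single straight line separates the two label groups, so the classifier needs more than one split.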
Sample code
[TestMethod]
public void TestMethod1()
{
    var xgb = new XGBoost.XGBClassifier();

    // Generate training vectors: one cluster per row of the truth table.
    // Arguments: count, xMin, xMax, yMin, yMax, label
    int countTrainingPoints = 50;
    entity.XGBArray trainClass_0_1 =
        Util.GenerateRandom2dPoints(countTrainingPoints / 2,
            0.0, 0.5,
            0.5, 1.0, 1.0);//0,1
    entity.XGBArray trainClass_1_0 =
        Util.GenerateRandom2dPoints(countTrainingPoints / 2,
            0.5, 1.0,
            0.0, 0.5, 1.0);//1,0
    entity.XGBArray trainClass_0_0 =
        Util.GenerateRandom2dPoints(countTrainingPoints / 2,
            0.0, 0.5,
            0.0, 0.5, 0.0);//0,0
    entity.XGBArray trainClass_1_1 =
        Util.GenerateRandom2dPoints(countTrainingPoints / 2,
            0.5, 1.0,
            0.5, 1.0, 0.0);//1,1

    // Train the model
    entity.XGBArray allVectorsTraining =
        Util.UnionOfXGBArrays(trainClass_0_1, trainClass_1_0,
            trainClass_0_0, trainClass_1_1);
    xgb.Fit(allVectorsTraining.Vectors, allVectorsTraining.Labels);

    // Test the model with points drawn from inside each quadrant
    int countTestingPoints = 10;
    entity.XGBArray testClass_0_1 =
        Util.GenerateRandom2dPoints(countTestingPoints,
            0.1, 0.4,
            0.6, 0.9, 1.0);//0,1
    entity.XGBArray testClass_1_0 =
        Util.GenerateRandom2dPoints(countTestingPoints,
            0.6, 0.9,
            0.1, 0.4, 1.0);//1,0
    entity.XGBArray testClass_0_0 =
        Util.GenerateRandom2dPoints(countTestingPoints,
            0.1, 0.4,
            0.1, 0.4, 0.0);//0,0
    entity.XGBArray testClass_1_1 =
        Util.GenerateRandom2dPoints(countTestingPoints,
            0.6, 0.9,
            0.6, 0.9, 0.0);//1,1
    entity.XGBArray allVectorsTest =
        Util.UnionOfXGBArrays(testClass_0_1, testClass_1_0,
            testClass_0_0, testClass_1_1);
    var resultsActual = xgb.Predict(allVectorsTest.Vectors);

    // MSTest expects the arguments in the order: expected, actual
    CollectionAssert.AreEqual(allVectorsTest.Labels, resultsActual);
}
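The helper Util.GenerateRandom2dPoints is called above but its body is not listed in this article. Judging only from the call sites, it appears to take a count, an x range, a y range, and a label. The sketch below is one possible implementation under that assumption; the LabeledPoints type, the parameter names, and the seed parameter are all hypothetical and are not taken from the accompanying project.

```csharp
using System;

// Hypothetical stand-in for entity.XGBArray: parallel arrays of vectors and labels.
public sealed class LabeledPoints
{
    public float[][] Vectors { get; }
    public float[] Labels { get; }

    public LabeledPoints(float[][] vectors, float[] labels)
    {
        Vectors = vectors;
        Labels = labels;
    }
}

public static class RandomPointsSketch
{
    // Assumed parameter order, inferred from the calls in the test: count, xMin, xMax, yMin, yMax, label.
    public static LabeledPoints GenerateRandom2dPoints(
        int count, double xMin, double xMax, double yMin, double yMax, double label, int seed = 42)
    {
        var rng = new Random(seed); // fixed seed keeps the sketch reproducible
        var vectors = new float[count][];
        var labels = new float[count];
        for (int i = 0; i < count; i++)
        {
            // Uniform sample inside the rectangle [xMin, xMax] x [yMin, yMax]
            float x = (float)(xMin + rng.NextDouble() * (xMax - xMin));
            float y = (float)(yMin + rng.NextDouble() * (yMax - yMin));
            vectors[i] = new[] { x, y };
            labels[i] = (float)label;
        }
        return new LabeledPoints(vectors, labels);
    }
}
```

Every generated point in one call shares the same label, so a cluster corresponds to a single row of the XOR truth table.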
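Persisting the Model to a File
Once a model has been trained and found to produce satisfactory results, you can use it in production. The SaveModelToFile method persists the model to a binary file, and the static LoadClassifierFromFile method reloads the saved model.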
string fileModel = "SimpleLinearClassifier.dat";
var xgbTrainer = new XGBoost.XGBClassifier();

// Train the model here (for example with xgbTrainer.Fit), then persist it
xgbTrainer.SaveModelToFile(fileModel);

// Load the persisted model
var xgbProduction = XGBoost.XGBClassifier.LoadClassifierFromFile(fileModel);
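The Iris Data Set
Overview
Source: Wikipedia
The data set contains 50 records for each of three species of Iris flower. It is a standard test case for demonstrating many statistical classification techniques. The species label is one of:
Iris-setosa
Iris-versicolor
Iris-virginica
Data structure
Source: Wikipedia
Parsing Iris Records from CSV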
using System.Collections.Generic;
using System.Globalization;
using Microsoft.VisualBasic.FileIO;

/// <summary>
/// The C# class Iris captures a single data row.
/// Col1..Col4 hold the four numeric measurements; Petal holds the species name.
/// </summary>
public class Iris
{
    public float Col1 { get; set; }
    public float Col2 { get; set; }
    public float Col3 { get; set; }
    public float Col4 { get; set; }
    public string Petal { get; set; }
}

/// <summary>
/// Reads the specified file line by line and creates one Iris POCO per record.
/// Uses the TextFieldParser class from the Microsoft.VisualBasic assembly.
/// </summary>
internal static Iris[] LoadIris(string filename)
{
    string pathFull = System.IO.Path.Combine(Util.GetProjectDir2(), filename);
    List<Iris> records = new List<Iris>();
    using (var parser = new TextFieldParser(pathFull))
    {
        parser.TextFieldType = FieldType.Delimited;
        parser.SetDelimiters(",");
        while (!parser.EndOfData)
        {
            var fields = parser.ReadFields();
            Iris oRecord = new Iris();
            // The invariant culture keeps "5.1" parsing correctly on comma-decimal locales
            oRecord.Col1 = float.Parse(fields[0], CultureInfo.InvariantCulture);
            oRecord.Col2 = float.Parse(fields[1], CultureInfo.InvariantCulture);
            oRecord.Col3 = float.Parse(fields[2], CultureInfo.InvariantCulture);
            oRecord.Col4 = float.Parse(fields[3], CultureInfo.InvariantCulture);
            oRecord.Petal = fields[4];
            records.Add(oRecord);
        }
    }
    return records.ToArray();
}
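Parsing decimal values is culture-sensitive in .NET: float.Parse without an explicit culture uses the current thread's culture, so a value such as 5.1 can fail or be misread on machines where the decimal separator is a comma. The following is a dependency-free sketch of parsing a single record with the invariant culture. The IrisLineParser class and its tuple return type are illustrative assumptions, not part of the accompanying project.

```csharp
using System;
using System.Globalization;

public static class IrisLineParser
{
    // Parses one CSV line such as "5.1,3.5,1.4,0.2,Iris-setosa"
    // into four numeric features and the species name.
    public static (float[] Features, string Species) ParseLine(string line)
    {
        string[] fields = line.Split(',');
        if (fields.Length != 5)
            throw new FormatException($"Expected 5 fields but found {fields.Length}.");

        var features = new float[4];
        for (int i = 0; i < 4; i++)
        {
            // InvariantCulture always treats '.' as the decimal separator
            features[i] = float.Parse(fields[i], CultureInfo.InvariantCulture);
        }
        return (features, fields[4].Trim());
    }
}
```

A plain Split is enough here because the Iris file contains no quoted fields. TextFieldParser remains the safer choice for general CSV input.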
Creating Feature Vectors from CSV
/// <summary>
/// Create XGBoost consumable feature vector from Iris POCO classes
/// </summary>
internal static XGVector<Iris>[] ConvertFromIrisToFeatureVectors(Iris[] records)
{
    List<XGVector<Iris>> vectors = new List<XGVector<Iris>>();
    foreach (var rec in records)
    {
        XGVector<Iris> newVector = new XGVector<Iris>();
        newVector.Original = rec;
        newVector.Features = new float[]
        {
            rec.Col1, rec.Col2, rec.Col3, rec.Col4
        };
        newVector.Label = ConvertLabelFromStringToNumeric(rec.Petal);
        vectors.Add(newVector);
    }
    return vectors.ToArray();
}

/// <summary>
/// Converts the string based name of the petal to a numeric representation
/// </summary>
internal static float ConvertLabelFromStringToNumeric(string petal)
{
    if (petal.Contains("setosa"))
    {
        return 0.0f;
    }
    else if (petal.Contains("versicolor"))
    {
        return 1.0f;
    }
    else if (petal.Contains("virginica"))
    {
        return 2.0f;
    }
    else
    {
        throw new NotImplementedException();
    }
}
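The training test later calls IrisUtils.ConvertLabelFromNumericToString, the inverse of the method above, but its body is not listed in this article. A minimal sketch of such an inverse, assuming the same 0/1/2 encoding, could look like this. The class name and the rounding behavior are assumptions made for illustration.

```csharp
using System;

public static class IrisLabelSketch
{
    // Index in this array = numeric label used by ConvertLabelFromStringToNumeric.
    private static readonly string[] Species =
    {
        "Iris-setosa", "Iris-versicolor", "Iris-virginica"
    };

    // Maps 0, 1, 2 back to the species name; rounds first so 1.0f and 0.9999f both map to class 1.
    public static string ConvertLabelFromNumericToString(float label)
    {
        int index = (int)Math.Round(label);
        if (index < 0 || index >= Species.Length)
            throw new ArgumentOutOfRangeException(nameof(label), label, "Unknown Iris class label.");
        return Species[index];
    }
}
```

Keeping both directions of the mapping next to each other makes it harder for the numeric encoding and the species names to drift apart.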
Loading Iris: Putting It All Together
[TestMethod]
public void BasicLoadData()
{
    string filename = "Iris\\Iris.train.data";
    iris.Iris[] records = IrisUtils.LoadIris(filename);
    entity.XGVector<iris.Iris>[] vectors =
        IrisUtils.ConvertFromIrisToFeatureVectors(records);
    Assert.IsTrue(records.Length >= 140);
}
Training and Testing Iris
[TestMethod]
public void TrainAndTestIris()
{
    // Load training vectors
    string filenameTrain = "Iris\\Iris.train.data";
    iris.Iris[] recordsTrain = IrisUtils.LoadIris(filenameTrain);
    entity.XGVector<iris.Iris>[] vectorsTrain =
        IrisUtils.ConvertFromIrisToFeatureVectors(recordsTrain);

    // Load testing vectors
    string filenameTest = "Iris\\Iris.test.data";
    iris.Iris[] recordsTest = IrisUtils.LoadIris(filenameTest);
    entity.XGVector<iris.Iris>[] vectorsTest =
        IrisUtils.ConvertFromIrisToFeatureVectors(recordsTest);

    int noOfClasses = 3;
    var xgbc = new XGBoost.XGBClassifier
        (objective: "multi:softprob", numClass: noOfClasses);
    entity.XGBArray arrTrain = Util.ConvertToXGBArray(vectorsTrain);
    entity.XGBArray arrTest = Util.ConvertToXGBArray(vectorsTest);
    xgbc.Fit(arrTrain.Vectors, arrTrain.Labels);

    // The result is a flat array holding noOfClasses probabilities per test vector
    var outcomeTest = xgbc.Predict(arrTest.Vectors);
    for (int index = 0; index < arrTest.Vectors.Length; index++)
    {
        string sExpected = IrisUtils.ConvertLabelFromNumericToString
            (arrTest.Labels[index]);
        float[] arrResults = new float[]
        {
            outcomeTest[index * noOfClasses + 0],
            outcomeTest[index * noOfClasses + 1],
            outcomeTest[index * noOfClasses + 2]
        };
        // The predicted class is the one with the highest probability
        int indexWithMaxValue = Util.GetIndexWithMaxValue(arrResults);
        string sActualClass = IrisUtils.ConvertLabelFromNumericToString
            ((float)indexWithMaxValue);
        Trace.WriteLine($"{index} Expected={sExpected} Actual={sActualClass}");
        Assert.AreEqual(sExpected, sActualClass);
    }

    string pathFull = System.IO.Path.Combine(Util.GetProjectDir2(),
        _fileModelIris);
    xgbc.SaveModelToFile(pathFull);
}
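The loop above reads the prediction result as a flat array with one block of class probabilities per test vector, then picks the most probable class with Util.GetIndexWithMaxValue, whose body is not listed here. The sketch below shows one way to write that argmax step and apply it to the whole flat array. The SoftprobSketch class and the PredictedClasses method are illustrative assumptions, not code from the accompanying project.

```csharp
using System;

public static class SoftprobSketch
{
    // Returns the position of the largest value; ties resolve to the first occurrence.
    public static int GetIndexWithMaxValue(float[] values)
    {
        if (values == null || values.Length == 0)
            throw new ArgumentException("At least one value is required.", nameof(values));
        int best = 0;
        for (int i = 1; i < values.Length; i++)
        {
            if (values[i] > values[best])
                best = i;
        }
        return best;
    }

    // Assumed layout: [row0 class0, row0 class1, ..., row1 class0, row1 class1, ...]
    public static int[] PredictedClasses(float[] flatProbabilities, int numClasses)
    {
        if (numClasses <= 0 || flatProbabilities.Length % numClasses != 0)
            throw new ArgumentException("Array length must be a multiple of numClasses.");
        int rows = flatProbabilities.Length / numClasses;
        var result = new int[rows];
        for (int row = 0; row < rows; row++)
        {
            var slice = new float[numClasses];
            Array.Copy(flatProbabilities, row * numClasses, slice, 0, numClasses);
            result[row] = GetIndexWithMaxValue(slice);
        }
        return result;
    }
}
```

For instance, the flat array { 0.7f, 0.2f, 0.1f, 0.1f, 0.3f, 0.6f } with three classes describes two rows, and PredictedClasses maps it to classes 0 and 2.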
Using the Code
GitHub
Solution Structure
|
|-----XGBoost
|
|-----XGBoostTests
| |
| |---iris
| | |
| | |--Iris.data
| | |
| | |--Iris.test.data
| | |
| | |--Iris.train.data
| | |
| | |--Iris.cs
| | |
| |
| |---IrisUtils.cs
| |
| |---IrisUnitTest.cs
| |
| |---SimpleLinearClassifierTests.cs
| |
| |---XORClassifierTests.cs
| |
|
|
History
- 4 September 2019: Initial version