65.9K
CodeProject 正在变化。 阅读更多。
Home

在 Java 中使用便携式 ONNX AI 模型

starIconstarIconstarIconstarIconstarIcon

5.00/5 (2投票s)

2020 年 9 月 11 日

CPOL

4分钟阅读

viewsIcon

17036

downloadIcon

246

在本文中,我简要概述了 ONNX Runtime 和 ONNX 格式。

在本系列关于在 2020 年使用便携式神经网络的文章中,您将学习如何在 x64 架构上安装 ONNX 并在 Java 中使用它。

ONNX 由微软与 Facebook 和 AWS 共同开发。ONNX 格式和 ONNX Runtime 都得到了行业支持,以确保所有重要框架都能将它们的图导出为 ONNX,并且这些模型可以在任何硬件配置上运行。

ONNX Runtime 是一个用于运行已转换为 ONNX 格式的机器学习模型的引擎。传统机器学习模型和深度学习模型(神经网络)都可以导出为 ONNX 格式。该运行时可以在 Linux、Windows 和 Mac 上运行,并可以在各种芯片架构上运行。它还可以利用 GPU 和 TPU 等硬件加速器。但是,并非所有操作系统、芯片架构和加速器的组合都有安装包,因此如果您使用的不是常见的组合,则可能需要从源代码构建运行时。请查看 ONNX Runtime 网站以获取您需要的组合的安装说明。本文将展示如何在 x64 架构上安装具有默认 CPU 的 ONNX Runtime,以及在具有 GPU 的 x64 架构上安装。

除了能够在多种硬件配置上运行外,该运行时还可以从大多数流行的编程语言调用。本文的目的是展示如何在 Java 中使用 ONNX Runtime。我将展示如何安装 onnxruntime 包。一旦安装了 ONNX Runtime,我将把先前导出的 MNIST 模型加载到 ONNX Runtime 中,并使用它进行预测。

安装和导入 ONNX Runtime

在使用 ONNX Runtime 之前,您需要将正确的依赖项添加到您的构建工具中。Maven 存储库是为包括 Maven 和 Gradle 在内的各种工具设置 ONNX Runtime 的一个好来源。要在 x64 架构上使用具有默认 CPU 的运行时,请参阅以下链接。

https://mvnrepository.com/artifact/org.bytedeco/onnxruntime-platform

要在 x64 架构上使用带有 GPU 的运行时,请使用以下链接。

https://mvnrepository.com/artifact/org.bytedeco/onnxruntime-platform-gpu

安装好运行时后,可以通过下面的 import 语句将其导入到您的 Java 代码文件中。导入 TensorProto 工具的 import 语句将有助于我们创建 ONNX 模型的输入,并有助于解释 ONNX 模型的输出(预测)。

import ai.onnxruntime.OnnxMl.TensorProto;
import ai.onnxruntime.OnnxMl.TensorProto.DataType;
import ai.onnxruntime.OrtSession.Result;
import ai.onnxruntime.OrtSession.SessionOptions;
import ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode;
import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;

加载 ONNX 模型

下面的代码片段展示了如何在运行在 Java 中的 ONNX Runtime 中加载 ONNX 模型。此代码创建了一个可以用来进行预测的会话对象。这里使用的模型是从 PyTorch 导出的 ONNX 模型。

这里有几点值得注意。首先,您需要查询会话以获取其输入。这是通过会话的 getInputInfo 方法完成的。我们的 MNIST 模型只有一个输入参数:一个包含 784 个浮点数的数组,代表 MNIST 数据集中的一张图像。如果您的模型有多个输入参数,那么 InputMetadata 将为每个参数提供一个条目。

Utilities.LoadTensorData();
String modelPath = "pytorch_mnist.onnx";

try (OrtSession session = env.createSession(modelPath, options)) {
   Map<String, NodeInfo> inputMetaMap = session.getInputInfo();
   Map<String, OnnxTensor> container = new HashMap<>();
   NodeInfo inputMeta = inputMetaMap.values().iterator().next();

   float[] inputData = Utilities.ImageData[imageIndex];
   string label = Utilities.ImageLabels[imageIndex];
   System.out.println("Selected image is the number: " + label);

   // this is the data for only one input tensor for this model
   Object tensorData =
            OrtUtil.reshape(inputData, ((TensorInfo) inputMeta.getInfo()).getShape());
   OnnxTensor inputTensor = OnnxTensor.createTensor(env, tensorData);
   container.put(inputMeta.getName(), inputTensor);

   // Run code omitted for brevity.

}

上面的代码中没有显示读取原始 MNIST 图像并将每个图像转换为 784 个浮点数数组的实用程序。每张图像的标签也从 MNIST 数据集中读取,以便确定预测的准确性。这段代码是标准的 Java 代码,但仍然鼓励您查看并使用它。如果您需要读取类似于 MNIST 数据集的图像,它将为您节省时间。

使用 ONNX Runtime 进行预测

下面的函数展示了如何使用我们在加载 ONNX 模型时创建的 ONNX 会话。

try (OrtSession session = env.createSession(modelPath, options)) {

   // Load code not shown for brevity.

   // Run the inference
   try (OrtSession.Result results = session.run(container)) {

      // Only iterates once
      for (Map.Entry<String, OnnxValue> r : results) {
         OnnxValue resultValue = r.getValue();
         OnnxTensor resultTensor = (OnnxTensor) resultValue;
         resultTensor.getValue()
         System.out.println("Output Name: {0}", r.Name);
         int prediction = MaxProbability(resultTensor);
         System.out.println("Prediction: " + prediction.ToString());
	}
   }
}

大多数神经网络不会直接返回预测。它们会返回每个输出类别的概率列表。对于我们的 MNIST 模型,每张图像的返回值将是一个包含 10 个概率的列表。概率最高的条目就是预测。一个有趣的测试是比较 ONNX 模型返回的概率与原始模型在其创建的框架内运行时返回的概率。理想情况下,模型格式和运行时的更改不应改变任何产生的概率。这可以作为每次模型发生更改时运行的良好单元测试。

总结和后续步骤

在本文中,我简要概述了 ONNX Runtime 和 ONNX 格式。然后,我展示了如何在 ONNX Runtime 中使用 Java 加载和运行 ONNX 模型。

本文的代码示例包含一个可运行的控制台应用程序,演示了此处展示的所有技术。此代码示例是 Github 存储库的一部分,该存储库探索了如何使用神经网络来预测 MNIST 数据集中找到的数字。具体来说,有示例展示了如何在 Keras、PyTorch、TensorFlow 1.0 和 TensorFlow 2.0 中创建神经网络。

如果您想了解更多关于导出到 ONNX 格式和使用 ONNX Runtime 的信息,请查看本系列的其它文章。

参考文献

© . All rights reserved.