65.9K
CodeProject 正在变化。 阅读更多。
Home

使 PyTorch AI 模型可移植到 ONNX

starIconstarIconstarIconstarIconstarIcon

5.00/5 (2投票s)

2020年9月8日

CPOL

5分钟阅读

viewsIcon

7341

downloadIcon

79

在本文中,我将为那些正在寻找用于构建和训练神经网络的深度学习框架的人们简要介绍 PyTorch。

在本系列文章中,我们将探讨如何在 2020 年使用可移植神经网络,您将学习如何将 PyTorch 模型转换为可移植的 ONNX 格式。

由于 ONNX 并不是一个用于构建和训练模型的框架,因此我将从简要介绍 PyTorch 开始。这对于从头开始并考虑使用 PyTorch 作为构建和训练模型的框架的工程师来说会很有用。

PyTorch 简介

PyTorch 于 2016 年发布,由 Facebook 的 AI 研究实验室 (FAIR) 开发。它已成为研究人员在自然语言处理和计算机视觉领域进行实验的首选框架。这很有趣,因为 TensorFlow 在生产环境中更为广泛地使用。研究实验室和生产环境之间的这种差异真正突显了像 ONNX 这样的标准的重要性,它为模型提供了一种通用格式,并提供了一种可供所有流行编程语言使用的运行时。例如,假设一个组织不希望在其生产环境中拥有所有可能的框架,而是想标准化一个。没有 ONNX,就需要在为生产选择的框架中重新实现模型并进行部署。这是一个非微不足道的工程任务。使用 ONNX,只需几行代码即可导出 PyTorch 模型,并用任何语言进行消费。生产中只需要 ONNX Runtime。

导入转换器

PyTorch 的维护者已将 ONNX 转换器集成到 PyTorch 本身中。您无需安装任何额外的包。安装 PyTorch 后,您可以通过在模块中包含以下导入来访问 PyTorch 到 ONNX 的转换器

import torch

导入 torch 模块后,您可以按如下方式访问转换函数

torch.onnx.export()

希望这是其他框架会采用的做法。将转换器与框架本身一起打包和版本化,可以减少需要安装的包的数量,还可以防止框架和转换器之间的版本不匹配。

快速浏览模型

在转换 PyTorch 模型之前,我们需要查看创建模型的代码,以确定输入的形状。下面的代码创建了一个 PyTorch 模型,该模型可以预测 MNIST 数据集中找到的数字。模型层中更详细的描述超出了本文的范围,但我们需要注意输入的形状。这里是 784。更具体地说,这段代码创建了一个模型,其中输入将是一个展平的张量,即一个包含 784 个浮点数的数组。784 有什么意义?嗯,MNIST 数据集中的每个图像都是一个 28 x 28 像素的图像。28 x 28 = 784。所以,一旦展平,我们的输入就是 784 个浮点数,每个浮点数代表一个灰度级别。底线:这个模型期望来自单个图像的 784 个浮点数。它不期望多维数组,也不期望图像批次。一次只有一个预测。这是将模型转换为 ONNX 时的重要事实。

def build_model():
   # Layer details for the neural network
   input_size = 784
   hidden_sizes = [128, 64]
   output_size = 10

   # Build a feed-forward network
   model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
                       nn.ReLU(),
                       nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                       nn.ReLU(),
                       nn.Linear(hidden_sizes[1], output_size),
                       nn.LogSoftmax(dim=1))
   return model

将 PyTorch 模型转换为 ONNX

下面的函数显示了如何使用 torch.onnx.export 函数。正确使用此函数有一些技巧。第一个也是最重要的技巧是正确设置样本输入。sample_input 参数用于确定 ONNX 模型的输入。export_to_onnx 函数将接受您提供的任何内容 — 只要它是张量 — 并且转换将无错误地进行。但是,如果样本输入的形状不正确,那么在尝试从 ONNX Runtime 运行 ONNX 模型时会出现错误。

def export_to_onnx(model):
   sample_input = torch.randn(1, 784)
   print(type(sample_input))
   torch.onnx.export(model,            # model being run
                     sample_input,     # model input (or a tuple for multiple inputs)
                     ONNX_MODEL_FILE,  # where to save the model
                     input_names = ['input'],   # the model's input names
                     output_names = ['output'] # the model's output names

   # Set metadata on the model.
   onnx_model = onnx.load(ONNX_MODEL_FILE)
   meta = onnx_model.metadata_props.add()
   meta.key = "creation_date"
   meta.value = datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
   meta = onnx_model.metadata_props.add()
   meta.key = "author"
   meta.value = 'keithpij'
   onnx_model.doc_string = 'MNIST model converted from Pytorch'
   onnx_model.model_version = 3  # This must be an integer or long.
   onnx.save(onnx_model, ONNX_MODEL_FILE)

如果原始 PyTorch 模型设计为接受 100 个图像的批次,那么这个样本输入是没问题的。但是,如前所述,我们的模型在进行预测时一次只设计为接受一个图像。如果使用此样本输入导出模型,则在运行模型时会出现错误。

为模型添加元数据的代码是一种最佳实践。随着用于训练模型的数据的演变,您的模型也会随之演变。因此,为模型添加元数据是一个好主意,这样您就可以将其与之前的模型区分开来。上面的示例在 doc_string 属性中添加了对模型的简短描述并设置了版本。creation_dateauthor 是添加到 metadata_props 属性包的自定义属性。您可以随时使用此属性包创建任意数量的自定义属性。不幸的是,model_version 属性需要一个整数或长整型,所以您将无法像服务一样使用 major.minor.revision 语法对其进行版本化。此外,导出函数会自动将模型保存到文件,因此要添加此元数据,您需要重新打开文件并重新保存。

总结和后续步骤

在本文中,我为那些正在寻找用于构建和训练神经网络的深度学习框架的人们简要介绍了 PyTorch。然后,我展示了如何使用 PyTorch 本身已有的转换工具将 PyTorch 模型转换为 ONNX 格式。我还展示了为导出的模型添加元数据的最佳实践。

由于本文的目的是演示如何将 Keras 模型转换为 ONNX 格式,因此我没有详细介绍构建和训练 Keras 模型。此帖子的代码示例包含探索 Keras 本身的代码。keras_mnist.py 模块是一个完整的端到端演示,展示了如何加载数据、探索图像以及训练模型。

接下来,我们将研究如何将 TensorFlow 模型转换为 ONNX。

参考文献

© . All rights reserved.