在 Go 中使用预训练模型和 TensorFlow





5.00/5 (4投票s)
探讨如何将 TensorFlow 的预训练模型之一集成到 Go 中执行 - 特别是检测图像中的多个对象。
机器学习面临的挑战之一是如何将训练好的模型部署到生产环境中。在训练完模型后,您可以“冻结”权重,并将其导出以在生产环境中使用,根据您的应用程序部署到任意数量的服务器实例上。
对于许多常见的用例,我们开始看到组织分享他们现成的训练模型,并且在 TensorFlow models repo 中已经有许多最常见的模型可供使用了。
对于许多构建大型 Web 服务的开发者来说,Go 已成为一种选择的语言。Go 也有一个不断增长的数据科学社区,但与 Python 等其他语言相比,一些工具仍然缺少文档或功能。
在本文中,我们将探讨如何将 TensorFlow 的预训练模型之一集成到 Go 中执行。我们将探讨的具体用例是在任何图像中检测多个对象——这是机器学习做得非常好的事情。在这种情况下,我们将使用新发布的 TensorFlow Object Detection 模型,该模型是在 COCO dataset 上训练的。
我们将构建一个小型命令行应用程序,该应用程序接收任何 JPG 图像作为输入,并输出另一张图像,其中识别出的对象已在图像中标注。您可以在此 Go CLI for using COCO models repo 中找到本文中所有代码以及完整的应用程序。
一个好的起点是查看 Go TensorFlow 绑定中包含的 示例应用程序,它使用 Inception 模型进行对象识别。尽管文档不详尽,但这个例子可以为我们提供宝贵的线索,指导我们如何使用绑定来处理其他预训练模型,这些模型类似但又不完全相同。
对于我们的用途,我们将使用在 COCO 数据集上训练的多对象检测模型。您可以在 GitHub 上找到该模型。您可以选择任何模型进行下载。我们将为了速度而牺牲一点精度,并使用移动端模型 ssd_mobilenet_v1_coco。
提取模型后,我们程序的第一个步骤是加载冻结图,以便我们可以使用它来识别我们的图像。
// Load a frozen graph to use for queries modelpath := filepath.Join(*modeldir, "frozen_inference_graph.pb") model, err := ioutil.ReadFile(modelpath) if err != nil { log.Fatal(err) } // Construct an in-memory graph from the serialized form. graph := tf.NewGraph() if err := graph.Import(model, ""); err != nil { log.Fatal(err) } // Create a session for inference over graph. session, err := tf.NewSession(graph, nil) if err != nil { log.Fatal(err) } defer session.Close()
令人欣慰的是,正如您所见,我们可以直接将协议缓冲区文件馈送到 NewGraph
函数,它将对其进行解码并构建图。然后,我们只需使用此图设置一个会话,就可以继续下一步了。
在最近一篇总结我 GopherCon 演讲的博文中,我使用了 LoadSavedModel
来加载一个模型,该模型是我在 Python 中训练并导出的,以便在 Go 中使用。在这种情况下,我们不能使用 LoadSavedModel
,而必须像上面那样直接加载图。
现在我们有了图,如何使用它来识别图像?这个图的输入和输出节点是什么?我们的输入数据需要什么形状?
不幸的是,这些问题都没有简单或文档齐全的答案!
步骤 1:识别图的输入和输出节点
在我 GopherCon 演示 中,我开玩笑说我直接查看协议缓冲区以找到模型 TensorFlow 节点的名称,并说如果您很聪明,您可以在导出模型之前在 Python 中打印出名称,或者将它们转储到磁盘文件。
令人惊讶的是,这仍然是一种可能的有效策略。结合一些关于该主题的谷歌搜索和对源代码的深入研究,您会发现该模型的节点如下:
节点名称 | 输入/输出 | 形状 | 数据描述 |
image_tensor | 输入 | [1,?,?,3] | uint8 格式的 RGB 像素值,呈方形(宽度、高度)。第一列代表批次大小。 |
detection_boxes | 输出 | [?][4] | 检测到的每个对象的边界框数组,格式为 [yMin , xMin , yMax , xMax ] |
detection_scores | 输出 | [?] | 检测到的每个对象的概率分数数组,介于 0..1 之间 |
detection_classes | 输出 | [?] | 检测到的每个对象在 COCO 对象基础上对应的类索引数组 |
num_detections | 输出 | [1] | 检测数量 |
我建议,在发布模型时,最好将此信息作为文档的一部分。一旦您获得了与输入/输出相关的节点名称,就可以使用 Shape
方法显示这些输入的形状。在我们的例子中,输入形状与前面提到的 Inception 示例中使用的形状相似。
从这里开始,我们现在可以着手加载图像并将其转换为我们可以用于图的格式。
步骤 2:加载图像并将其转换为张量
接下来我们需要做的是加载作为命令行参数提供的图像,并将其转换为张量。有多种方法可以做到这一点,但 Google 在 Inception 示例中使用的其中一种方法非常酷,所以我将在此使用它的简化版本来演示您可以做到这一点——我之前并不知道可以这样做,也许你们也不知道!
因此,在加载 JPG 文件后,我们可以构造一个 TensorFlow 图来解码它,并输出一个张量以馈送到检测图中。太棒了!请记住,严格来说,TensorFlow 是一个通用计算图库,因此它实际上可以执行相当广泛的功能,并且在我们尝试转换数据以与 TensorFlow 一起使用的情况下,使用它来帮助我们完成这项工作是说得通的。
func decodeJpegGraph() (graph *tf.Graph, input, output tf.Output, err error) { s := op.NewScope() input = op.Placeholder(s, tf.String) output = op.ExpandDims(s, op.DecodeJpeg(s, input, op.DecodeJpegChannels(3)), op.Const(s.SubScope("make_batch"), int32(0))) graph, err = s.Finalize() return graph, input, output, err }
实际上,这个图所做的唯一事情就是解码 JPG,但有一个很酷的地方是,如果您还记得,我们的输入需要是 [1,?,?,3] 的形式,并且我们需要添加一列来指示批次大小。在我们的例子中,因为只有一张图像,我们只是添加了一个额外的 0,但 ExpandDims
操作可以轻松地操作我们的 JPG 数据。
我们将此图的输入和输出节点作为此函数的返回值返回,然后就可以使用 JPG 数据作为输入,在该图上运行一个会话,而这个图的输出将是我们可以在 COCO 检测图(执行实际工作)中使用的张量。
步骤 3:执行 COCO 图以识别对象
我们现在已经将图像转换为张量,并且已经识别了 COCO 图上的所有输入和输出节点。现在,如果我们在一个会话中执行该图,我们将收到一个列表,其中包含图像中检测到的对象的概率。
output, err := session.Run(
map[tf.Output]*tf.Tensor{
inputop.Output(0): tensor,
},
[]tf.Output{
o1.Output(0),
o2.Output(0),
o3.Output(0),
o4.Output(0),
},
nil)
上面的变量 tensor
是前面我们构建的 DecodeJpeg
图的输出。输出列表(o1,o2,o3,o4
)是上面表格中概述的各种输出。在此阶段,我们可以解析输出结果。
解析结果时需要注意几点:
- 您可能希望设置一个阈值,低于该阈值您希望忽略结果,因为算法会尝试检测概率非常低的事物。我过滤掉了低于 40% 置信度的所有结果。
detection_scores
列表是按概率排序的,并且每个对应的数组也同样排序。因此,例如,索引 0 将是检测到的概率最高的对象。detection_boxes
将包含其边界框的坐标,而detection_classes
将包含对象的类标签(例如,对象名称:“dog
”、“person
” 等)。- 框坐标是归一化的,因此如果您想将它们转换为图像中的像素坐标,则必须确保缓存原始 JPG 的宽度和高度。
步骤 4:可视化输出
仅仅打印出概率列表、类索引和边界框尺寸并不是很有趣,所以我们将扩展我们的 CLI,输出一个图像版本,其中模型的结果已渲染到其中。就像许多现有示例一样,我们将绘制边界框并以置信度百分比标记它们。
我们将仅使用内置的 image
包进行一些基本渲染,以及内置字体来渲染标签。
TensorFlow 仓库中找到的标签似乎已过时,与模型不符,因为我能够检测到标签文件中不存在的对象。因此,我在这里使用(并包含)来自 COCO-Stuff repo 的扩展 COCO 对象列表,格式为每行一个。
将标签加载到一个简单的数组中,并利用 Go 的标准库进行图像处理的工具,我们可以相对轻松地遍历结果并输出一个已识别每个对象的图像。我们将使用 Google 提供的标准的 dog 图像作为示例。
总结
就是这样!我们现在有了一个小型 Go 程序,可以接收任何图像并使用 Google 提供的流行的 COCO TensorFlow 模型对其进行识别。我在 GitHub 的 Go CLI for COCO repo 中发布了此程序的所有代码,包括本文中的所有代码片段。
很明显,在文档记录这些模型的“API”以及如何处理它们的最佳实践方面,仍然有大量工作要做。然而,运营化机器学习算法正成为一种越来越普遍的用例,并且随着 Go 在服务器应用程序中的广泛使用,拥有关于如何使用这些模型的良好工具和信息至关重要。
>,