65.9K
CodeProject 正在变化。 阅读更多。
Home

在 Android AI 危险检测中加载 TensorFlow 模型

starIconstarIconstarIconstarIconstarIcon

5.00/5 (2投票s)

2021年1月6日

CPOL

3分钟阅读

viewsIcon

6012

在Android上的AI危险检测系列文章中,我们将把TensorFlow Lite模型添加到项目中,并准备好进行处理。

在本系列的前一篇文章中,我们创建了一个项目,该项目将用于驾驶员的实时危险检测,并准备了一个检测模型以供TensorFlow Lite使用。 在这里,我们将继续加载模型并准备好进行图像处理。

要将模型添加到项目中,请在src/main中创建一个名为assets的新文件夹。 将TensorFlow Lite模型和包含标签的文本文件复制到src/main/assets,使其成为项目的一部分。

要使用该模型,我们必须编写代码来加载它并通过它传递数据。 检测代码将放置在一个可以被两个用户界面共享的类中,以便相同的代码可以用于静态图像(用于测试)和实时视频流。

为模型格式化我们的数据

在开始为此编写代码之前,我们需要知道模型期望其输入数据的结构如何。 数据作为多维数组传入和传出。 这也称为数据的“形状”。 通常,当您找到模型时,此信息将被记录在案。

您也可以使用工具Netron检查数据。 从此工具打开模型时,将显示组成网络的节点。 单击输入节点(显示在图形顶部)会显示输入数据(在本例中为图像)和网络输出的信息格式。 在这种情况下,我们看到输入数据是一个32位浮点数数组。 数组的维度为1x416x416x3。 这意味着网络将一次接受一张416 x 416像素的图像,其中包含红色,绿色和蓝色分量。 如果您要为此项目使用其他模型,则需要检查模型的输入和输出,并相应地调整代码。 我们将在解释结果时更详细地检查输出数据。

向项目添加一个名为Detector的新类。 用于管理已训练网络的所有代码都将添加到此类中。 构建该类时,它将接受图像并以更易于使用的格式提供结果。 我们应该向该类添加一些常量和字段,以便开始使用它。 这些字段包括一个包含已训练网络的TensorFlow Interpreter对象、模型识别的对象类列表以及应用程序上下文。

class Detector {
   val TF_MODEL_NAME = "yolov4.tflite"
   val IMAGE_WIDTH = 416
   val IMAGE_HEIGHT = 416
   val TAG = "Detector"
   val useGpuDelegate = false;
   val useNNAPI=true;
   val context: Context;
   lateinit var tfLiteInterpreter:Interpreter
   var labelList = Vector<String>()

   //These output values are structured to match the output of the trained model being used
   var buf0 = Array(1) { Array(52) { Array(52) { Array(3) { FloatArray(85) } } } }
   var buf1 = Array(1) { Array(26) { Array(26) { Array(3) { FloatArray(85) } } } }
   var buf2 = Array(1) { Array(13) { Array(13) { Array(3) { FloatArray(85) } } } }
   var outputBuffers: HashMap<Int, Any>? = null
}

此类的构造函数将创建输出缓冲区,加载网络模型,并从assets文件夹加载对象类的名称。

class Detector {
   val TF_MODEL_NAME = "yolov4.tflite"
   val IMAGE_WIDTH = 416
   val IMAGE_HEIGHT = 416
   val TAG = "Detector"
   val useGpuDelegate = false;
   val useNNAPI=true;
   val context: Context;
   lateinit var tfLiteInterpreter:Interpreter
   var labelList = Vector<String>()

   //These output values are structured to match the output of the trained model being used
   var buf0 = Array(1) { Array(52) { Array(52) { Array(3) { FloatArray(85) } } } }
   var buf1 = Array(1) { Array(26) { Array(26) { Array(3) { FloatArray(85) } } } }
   var buf2 = Array(1) { Array(13) { Array(13) { Array(3) { FloatArray(85) } } } }
   var outputBuffers: HashMap<Int, Any>? = null
}

测试模型

执行网络模型只需要几行代码。 当将图像提供给Detector类时,它将被调整大小以匹配网络的要求。 Bitmap图像中的数据被编码为字节。 这些值必须转换为32位浮点值。 TensorFlow Lite库包含使此类常见转换变得容易的功能。 TensorImage类型还具有一个方便的方法,允许将其用作需要输入缓冲区的方法的缓冲区。

public fun processImage(sourceImage: Bitmap) {
   val imageProcessor = ImageProcessor.Builder()
           .add(ResizeOp(IMAGE_HEIGHT, IMAGE_WIDTH, ResizeOp.ResizeMethod.BILINEAR))
           .build()
   var tImage = TensorImage(DataType.FLOAT32)
   tImage.load(sourceImage)
   tImage = imageProcessor.process(tImage)
   tfLiteInterpreter.runForMultipleInputsOutputs(arrayOf<any>(tImage.buffer), outputBuffers!!)
}</any>

要测试这一点,请向项目添加一个新布局。 该布局将具有一个简单的界面,以允许选择来自设备的图像。 所选图像将由检测器处理。

<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout>
   <ImageView
       android:id="@+id/selected_image_view"
 />
   <Button
       android:id="@+id/select_image_button"
       android:onClick="onSelectImageClicked"
 />
</androidx.constraintlayout.widget.ConstraintLayout>

此activity的代码打开系统图像选择器。 选择图像并将其传递回应用程序时,它会将图像传递给检测器。

public override fun onActivityResult(reqCode: Int, resultCode: Int, data: Intent?) {
   super.onActivityResult(reqCode, resultCode, data)
   if (resultCode == RESULT_OK) {
       if (reqCode == SELECT_PICTURE) {
           val selectedUri = data!!.data
           val fileString = selectedUri!!.path
           selected_image_view!!.setImageURI(selectedUri)
           var sourceBitmap: Bitmap? = null
           try {
               sourceBitmap =
                   MediaStore.Images.Media.getBitmap(this.contentResolver, selectedUri)
               RunDetector(sourceBitmap)
           } catch (e: IOException) {
               e.printStackTrace()
           }
       }
   }
}

fun RunDetector(bitmap: Bitmap?) {
   if (detector == null) detector = Detector(this)
   detector!!.processImage(bitmap!!)
}

UI布局的结果

现在我们可以选择一个图像,检测器将处理该图像,识别其中的对象。 但是结果意味着什么? 我们如何使用这些结果来警告用户有关危险? 在本系列的下一篇文章中,我们将解释结果并向用户提供相关信息。

© . All rights reserved.