如何在Android上以CSV作为输入正确运行tensorflow lite推理



我有一个tflite模型,期望输入形状为(1000,12(。只是为了测试它,我打算加载一个CSV文件并在上面运行推理。下面是我的代码和运行它时得到的错误消息的相关部分。

我认为我在正确加载或读取CSV文件时犯了一个错误。我是安卓系统的新手,很乐意在这件事上得到任何帮助!

val testModel = myModel.newInstance(context)
// Creates inputs for reference.
val inputFeature0 = TensorBuffer.createFixedSize(intArrayOf(1, 1000, 12), DataType.FLOAT32)
val openRawResource = resources.openRawResource(R.raw.inputdata).readBytes()
val byteBuffer = ByteBuffer.wrap(openRawResource)
// inputFeature0.loadBuffer(byteBuffer)
inputFeature0.loadBuffer(byteBuffer)
// Runs model inference and gets result.
val outputs = testModel.process(inputFeature0)
val outputFeature0 = outputs.outputFeature0AsTensorBuffer
// Releases model resources if no longer used.
model.close()
Caused by: java.lang.IllegalArgumentException: The size of byte buffer and the shape do not match.
at org.tensorflow.lite.support.common.SupportPreconditions.checkArgument(SupportPreconditions.java:104)
at org.tensorflow.lite.support.tensorbuffer.TensorBuffer.loadBuffer(TensorBuffer.java:296)
at org.tensorflow.lite.support.tensorbuffer.TensorBuffer.loadBuffer(TensorBuffer.java:323)
at com.example.ecgclassifier.MainActivity.analyze(MainActivity.kt:47)
at com.example.ecgclassifier.MainActivity.onCreate(MainActivity.kt:23)

请确保上面的原始资源字节数组有一个用于张量的浮点缓冲区数组,形状为[11000,12]。

错误消息称,来自byteBuffer的给定字节数组与[11000,12]形状的浮点张量的大小要求不匹配。

最新更新