在Tensorflow Lite C API中注册自定义运算符



我正在使用C API在Android上运行tensorflow lite。我的模型需要操作员RandomStandardNormal,它最近在tensorflowv2.4.0-rc0中作为自定义操作原型实现,这里是

TfLiteInterpreterOptionsAddCustomOp()函数列在tensorflow/lite/c_api_experimental.h:中

TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp(
TfLiteInterpreterOptions* options, const char* name,
const TfLiteRegistration* registration, int32_t min_version,
int32_t max_version);

看看这个例子&线程,我正在尝试像这样使用TfLiteInterpreterOptionsAddCustomOp

// create model and interpreter options
TfLiteModel *model = TfLiteModelCreateFromFile("path/to/model.tflite");
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
// register custom ops
TfLiteInterpreterOptionsAddCustomOp(options, "RandomStandardNormal", Register_RANDOM_STANDARD_NORMAL(), 1, 1);
// create the interpreter
TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
TfLiteInterpreterAllocateTensors(interpreter);

我看到Register_RANDOM_STANDARD_NORMAL()函数是在tensorflow/lite/kernels/custom_ops_register.h中的tflite::ops::customC++命名空间中定义的。但是,当我试图将其包含在C文件中时,编译器会抱怨,因为namespace是C中的未知类型。

如何使用tensorflow lite C API注册自定义运算符?我是否需要使用C++编译器才能将C API与此自定义运算符一起使用,因为它是在C++中定义的?

注:编译libtensorflowlite_c.so时,我在bazel BUILD deps中包含了//tensorflow/lite/kernels:custom_ops

这似乎是通过以下解决方法在Github上得到的:

https://github.com/tensorflow/tensorflow/issues/44664#issuecomment-723310060

在tensorflow github上,@jdduke建议了一个临时解决方案:

  • extern "C"包装头添加到custom_ops_register.h
extern "C" {
TFL_CAPI_EXPORT TfLiteRegistration* TfLiteRegisterRandomStandardNormal();
}
  • random_standard_normal.cc添加extern "C"包装器实现
extern "C" {
TFL_CAPI_EXPORT TfLiteRegistration* TfLiteRegisterRandomStandardNormal() {
return tflite::ops::custom::Register_RANDOM_STANDARD_NORMAL();
}
}
  • 确保//tensorflow/lite/kernels:custom_ops作为依赖项包含在tensorflow/lite/c/BUILD
tflite_cc_shared_object(
name = "tensorflowlite_c",
linkopts = select({
"//tensorflow:ios": [
"-Wl,-exported_symbols_list,$(location //tensorflow/lite/c:exported_symbols.lds)",
],
"//tensorflow:macos": [
"-Wl,-exported_symbols_list,$(location //tensorflow/lite/c:exported_symbols.lds)",
],
"//tensorflow:windows": [],
"//conditions:default": [
"-z defs",
"-Wl,--version-script,$(location //tensorflow/lite/c:version_script.lds)",
],
}),
per_os_targets = True,
deps = [
":c_api",
":c_api_experimental",
":exported_symbols.lds",
":version_script.lds",
"//tensorflow/lite/kernels:custom_ops", # here
],
)
  • 修改我的C++代码以调用这个新的包装器函数
TfLiteInterpreterOptionsAddCustomOp(options, "RandomStandardNormal", TfLiteRegisterRandomStandardNormal(), 1, 1);

它成功了!我的张量终于在安卓上分配了:(

最新更新