当我在新的自定义 c++ 操作中调用标准操作(例如 MatMul)时,Bazel 返回错误



我在tensorflow中实现了一个新的自定义c ++操作。在相应的操作内核的计算函数中,调用了一些标准操作(例如 MatMul(。 主要源代码是:

REGISTER_OP("NewOp")
.Input("input: int32")
.Output("output: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"
using namespace tensorflow;
using namespace tensorflow::ops;
class MyNewOp : public OpKernel {
public:
explicit MyNewOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
……
// Create an output tensor
……
Scope root = Scope::NewRootScope();
auto A = Const(root, { {35.f, 22.f}, {-10.f, 0.f} });
auto b = Const(root, { {30.f, 55.f} });
auto v = MatMul(root.WithOpName("v"), A, b, MatMul::TransposeB(true));
std::vector<Tensor> results;
ClientSession session(root);
TF_CHECK_OK(session.Run({v}, &results));
// Set the output tensor according to the results of MatMul
……
}
};
REGISTER_KERNEL_BUILDER(Name("NewOp").Device(DEVICE_CPU), MyNewOp);

而对应的Bazel BUILD文件是:

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
tf_custom_op_library(
name = "MyNewOp.so",
srcs = ["mynewop.cc"],
deps = [
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
"//tensorflow/core:tensorflow",
],
)

当我构建上述目标时,Bazel 返回错误:

tensorflow/cc:cc_ops cannot depend on tensorflow/core:framework

如何解决此问题?我想知道我是否可以在新的自定义 c++ op 中调用 ternsorflow 预定义操作?谢谢!

您遇到的问题是,您的自定义操作依赖于此规则明确禁止的tensorflow/core:framework

disallowed_deps=[
clean_dep("//tensorflow/core:framework"),
clean_dep("//tensorflow/core:lib")
]

最好的方法是找到另一种解决方案。

如果你真的想要这种依赖关系,那么有一种黑客方式可以在没有禁止依赖的情况下重新实现tf_custom_op_library规则。

例如,可以通过这种方式完成此操作:

load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
load("//tensorflow:tensorflow.bzl", "clean_dep")
tf_cc_shared_object(
name = "MyNewOp.so",
srcs = ["mynewop.cc"],
copts = tf_copts(is_external=True),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
"//tensorflow/core:tensorflow",],
linkopts= select({
"//conditions:default": [
"-lm",
],
clean_dep("//tensorflow:windows"): [],
clean_dep("//tensorflow:windows_msvc"): [],
clean_dep("//tensorflow:darwin"): [],
}),
)

工作正常:

Target //tensorflow/user_ops:MyNewOp.so up-to-date:
bazel-bin/tensorflow/user_ops/MyNewOp.so
INFO: Elapsed time: 46.399s, Critical Path: 19.71s
INFO: 397 processes, local.
INFO: Build completed successfully, 400 total actions

最新更新