如何在TensorFlow C API中创建字符串类型张量



参数列表中的data到底应该是什么?

TF_Tensor* tensorStr = TF_NewTensor(TF_STRING, nullptr, 0, &data[0], 8, no_op, nullptr);

我尝试了:

char * data = "blah";
char* data[] = {"blah"};
char data[1][4] = {{'b','l','a','h'}};

全部不幸。输入输入时。我总是得到:

Malformed TF_STRING tensor; element 0 out of range

有效(但有点丑)代码的示例,该代码创建字符串张量:

std::string input_str = "abracdabra";  // any input string
size_t encoded_size = TF_StringEncodedSize(input_str.size());
size_t total_size = 8 + encoded_size;  // 8 extra bytes - for start_offset 
char *input_encoded = (char*)malloc(total_size);
for (int i =0; i < 8; ++i) {  // fills start_offset
    input_encoded[i] = 0;
}
TF_StringEncode(input_str.c_str(), input_str.size(), input_encoded+8, encoded_size, status); // fills the rest of tensor data
if (TF_GetCode(status) != TF_OK){
    fprintf(stderr, "ERROR: something wrong with encoding: %s", TF_Message(status));
    return 1;
}
TF_Tensor* input = TF_NewTensor(TF_STRING, NULL, 0, input_encoded, total_size, &Deallocator, 0);

为什么它可以工作:https://github.com/tensorflow/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h#l213根据此链接,字符串张量的数据包含两个部分。最后一个是通过TF_Stringencode函数编码的输入字符串。第一个是数组" start_offset",我不完全理解它的角色,但看起来八个零可以做到这一点)

可以在C API测试中找到张量创建的另一个示例:https://github.com/tensorflow/tensorflow/tensorflow/blob/master/tensorflow/c/c_api_test.cc#l1934

,除非张量是标量(仅保留一个数字),则需要传递尺寸信息。您正在将NULLPTR传递给当前代码中的DIMS,这就是为什么会引起错误的原因。您可以在此处查看如何调用字符串的TF_NewTensor的示例:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.cc#l441

最新更新