TensorFlow C API的TF_SessionRun无法运行在其他成员函数中创建的会话



我有一个 C++ 类 RF,它有两个成员函数:load()(用于加载 SavedModel(.pb))和 detect()(用于多次运行模型进行推理)。我还有另一个 C++ 类 IMPL,它有两个函数:initialize() 和 create()。IMPL 以一个 RF 对象作为成员。当我在 IMPL 的同一个函数中先后调用 RF::load() 和 RF::detect() 时,一切运行正常。但是,当我在 IMPL::initialize() 中调用 RF::load()、在 IMPL::create() 中调用 RF::detect() 时,TF_SessionRun() 会卡住,既没有任何输出,也不报错。我的代码或 TF_SessionRun() 的用法出了什么问题?

class RF
{
public:
RF(){}
~RF();
int load();
int detect(cv::Mat Image&);
const char* saved_model_dir = "path/";// set model path 
const char* tags = "serve";    
int ntags = 1;
private:
const int NumInputs = 1;
const int NumOutputs = 1;
const unsigned short int LengthFV = 16;

TF_Graph* Graph;
TF_Status* Status;
TF_Session* Session;
TF_SessionOptions* SessionOpts;
TF_Buffer* RunOpts;
TF_Output* Input;
TF_Output* Output;
TF_Tensor** InputValues;
TF_Tensor** OutputValues;
TF_Tensor* vecTensor;
TF_Tensor * imgTensor;
};
RF::~RF()
{
// release memory;
}
int RF::load()
{
Graph = TF_NewGraph();
Status = TF_NewStatus();
SessionOpts = TF_NewSessionOptions();
RunOpts = NULL;
Session = TF_LoadSessionFromSavedModel(SessionOpts, RunOpts, saved_model_dir, &tags, ntags, Graph, NULL, Status);
if(TF_GetCode(Status) != TF_OK) {
std::cout << "nERROR: Failed to load SavedModel." << TF_Message(Status);
return -1;  }
// prepare input and output tensors
Input = new TF_Output[NumInputs];// input tensor
Output = new TF_Output[NumOutputs];// output tensor
TF_Output t0 = {TF_GraphOperationByName(Graph, "serving_default_input_image"), 0};  
if(t0.oper == NULL) {
std::cout << "nERROR: Failed TF_GraphOperationByName serving_default_input_image.";
return -1;  }
Input[0] = t0;       

TF_Output t2 = {TF_GraphOperationByName(Graph, "StatefulPartitionedCall"), 0};  
if(t2.oper == NULL) {
std::cout << "nERROR: Failed TF_GraphOperationByName StatefulPartitionedCall.";   
return -1;  }
Output[0] = t2;
// allocate data for input and output tensors
InputValues  = (TF_Tensor**)new TF_Tensor*[NumInputs];
OutputValues = (TF_Tensor**)new TF_Tensor*[NumOutputs];
int64_t dims_vector[] = {1, LengthFV};// depends on the output of retinaface
unsigned short int dim = 2;
std::vector<float> output_buffer(LengthFV);
vecTensor = TF_NewTensor(TF_FLOAT, dims_vector, dim, output_buffer.data(), LengthFV*sizeof(float), &NoOpDeallocator, 0);
if (vecTensor == NULL)  {
std::cout << "nERROR: Failed generation of TF_Tensor* vecTensor."; 
return -1;  }
OutputValues[0] = vecTensor;

return 0;    
}
/// Runs one inference on `img` using the session created by load().
///
/// The image is copied into a {1, rows, cols, 3} TF_FLOAT tensor, in the same
/// row-major, interleaved-channel order as cv::Mat storage.
/// NOTE(review): expects an 8-bit 3-channel image; channel order (BGR vs RGB)
/// and any normalization the model needs are not handled here — confirm.
/// On success the result is left in OutputValues[] until the next call.
/// @return 0 on success, -1 on failure (a message is printed to std::cout).
int RF::detect(cv::Mat& img)
{
    if (Session == nullptr || Input == nullptr || Output == nullptr ||
        InputValues == nullptr || OutputValues == nullptr) {
        std::cout << "\nERROR: RF::detect called before a successful RF::load.";
        return -1;
    }
    if (img.empty() || img.depth() != CV_8U || img.channels() != 3) {
        std::cout << "\nERROR: RF::detect expects a non-empty 8-bit 3-channel image.";
        return -1;
    }

    const int rows = img.rows;
    const int cols = img.cols;
    const int channels = 3;
    const int num_dims = 4;
    const int64_t dims_img[] = {1, rows, cols, channels}; // {1, rows, cols, 3}
    const size_t count = static_cast<size_t>(rows) * cols * channels;

    // Let TF own the input buffer. Wrapping a local vector with a no-op
    // deallocator is unsafe, because an output tensor may alias the input
    // buffer and outlive this call.
    imgTensor = TF_AllocateTensor(TF_FLOAT, dims_img, num_dims, count * sizeof(float));
    if (imgTensor == nullptr) {
        std::cout << "\nERROR: Failed generation of TF_Tensor* imgTensor.";
        return -1;
    }
    float* dst = static_cast<float*>(TF_TensorData(imgTensor));
    // Copy row by row so non-continuous Mats (e.g. ROIs) are handled correctly.
    for (int r = 0; r < rows; ++r) {
        const unsigned char* src = img.ptr<unsigned char>(r);
        for (int k = 0; k < cols * channels; ++k)
            *dst++ = static_cast<float>(src[k]);
    }
    InputValues[0] = imgTensor;

    // TF_SessionRun overwrites OutputValues[] with new caller-owned tensors.
    // Free the previous result first so repeated calls do not leak. Skip
    // vecTensor, which is owned and freed separately.
    for (int i = 0; i < NumOutputs; ++i) {
        if (OutputValues[i] != nullptr && OutputValues[i] != vecTensor)
            TF_DeleteTensor(OutputValues[i]);
        OutputValues[i] = nullptr;
    }

    std::cout << "\nRunning TF_SessionRun...";
    TF_SessionRun(Session, nullptr, Input, InputValues, NumInputs,
                  Output, OutputValues, NumOutputs, nullptr, 0, nullptr, Status);

    // The caller keeps ownership of input tensors; this one is no longer needed.
    TF_DeleteTensor(imgTensor);
    imgTensor = nullptr;
    InputValues[0] = nullptr;

    if (TF_GetCode(Status) != TF_OK) {
        std::cout << "\nERROR: " << TF_Message(Status);
        return -1;
    }
    std::cout << "\nSUCCESS: TF_SessionRun is OK.";
    return 0;
}
/// Owns one RF model and drives it through a two-phase API:
/// initialize() loads the SavedModel, and create() runs inference on an image.
class IMPL
{
private:
    RF rf;
    bool loaded = false; // true once rf.load() has returned 0
public:
    /// Loads the model; a no-op if it is already loaded, so the TF session is
    /// not re-created (and the old one leaked) on repeated calls.
    /// @param status unused; defaulted so `initialize()` also compiles.
    /// @return 0 on success, -1 if RF::load() failed.
    int initialize(int status = 0)
    {
        (void)status;
        if (loaded)
            return 0;
        if (rf.load() != 0)
            return -1;
        loaded = true;
        return 0;
    }

    /// Runs RF::detect() on `image`.
    /// @return 0 on success, -1 if the model is not loaded or detection failed.
    int create(cv::Mat& image)
    {
        // Never hand an unloaded/half-loaded session to TF_SessionRun.
        if (!loaded)
            return -1;
        if (rf.detect(image) != 0)
            return -1;
        return 0;
    }
};
main()
{
IMPL impl;
impl.initialize();
impl.create();
}

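原因在于使用了多个进程:initialize() 和 create() 运行在不同的进程中(进程 ID 不同),因此 create() 中使用的指针并不指向当前进程中有效的会话。(补充说明:如果子进程是通过 fork() 创建的,子进程中只会复制调用 fork() 的那一个线程;TensorFlow 在父进程中创建的工作线程池在子进程中并不存在,这很可能就是 TF_SessionRun() 一直挂起而不是报错的原因。)下面的解决方案可能并不优雅,但它解决了我的问题:我修改了 create(),让模型在第一次调用时加载,之后每次调用直接复用。现在模型只需加载一次,即可多次推理。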


// Replacement for IMPL::create(): loads the model on first use so the TF
// session is created in the same process that later runs it.
// Requires a `bool initialized = false;` member in RF.
// Returns 0 on success, -1 if loading or detection failed.
int create(cv::Mat& image)
{
    if (!rf.initialized)
    {
        if (rf.load() != 0)
        {
            std::cout << "\nERROR: loading failed.";
            // Do not mark as initialized or run detect() on a model that failed to load.
            return -1;
        }
        std::cout << "\nRF loaded.";
        rf.initialized = true;
    }
    return rf.detect(image) != 0 ? -1 : 0;
}

最新更新