Code to store Eigen tensors as binary files



Based on the excellent answer here, I tried to build a way to serialize an Eigen tensor.

Building on the cereal library, I tried the following:

namespace cereal
{
    template <class Archive, class Derived> inline
    typename std::enable_if<traits::is_output_serializable<BinaryData<typename Derived::Scalar>, Archive>::value, void>::type
    save(Archive & ar, Eigen::PlainObjectBase<Derived> const & m)
    {
        const auto& d = m.dimensions();
        const int dims = d.size;
        for (int i = 0; i < dims; i++)
        {
            ar(d[i]);
        }
        ar(binary_data(m.data(), m.size() * sizeof(typename Derived::Scalar)));
    }

    template <class Archive, class Derived> inline
    typename std::enable_if<traits::is_input_serializable<BinaryData<typename Derived::Scalar>, Archive>::value, void>::type
    load(Archive & ar, Eigen::PlainObjectBase<Derived> const & m)
    {
        const auto& d = m.dimensions();
        const int dims = d.size;
        for (int i = 0; i < dims; i++)
        {
            ar(d[i]);
        }
        ar(binary_data(m.data(), static_cast<std::size_t>(m.size() * sizeof(typename Derived::Scalar))));
    }
}

I then use it to save and load a tensor like this:

Eigen::Tensor<double, 3> tensor3dmaus = {4, 3, 2};
tensor3dmaus.setValues( {{{1, 2}, {3, 4}, {5, 6}},
                         {{7, 8}, {9, 10}, {11, 12}},
                         {{13, 14}, {15, 16}, {17, 18}},
                         {{19, 20}, {21, 22}, {23, 24}}} );
{
    std::ofstream out("eigen.cereal", std::ios::binary);
    cereal::BinaryOutputArchive archive_o(out);
    archive_o(tensor3dmaus);
}
std::cout << "test:" << std::endl << tensor3dmaus << std::endl;

Eigen::Tensor<double, 3> test_loaded;
{
    std::ifstream in("eigen.cereal", std::ios::binary);
    cereal::BinaryInputArchive archive_i(in);
    archive_i(test_loaded);
}
std::cout << "test loaded:" << std::endl << test_loaded << std::endl;

However, when compiling, I get the following error message:

Types must either have a serialize function, load/save pair, or 
load_minimal/save_minimal pair (you may not mix these). 
Serialize functions generally have the following signature: 
template<class Archive> 
void serialize(Archive & ar) 
{ 
ar( member1, member2, member3 ); 
} 
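Can anyone help me figure this out? I also tried using Eigen's TensorBase class (with access level 3) instead of PlainObjectBase, but unfortunately that didn't work.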

Any hints on how to adapt the code to work with tensors would be greatly appreciated, thanks!

I think your problem is that you have

load(Archive & ar, Eigen::PlainObjectBase<Derived> const & m){

rather than

load(Archive & ar, Eigen::PlainObjectBase<Derived> & m){
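Independently of the const issue, the question's overloads cannot be chosen for a tensor at all: Eigen::Tensor derives from Eigen::TensorBase, not from Eigen::PlainObjectBase, so deducing Derived from a parameter of type PlainObjectBase<Derived> fails, the template silently drops out of overload resolution, and cereal finds no save/load pair, which is consistent with the error above. The sketch below reproduces that C++ deduction rule using only the standard library. PlainBase, TensorBaseLike, MatrixLike, TensorLike, load_like and has_load_like are hypothetical stand-ins, not Eigen or cereal code.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical stand-ins: PlainBase plays the role of Eigen::PlainObjectBase,
// TensorBaseLike the role of Eigen::TensorBase (both CRTP bases).
template <class Derived> struct PlainBase {};
template <class Derived> struct TensorBaseLike {};

struct MatrixLike : PlainBase<MatrixLike> {};      // inherits PlainBase, like Eigen::Matrix
struct TensorLike : TensorBaseLike<TensorLike> {}; // does not, like Eigen::Tensor

// Same parameter shape as the question's load(): Derived must be deduced
// from a PlainBase<Derived> base class of the argument.
template <class Derived>
void load_like(PlainBase<Derived>& m) { (void)m; }

// Detects whether load_like(t) is viable for an lvalue of type T,
// roughly the way a serialization library probes for save/load overloads.
template <class T, class = void>
struct has_load_like : std::false_type {};

template <class T>
struct has_load_like<T, decltype(load_like(std::declval<T&>()))> : std::true_type {};

// Deduction looks through base classes, so MatrixLike matches with Derived = MatrixLike...
static_assert(has_load_like<MatrixLike>::value, "deduced via derived-to-base conversion");
// ...but TensorLike has no PlainBase<D> base, so deduction fails and the overload vanishes.
static_assert(!has_load_like<TensorLike>::value, "no PlainBase<D> base, no viable overload");
```

Because nothing about this failure is reported as a hard error, the only symptom is the library's generic "no serialization function" message. Writing the overloads against Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType_> itself sidesteps the question of which base class to target.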

I have the following overloads, written directly against Eigen::Tensor, and they seem to work fine:

namespace cereal
{
    // binary serialization
    template <class Archive, typename Scalar_, int NumIndices_, int Options_, typename IndexType_> inline
    typename std::enable_if<cereal::traits::is_output_serializable<cereal::BinaryData<Scalar_>, Archive>::value, void>::type
    save(Archive& ar, Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType_> const & t)
    {
        int numCells = 1;
        for (auto i = 0; i < NumIndices_; ++i)
        {
            int size = t.dimension(i);
            ar(size);
            numCells *= size;
        }
        ar(binary_data(t.data(), numCells * sizeof(Scalar_)));
    }

    template <class Archive, typename Scalar_, int NumIndices_, int Options_, typename IndexType_> inline
    typename std::enable_if<cereal::traits::is_input_serializable<cereal::BinaryData<Scalar_>, Archive>::value, void>::type
    load(Archive& ar, Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType_> & t)
    {
        auto dims = t.dimensions();
        int numCells = 1;
        for (auto i = 0; i < NumIndices_; ++i)
        {
            int size;
            ar(size);
            dims[i] = size;
            numCells *= size;
        }
        t.resize(dims);
        ar(binary_data(t.data(), static_cast<std::size_t>(numCells * sizeof(Scalar_))));
    }
}  // namespace cereal
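To make the byte layout of these overloads concrete without needing Eigen or cereal, here is a minimal standard-library sketch of the same format: each dimension written as an int, followed by the raw bytes of the contiguous scalar buffer. Tensor3, save_tensor and load_tensor are hypothetical names standing in for Eigen::Tensor<double, 3> and the cereal save/load pair; this is an illustration of the layout, not cereal's implementation. Like cereal's BinaryOutputArchive, it writes in the machine's native byte order, so files are not portable across platforms with different endianness.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <istream>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for Eigen::Tensor<double, 3>: three dimensions plus a
// contiguous buffer holding dims[0] * dims[1] * dims[2] scalars.
struct Tensor3 {
    std::array<int, 3> dims{};
    std::vector<double> data;
};

// Mirrors save(): one int per dimension, then the raw scalar bytes.
void save_tensor(std::ostream& os, const Tensor3& t) {
    std::size_t numCells = 1;
    for (int d : t.dims) {
        os.write(reinterpret_cast<const char*>(&d), sizeof d);
        numCells *= static_cast<std::size_t>(d);
    }
    assert(t.data.size() == numCells);  // buffer must agree with the dimensions
    os.write(reinterpret_cast<const char*>(t.data.data()),
             static_cast<std::streamsize>(numCells * sizeof(double)));
}

// Mirrors load(): the tensor is taken by non-const reference so it can be
// resized before the payload is read into its buffer.
void load_tensor(std::istream& is, Tensor3& t) {
    std::size_t numCells = 1;
    for (int& d : t.dims) {
        is.read(reinterpret_cast<char*>(&d), sizeof d);
        if (!is || d < 0) throw std::runtime_error("bad tensor header");
        numCells *= static_cast<std::size_t>(d);
    }
    t.data.resize(numCells);  // plays the role of t.resize(dims)
    is.read(reinterpret_cast<char*>(t.data.data()),
            static_cast<std::streamsize>(numCells * sizeof(double)));
    if (!is) throw std::runtime_error("truncated tensor payload");
}
```

The design relies on save and load agreeing on the header type: both sides use int here, just as the answer's overloads write and read int per dimension, so changing one without the other would misalign the payload.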
