同时接受特征密集矩阵和稀疏矩阵的函数



我正在努力将稀疏矩阵支持添加到开源数学库中,并希望DenseSparse矩阵类型都没有重复的函数。

下面的示例显示了一个add函数。一个具有两个函数的工作示例,然后是两次失败的尝试。下面提供了代码示例的 godbolt 链接。

我已经查看了关于编写采用特征类型的函数的特征文档,但他们使用Eigen::EigenBase的答案不起作用,因为MatrixBaseSparseMatrixBase都有可用的特定方法,而这些方法在EigenBase中不存在

https://eigen.tuxfamily.org/dox/TopicFunctionTakingEigenTypes.html

我们使用C++14,任何帮助,非常感谢您的时间!

#include <Eigen/Core>
#include <Eigen/Sparse>
#include <iostream>
// Sparse matrix helper
using triplet_d = Eigen::Triplet<double>;
using sparse_mat_d = Eigen::SparseMatrix<double>;
std::vector<triplet_d> tripletList;
// Returns plain object
template <typename Derived>
using eigen_return_t = typename Derived::PlainObject;
// Below two are the generics that work
template <class Derived>
eigen_return_t<Derived> add(const Eigen::MatrixBase<Derived>& A) {
return A + A;
}
template <class Derived>
eigen_return_t<Derived> add(const Eigen::SparseMatrixBase<Derived>& A) {
return A + A;
}
int main()
{
// Fill up the sparse and dense matrices
tripletList.reserve(4);
tripletList.push_back(triplet_d(0, 0, 1));
tripletList.push_back(triplet_d(0, 1, 2));
tripletList.push_back(triplet_d(1, 0, 3));
tripletList.push_back(triplet_d(1, 1, 4));
sparse_mat_d mat(2, 2);
mat.setFromTriplets(tripletList.begin(), tripletList.end());
Eigen::Matrix<double, -1, -1> v(2, 2);
v << 1, 2, 3, 4;
// Works fine
sparse_mat_d output = add(mat * mat);
std::cout << output;
// Works fine
Eigen::Matrix<double, -1, -1> output2 = add(v * v);
std::cout << output2;
} 

我只想有一个同时接受稀疏和密集矩阵的函数,而不是两个 add 函数,但下面的尝试没有成功。

模板模板类型

我的尝试显然很糟糕,但是用模板模板类型替换上面的两个add函数会导致模棱两可的基类错误。

template <template <class> class Container, class Derived>
Container<Derived> add(const Container<Derived>& A) {
return A + A;    
}

错误:

<source>: In function 'int main()':
<source>:35:38: error: no matching function for call to 'add(const Eigen::Product<Eigen::SparseMatrix<double, 0, int>, Eigen::SparseMatrix<double, 0, int>, 2>)'
35 |   sparse_mat_d output = add(mat * mat);
|                                      ^
<source>:20:20: note: candidate: 'template<template<class> class Container, class Derived> Container<Derived> add(const Container<Derived>&)'
20 | Container<Derived> add(const Container<Derived>& A) {
|                    ^~~
<source>:20:20: note:   template argument deduction/substitution failed:
<source>:35:38: note:   'const Container<Derived>' is an ambiguous base class of 'const Eigen::Product<Eigen::SparseMatrix<double, 0, int>, Eigen::SparseMatrix<double, 0, int>, 2>'
35 |   sparse_mat_d output = add(mat * mat);
|                                      ^
<source>:40:52: error: no matching function for call to 'add(const Eigen::Product<Eigen::Matrix<double, -1, -1>, Eigen::Matrix<double, -1, -1>, 0>)'
40 |   Eigen::Matrix<double, -1, -1> output2 = add(v * v);
|                                                    ^
<source>:20:20: note: candidate: 'template<template<class> class Container, class Derived> Container<Derived> add(const Container<Derived>&)'
20 | Container<Derived> add(const Container<Derived>& A) {
|                    ^~~
<source>:20:20: note:   template argument deduction/substitution failed:
<source>:40:52: note:   'const Container<Derived>' is an ambiguous base class of 'const Eigen::Product<Eigen::Matrix<double, -1, -1>, Eigen::Matrix<double, -1, -1>, 0>'
40 |   Eigen::Matrix<double, -1, -1> output2 = add(v * v);
|                                                    ^

我相信这是同样的钻石继承问题:

https://www.fluentcpp.com/2017/05/19/crtp-helper/

使用 std::conditional_t

下面尝试使用conditional_t来推断正确的输入类型

#include <Eigen/Core>
#include <Eigen/Sparse>
#include <iostream>
// Sparse matrix helper
using triplet_d = Eigen::Triplet<double>;
using sparse_mat_d = Eigen::SparseMatrix<double>;
std::vector<triplet_d> tripletList;

// Returns plain object
template <typename Derived>
using eigen_return_t = typename Derived::PlainObject;
// Check it Object inherits from DenseBase
template<typename Derived>
using is_dense_matrix_expression = std::is_base_of<Eigen::DenseBase<std::decay_t<Derived>>, std::decay_t<Derived>>;
// Check it Object inherits from EigenBase
template<typename Derived>
using is_eigen_expression = std::is_base_of<Eigen::EigenBase<std::decay_t<Derived>>, std::decay_t<Derived>>;
// Alias to deduce if input should be Dense or Sparse matrix
template <typename Derived>
using eigen_matrix = typename std::conditional_t<is_dense_matrix_expression<Derived>::value,
typename Eigen::MatrixBase<Derived>, typename Eigen::SparseMatrixBase<Derived>>;
template <typename Derived>
eigen_return_t<Derived> add(const eigen_matrix<Derived>& A) {
return A + A;
}
int main()
{
tripletList.reserve(4);
tripletList.push_back(triplet_d(0, 0, 1));
tripletList.push_back(triplet_d(0, 1, 2));
tripletList.push_back(triplet_d(1, 0, 3));
tripletList.push_back(triplet_d(1, 1, 4));
sparse_mat_d mat(2, 2);
mat.setFromTriplets(tripletList.begin(), tripletList.end());
sparse_mat_d output = add(mat * mat);
std::cout << output;
Eigen::Matrix<double, -1, -1> v(2, 2);
v << 1, 2, 3, 4;
Eigen::Matrix<double, -1, -1> output2 = add(v * v);
std::cout << output2;
} 

这将引发错误

<source>: In function 'int main()':
<source>:94:38: error: no matching function for call to 'add(const Eigen::Product<Eigen::SparseMatrix<double, 0, int>, Eigen::SparseMatrix<double, 0, int>, 2>)'
94 |   sparse_mat_d output = add(mat * mat);
|                                      ^
<source>:79:25: note: candidate: 'template<class Derived> eigen_return_t<Derived> add(eigen_matrix<Derived>&)'
79 | eigen_return_t<Derived> add(const eigen_matrix<Derived>& A) {
|                         ^~~
<source>:79:25: note:   template argument deduction/substitution failed:
<source>:94:38: note:   couldn't deduce template parameter 'Derived'
94 |   sparse_mat_d output = add(mat * mat);
|                                      ^
<source>:99:52: error: no matching function for call to 'add(const Eigen::Product<Eigen::Matrix<double, -1, -1>, Eigen::Matrix<double, -1, -1>, 0>)'
99 |   Eigen::Matrix<double, -1, -1> output2 = add(v * v);
|                                                    ^
<source>:79:25: note: candidate: 'template<class Derived> eigen_return_t<Derived> add(eigen_matrix<Derived>&)'
79 | eigen_return_t<Derived> add(const eigen_matrix<Derived>& A) {
|                         ^~~
<source>:79:25: note:   template argument deduction/substitution failed:
<source>:99:52: note:   couldn't deduce template parameter 'Derived'
99 |   Eigen::Matrix<double, -1, -1> output2 = add(v * v);

这似乎是因为无法像此链接那样推断依赖类型的依赖参数。

https://deque.blog/2017/10/12/why-template-parameters-of-dependent-type-names-cannot-be-deduced-and-what-to-do-about-it/

戈博尔特示例

下面的神霹雳有上面的所有实例可供玩

https://godbolt.org/z/yKEAsn

有没有办法只有一个功能而不是两个?我们有很多函数可以同时支持稀疏矩阵和密集矩阵,因此最好避免代码重复。

编辑:可能的答案

@Max朗霍夫建议使用

template <class Mat>
auto add(const Mat& A) {
return A + A; 
}

auto关键字对于本征有点危险

https://eigen.tuxfamily.org/dox/TopicPitfalls.html

template <class Mat> 
typename Mat::PlainObject add(const Mat& A) { 
return A + A; 
}

有效,尽管 tbh 我不完全确定为什么在这种情况下返回普通对象有效

编辑

编辑

有几个人提到了auto关键字的使用。可悲的是,Eigen 不能很好地与auto在 C++11 上的第二个和下面的链接中的自动中引用的那样

https://eigen.tuxfamily.org/dox/TopicPitfalls.html

在某些情况下可以使用 auto,尽管我想看看是否有一种通用的auto'ish 方式是对 Eigen 模板返回类型的抱怨

有关带有自动的段错误的示例,您可以尝试将添加替换为

template <typename T1>
auto add(const T1& A) 
{
return ((A+A).eval()).transpose();
}

如果你想传递EigenBase<Derived>,你可以使用.derived()提取底层类型(本质上,这只是强制转换为Derived const&):

template <class Derived>
eigen_return_t<Derived> add(const Eigen::EigenBase<Derived>& A_) {
Derived const& A = A_.derived();
return A + A;
}

更高级的是,对于此特定示例,由于您使用了两次A,因此可以使用内部评估器结构来表达:

template <class Derived>
eigen_return_t<Derived> add2(const Eigen::EigenBase<Derived>& A_) {
// A is used twice:
typedef typename Eigen::internal::nested_eval<Derived,2>::type NestedA;
NestedA A (A_.derived());
return A + A;
}

这样做的好处是,当将产品作为A_传递时,在评估A+A时不会对其进行两次评估,但如果A_类似于Block<...>,则不会不必要地复制。但是,实际上并不建议使用internal功能(其API可能随时更改)。

编译器的问题如下:

无法推断模板参数"派生">

传递Derived所需的类型可能应该有效,如下所示:

add<double>(v * v)

但是我不确定,因为在我看来Eigen::MatrixEigen::MatrixBase类型不同。

但是,如果你对编译器的限制较少,它将能够找出类型:

template <typename T>
auto add(const T& A) {
return A + A;
}

编辑:

刚刚在评论中看到该解决方案已经发布,并且 Eigen 文档建议不要使用auto.我不熟悉 Eigen,但在我看来,从浏览文档来看,可能是 Eigen 产生了表示表达式的结果 - 例如,将矩阵加法表示为算法的对象;不是矩阵加法结果本身。在这种情况下,如果您知道A + A会导致类型T(在我看来它实际上应该用于operator+),您可以这样写:

template <typename T>
T add(const T& A) {
return A + A;
}

在矩阵示例中,这应该强制返回矩阵结果;而不是表示表达式的对象。但是,由于您最初一直在使用eigen_result_t,因此我不是100%确定。

我还没有理解你所有的代码和注释。无论如何,您的问题似乎归结为找到一种方法来编写可以接受服务器矩阵类型的函数。

template <typename T>
auto add(const T& A)
{
return 2*A;
}

您还可以添加 2 个不同类型的矩阵:

template <typename T1, typename T2>
auto add(const T1& A, const T2& B) -> decltype(A+B) // decltype can be omitted since c++14
{
return A + B;
}

然后,add(A,A)给出与add(A)相同的结果。但是我认为带有 2 个参数的add函数更有意义。而且它更通用,因为您可以将稀疏矩阵与密集矩阵相加。

int main()
{
constexpr size_t size = 10;
Eigen::SparseMatrix<double> spm_heap(size,size);
Eigen::MatrixXd m_heap(size,size);
Eigen::Matrix<double,size,size> m_stack; 
// fill the matrices
std::cout << add(spm_heap,m_heap);
std::cout << add(spm_heap,m_stack);
return 0;
}

编辑

关于您声明auto不应与特征一起使用的编辑。这很有趣!

template <typename T>
auto add(const T& A) 
{
return ((A+A).eval()).transpose();
}

这会产生一个segfault。为什么?auto确实很好地推导了类型,但推导的类型不是decltype(A),而是该类型的引用。为什么?我最初认为这是因为返回值周围的括号(如果有兴趣,请阅读此处),但这似乎是由于transpose函数的返回类型。

无论如何,克服这个问题很容易。正如您所建议的,您可以删除auto

template <typename T>
T add(const T& A) 
{
return ((A+A).eval()).transpose();
}

或者,您可以使用auto但指定所需的返回类型:

template <typename T>
auto add(const T& A) -> typename std::remove_reference<decltype(A)>::type // or simply decltype(A.eval())
{
return ((A+A).eval()).transpose();
}

现在,对于这个特定的add函数,第一个选项(省略auto)是最好的解决方案。但是,对于另一个需要 2 个不同类型的参数的add函数,这是一个很好的解决方案:

template <typename T1, typename T2>
auto add(const T1& A, const T2& B) -> decltype((A+B).eval())
{
return ((A+B).eval()).transpose();
}

相关内容

  • 没有找到相关文章

最新更新