在我的代码中有几个实例,其中我有一个基于1xN数组系数的条件,并且需要根据这些条件设置MxN数组的整个列。在我的例子中,N是Eigen::Dynamic
, M的范围从2到4,但是在每个实例中都是一个编译时常数。
这里有一个简单的函数来说明我的意思,其中a
和b
是构成条件的1xN数组,c
是附加数据的2xN数组,res
是外参数,其列总是作为一个整体设置:
#include <iostream>
#include <Eigen/Dense>
using namespace Eigen;
template<Index nRows>
using ArrayNXd = Array<double, nRows, Dynamic>;
using Array1Xd = ArrayNXd<1>;
using Array2Xd = ArrayNXd<2>;
using Array3Xd = ArrayNXd<3>;
void asFunction(
Array3Xd& res,
const Array1Xd& a, const Array1Xd& b, const Array2Xd& c
){
for (Index col{0}; col<a.cols(); ++col){
if ( a[col] > b[col] )
res.col(col) = Array3d{
a[col] + b[col],
(a[col] + b[col]) * c(0, col),
(a[col] - b[col]) * c(1, col)
};
else
res.col(col) = Array3d{
a[col] - b[col],
a[col] + b[col],
(a[col] + b[col]) * (a[col] - b[col])
};
}
}
int main(){
Array1Xd a (3), b(3);
Array2Xd c (2, 3);
a << 1, 2, 3;
b << 0, 1, 2;
c <<
0, 1, 2,
1, 2, 3;
Array3Xd res (3,3);
asFunction(res, a, b, c);
std::cout << "as function:n" << res << "n";
return 0;
}
类似的函数在我的代码的性能关键部分*中使用,我觉得我把性能留在表上,因为使用Eigen
类型的循环通常不是最佳解决方案。
*是的,我做了侧写。
我写了与NullaryExpr
相同的函数,但这有点慢。我想这是有道理的,考虑到条件的额外评估和每行的分支:
#include <iostream>
#include <Eigen/Dense>
using namespace Eigen;
template<Index nRows>
using ArrayNXd = Array<double, nRows, Dynamic>;
using Array1Xd = ArrayNXd<1>;
using Array2Xd = ArrayNXd<2>;
using Array3Xd = ArrayNXd<3>;
class MyFunctor
{
public:
using Scalar = double;
static constexpr Index
RowsAtCompileTime { 3 },
MaxRowsAtCompileTime { 3 },
ColsAtCompileTime { Dynamic },
MaxColsAtCompileTime { Dynamic };
using DenseType = Array<
Scalar , RowsAtCompileTime, ColsAtCompileTime,
ColMajor, MaxRowsAtCompileTime, MaxColsAtCompileTime
>;
private:
typename Array1Xd::Nested m_a;
typename Array1Xd::Nested m_b;
typename Array2Xd::Nested m_c;
public:
MyFunctor(
const Array1Xd& a,
const Array1Xd& b,
const Array2Xd& c
) : m_a {a}, m_b {b}, m_c{c}
{}
bool cond(Index col) const {
return m_a[col] > m_b[col];
}
Scalar func1(Index col) const {
return m_a[col] + m_b[col];
}
Scalar func2(Index col) const {
return m_a[col] - m_b[col];
}
Scalar func3(Index row, Index col) const {
switch(row){
case 0: return func1(col);
case 1: return func1(col) * m_c(0, col);
case 2: return func2(col) * m_c(1, col);
default: __builtin_unreachable();
}
}
Scalar func4(Index row, Index col) const {
switch (row){
case 0: return func2(col);
case 1: return func1(col);
case 2: return func1(col) / func2(col);
default: __builtin_unreachable();
}
}
Scalar operator() (Index row, Index col) const {
if ( cond(col) )
return func3(row, col);
else
return func4(row, col);
}
};
using MyReturnType = Eigen::CwiseNullaryOp<
MyFunctor, typename MyFunctor::DenseType
>;
MyReturnType asFunctor(
const Array1Xd& a,
const Array1Xd& b,
const Array2Xd& c
){
using DenseType = typename MyFunctor::DenseType;
return DenseType::NullaryExpr(
3, a.cols(),
MyFunctor(a, b, c)
);
}
int main(){
Array1Xd a (3), b(3);
Array2Xd c (2, 3);
a << 1, 2, 3;
b << 0, 1, 2;
c <<
0, 1, 2,
1, 2, 3;
std::cout << "as functor:n" << asFunctor(a,b,c) << "n";
return 0;
}
我的问题是:是否有一种更有效的方法来实现类似于上面的的逻辑(计算矩阵的每一列的标量条件,根据条件返回整个列的值)使用eigen
库?
注意:使用表达式会稍微优先,因为我不需要担心内存分配,输出参数等,并且代码可以用标量来编写,这使得它更容易理解。
编辑:注2:我也试过使用<Condition>.template replicate<nRows,1>().select(..., ...)
,但它更慢,更难阅读。
你可以使用Eigen的选择方法,但它只适用于标量,所以你必须沿着一个维度循环。
const auto condition = a > b;
res.row(0) = condition.select(a + b /*true*/, a - b /*false*/);
res.row(1) = condition.select((a + b) * c.row(0), a + b);
res.row(2) = condition.select((a - b) * c.row(1), (a + b) * (a - b));
请注意,如果您对所有数组进行转置,您可能会更快。然后逐列迭代,因为特征是列为主的,所以矢量化得更好。
所以我只看了这段代码
for (Index col{0}; col<a.cols(); ++col){
if ( a[col] > b[col] )
res.col(col) = Array3d{
a[col] + b[col],
(a[col] + b[col]) * c(0, col),
(a[col] - b[col]) * c(1, col)
};
else
res.col(col) = Array3d{
a[col] - b[col],
a[col] + b[col],
(a[col] + b[col]) * (a[col] - b[col])
};
}
我怀疑,但无法证明,每次调用a[col]和b[col]时都被访问。您可能希望尝试为重用的值创建简短的临时值。例如:所以我只看了这段代码
for (Index col{0}; col<a.cols(); ++col){
auto acol=a[col];
auto bcol=b[col];
auto apb=acol+bcol;
auto amb=acol-bcol;
if ( acol > bcol )
res.col(col) = Array3d{
apb,
(apb) * c(0, col),
(amb) * c(1, col)
};
else
res.col(col) = Array3d{
amb,
apb,
(apb) * (amb)
};
}
是的,我知道这不是你想要的。也许这对你有帮助