如何在使用递归参数包C++时避免重复代码



在c++中使用可变参数时,如何避免代码重复?请注意,我正在递归地使用模板来实现我的目标,因此我需要一些基本情况和递归情况。这会造成大量的代码重复,有什么方法可以减少这种重复吗?

下面,提供了一个创建任意张量(N维数组(的代码示例。

它工作得很好,但重复太多了。当像这样递归地使用模板参数包时,我如何避免编写重复的代码?

#include <cstddef>
#include <array>
#include <iostream>
template<typename T, std::size_t...>
class Tensor;
template<typename T, std::size_t N>
class Tensor<T, N> {
using Type = std::array<T, N>;
Type data;
public:
Tensor()
{
zero();
}
void zero()
{
fill(0);
}
Type::iterator begin() { return data.begin(); }
Type::iterator end() { return data.end(); }
void fill(T value)
{
std::fill(data.begin(), data.end(), value);
}
void print() const
{
std::cout << "[";
for(const auto& v : data)
{
std::cout << v << ",";
}
std::cout << "]";
}
};
template<typename T, std::size_t N, std::size_t M>
class Tensor<T, N, M>
{
using Type = std::array<Tensor<T, M>, N>;
Type data;
public:
Tensor()
{
zero();
}
void zero()
{
fill(0);
}
Type::iterator begin() { return data.begin(); }
Type::iterator end() { return data.end(); }
void fill(T value)
{
for(auto& v: data) {
std::fill(v.begin(), v.end(), value);
}
}
void print() const
{
std::cout << "[";
for(const auto& v : data)
{
v.print();
std::cout << ",";
}
std::cout << "]";
}
};
template<typename T, std::size_t N, std::size_t... M>
class Tensor<T, N, M...>
{
using Type = std::array<Tensor<T, M...>, N>;
Type data;
public:
Type::iterator begin() { return data.begin(); }
Type::iterator end() { return data.end(); }
Tensor()
{
zero();
}
void zero()
{
fill(0);
}
void fill(T value)
{
for(auto& v: data) {
v.fill(value);
}
}
void print() const
{
std::cout << "[";
for(const auto& v : data)
{
v.print();
std::cout << ",";
}
std::cout << "]";
}
};

一维张量和多维张量之间的唯一区别是std::array的类型,T表示单个,Tensor<T, M...>表示另一个。

template<typename T, std::size_t N, std::size_t... M>
class Tensor<T, N, M...> {
using InnerT = std::conditional_t<(sizeof...(M) > 0),
Tensor<T, M...>,
T>;
using Type = std::array<InnerT, N>;
Type data;
}

然后,使用if constexpr来区分一维情况、

void fill(T value)
{
if constexpr(sizeof...(M) > 0) {
for(auto& v: data) {
v.fill(value);
}
} else {
std::fill(data.begin(), data.end(), value);
}
}
void print() const
{
std::cout << "[";
for(const auto& v : data)
{
if constexpr(sizeof...(M) > 0) {
v.print();
std::cout << ",";
} else {
std::cout << v << ",";
}
}
std::cout << "]";
}

演示

相关内容

  • 没有找到相关文章

最新更新