C++ - 如果给出参数(函数),则修改成员函数



我正在创建一个类,该类将12函数作为输入参数。

我的目标是,如果只给出一个函数func则使用num_dfunc计算成员函数dfunc(它是func的数值导数,并在类中硬编码)。如果给出两个函数funcanalytical_dfunc则使用分析版本计算导数。实现这一目标的最佳方法是什么?

这是我代码的一部分

class MyClass
{
public:
int dim = 2;
vector<double> num_dfunc(vector<double> l0)
{
// Numerical gradient of the potential up to second order
// #TODO This should be rewritten!
vector<double> result(dim);
double eps = 0.001;
for (int i = 0; i < dim; i++)
{
vector<double> lp2 = l0;
lp2[i] += 2 * eps;
vector<double> lp1 = l0;
lp1[i] += eps;
vector<double> lm1 = l0;
lm1[i] -= eps;
vector<double> lm2 = l0;
lm2[i] -= 2 * eps;
result[i] = (-func(lp2) + 8 * func(lp1) - 8 * func(lm1) + func(lm2)) / (12 * eps);
}
return result;
}
double (*func)(vector<double>);   // Potencial pointer
vector<double> (*dfunc)(vector<double>);  // Gradient pointer
MyClass(double (*func)(vector<double>))
{
this->func = func;
// THIS IS WRONG
this->dfunc = num_dfunc;
}
MyClass(double (*func)(vector<double>),double (*analytical_dfunc)(vector<double>))
{
this->func = func;
// THIS IS WRONG
this->dfunc = analytical_dfunc;
}

这在某种程度上是我想做的pythonic方式。

PS:到目前为止,这不是我所拥有的,我尝试了很多东西,但没有一个奏效。

编辑:错误,在 dfunc 返回类型上。analytical_func 上的错别字

我会num_dfunc做一个静态成员函数,或者可能更好的自由函数,与MyClass无关。我会修改它,使其也采用零阶函数和维度作为输入。

如果仅使用一个参数调用构造函数,则可以使用 lambda 从静态num_dfunc创建第二个函数。

[=](auto const& arg){
return num_dfunc(func, dim, arg);
}

符号[=]捕获lambda主体中所有需要的变量(funcdim),这些变量不属于参数列表(auto const& arg)

此外,我会让你的类成为接受任何类型的函数F的类模板,C++有许多不同类型的可调用对象,其中原始函数指针只是一个(还有函子、lambdas、std::function......如果你把你的类作为一个模板,它将适用于所有类型的函数类型。

我还没有测试过这个,但基本上你的类看起来像这样:

#include <vector>

template <typename F>
class MyClass
{
private:
static std::vector<double> num_dfunc(
F const& f,
int dim,
std::vector<double> const& l0
)
{
// Numerical gradient of the potential up to second order
// #TODO This should be rewritten!
std::vector<double> result(dim);
double eps = 0.001;
for (int i = 0; i < dim; i++)
{
std::vector<double> lp2 = l0;
lp2[i] += 2 * eps;
std::vector<double> lp1 = l0;
lp1[i] += eps;
std::vector<double> lm1 = l0;
lm1[i] -= eps;
std::vector<double> lm2 = l0;
lm2[i] -= 2 * eps;
result[i] = (-f(lp2) + 8 * f(lp1) - 8 * f(lm1) + f(lm2)) / (12 * eps);
}
return result;
}

int dim;
F func;   // Potencial pointer
F dfunc;  // Gradient pointer
public:
MyClass(F const& fun)
: dim(2)
, func(fun)
, dfunc(
[=](auto const& arg){
return num_dfunc(func, dim, arg);
}
)
{
}
MyClass(F const& fun, F const& analytical_fun)
: dim(2)
, func(fun)
, dfunc(analytical_fun)
{
}
};

您可以在编译器资源管理器上使用以下方法:https://godbolt.org/z/EMxqr7KE4

最新更新