我需要将此代码转换为pytorch。下面给出的代码使用np.vectorize。我正在寻找一个与pytorch等效的代码。
class SimplexPotentialProjection(object):
def __init__(self, potential, inversePotential, strong_convexity_const, precision = 1e-10):
self.inversePotential = inversePotential
self.gradPsi = np.vectorize(potential)
self.gradPsiInverse = np.vectorize(inversePotential)
self.precision = precision
self.strong_convexity_const = strong_convexity_const
numpy.vectorize
的文档明确指出:
提供
vectorize
功能主要是为了方便,而不是为了性能。该实现本质上是一个for循环。
因此,为了将numpy代码转换为pytorch,您只需要在循环中对其张量参数应用potential
和inversePotential
。然而,这可能是非常低效的。你最好重新执行你的职能;原生地";在张量上以矢量化的方式。