是否有任何替代的tf.pytorch中的Random_gamma ?



我正在将TensorFlow存储库转换为PyTorch代码。我遇到了这行代码:

tf.squeeze(tf.random_gamma(shape =(self.n_sample,),alpha=self.alpha+tf.to_float(self.B)))

我想知道等效的tf.random_gamma在PyTorch。我认为torch.distributions.gamma.Gamma的工作方式不同。

看起来在这种情况下可以使用torch.distributions.gamma.Gamma。下面是一个例子:

import torch
from torch.distributions.gamma import Gamma

def random_gamma(shape, alpha, beta=1.0):
alpha = torch.ones(shape) * torch.tensor(alpha)
beta = torch.ones(shape) * torch.tensor(beta)
gamma_distribution = Gamma(alpha, beta)
return gamma_distribution.sample()
print(random_gamma(shape=(10,), alpha=3.0))

输出:

tensor([2.7673, 1.5498, 6.5191, 5.2923, 3.3204, 3.9286, 1.4163, 1.2400, 3.9661, 1.7663])

不同之处在于torch.distributions.gamma.Gamma需要alpha和beta的完全张量,而不是像在TF中那样的形状+值。此外,TF版本的默认值为1,我试图在示例代码中模仿。

创建一次分布实例是有意义的,以防函数将被多次使用。

最新更新