Python scipy 重载_stats beta 发行版的函数



我需要为我的 beta 发行版重载 _stats 函数。这是我当前的代码:

from scipy.stats import beta
import scipy.stats as st
class CustomBeta(st.rv_continuous):
    def _stats(self, a, b):
        # will add own code here
        mn = a * 1.0 / (a + b)
        var = (a * b * 1.0) / (a + b + 1.0) / (a + b) ** 2.0
        g1 = 2.0 * (b - a) * sqrt((1.0 + a + b) / (a * b)) / (2 + a + b)
        g2 = 6.0 * (a ** 3 + a ** 2 * (1 - 2 * b) + b ** 2 * (1 + b) - 2 * a * b * (2 + b))
        g2 /= a * b * (a + b + 2) * (a + b + 3)
        return mn, var, g1, g2
dist = beta(4, 6)
print dist.rvs()  # works fine

dist = CustomBeta(4, 6)
print dist.rvs()  # crashes

从我的自定义对象获取_rvs()会给我一个很长的堆栈跟踪和一个错误

RuntimeError: maximum recursion depth exceeded

这与

重载_stats无关。相同的行为仅仅是由以下原因引起的

class CustomBeta(st.rv_continuous):
    pass
dist = CustomBeta(4, 6)
print(dist.rvs())  # crashes

rv_continuous的文件指出

可以通过对rv_continuous类进行子类化并至少重新定义_pdf_cdf方法来定义新的随机变量。

您需要提供其中至少一种方法来计算概率密度函数 (pdf( 或累积概率密度函数 (cdf(。

此外

[ rv_continuous ] 不能直接用作发行版。

它的用法如下:

class CustomBetaGen(st.rv_continuous):
    ...
CustomBeta = CustomBetaGen(name='CustomBeta')
dist = CustomBeta(4, 6)

最后,如果您不提供_rvs方法,rvs.()似乎无法正常工作。

将所有内容放在一起并从beta发行版中窃取适当的方法:

from scipy.stats import beta
import scipy.stats as st
import numpy as np
class CustomBetaGen(st.rv_continuous):
    def _cdf(self, x, a, b):
        return beta.cdf(x, a, b)
    def _pdf(self, x, a, b):
        return beta.pdf(x, a, b)
    def _rvs(self, a, b):
        return beta.rvs(a, b)
    def _stats(self, a, b):
        # will add own code here
        mn = a * 1.0 / (a + b)
        var = (a * b * 1.0) / (a + b + 1.0) / (a + b) ** 2.0
        g1 = 2.0 * (b - a) * np.sqrt((1.0 + a + b) / (a * b)) / (2 + a + b)
        g2 = 6.0 * (a ** 3 + a ** 2 * (1 - 2 * b) + b ** 2 * (1 + b) - 2 * a * b * (2 + b))
        g2 /= a * b * (a + b + 2) * (a + b + 3)
        return mn, var, g1, g2
CustomBeta = CustomBetaGen(name='CustomBeta')
dist = beta(4, 6)
print(dist.rvs())  # works fine
print(dist.stats())  # (array(0.4), array(0.021818181818181816))
dist = CustomBeta(4, 6)
print(dist.rvs())  # works fine
print(dist.stats())  # (array(0.4), array(0.021818181818181816))

最新更新