我正在尝试生成一个多项式数据集。我写了一个代码
def generate_dataset1():
n = 500
X = 2 - 3 * np.random.normal(0, 1, n)
y = X - 2 * (X ** 2) + 0.5 * ( X ** 3) + np.random.normal(-3, 3, n)
m = np.random.uniform(0.3, 0.5, (n, ))
b = np.random.uniform(5, 10, (n, ))
plt.scatter(X, y, s=10)
plt.show()
现在,如果我想使用给定的公式(来自维基百科)生成一个数据集,你能告诉我我必须改变什么吗?
y = B_0 + B_1*x, B_2*x2 + B_3*x3 + ... + e
其中,x2
表示x (square)
,x3
表示x (cube)
,依此类推,e
为均值为零的不可观测随机误差。
x
与B
相乘的方法有很多,比如点积。但是我认为for循环已经足够好了。循环遍历B
和x
元素:
def generate_dataset(B, n):
# B is beta, n is number of sample
e = np.random.normal(-3, 3, n)
X = 2 - 3 * np.random.normal(0, 1, n)
y = 0
for i in range(len(B)):
y += B[i] * X**i
y += e
return X, y
def plot_dataset(X, y):
#m = np.random.uniform(0.3, 0.5, (n, )) # not sure why you need this
#b = np.random.uniform(5, 10, (n, )) # not sure why you need this
plt.scatter(X, y, s=10)
plt.show()
n = 500
B = [0, 1, -2, 0.5] # [beta0, beta1, beta2, beta3]
X, y = generate_dataset(B, 500)
plot_dataset(X, y)