多元回归构造,无需循环遍历 numpy 数组



我正在构建 1500 个不同的模型,以使用相同的 1500 个预测因子预测 1500 个不同的y值,Xs ,在线性模型中。我每个都有 15 个数据点。我把这些Ys放在一个数组中,Xs放在另一个数组中。

Ys = np.random.rand(15,1500)
Xs = np.random.rand(15,1500)

我可以遍历Ys列并拟合我的模型并获得所有Xs的系数。

>>> from sklearn import linear_model
>>> clf = linear_model.LinearRegression()
>>> def f(Ys,Xs):
...     for i in range(Ys.shape[1]):
...         clf.fit(Xs,Ys[:,i])
...         print clf.coef_
>>> f(Ys,Xs)
[ 0.00415945  0.00518805  0.00200809 ..., -0.00293134  0.00405276
 -0.00082493]
[-0.00278009 -0.00926449  0.00849694 ..., -0.00183793  0.00493365
 -0.00053502]
[-0.004892   -0.00067937  0.00490643 ...,  0.00074988  0.00166438
  0.00197527]...

这足够有效,但是循环遍历Ys列似乎是处理这些数组的低效方法,尤其是在我将交叉验证引入图片之后。

是否有某种apply等价物(如pandas)可以使它更有效率?

几个想法:

(1) 假设每个线性模型的预测变量 (1500) 多于数据点 (15),您的模型将过度拟合训练数据(它们对新数据没有预测能力)。 考虑改用岭回归 (http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html)

(2) 如果在一系列线性模型中重复使用同一组预测变量,则可以考虑到线性回归的解为 coef = inv(Xs'*Xs)Xsy 。 请注意,inv(Xs'*Xs)*Xs 对于每个线性模型都是相同的。 因此,您可以将所有线性模型同时计算为 inv(Xs'*Xs)XsYs。 如果您最终使用 Ridge 回归,则需要稍微修改此公式以 inv(Xs'Xs + alpha I)Xs Ys(其中 I 是 15 x 15 的单位矩阵)。

线性回归估计器支持开箱即用的多目标回归,您可以简单地执行以下操作:

>>> import numpy as np
>>> Ys = np.random.rand(15,1500)
>>> Xs = np.random.rand(15,1500)
>>> from sklearn.linear_model import LinearRegression
>>> clf = LinearRegression().fit(Xs, Ys)

系数存储在形状(n_targets、n_features)的coef_属性中:

>>> clf.coef_
array([[  5.55249034e-03,   4.80064644e-03,  -9.84935468e-03, ...,
     -4.56988996e-03,   1.13633031e-03,   1.76111517e-03],
   [ -3.92718305e-03,  -3.97534623e-03,   6.19243263e-03, ...,
     -1.87971624e-03,  -1.45732814e-03,   1.51018259e-03],
   [ -6.79887329e-04,  -4.80656996e-04,   1.74724622e-03, ...,
     -3.42881741e-04,  -3.48451425e-03,  -3.85790348e-04],
   ..., 
   [ -1.73318217e-03,  -8.70409477e-03,  -9.64475499e-05, ...,
     -4.52182601e-03,   3.49238171e-03,  -1.50492517e-03],
   [  2.77132135e-05,  -7.12606751e-04,   4.32136642e-03, ...,
      3.34105396e-03,   1.98439783e-03,  -1.04102019e-03],
   [  1.93154992e-03,   2.45374075e-03,  -1.17614144e-03, ...,
     -2.33196606e-03,   1.60940753e-03,   2.04974586e-03]])

相关内容

  • 没有找到相关文章

最新更新