从统计模型中得到OLS拟合的预测



我正试图从OLS拟合中获得样本预测,如下所示,

import numpy as np
import pandas as pd
import statsmodels.api as sm
macrodata = sm.datasets.macrodata.load_pandas().data
macrodata.index = pd.period_range('1959Q1', '2009Q3', freq='Q')
mod = sm.OLS(macrodata['realgdp'], sm.add_constant(macrodata[['realdpi', 'realinv', 'tbilrate', 'unemp']])).fit()
mod.get_prediction(sm.add_constant(macrodata[['realdpi', 'realinv', 'tbilrate', 'unemp']])).summary_frame(0.95).head()

这很好。但如果我改变mod.get_prediction中回归量的位置,我得到了不同的估计,

mod.get_prediction(sm.add_constant(macrodata[['tbilrate', 'unemp', 'realdpi', 'realinv']])).summary_frame(0.95).head()

这很令人惊讶。mod.get_prediction不能根据列名识别回归量吗?

正如评论中所指出的,sm.OLS将把你的数据帧转换成一个数组进行拟合,同样地,对于预测,它期望预测器以相同的顺序。

如果希望使用列名,可以使用公式接口,请参阅文档了解更多详细信息。下面我应用你的例子:

import statsmodels.api as sm
import statsmodels.formula.api as smf
macrodata = sm.datasets.macrodata.load_pandas().data
mod = smf.ols(formula='realgdp ~ realdpi + realinv + tbilrate + unemp', data=macrodata)
res = mod.fit()

按提供的顺序:

res.get_prediction(macrodata[['realdpi', 'realinv', 'tbilrate', 'unemp']]).summary_frame(0.95).head()
mean    mean_se  mean_ci_lower  mean_ci_upper  obs_ci_lower  obs_ci_upper
0  2716.423418  14.608110    2715.506229    2717.340607   2710.782460   2722.064376
1  2802.820840  13.714821    2801.959737    2803.681943   2797.188729   2808.452951
2  2781.041564  12.615903    2780.249458    2781.833670   2775.419588   2786.663539
3  2786.894138  12.387428    2786.116377    2787.671899   2781.274166   2792.514110
4  2848.982580  13.394688    2848.141577    2849.823583   2843.353507   2854.611653

如果我们翻转列,结果是一样的:

res.get_prediction(macrodata[['tbilrate', 'unemp', 'realdpi', 'realinv']]).summary_frame(0.95).head()
mean    mean_se  mean_ci_lower  mean_ci_upper  obs_ci_lower  obs_ci_upper
0  2716.423418  14.608110    2715.506229    2717.340607   2710.782460   2722.064376
1  2802.820840  13.714821    2801.959737    2803.681943   2797.188729   2808.452951
2  2781.041564  12.615903    2780.249458    2781.833670   2775.419588   2786.663539
3  2786.894138  12.387428    2786.116377    2787.671899   2781.274166   2792.514110
4  2848.982580  13.394688    2848.141577    2849.823583   2843.353507   2854.611653

最新更新