我正试图从OLS拟合中获得样本预测,如下所示,
import numpy as np
import pandas as pd
import statsmodels.api as sm
macrodata = sm.datasets.macrodata.load_pandas().data
macrodata.index = pd.period_range('1959Q1', '2009Q3', freq='Q')
mod = sm.OLS(macrodata['realgdp'], sm.add_constant(macrodata[['realdpi', 'realinv', 'tbilrate', 'unemp']])).fit()
mod.get_prediction(sm.add_constant(macrodata[['realdpi', 'realinv', 'tbilrate', 'unemp']])).summary_frame(0.95).head()
这很好。但如果我改变mod.get_prediction
中回归量的位置,我得到了不同的估计,
mod.get_prediction(sm.add_constant(macrodata[['tbilrate', 'unemp', 'realdpi', 'realinv']])).summary_frame(0.95).head()
这很令人惊讶。mod.get_prediction
不能根据列名识别回归量吗?
正如评论中所指出的,sm.OLS
将把你的数据帧转换成一个数组进行拟合,同样地,对于预测,它期望预测器以相同的顺序。
如果希望使用列名,可以使用公式接口,请参阅文档了解更多详细信息。下面我应用你的例子:
import statsmodels.api as sm
import statsmodels.formula.api as smf
macrodata = sm.datasets.macrodata.load_pandas().data
mod = smf.ols(formula='realgdp ~ realdpi + realinv + tbilrate + unemp', data=macrodata)
res = mod.fit()
按提供的顺序:
res.get_prediction(macrodata[['realdpi', 'realinv', 'tbilrate', 'unemp']]).summary_frame(0.95).head()
mean mean_se mean_ci_lower mean_ci_upper obs_ci_lower obs_ci_upper
0 2716.423418 14.608110 2715.506229 2717.340607 2710.782460 2722.064376
1 2802.820840 13.714821 2801.959737 2803.681943 2797.188729 2808.452951
2 2781.041564 12.615903 2780.249458 2781.833670 2775.419588 2786.663539
3 2786.894138 12.387428 2786.116377 2787.671899 2781.274166 2792.514110
4 2848.982580 13.394688 2848.141577 2849.823583 2843.353507 2854.611653
如果我们翻转列,结果是一样的:
res.get_prediction(macrodata[['tbilrate', 'unemp', 'realdpi', 'realinv']]).summary_frame(0.95).head()
mean mean_se mean_ci_lower mean_ci_upper obs_ci_lower obs_ci_upper
0 2716.423418 14.608110 2715.506229 2717.340607 2710.782460 2722.064376
1 2802.820840 13.714821 2801.959737 2803.681943 2797.188729 2808.452951
2 2781.041564 12.615903 2780.249458 2781.833670 2775.419588 2786.663539
3 2786.894138 12.387428 2786.116377 2787.671899 2781.274166 2792.514110
4 2848.982580 13.394688 2848.141577 2849.823583 2843.353507 2854.611653