我如何使用 sklearn 按组运行多个单变量回归



我正在尝试复制这个解决方案 Python pandas:如何按组运行多个单变量回归,但使用 sklearn 线性回归而不是统计模型。

import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
df = pd.DataFrame({
  'y': np.random.randn(20),
  'x1': np.random.randn(20), 
  'x2': np.random.randn(20),
  'grp': ['a', 'b'] * 10})

def ols_res(x, y):
    return pd.Series(LinearRegression.fit(x,y).predict(x))

results = df.groupby('grp').apply(lambda x : x[['x1', 'x2']].apply(ols_res, y=x['y']))
print(results)

我得到:

TypeError: ("fit() missing 1 required positional argument: 'y'", 'occurred at index x1')

结果应该与我链接的文章相同,即:

             x1        x2
grp                      
a   0 -0.102766 -0.205196
    1 -0.073282 -0.102290
    2  0.023832  0.033228
    3  0.059369 -0.017519
    4  0.003281 -0.077150
        ...       ...
b   5  0.072874 -0.002919
    6  0.180362  0.000502
    7  0.005274  0.050313
    8 -0.065506 -0.005163
    9  0.003419 -0.013829

您的代码有两个小问题:

  1. 您不实例化 LinearRegression 对象,因此您的代码实际上尝试调用 LinearRegression 的未绑定 fit 方法。

  2. 即使修复此问题,LinearRegression实例也无法执行fittransform,因为它需要一个 2D 数组并获得一个 1D 数组。因此,您还需要重塑每个Series中包含的数组。

import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
df = pd.DataFrame({
  'y': np.random.randn(20),
  'x1': np.random.randn(20), 
  'x2': np.random.randn(20),
  'grp': ['a', 'b'] * 10})
def ols_res(x, y):
    x_2d = x.values.reshape(len(x), -1)
    return pd.Series(LinearRegression().fit(x_2d, y).predict(x_2d))
results = df.groupby('grp').apply(lambda df: df[['x1', 'x2']].apply(ols_res, y=df['y']))
print(results)

输出:

             x1        x2
grp                      
a   0 -0.126680  0.137907
    1 -0.441300 -0.595972
    2 -0.285903 -0.385033
    3 -0.252434  0.560938
    4 -0.046632 -0.718514
    5 -0.267396 -0.693155
    6 -0.364425 -0.476643
    7 -0.221493 -0.779082
    8 -0.203781  0.722860
    9 -0.106912 -0.090262
b   0 -0.015384  0.092137
    1  0.478447  0.032881
    2  0.366102  0.059832
    3 -0.055907  0.055388
    4 -0.221876  0.013941
    5 -0.054299  0.048263
    6  0.043979  0.024594
    7 -0.307831  0.059972
    8 -0.226570 -0.024809
    9  0.394460  0.038921

最新更新