用于分割二维数组以在train_test_split()中使用的函数error ValueError:太多的值无法解压缩


大家好,我不知道为什么这个函数不能按预期工作。如有任何指导,我们将不胜感激。我得到错误:
ValueError: too many values to unpack (expected 2)

编写一个函数,将一个二维numpy数组作为输入,并以(X_train, y_train), (X_test, y_test)的形式返回四个变量,其中(X_train, y_train)是训练集的特征+响应,(X-test, y_test)是测试集的特征-响应。

功能规范:

  • 应采用二维numpyarray作为输入
  • 应该拆分数组,使得X是年份,y是相应的总体
  • 应返回两个形式为(X_train, y_train), (X_test, y_test)tuples
  • 应将sklearn的train_testrongplit函数与test_size = 0.2random_state = 42一起使用

Numpy数组输入:

array([[  1960,  54211],
[  1961,  55438],
[  1962,  56225],
[  1963,  56695],
[  1964,  57032],
[  1965,  57360],
[  1966,  57715],
[  1967,  58055],
[  1968,  58386],
[  1969,  58726],
[  1970,  59063],
[  1971,  59440]...

我的功能:

def feature_response_split(arr):
X, y = np.split(arr, 2, axis=1)
(X_train, y_train), (X_test, y_test) = train_test_split(X, y, test_size = 0.2, random_state = 42)
return (X_train, y_train), (X_test, y_test)

输入代码:(不可更改(

data = get_year_pop('Aruba')
(X_train, y_train), (X_test, y_test) = feature_response_split(data)

预期输出:

X_train == array([1996, 1991, 1968, 1977, 1966, 1964, 2001, 1979, 1990, 2009, 2010,
2014, 1975, 1969, 1987, 1986, 1976, 1984, 1993, 2015, 2000, 1971,
1992, 2016, 2003, 1989, 2013, 1961, 1981, 1962, 2005, 1999, 1995,
1983, 2007, 1970, 1982, 1978, 2017, 1980, 1967, 2002, 1974, 1988,
2011, 1998])
y_train == array([ 83200,  64622,  58386,  60366,  57715,  57032,  92898,  59980,
62149, 101453, 101669, 103795,  60657,  58726,  61833,  62644,
60586,  62836,  72504, 104341,  90853,  59440,  68235, 104822,
97017,  61032, 103187,  55438,  60567,  56225, 100031,  89005,
80324,  62201, 101220,  59063,  61345,  60103, 105264,  60096,
58055,  94992,  60528,  61079, 102053,  87277])
X_test == array([1960, 1965, 1994, 1973, 2004, 2012, 1997, 1985, 2006, 1972, 2008,
1963])
y_test == array([ 54211,  57360,  76700,  60243,  98737, 102577,  85451,  63026,
100832,  59840, 101353,  56695])

这是不正确的:

(X_train, y_train), (X_test, y_test) = train_test_split(X, y, ...)

根据文件,它应该是

X_train, X_test, y_train, y_test = train_test_split(X, y, ...)

相关内容

最新更新