我正在尝试将LogisticRegression应用于我的数据集。
我已将数据拆分为训练、测试和验证。数据使用一种热编码进行规范化。我正在得到
ValueError: bad input shape (527, 2)
这是我的代码:
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
#read the data
train_data = pd.read_csv('ip2ttt_train.data',header=None)
test_data = pd.read_csv('ip2ttt_test.data', header=None)
valid_data = pd.read_csv('ip2ttt_valid.data', header=None)
#for valid dataset
valid_label = valid_data[9]
valid_features = valid_data.drop(columns =9)
#for test dataset
test_label = test_data[9]
test_features = test_data.drop(columns =9)
#for train dataset
train_label = train_data[9]
train_features = train_data.drop(columns =9)
X_valid = pd.get_dummies(valid_features)
y_valid = pd.get_dummies(valid_label)
X_test = pd.get_dummies(test_features)
y_test = pd.get_dummies(test_label)
X_train = pd.get_dummies(train_features)
y_train = pd.get_dummies(train_label)
clf = LogisticRegression(random_state=0, multi_class='multinomial', solver='newton-cg', penalty='l2') #penalty = L1 or L2 and solver = newton-cg or lbfgs
clf.fit(X_train, y_train)
下面是 X 和 y 的形状:
X_train.shape
(527, 27)
y_train.shape
(527, 2)
我尝试过:
我发现我需要改变y_train
的形状。我尝试将y_train
转换为np.array
并flatten()
它,但它不起作用。我想我需要(527,1)
形状。我也尝试了reshape([527,1])
但它给了我一个错误。我知道那件事
y:形状的阵列状(n_samples,(
相对于 X 的目标向量。
但不知道如何正确实现它。
更新:train_label
的示例数据:
0 positive
1 positive
2 positive
3 positive
4 positive
...
522 negative
523 negative
524 negative
525 negative
526 negative
Name: 9, Length: 527, dtype: object
train_features的示例数据
0 1 2 3 4 5 6 7 8
0 x x x x o o x o o
1 x x x x o o o x o
2 x x x x o o b o b
3 x x x x o b o o b
4 x x x x b o o b o
... ... ... ... ... ... ... ... ... ...
522 x o x o o x x x o
523 o x x x o o x o x
524 o x x x o o o x x
525 o x o x x o x o x
526 o x o x o x x o x
我试图在没有热编码的情况下将它们输入fit()
并得到错误:ValueError: could not convert string to float: 'x'
数据使用一种热编码进行规范化。
scikit-learn的LogisticRegression
不应该是这种情况;正如引用的文档所说:
y:形状的阵列状(n_samples,(
相对于 X 的目标向量。
您需要为所有标签(训练、验证、测试(提供(n_samples,)
形状。你应该删除所有用于定义y_train
、y_valid
和y_test
的pd.get_dummies()
命令,并分别使用train_label
、valid_label
和test_label
。