读取csv文件时出错:将列从字符串转换为浮点数



我试图读取一个csv文件,其中包含一个列,SpType,其中有字符串值。我的变量被转换成一个对象,但我需要它是float类型。下面是代码片段:

data = pd.read_csv("/content/Star3642_balanced.csv")
X_orig = data[["Vmag", "Plx", "e_Plx", "B-V", "SpType", "Amag"]].to_numpy()

下面是错误的原因:

X = torch.tensor(X_orig, dtype=torch.float32)

错误读取"can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool."

我试着在阅读csv文件后这样做,但它没有帮助:

data["SpType"] = data.SpType.astype(float)
有没有人能告诉我应该怎么做?

字符串应该被编码成数值。最简单的方法是使用Pandas的one-hot编码(在这种情况下会创建许多额外的列,但是神经网络应该可以毫不费力地处理它们):

ohe = pd.get_dummies(data["SpType"], drop_first=True)
data[ohe.columns] = ohe
data = data.drop(["SpType"], axis=1)

或者,您可以使用sklearn编码器或category_encoders库——更复杂的编码可能需要单独处理测试集,以避免目标泄漏。

最新更新