我在网上得到了这段代码,一个热编码了一个标签编码值数组。我特别不明白最后一行。请帮忙
我最初认为,只要 y 是 1,它就会用 1 替换该索引的值,但是,如何呢?
def read_dataset():
df = pd.read_csv("sonar.all-data.csv")
x = df[df.columns[0:60]].values
y = df[df.columns[60]]
encoder = LabelEncoder()
encoder.fit(y)
y = oneHotEncode(y)
return(x, y)
def oneHotEncode(labels):
n_labels = len(labels)
n_unique_labels = len(np.unique(labels))
oneHE = np.zeros((n_labels, n_unique_labels))
oneHE[np.arange(n_labels), labels] = 1
return oneHE
我期待这段代码是如何工作的,但我不明白 np.arange 的那行
np.arange()
类似于range()
,但创建了一个numpy数组。因此,如果您有 10 个标签,它将返回一个数组,其中包含从 0 到 9 的连续数字。这用于选择oneHE
数组的行(初始化后仅包含零(。labels
用于选择列。
因此,只需在所有行中选择相应的列并将值设置为 1。