我正在为spark ML模型设置分类变量。我有一个包含分类变量数组的列,而不是包含单个分类变量的列。请参阅下面的示例数据。
(尽管这些是数字,但它们代表了一个类别)。
我需要将这些特性分离为单独的特性,例如,重要的是要保留#1、#3、#6和#7具有类别19,而不管阵列中还有哪些其他类别。
我可以使用SQL手动识别所有分类变量,并为每个变量创建一列。但这似乎并不优雅,我认为必须有一种更好的方法来将所有类别集中到列,然后指定一个1或0,这可能是一个热编码。或者,我想知道是否有更好的方法来思考这个问题。
我使用的是scala 2.2.0(目前无法升级),所以我不能使用更新的数组函数。
+---------------+----------------+
|id |categorical_code|
+---------------+----------------+
|1 | [19] |
|2 | [87, 19] |
|3 | [18] |
|4 | [96] |
|5 | [18] |
|6 | [111, 22, 19] |
|7 | [161, 19, 18] |
|8 | [12] |
|9 | [170] |
+---------------+----------------+
需要的输出(我认为)类似于:
id,cat_12,cat_18,cat_19,cat_22,cat_87,cat_111,cat_161,cat_170
1,,,1,,,,,
2,,,1,,1,,,
3,,1,,,,,,
4,,,,,,,,
5,,1,,,,,,
6,,,1,1,,1,1,
7,,1,1,,,,,
8,1,,,,,,,1
9,,,,,,,,
我们可以将数组分解成单独的行,然后使用groupby pivot来获得所需的输出。
val df2 =
df.
select(
df("id"),
explode(df("categorical_code")).as("categorical_code"),
lit(1).as("categorical_code_exist")
)
df2.show()
+---+----------------+----------------------+
| id|categorical_code|categorical_code_exist|
+---+----------------+----------------------+
| 1| 19| 1|
| 2| 87| 1|
| 2| 19| 1|
| 3| 18| 1|
| 4| 96| 1|
| 5| 18| 1|
| 6| 111| 1|
| 6| 22| 1|
| 6| 19| 1|
| 7| 161| 1|
| 7| 19| 1|
| 7| 18| 1|
| 8| 12| 1|
| 9| 170| 1|
+---+----------------+----------------------+
val df3 =
df2.
groupBy("id").
pivot("categorical_code").
agg(coalesce(first(df2("categorical_code_exist")))).
orderBy("id")
df3.show()
+---+----+----+----+----+----+----+----+----+----+
| id| 12| 18| 19| 22| 87| 96| 111| 161| 170|
+---+----+----+----+----+----+----+----+----+----+
| 1|null|null| 1|null|null|null|null|null|null|
| 2|null|null| 1|null| 1|null|null|null|null|
| 3|null| 1|null|null|null|null|null|null|null|
| 4|null|null|null|null|null| 1|null|null|null|
| 5|null| 1|null|null|null|null|null|null|null|
| 6|null|null| 1| 1|null|null| 1|null|null|
| 7|null| 1| 1|null|null|null|null| 1|null|
| 8| 1|null|null|null|null|null|null|null|null|
| 9|null|null|null|null|null|null|null|null| 1|
+---+----+----+----+----+----+----+----+----+----+
df3.printSchema()
root
|-- id: integer (nullable = true)
|-- 12: integer (nullable = true)
|-- 18: integer (nullable = true)
|-- 19: integer (nullable = true)
|-- 22: integer (nullable = true)
|-- 87: integer (nullable = true)
|-- 96: integer (nullable = true)
|-- 111: integer (nullable = true)
|-- 161: integer (nullable = true)
|-- 170: integer (nullable = true)