为Spark ML编码转置或旋转一组分类变量的最佳方式



我正在为spark ML模型设置分类变量。我有一个包含分类变量数组的列,而不是包含单个分类变量的列。请参阅下面的示例数据。

(尽管这些是数字,但它们代表了一个类别)。

我需要将这些特性分离为单独的特性,例如,重要的是要保留#1、#3、#6和#7具有类别19,而不管阵列中还有哪些其他类别。

我可以使用SQL手动识别所有分类变量,并为每个变量创建一列。但这似乎并不优雅,我认为必须有一种更好的方法来将所有类别集中到列,然后指定一个1或0,这可能是一个热编码。或者,我想知道是否有更好的方法来思考这个问题。

我使用的是scala 2.2.0(目前无法升级),所以我不能使用更新的数组函数。

+---------------+----------------+
|id             |categorical_code|
+---------------+----------------+
|1              |           [19] |
|2              |       [87, 19] |
|3              |           [18] |
|4              |           [96] |
|5              |           [18] |
|6              |  [111, 22, 19] |
|7              |  [161, 19, 18] |
|8              |           [12] |
|9              |          [170] |
+---------------+----------------+

需要的输出(我认为)类似于:

id,cat_12,cat_18,cat_19,cat_22,cat_87,cat_111,cat_161,cat_170
1,,,1,,,,,
2,,,1,,1,,,
3,,1,,,,,,
4,,,,,,,,
5,,1,,,,,,
6,,,1,1,,1,1,
7,,1,1,,,,,
8,1,,,,,,,1
9,,,,,,,,

我们可以将数组分解成单独的行,然后使用groupby pivot来获得所需的输出。

val df2 =
df.
select(
df("id"),
explode(df("categorical_code")).as("categorical_code"),
lit(1).as("categorical_code_exist")
)
df2.show()
+---+----------------+----------------------+
| id|categorical_code|categorical_code_exist|
+---+----------------+----------------------+
|  1|              19|                     1|
|  2|              87|                     1|
|  2|              19|                     1|
|  3|              18|                     1|
|  4|              96|                     1|
|  5|              18|                     1|
|  6|             111|                     1|
|  6|              22|                     1|
|  6|              19|                     1|
|  7|             161|                     1|
|  7|              19|                     1|
|  7|              18|                     1|
|  8|              12|                     1|
|  9|             170|                     1|
+---+----------------+----------------------+
val df3 =
df2.
groupBy("id").
pivot("categorical_code").
agg(coalesce(first(df2("categorical_code_exist")))).
orderBy("id")
df3.show()
+---+----+----+----+----+----+----+----+----+----+
| id|  12|  18|  19|  22|  87|  96| 111| 161| 170|
+---+----+----+----+----+----+----+----+----+----+
|  1|null|null|   1|null|null|null|null|null|null|
|  2|null|null|   1|null|   1|null|null|null|null|
|  3|null|   1|null|null|null|null|null|null|null|
|  4|null|null|null|null|null|   1|null|null|null|
|  5|null|   1|null|null|null|null|null|null|null|
|  6|null|null|   1|   1|null|null|   1|null|null|
|  7|null|   1|   1|null|null|null|null|   1|null|
|  8|   1|null|null|null|null|null|null|null|null|
|  9|null|null|null|null|null|null|null|null|   1|
+---+----+----+----+----+----+----+----+----+----+
df3.printSchema()
root
|-- id: integer (nullable = true)
|-- 12: integer (nullable = true)
|-- 18: integer (nullable = true)
|-- 19: integer (nullable = true)
|-- 22: integer (nullable = true)
|-- 87: integer (nullable = true)
|-- 96: integer (nullable = true)
|-- 111: integer (nullable = true)
|-- 161: integer (nullable = true)
|-- 170: integer (nullable = true)

相关内容

  • 没有找到相关文章

最新更新