聚合pyspark中的One Hot Encoded功能



我对python很有经验,但对pyspark完全陌生。我有一个数据框架,它包含大约50M行,具有几个分类特征。对于每个功能,我都有一个热编码。下面是一个简化但具有代表性的代码示例。

从pyspark.ml.feature导入StringIndexer,OneHotEncoder从pyspark.ml导入管道

df = sc.parallelize([
(1, 'grocery'),
(1, 'drinks'),
(1, 'bakery'),
(2, 'grocery'),
(3, 'bakery'),
(3, 'bakery'),
]).toDF(["id", "category"])
indexer = StringIndexer(inputCol='category', outputCol='categoryIndex')
encoder = OneHotEncoder(inputCol='categoryIndex', outputCol='categoryVec')
pipe = Pipeline(stages = [indexer, encoder])
newDF = pipe.fit(df).transform(df)

给出输出

+---+--------+-------------+-------------+
| id|category|categoryIndex|  categoryVec|
+---+--------+-------------+-------------+
|  1| grocery|          1.0|(2,[1],[1.0])|
|  1|  drinks|          2.0|    (2,[],[])|
|  1|  bakery|          0.0|(2,[0],[1.0])|
|  2| grocery|          1.0|(2,[1],[1.0])|
|  3|  bakery|          0.0|(2,[0],[1.0])|
|  3|  bakery|          0.0|(2,[0],[1.0])|
+---+--------+-------------+-------------+

我现在想groupBy"id",并用sum聚合"categoryVec"列,这样我就可以为每个id获得一行,并用一个向量指示客户正在购物的(可能是几个(类别中的哪一个。在panda中,这只是对pd.get_dummies()步骤中产生的每一列应用sum/me均值的情况,但在这里似乎并不那么简单。

然后,我将把输出传递给ML算法,这样我就需要能够在输出上使用VectorAssembler或类似的工具。

哦,我真的需要一个pyspark解决方案。

非常感谢你的帮助!

您可以为此使用Counvectorizer。它将类别索引数组转换为编码向量。

from pyspark.ml.feature import CountVectorizer
from pyspark.ml import Pipeline
from pyspark.sql.window import Window
from pyspark.sql import functions as F

df = sc.parallelize([
(1, 'grocery'),
(1, 'drinks'),
(1, 'bakery'),
(2, 'grocery'),
(3, 'bakery'),
(3, 'bakery'),
]).toDF(["id", "category"]) 
.groupBy('id') 
.agg(F.collect_list('category').alias('categoryIndexes'))
cv = CountVectorizer(inputCol='categoryIndexes', outputCol='categoryVec')
transformed_df = cv.fit(df).transform(df)
transformed_df.show()

结果:

+---+--------------------+--------------------+
| id|     categoryIndexes|         categoryVec|
+---+--------------------+--------------------+
|  1|[grocery, drinks,...|(3,[0,1,2],[1.0,1...|
|  3|    [bakery, bakery]|       (3,[0],[2.0])|
|  2|           [grocery]|       (3,[1],[1.0])|
+---+--------------------+--------------------+

最新更新