>我有一个这样的数据帧
data = [(("ID1", "A", 1)), (("ID1", "B", 5)), (("ID2", "A", 12)),
(("ID3", "A", 3)), (("ID3", "B", 3)), (("ID3", "C", 5)), (("ID4", "A", 10))]
df = spark.createDataFrame(data, ["ID", "Type", "Value"])
df.show()
+---+----+-----+
| ID|Type|Value|
+---+----+-----+
|ID1| A| 1|
|ID1| B| 5|
|ID2| A| 12|
|ID3| A| 3|
|ID3| B| 3|
|ID3| C| 5|
|ID4| A| 10|
+---+----+-----+
我只想提取那些只包含一个特定类型的行(或 ID) - "A">
因此,我的预期输出将包含以下行
+---+----+-----+
| ID|Type|Value|
+---+----+-----+
|ID2| A| 1|
|ID4| A| 10|
+---+----+-----+
对于每个 ID 可以包含任何类型 - A、B、C 等。我想提取那些只包含一个且只有一个类型的 ID - 'A'
如何在 PySpark 中实现这一点
您可以对其应用过滤器。
import pyspark.sql.functions as f
data = [(("ID1", "A", 1)), (("ID1", "B", 5)), (("ID2", "A", 12)),
(("ID3", "A", 3)), (("ID3", "B", 3)), (("ID3", "C", 5)), (("ID4", "A", 10))]
df = spark.createDataFrame(data, ["ID", "Type", "Value"])
df.show()
+---+----+-----+
| ID|Type|Value|
+---+----+-----+
|ID1| A| 1|
|ID1| B| 5|
|ID2| A| 12|
|ID3| A| 3|
|ID3| B| 3|
|ID3| C| 5|
|ID4| A| 10|
+---+----+-----+
x= df.filter(f.col('Type')=='A')
x.show()
如果我们需要过滤所有只有一条记录的 ID,并且类型也为"A",那么下面的代码可能是解决方案
df.registerTempTable('table1')
sqlContext.sql('select a.ID, a.Type,a.Value from table1 as a, (select ID, count(*) as cnt_val from table1 group by ID) b where a.ID = b.ID and (a.Type=="A" and b.cnt_val ==1)').show()
+---+----+-----+
| ID|Type|Value|
+---+----+-----+
|ID2| A| 12|
|ID4| A| 10|
+---+----+-----+
会有更好的替代方法来找到相同的。
按照OP的要求,我记下了我在评论下写的答案。
手头问题的目的是过滤掉每个特定ID
只有一个Type
A
元素而没有其他元素的DataFrame
。
# Loading the requisite packages
from pyspark.sql.functions import col, collect_set, array_contains, size, first
这个想法是首先通过ID
来aggregate()
DataFrame
,从而我们使用数组中的collect_set()
将所有unique
元素分组Type
。拥有unique
元素很重要,因为对于特定ID
可能会有两行,其中两行都有Type
作为A
。这就是为什么我们应该使用collect_set()
而不是collect_list()
,因为后者不会返回唯一元素,而是返回所有元素。
然后我们应该使用first()
来获取组Type
和Value
的第一个值。如果A
是特定ID
唯一可能unique
Type
,则first()
将返回唯一的值A
,以防A
发生一次,如果存在重复的A
,则返回最大值。
df = df = df.groupby(['ID']).agg(first(col('Type')).alias('Type'),
first(col('Value')).alias('Value'),
collect_set('Type').alias('Type_Arr'))
df.show()
+---+----+-----+---------+
| ID|Type|Value| Type_Arr|
+---+----+-----+---------+
|ID2| A| 12| [A]|
|ID3| A| 3|[A, B, C]|
|ID1| A| 1| [A, B]|
|ID4| A| 10| [A]|
+---+----+-----+---------+
最后,我们将同时放置 2 个条件来过滤掉所需的数据集。
条件 1:它使用array_contains()
检查Type
数组中是否存在A
。
条件 2:它检查数组的size
。如果大小大于 1,则应该有多个Types
。
df = df.where(array_contains(col('Type_Arr'),'A') & (size(col('Type_Arr'))==1)).drop('Type_Arr')
df.show()
+---+----+-----+
| ID|Type|Value|
+---+----+-----+
|ID2| A| 12|
|ID4| A| 10|
+---+----+-----+
我不精通 Python,这里有一个 Scala 中可能的解决方案:
df.groupBy("ID").agg(collect_set("Type").as("Types"))
.select("ID").where((size($"Types")===1).and(array_contains($"Types", "A"))).show()
+---+
| ID|
+---+
|ID2|
|ID4|
+---+
这个想法是通过ID
进行分组,并仅过滤包含A
值的大小为 1 的Types
。