使用Scala,如何将dataFrame拆分为具有相同列值的多个dataFrame(无论是数组还是集合)。例如,我想分割以下DataFrame:
ID Rate State
1 24 AL
2 35 MN
3 46 FL
4 34 AL
5 78 MN
6 99 FL
:
数据集1
ID Rate State
1 24 AL
4 34 AL
数据集2
ID Rate State
2 35 MN
5 78 MN
数据集3
ID Rate State
3 46 FL
6 99 FL
您可以收集唯一的状态值并简单地映射结果数组:
val states = df.select("State").distinct.collect.flatMap(_.toSeq)
val byStateArray = states.map(state => df.where($"State" <=> state))
或映射:
val byStateMap = states
.map(state => (state -> df.where($"State" <=> state)))
.toMap
在Python中也是一样:
from itertools import chain
from pyspark.sql.functions import col
states = chain(*df.select("state").distinct().collect())
# PySpark 2.3 and later
# In 2.2 and before col("state") == state)
# should give the same outcome, ignoring NULLs
# if NULLs are important
# (lit(state).isNull() & col("state").isNull()) | (col("state") == state)
df_by_state = {state:
df.where(col("state").eqNullSafe(state)) for state in states}
这里明显的问题是,它需要对每个级别进行完整的数据扫描,因此这是一个昂贵的操作。如果您正在寻找分割输出的方法,请参见如何将RDD分割为两个或更多RDD ?特别地,您可以将Dataset
按照感兴趣的列进行分区:
val path: String = ???
df.write.partitionBy("State").parquet(path)
并在需要时回读:
// Depend on partition prunning
for { state <- states } yield spark.read.parquet(path).where($"State" === state)
// or explicitly read the partition
for { state <- states } yield spark.read.parquet(s"$path/State=$state")
根据数据的大小、分割的级别、存储和输入的持久性级别,它可能比多个过滤器更快或更慢。
如果您将数据框架作为临时表,这将非常简单(如果spark版本是2)。
df1.createOrReplaceTempView("df1")
现在你可以做查询了,
var df2 = spark.sql("select * from df1 where state = 'FL'")
var df3 = spark.sql("select * from df1 where state = 'MN'")
var df4 = spark.sql("select * from df1 where state = 'AL'")
现在你得到了df2, df3, df4。如果您想将它们作为列表,您可以使用
df2.collect()
df3.collect()
甚至映射/过滤功能。请参考https://spark.apache.org/docs/latest/sql-programming-guide.html#datasets-and-dataframes
火山灰你可以使用…
var stateDF = df.select("state").distinct() // to get states in a df
val states = stateDF.rdd.map(x=>x(0)).collect.toList //to get states in a list
for (i <- states) //loop to get each state
{
var finalDF = sqlContext.sql("select * from table1 where state = '" + state
+"' ")
}