Reducing to a single column in Spark Scala



I have the DataFrame below. Its columns are created dynamically based on slots (1, 2, 3, 4, etc.).

scala> df1.show
+---+-------+-------+-------+-------+-----+-----+-----+-----+
| ID|1_count|2_count|3_count|4_count|1_per|2_per|3_per|4_per|
+---+-------+-------+-------+-------+-----+-----+-----+-----+
|  1|      3|     11|     15|      3|   70|   80|  150|   20|
|  2|     19|      5|     15|      3|  150|   80|  200|   43|
|  3|     30|     15|     15|     39|   55|   80|  150|  200|
|  4|      8|     65|      3|      3|   70|   80|  150|   55|
+---+-------+-------+-------+-------+-----+-----+-----+-----+

It was created with:

val df1 = Seq(
  (1, 3, 11, 15, 3, 70, 80, 150, 20),
  (2, 19, 5, 15, 3, 150, 80, 200, 43),
  (3, 30, 15, 15, 39, 55, 80, 150, 200),
  (4, 8, 65, 3, 3, 70, 80, 150, 55)
).toDF("ID", "1_count", "2_count", "3_count", "4_count", "1_per", "2_per", "3_per", "4_per")

I need to pick the first occurrence of (count, per) where per < 100 and count > 10. This should be a row-level operation, i.e., done for each ID.

Expected output:

+---+-----+---+----+
| ID|count|per|slot|
+---+-----+---+----+
|  1|   11| 80|   2|
|  2|    0|  0|   0|
|  3|   30| 55|   1|
|  4|   65| 80|   2|
+---+-----+---+----+

The logic for the output is, for each ID, to find the first column pair (x_count, x_per) that satisfies the condition (per < 100 and count > 10). For example, for ID 1, slot 1 fails (count 3 is not > 10) but slot 2 passes (count 11 > 10 and per 80 < 100), so the result is (11, 80, 2); for ID 2 no slot qualifies, so the defaults (0, 0, 0) are returned.

A clean solution can be to use a UDF:

val df1 = Seq(
  (1, 3, 11, 15, 3, 70, 80, 150, 20),
  (2, 19, 5, 15, 3, 150, 80, 200, 43),
  (3, 30, 15, 15, 39, 55, 80, 150, 200),
  (4, 8, 65, 3, 3, 70, 80, 150, 55)
).toDF("ID", "1_count", "2_count", "3_count", "4_count", "1_per", "2_per", "3_per", "4_per")

// pair up the count and per columns and combine each pair into an array
val allCols = df1.columns.filter(_.contains("_count")).map(col)
  .zip(df1.columns.filter(_.contains("_per")).map(col))
  .map(x => array(x._1, x._2))

// UDF that extracts the first matching (count, per, slot), or (0, 0, 0) if nothing matches
val findFirstOccur = udf((all: Seq[Seq[Int]]) => {
  all.zipWithIndex
    .filter(r => r._1(0) > 10 && r._1(1) < 100)
    .map(x => (x._1(0), x._1(1), x._2 + 1))
    .headOption.getOrElse((0, 0, 0))
})

// new column from the UDF, then unpack the tuple fields
val df2 = df1.withColumn("a", findFirstOccur(array(allCols: _*)))
  .select($"ID", $"a._1".as("count"), $"a._2".as("per"), $"a._3".as("slot"))

Output:

+---+-----+---+----+
|ID |count|per|slot|
+---+-----+---+----+
|1  |11   |80 |2   |
|2  |0    |0  |0   |
|3  |30   |55 |1   |
|4  |65   |80 |2   |
+---+-----+---+----+
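
(The UDF returns a Scala Tuple3, which Spark encodes as a struct with fields _1, _2 and _3; that is why the final select unpacks $"a._1", $"a._2" and $"a._3".)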
  1. Load the DataFrame
  2. Map the columns into array columns countArr and perArr
  3. Add a row mapper that loops through the columns and finds the first matching entry
  4. Map the row to the newly matched columns, or to the default (rowId, 0, 0, 0)
import org.apache.spark.sql.functions._
import scala.collection.mutable

object PerCount {

  def main(args: Array[String]): Unit = {
    val spark = Constant.getSparkSess
    import spark.implicits._

    val df = List(
      (1, 3, 11, 15, 3, 70, 80, 150, 20),
      (2, 19, 5, 15, 3, 150, 80, 200, 43),
      (3, 30, 15, 15, 39, 55, 80, 150, 200),
      (4, 8, 65, 3, 3, 70, 80, 150, 55)
    ).toDF("ID", "1_count", "2_count", "3_count", "4_count", "1_per", "2_per", "3_per", "4_per")

    val countArrayColumns = List("1_count", "2_count", "3_count", "4_count")
    val perArrayColumns = List("1_per", "2_per", "3_per", "4_per")

    df.withColumn("countArr", array(countArrayColumns.map(col): _*))
      .withColumn("perArr", array(perArrayColumns.map(col): _*))
      .map(row => {
        val countArr = row.getAs[mutable.WrappedArray[Int]]("countArr")
        val perArr = row.getAs[mutable.WrappedArray[Int]]("perArr")
        // first slot with count > 10 and per < 100, defaulting to (0, 0, 0)
        val (position, count, per) = countArr.zipWithIndex
          .filter(row => row._1 > 10 && perArr(row._2) < 100)
          .map(row => (row._2 + 1, row._1, perArr(row._2)))
          .headOption.getOrElse((0, 0, 0))
        (row.getInt(0), count, per, position)
      }).toDF("ID", "count", "per", "slot")
      .show()
  }
}
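
Note that Constant.getSparkSess is not part of Spark; it appears to be the answerer's own helper for obtaining a SparkSession. To run the snippet standalone, any locally built session will do, for example:

import org.apache.spark.sql.SparkSession

// minimal local session, stand-in for the answerer's Constant.getSparkSess helper
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("PerCount")
  .getOrCreate()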

Here it is in Python - hopefully you are able to convert it to Scala:

%python
from pyspark.sql import functions as F

df1 = spark.createDataFrame([
    (1, 3, 11, 15, 3, 70, 80, 150, 20),
    (2, 19, 5, 15, 3, 150, 80, 200, 43),
    (3, 30, 15, 15, 39, 55, 80, 150, 200),
    (4, 8, 65, 3, 3, 70, 80, 150, 55)],
    ["ID", "1_count", "2_count", "3_count", "4_count", "1_per", "2_per", "3_per", "4_per"])

# first slot whose count > 10 and per < 100; when() picks the first matching branch
# (integer literals keep the columns numeric, matching the expected output)
df1 = df1.withColumn('slot',
    F.when((F.col('1_count') > 10) & (F.col('1_per') < 100), 1)
     .when((F.col('2_count') > 10) & (F.col('2_per') < 100), 2)
     .when((F.col('3_count') > 10) & (F.col('3_per') < 100), 3)
     .when((F.col('4_count') > 10) & (F.col('4_per') < 100), 4)
     .otherwise(0))
# pull the count and per values belonging to the chosen slot
df1 = df1.withColumn('count',
    F.when(F.col('slot') == 1, F.col('1_count'))
     .when(F.col('slot') == 2, F.col('2_count'))
     .when(F.col('slot') == 3, F.col('3_count'))
     .when(F.col('slot') == 4, F.col('4_count'))
     .otherwise(0))
df1 = df1.withColumn('per',
    F.when(F.col('slot') == 1, F.col('1_per'))
     .when(F.col('slot') == 2, F.col('2_per'))
     .when(F.col('slot') == 3, F.col('3_per'))
     .when(F.col('slot') == 4, F.col('4_per'))
     .otherwise(0))
df1 = df1.select('ID', 'count', 'per', 'slot')
df1.show()

-------Output---------

+---+-----+---+----+
| ID|count|per|slot|
+---+-----+---+----+
|  1|   11| 80|   2|
|  2|    0|  0|   0|
|  3|   30| 55|   1|
|  4|   65| 80|   2|
+---+-----+---+----+
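
For reference (this sketch is mine, not part of the original answer), the same chained when/otherwise logic ports to Scala directly; folding over the slot numbers avoids spelling out every branch by hand. It assumes df1 as defined in the question:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._

// condition for one slot: count > 10 and per < 100 (hypothetical helper name)
def hit(i: Int): Column = col(s"${i}_count") > 10 && col(s"${i}_per") < 100

// chain when(...) clauses in slot order; the first match wins, 0 when none match
def firstOf(value: Int => Column): Column =
  (2 to 4).foldLeft(when(hit(1), value(1)))((acc, i) => acc.when(hit(i), value(i)))
    .otherwise(0)

val result = df1.select(
  col("ID"),
  firstOf(i => col(s"${i}_count")).as("count"),
  firstOf(i => col(s"${i}_per")).as("per"),
  firstOf(i => lit(i)).as("slot")
)
result.show()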

Here is my attempt:

1. No UDF

2. You can add any number of (count, per) slots and the solution works as-is

Load the data provided:

val cols = Seq("ID", "1_count", "2_count", "3_count", "4_count", "1_per", "2_per", "3_per", "4_per")
val df1 = Seq(
  (1, 3, 11, 15, 3, 70, 80, 150, 20),
  (2, 19, 5, 15, 3, 150, 80, 200, 43),
  (3, 30, 15, 15, 39, 55, 80, 150, 200),
  (4, 8, 65, 3, 3, 70, 80, 150, 55)
).toDF(cols: _*)
df1.show(false)
df1.printSchema()
/**
* +---+-------+-------+-------+-------+-----+-----+-----+-----+
* |ID |1_count|2_count|3_count|4_count|1_per|2_per|3_per|4_per|
* +---+-------+-------+-------+-------+-----+-----+-----+-----+
* |1  |3      |11     |15     |3      |70   |80   |150  |20   |
* |2  |19     |5      |15     |3      |150  |80   |200  |43   |
* |3  |30     |15     |15     |39     |55   |80   |150  |200  |
* |4  |8      |65     |3      |3      |70   |80   |150  |55   |
* +---+-------+-------+-------+-------+-----+-----+-----+-----+
*
* root
* |-- ID: integer (nullable = false)
* |-- 1_count: integer (nullable = false)
* |-- 2_count: integer (nullable = false)
* |-- 3_count: integer (nullable = false)
* |-- 4_count: integer (nullable = false)
* |-- 1_per: integer (nullable = false)
* |-- 2_per: integer (nullable = false)
* |-- 3_per: integer (nullable = false)
* |-- 4_per: integer (nullable = false)
*/

Create an array<struct> and filter it with the FILTER higher-order function (Spark >= 2.4.0):

val (countCols, perCols) = cols.filter(_ != "ID").partition(_.endsWith("count"))
// build one struct per slot: (count, per, slot), slot parsed from the column name
val struct = countCols.zip(perCols).map { case (countCol, perCol) =>
  expr(s"named_struct('count', $countCol, 'per', $perCol, " +
    s"'slot', cast(substring_index('$countCol', '_', 1) as int))")
}
// keep the first struct with count > 10 and per < 100, falling back to (0, 0, 0)
val processedDf = df1.select($"ID", array(struct: _*).as("count_per"))
  .withColumn("count_per_p", coalesce(
    expr("FILTER(count_per, x -> x.count > 10 and x.per < 100)[0]"),
    expr("named_struct('count', 0, 'per', 0, 'slot', 0)")
  ))
  .selectExpr("ID", "count_per_p.*")
processedDf.show(false)
processedDf.printSchema()
processedDf.show(false)
processedDf.printSchema()
/**
* +---+-----+---+----+
* |ID |count|per|slot|
* +---+-----+---+----+
* |1  |11   |80 |2   |
* |2  |0    |0  |0   |
* |3  |30   |55 |1   |
* |4  |65   |80 |2   |
* +---+-----+---+----+
*
* root
* |-- ID: integer (nullable = false)
* |-- count: integer (nullable = false)
* |-- per: integer (nullable = false)
* |-- slot: integer (nullable = true)
*/
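
As a side note beyond the original answer: on Spark 3.0+, the same FILTER step can be expressed with the typed filter function from org.apache.spark.sql.functions instead of an SQL string. A minimal sketch, reusing the struct column list built above:

import org.apache.spark.sql.functions.{filter => arrayFilter}

// same FILTER as above via the column DSL; `struct` is the Seq of named_struct
// columns defined in the previous snippet
val df3 = df1.select($"ID", array(struct: _*).as("count_per"))
  .withColumn("count_per_p", coalesce(
    arrayFilter($"count_per", x => x.getField("count") > 10 && x.getField("per") < 100)
      .getItem(0),
    expr("named_struct('count', 0, 'per', 0, 'slot', 0)")
  ))
  .selectExpr("ID", "count_per_p.*")
df3.show(false)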
