将所有列在列火花上放置所有具有特殊条件的列



我有一个数据集,我需要删除等于0的标准偏差的列。我尝试了:

val df = spark.read.option("header",true)
  .option("inferSchema", "false").csv("C:/gg.csv")
val finalresult = df
  .agg(df.columns.map(stddev(_)).head, df.columns.map(stddev(_)).tail: _*)

我想计算每列的标准偏差,如果列等于零

,则将其丢弃
RowNumber,Poids,Age,Taille,0MI,Hmean,CoocParam,LdpParam,Test2,Classe,
0,87,72,160,5,0.6993,2.9421,2.3745,3,4,
1,54,70,163,5,0.6301,2.7273,2.2205,3,4,
2,72,51,164,5,0.6551,2.9834,2.3993,3,4,
3,75,74,170,5,0.6966,2.9654,2.3699,3,4,
4,108,62,165,5,0.6087,2.7093,2.1619,3,4,
5,84,61,159,5,0.6876,2.938,2.3601,3,4,
6,89,64,168,5,0.6757,2.9547,2.3676,3,4,
7,75,72,160,5,0.7432,2.9331,2.3339,3,4,
8,64,62,153,5,0.6505,2.7676,2.2255,3,4,
9,82,58,159,5,0.6748,2.992,2.4043,3,4,
10,67,49,160,5,0.6633,2.9367,2.333,3,4,
11,85,53,160,5,0.6821,2.981,2.3822,3,4,

您可以尝试此操作,使用getValueMapfilter获取要删除的列名,然后删除它们:

//Extract the standard deviation from the data frame summary:    
val stddev = df.describe().filter($"summary" === "stddev").drop("summary").first()
// Use `getValuesMap` and `filter` to get the columns names where stddev is equal to 0:    
val to_drop = stddev.getValuesMap[String](df.columns).filter{ case (k, v) => v.toDouble == 0 }.keys
//Drop 0 stddev columns    
df.drop(to_drop.toSeq: _*).show
+---------+-----+---+------+------+---------+--------+
|RowNumber|Poids|Age|Taille| Hmean|CoocParam|LdpParam|
+---------+-----+---+------+------+---------+--------+
|        0|   87| 72|   160|0.6993|   2.9421|  2.3745|
|        1|   54| 70|   163|0.6301|   2.7273|  2.2205|
|        2|   72| 51|   164|0.6551|   2.9834|  2.3993|
|        3|   75| 74|   170|0.6966|   2.9654|  2.3699|
|        4|  108| 62|   165|0.6087|   2.7093|  2.1619|
|        5|   84| 61|   159|0.6876|    2.938|  2.3601|
|        6|   89| 64|   168|0.6757|   2.9547|  2.3676|
|        7|   75| 72|   160|0.7432|   2.9331|  2.3339|
|        8|   64| 62|   153|0.6505|   2.7676|  2.2255|
|        9|   82| 58|   159|0.6748|    2.992|  2.4043|
|       10|   67| 49|   160|0.6633|   2.9367|   2.333|
|       11|   85| 53|   160|0.6821|    2.981|  2.3822|
+---------+-----+---+------+------+---------+--------+

好吧,我写了一个独立于您数据集的解决方案。所需的导入和示例数据:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{lit, stddev, col}
val df = spark.range(1, 1000).withColumn("X2", lit(0)).toDF("X1","X2")
df.show(5)
// +---+---+
// | X1| X2|
// +---+---+
// |  1|  0|
// |  2|  0|
// |  3|  0|
// |  4|  0|
// |  5|  0|

首先按列计算标准偏差:

// no need to rename but I did it to become more human 
// readable when you show df2
val aggs = df.columns.map(c => stddev(c).as(c)) 
val stddevs = df.select(aggs: _*)
stddevs.show // df2 contains the stddev of each columns
// +-----------------+---+
// |               X1| X2|
// +-----------------+---+
// |288.5307609250702|0.0|
// +-----------------+---+

收集第一行和过滤列以保持:

val columnsToKeep: Seq[Column] = stddevs.first  // Take first row
  .toSeq  // convert to Seq[Any]
  .zip(df.columns)  // zip with column names
  .collect {
    // keep only names where stddev != 0
    case (s: Double, c) if s != 0.0  => col(c) 
  }

选择并检查结果:

df.select(columnsToKeep: _*).show
// +---+
// | X1|
// +---+
// |  1|
// |  2|
// |  3|
// |  4|
// |  5|

相关内容

  • 没有找到相关文章

最新更新