Spark:分布式系统性能差.如何改进



我写了一个简单的Spark程序,想把它部署到分布式服务器上。这很简单:

获取数据->整理数据->训练数据->重新应用查看训练结果

输入数据只有10K行,有3个特征。我首先在本地机器上运行,使用"local[*]"。它只运行大约3分钟。现在,当我部署到集群时,它运行得非常慢:半个小时都没有完成。在训练阶段,它变得非常缓慢。

我很好奇,如果我做错了什么。请帮我查一下。我使用Spark 1.6.1.

I submit:

spark-submit --packages com.databricks:spark-csv_2.11:1.5.0  orderprediction_2.11-1.0.jar --driver-cores 1 --driver-memory 4g --executor-cores 8 --executor-memory 4g

代码如下:

 def main(args: Array[String]) {
    // Set the log level to only print errors
    Logger.getLogger("org").setLevel(Level.ERROR)
    val conf = new SparkConf()
        .setAppName("My Prediction")
        //.setMaster("local[*]")
    val sc = new SparkContext(conf)
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    val data = sqlContext.read
        .option("header","true")
        .option("delimiter", "t")
        .format("com.databricks.spark.csv")
        .option("inferSchema","true")
        .load("mydata.txt")
    data.printSchema()
    data.show()
    val dataDF = data.toDF().filter("clicks >=10")
    dataDF.show()
    val assembler = new VectorAssembler()
      .setInputCols(Array("feature1", "feature2", "feature3"))
      .setOutputCol("features")
    val trainset = assembler.transform(dataDF).select("target", "features")
    trainset.printSchema()
    val trainset2 = trainset.withColumnRenamed("target", "label")
    trainset2.printSchema()
    val trainset3 = trainset2.withColumn("label", trainset2.col("label").cast(DataTypes.DoubleType))
    trainset3.cache() // cache data into memory
    trainset3.printSchema()
    trainset3.show()
    // Train a RandomForest model.
    println("training Random Forest")
    val rf = new RandomForestRegressor()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setNumTrees(1000)
    val rfmodel = rf.fit(trainset3)
    println("prediction")
    val result = rfmodel.transform(trainset3)
    result.show()
}

更新:经过调查,我发现它在

卡住了
collectAsMap at RandomForest.scala:525

这条线已经花了1.1个小时,还没有完工。数据,我相信只有几兆。

你正在构建一个由1000个随机树组成的随机森林,它将训练1000个实例。

在代码中,collectAsMap是第一个动作,而其余的都是转换(惰性计算)。所以当你看到挂在那一行时这是因为现在所有的maps, flatMaps, filters, groupBy,等都被求值了

最新更新