我在Spark 2.1中涉及余弦/列相似性有第二个问题。我对Scala和所有火花环境有点新手,这对我来说并不清楚:
如何恢复Spark中RowMatrix的每种组合的列相似性。这是我尝试的:
数据:
import org.apache.spark.sql.{SQLContext, Row, DataFrame}
import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType, DoubleType}
import org.apache.spark.sql.functions._
// rdd
val rowsRdd: RDD[Row] = sc.parallelize(
Seq(
Row(2.0, 7.0, 1.0),
Row(3.5, 2.5, 0.0),
Row(7.0, 5.9, 0.0)
)
)
// Schema
val schema = new StructType()
.add(StructField("item_1", DoubleType, true))
.add(StructField("item_2", DoubleType, true))
.add(StructField("item_3", DoubleType, true))
// Data frame
val df = spark.createDataFrame(rowsRdd, schema)
代码:
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, CoordinateMatrix, RowMatrix}
val rows = new VectorAssembler().setInputCols(df.columns).setOutputCol("vs")
.transform(df)
.select("vs")
.rdd
val items_mllib_vector = rows.map(_.getAs[org.apache.spark.ml.linalg.Vector](0))
.map(org.apache.spark.mllib.linalg.Vectors.fromML)
val mat = new RowMatrix(items_mllib_vector)
val simsPerfect = mat.columnSimilarities()
println("Pairwise similarities are: " + simsPerfect.entries.collect.mkString(", "))
输出:
Pairwise similarities are: MatrixEntry(0,2,0.24759378423606918), MatrixEntry(1,2,0.7376189553526812), MatrixEntry(0,1,0.8355316482961213)
所以我得到的是我的列和相似性的SimSperfect org.apache.spark.mllib.linalg.distributed.CoordinateMatrix
。如何将其转换回数据框并获取正确的列名称?
我首选的输出:
item_from | item_to | similarity
1 | 2 | 0.83 |
1 | 3 | 0.24 |
2 | 3 | 0.73 |
预先感谢
此方法也可以使用,而无需将行转换为字符串:
val transformedRDD = simsPerfect.entries.map{case MatrixEntry(row: Long, col:Long, sim:Double) => (row,col,sim)}
val dff = sqlContext.createDataFrame(transformedRDD).toDF("item_from", "item_to", "sim")
在哪里,我认为val sqlContext = new org.apache.spark.sql.SQLContext(sc)
已经定义,sc
是SparkContext。
我找到了解决我问题的解决方案:
//Transform result to rdd
val transformedRDD = simsPerfect.entries.map{case MatrixEntry(row: Long, col:Long, sim:Double) => Array(row,col,sim).mkString(",")}
//Transform rdd[String] to rdd[Row]
val rdd2 = transformedRDD.map(a => Row(a))
// to DF
val dfschema = StructType(Array(StructField("value",StringType)))
val rddToDF = spark.createDataFrame(rdd2,dfschema)
//create new DF with schema
val newdf = rddToDF.select(expr("(split(value, ','))[0]").cast("string").as("item_from")
,expr("(split(value, ','))[1]").cast("string").as("item_to")
,expr("(split(value, ','))[2]").cast("string").as("sim"))
我敢肯定还有另一种更简单的方法可以做到这一点,但是我很高兴它有效。