我正在尝试将实际包含双精度值的csv字符串转换为与spark-ml兼容的数据集。由于我事先不知道预期的功能数量,我决定使用一个辅助类"Instance",它已经包含分类器要使用的正确数据类型,并且在其他一些情况下已经按预期工作:
public class Instance implements Serializable {
/**
*
*/
private static final long serialVersionUID = 6091606543088855593L;
private Vector indexedFeatures;
private double indexedLabel;
...getters and setters for both fields...
}
我得到意外行为的部分是这个:
Encoder<Instance> encoder = Encoders.bean(Instance.class);
System.out.println("encoder.schema()");
encoder.schema().printTreeString();
Dataset<Instance> dfInstance = df.select("value").as(Encoders.STRING())
.flatMap(s -> {
String[] splitted = s.split(",");
int length = splitted.length;
double[] features = new double[length-1];
for (int i=0; i<length-1; i++) {
features[i] = Double.parseDouble(splitted[i]);
}
if (length < 2) {
return Collections.emptyIterator();
} else {
return Collections.singleton(new Instance(
Vectors.dense(features),
Double.parseDouble(splitted[length-1])
)).iterator();
}
}, encoder);
System.out.println("dfInstance");
dfInstance.printSchema();
dfInstance.show(5);
我在控制台上得到以下输出:
encoder.schema()
root
|-- indexedFeatures: vector (nullable = true)
|-- indexedLabel: double (nullable = false)
dfInstance
root
|-- indexedFeatures: struct (nullable = true)
|-- indexedLabel: double (nullable = true)
+---------------+------------+
|indexedFeatures|indexedLabel|
+---------------+------------+
| []| 0.0|
| []| 0.0|
| []| 1.0|
| []| 0.0|
| []| 1.0|
+---------------+------------+
only showing top 5 rows
编码器架构正确地将索引特征行数据类型显示为向量。但是当我应用编码器并进行转换时,它会给我一行结构类型,不包含任何真实对象。
我想了解为什么 Spark 为我提供结构类型而不是正确的矢量类型。
实际上,我的答案并不是解释为什么你会得到一个结构类型。但是基于上一个问题,我可能会提供一种解决方法。
原始输入使用 DataFrameReader 的 csv 函数解析,再次使用 VectorAssembler:
Dataset<Row> csv = spark.read().option("inferSchema", "true")
.csv(inputDf.select("value").as(Encoders.STRING()));
String[] fieldNames = csv.schema().fieldNames();
VectorAssembler assembler = new VectorAssembler().setInputCols(
Arrays.copyOfRange(fieldNames, 0, fieldNames.length-1))
.setOutputCol("indexedFeatures");
Dataset<Row> result = assembler.transform(csv)
.withColumn("indexedLabel", functions.col(fieldNames[fieldNames.length-1]))
.select("indexedFeatures", "indexedLabel");