我正在尝试使用逻辑回归制作分类器根据像素(特征)的值预测右数字(标签)。
我在 Java 中使用 Apache Spark,并且在将 mnist 数据库中的数据转换为 libsvm 格式后,我正在使用它,这是我的代码:
package ml;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.optimization.L1Updater;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import scala.Tuple2;
public class MNIST5 {
static String trainImagesPath = "train-images.idx3-ubyte";
static String trainLabelsPath = "train-labels.idx1-ubyte";
static String testImagesPath = "t10k-images.idx3-ubyte";
static String testLabelsPath = "t10k-labels.idx1-ubyte";
static SparkConf conf = new SparkConf()
.setMaster("local")
.setAppName("Machine learning - MNIST Example");
static SparkContext sc = SparkContext.getOrCreate(conf);
public static void main(String[] args) throws FileNotFoundException, UnsupportedEncodingException {
mnist_spark_logistic_regression();
//saveMnistDataLibsvmFormat();
}
static void mnist_spark_logistic_regression(){
long t;
System.out.println("Loading training data ...");
t = System.currentTimeMillis();
JavaRDD<LabeledPoint> trainData = MLUtils.loadLibSVMFile(sc, "mnist-train-data.txt").toJavaRDD();
System.out.println(System.currentTimeMillis()-t+" ms"); // 6661 ms
System.out.println("Training logistic regression classifier ...");
t = System.currentTimeMillis();
// Run training algorithm to build the model.
LogisticRegressionWithLBFGS lr = new LogisticRegressionWithLBFGS()
.setNumClasses(10);
//lr.optimizer().setUpdater(new L1Updater());
LogisticRegressionModel model = lr.run(trainData.rdd());
System.out.println(System.currentTimeMillis()-t+" ms"); // 1951 ms
// print weights and intercept
System.out.println("numClasses: "+model.numClasses());
System.out.println("numFeatures: "+model.numFeatures());
System.out.println("Weights: "+model.weights());
System.out.println("Wlength: "+model.weights().size());
System.out.println("Intercept: "+model.intercept());
System.out.println("Loading testing data ...");
t = System.currentTimeMillis();
JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, "mnist-test-data.txt").toJavaRDD();
System.out.println(System.currentTimeMillis()-t+" ms"); // 11356 ms
System.out.println("Compute raw scores on the test set ...");
t = System.currentTimeMillis();
// Compute raw scores on the test set.
JavaPairRDD<Object, Object> predictionAndLabels = testData.mapToPair(
(p) -> {
return new Tuple2<>(model.predict(p.features()), p.label());
}
);
System.out.println(System.currentTimeMillis()-t+" ms"); // 47 ms
System.out.println("Iterate ...");
t = System.currentTimeMillis();
JavaRDD<Integer> wyw = testData.map(new Function<LabeledPoint, Integer>() {
@Override
public Integer call(LabeledPoint t1) throws Exception {
double yb = model.predict(t1.features());
if(yb==t1.label())
System.out.println("label: "+t1.label()+", predicted: "+yb);
return 0;
}
});
wyw.collect();
System.out.println(System.currentTimeMillis()-t+" ms");
System.out.println("Evaluating ...");
t = System.currentTimeMillis();
// Get evaluation metrics.
MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
double accuracy = metrics.accuracy();
System.out.println("Accuracy = " + accuracy); // 0.098
System.out.println(System.currentTimeMillis()-t+" ms"); // 1108 ms
// Save and load model
model.save(sc, "mnist_logreg_model"+"/javaMNISTLogisticRegressionWithLBFGSModel");
LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "mnist_logreg_model"+"/javaMNISTLogisticRegressionWithLBFGSModel");
System.out.println(sameModel);
}
static ArrayList<LabeledPoint> getData(String imagesPath, String labelsPath){
JavaRDD<LabeledPoint> data;
ArrayList<LabeledPoint> lpts = new ArrayList<>();
FileInputStream inImage = null;
FileInputStream inLabel = null;
try {
inImage = new FileInputStream(imagesPath);
inLabel = new FileInputStream(labelsPath);
int magicNumberImages = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
int numberOfImages = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
int numberOfRows = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
int numberOfColumns = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
int magicNumberLabels = (inLabel.read() << 24) | (inLabel.read() << 16) | (inLabel.read() << 8) | (inLabel.read());
int numberOfLabels = (inLabel.read() << 24) | (inLabel.read() << 16) | (inLabel.read() << 8) | (inLabel.read());
int numberOfPixels = numberOfRows * numberOfColumns;
double[] imgPixels = new double[numberOfPixels];
for(int i = 0; i < numberOfImages; i++) {
//if(i % 100 == 0) {System.out.println("Number of images extracted: " + i);}
for(int p = 0; p < numberOfPixels; p++) {
imgPixels[p] = inImage.read();
}
int label = inLabel.read();
LabeledPoint lp = LabeledPoint.apply(label, Vectors.dense(imgPixels));
lpts.add(lp);
}
}
catch (FileNotFoundException e) { e.printStackTrace(); }
catch (IOException e) { e.printStackTrace(); }
finally {
if (inImage != null) {
try {
inImage.close();
} catch (IOException e) { e.printStackTrace(); }
}
if (inLabel != null) {
try {
inLabel.close();
} catch (IOException e) { e.printStackTrace(); }
}
}
return lpts;
}
static JavaRDD<LabeledPoint> loadData(String imagesPath, String labelsPath){
ArrayList<LabeledPoint> lpts = getData(imagesPath, labelsPath);
JavaSparkContext jsc = new JavaSparkContext(sc);
JavaRDD<LabeledPoint> data = jsc.parallelize(lpts);
return data;
}
static void saveMnistDataLibsvmFormat() throws FileNotFoundException, UnsupportedEncodingException{
ArrayList<LabeledPoint> data = getData(testImagesPath, testLabelsPath);
PrintWriter writer = new PrintWriter("mnist-test-data.txt", "UTF-8");
for(LabeledPoint lp : data){
StringBuilder s = new StringBuilder();
s.append(lp.label()).append(" ");
int i;
double[] arr = lp.features().toArray();
for(i=0;i<arr.length-1;i++)
if(arr[i]!=0)
s.append(i+1).append(":").append(arr[i]).append(" ");
if(arr[i]!=0)
s.append(i+1).append(":").append(arr[i]);
writer.println(s.toString());
}
// writer.println("The first line");
// writer.println("The second line");
writer.close();
ArrayList<LabeledPoint> data2 = getData(trainImagesPath, trainLabelsPath);
PrintWriter writer2 = new PrintWriter("mnist-train-data.txt", "UTF-8");
for(LabeledPoint lp : data2){
StringBuilder s = new StringBuilder();
s.append(lp.label()).append(" ");
int i;
double[] arr = lp.features().toArray();
for(i=0;i<arr.length-1;i++)
if(arr[i]!=0)
s.append(i+1).append(":").append(arr[i]).append(" ");
if(arr[i]!=0)
s.append(i+1).append(":").append(arr[i]);
writer2.println(s.toString());
}
// writer.println("The first line");
// writer.println("The second line");
writer2.close();
}
}
权重的值都等于零,我不明白为什么?请帮忙,谢谢。
你的意思是什么
`if(yb==t1.label())
System.out.println("label: "+t1.label()+", predicted: "+yb);
return 0;`
它始终返回 0。