Logistic regression in apache spark mllib - mnist



我正在尝试使用逻辑回归制作分类器根据像素(特征)的值预测右数字(标签)。
我在 Java 中使用 Apache Spark,并且在将 mnist 数据库中的数据转换为 libsvm 格式后,我正在使用它,这是我的代码:

package ml;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.optimization.L1Updater;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import scala.Tuple2;

public class MNIST5 {
    static String trainImagesPath = "train-images.idx3-ubyte";
    static String trainLabelsPath = "train-labels.idx1-ubyte";
    static String testImagesPath = "t10k-images.idx3-ubyte";
    static String testLabelsPath = "t10k-labels.idx1-ubyte";

    static SparkConf conf = new SparkConf()
            .setMaster("local")
            .setAppName("Machine learning - MNIST Example");
    static SparkContext sc = SparkContext.getOrCreate(conf);
    public static void main(String[] args) throws FileNotFoundException, UnsupportedEncodingException {
        mnist_spark_logistic_regression();
        //saveMnistDataLibsvmFormat();
    }
    static void mnist_spark_logistic_regression(){
        long t;
        System.out.println("Loading training data ...");
        t = System.currentTimeMillis();
        JavaRDD<LabeledPoint> trainData = MLUtils.loadLibSVMFile(sc, "mnist-train-data.txt").toJavaRDD();
        System.out.println(System.currentTimeMillis()-t+" ms"); // 6661 ms

        System.out.println("Training logistic regression classifier ...");
        t = System.currentTimeMillis();
        // Run training algorithm to build the model.         
        LogisticRegressionWithLBFGS lr = new LogisticRegressionWithLBFGS()           
            .setNumClasses(10);
        //lr.optimizer().setUpdater(new L1Updater());        
        LogisticRegressionModel model = lr.run(trainData.rdd());
        System.out.println(System.currentTimeMillis()-t+" ms"); // 1951 ms
        // print weights and intercept
        System.out.println("numClasses: "+model.numClasses());
        System.out.println("numFeatures: "+model.numFeatures());
        System.out.println("Weights: "+model.weights());
        System.out.println("Wlength: "+model.weights().size());
        System.out.println("Intercept: "+model.intercept());

        System.out.println("Loading testing data ...");
        t = System.currentTimeMillis();
        JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, "mnist-test-data.txt").toJavaRDD();
        System.out.println(System.currentTimeMillis()-t+" ms"); // 11356 ms

        System.out.println("Compute raw scores on the test set ...");
        t = System.currentTimeMillis();
        // Compute raw scores on the test set.
        JavaPairRDD<Object, Object> predictionAndLabels = testData.mapToPair(
            (p) -> {
                return new Tuple2<>(model.predict(p.features()), p.label());
            }
        );
        System.out.println(System.currentTimeMillis()-t+" ms"); // 47 ms

        System.out.println("Iterate ...");
        t = System.currentTimeMillis();
        JavaRDD<Integer> wyw = testData.map(new Function<LabeledPoint, Integer>() {
            @Override
            public Integer call(LabeledPoint t1) throws Exception {
                double yb = model.predict(t1.features());
                if(yb==t1.label())
                    System.out.println("label: "+t1.label()+", predicted: "+yb);
                return 0;
            }
        });
        wyw.collect();
        System.out.println(System.currentTimeMillis()-t+" ms");

        System.out.println("Evaluating ...");
        t = System.currentTimeMillis();
        // Get evaluation metrics.
        MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
        double accuracy = metrics.accuracy();
        System.out.println("Accuracy = " + accuracy); // 0.098
        System.out.println(System.currentTimeMillis()-t+" ms"); // 1108 ms

        // Save and load model
        model.save(sc, "mnist_logreg_model"+"/javaMNISTLogisticRegressionWithLBFGSModel");
        LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "mnist_logreg_model"+"/javaMNISTLogisticRegressionWithLBFGSModel");
        System.out.println(sameModel);
    }

    static ArrayList<LabeledPoint> getData(String imagesPath, String labelsPath){        
        JavaRDD<LabeledPoint> data;
        ArrayList<LabeledPoint> lpts = new ArrayList<>();
        FileInputStream inImage = null;
        FileInputStream inLabel = null;
        try {
            inImage = new FileInputStream(imagesPath);
            inLabel = new FileInputStream(labelsPath);
            int magicNumberImages = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int numberOfImages = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int numberOfRows  = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int numberOfColumns = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int magicNumberLabels = (inLabel.read() << 24) | (inLabel.read() << 16) | (inLabel.read() << 8) | (inLabel.read());
            int numberOfLabels = (inLabel.read() << 24) | (inLabel.read() << 16) | (inLabel.read() << 8) | (inLabel.read());
            int numberOfPixels = numberOfRows * numberOfColumns;
            double[] imgPixels = new double[numberOfPixels];
            for(int i = 0; i < numberOfImages; i++) {
                //if(i % 100 == 0) {System.out.println("Number of images extracted: " + i);}
                for(int p = 0; p < numberOfPixels; p++) {
                    imgPixels[p] = inImage.read();
                }
                int label = inLabel.read();
                LabeledPoint lp = LabeledPoint.apply(label, Vectors.dense(imgPixels));
                lpts.add(lp);
            }
        } 
        catch (FileNotFoundException e) { e.printStackTrace(); } 
        catch (IOException e) { e.printStackTrace(); } 
        finally {
            if (inImage != null) {
                try {
                    inImage.close();
                } catch (IOException e) { e.printStackTrace(); }
            }
            if (inLabel != null) {
                try {
                    inLabel.close();
                } catch (IOException e) { e.printStackTrace(); }
            }
        }
        return lpts;
    }
    static JavaRDD<LabeledPoint> loadData(String imagesPath, String labelsPath){        
        ArrayList<LabeledPoint> lpts = getData(imagesPath, labelsPath);
        JavaSparkContext jsc = new JavaSparkContext(sc);
        JavaRDD<LabeledPoint> data = jsc.parallelize(lpts);
        return data;
    }
    static void saveMnistDataLibsvmFormat() throws FileNotFoundException, UnsupportedEncodingException{
        ArrayList<LabeledPoint> data = getData(testImagesPath, testLabelsPath);        
        PrintWriter writer = new PrintWriter("mnist-test-data.txt", "UTF-8");
        for(LabeledPoint lp : data){
            StringBuilder s = new StringBuilder();
            s.append(lp.label()).append(" ");
            int i;
            double[] arr = lp.features().toArray();
            for(i=0;i<arr.length-1;i++)
                if(arr[i]!=0)
                    s.append(i+1).append(":").append(arr[i]).append(" ");
            if(arr[i]!=0)
                s.append(i+1).append(":").append(arr[i]);
            writer.println(s.toString());
        }
        // writer.println("The first line");
        // writer.println("The second line");
        writer.close();
        ArrayList<LabeledPoint> data2 = getData(trainImagesPath, trainLabelsPath);
        PrintWriter writer2 = new PrintWriter("mnist-train-data.txt", "UTF-8");
        for(LabeledPoint lp : data2){
            StringBuilder s = new StringBuilder();
            s.append(lp.label()).append(" ");
            int i;
            double[] arr = lp.features().toArray();
            for(i=0;i<arr.length-1;i++)
                if(arr[i]!=0)
                    s.append(i+1).append(":").append(arr[i]).append(" ");
            if(arr[i]!=0)
                s.append(i+1).append(":").append(arr[i]);
            writer2.println(s.toString());
        }
        // writer.println("The first line");
        // writer.println("The second line");
        writer2.close();
    }
}


权重的值都等于零,我不明白为什么?请帮忙,谢谢。

你的意思是什么

            `if(yb==t1.label())
                System.out.println("label: "+t1.label()+", predicted: "+yb);
            return 0;`

它始终返回 0。

最新更新