Reading the weights of a DNNRegressor-generated model from Java



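I'm looking for some sample code that uses Java to extract the weights and biases from a TensorFlow model. Does anyone know whether this is possible? Would I be better off using the Python API?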

Thanks, Bill

Something like the following worked for me:

import java.util.Arrays;
import org.tensorflow.Graph;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.GraphDef;

public class PrintWeights {
    // Path to the exported SavedModel directory.
    private static String graphPath = "models/export/exporter/1390282892";

    public static void main(String[] args) throws Exception {
        // Load the SavedModel with the "serve" tag. Closing the bundle also
        // releases its graph and session, so the session is not closed separately.
        try (SavedModelBundle b = SavedModelBundle.load(graphPath, "serve")) {
            Graph graph = b.graph();
            // Not needed for the fetch below; the parsed GraphDef can be
            // inspected to look up node names.
            GraphDef graphDef = GraphDef.parseFrom(graph.toGraphDef());

            Session session = b.session();
            // Fetching a variable by its op name returns its current value.
            try (Tensor<Float> tensor = session.runner()
                    .fetch("dnn/hiddenlayer_0/kernel")
                    .run().get(0).expect(Float.class)) {
                // The kernel is a 2-D tensor, so copy it into a float[rows][cols] array.
                float[][] weights = new float[(int) tensor.shape()[0]][(int) tensor.shape()[1]];
                tensor.copyTo(weights);
                // Print one row of the matrix per line.
                System.out.println(Arrays.deepToString(weights).replace("], ", "]\n"));
            }
        }
    }
}
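
The printing step can be tried on its own, without TensorFlow on the classpath. The sketch below is only an illustration: the class name `WeightFormat`, the helper `formatRows`, and the 2x3 matrix are hypothetical stand-ins for the values copied out of the tensor. It shows that replacing `"], "` with `"]\n"` in the `Arrays.deepToString` output puts each row of the weight matrix on its own line.

```java
import java.util.Arrays;

public class WeightFormat {
    // Formats a 2-D weight matrix so that each row appears on its own line.
    static String formatRows(float[][] weights) {
        return Arrays.deepToString(weights).replace("], ", "]\n");
    }

    public static void main(String[] args) {
        // Hypothetical 2x3 kernel standing in for values copied from a tensor.
        float[][] weights = {
            {0.1f, -0.2f, 0.3f},
            {0.4f, 0.5f, -0.6f}
        };
        System.out.println(formatRows(weights));
    }
}
```

The same helper should work for any rectangular `float[][]`, so it could be reused for other layers' kernels once they have been copied out of their tensors.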
