冻结后存在Batchnorm层,用于从DeeplabV3+模型进行推理,这可以吗



导出经过训练的模型后,BatchNorm层仍然存在。我在某个地方读到,出于两个原因,这些应该被删除:

  1. 网络输出可能错误
  2. 整个网络的速度

嗯,我对1有疑问。但第二个事实听起来合乎逻辑,所以我的问题是:那么如何过滤掉这些层呢?

环境:来自Tensorflow GitHub的模型,并在Tensorflow 1.15.3上进行了培训。

出口二手:

python deeplab/export_model.py 
--num_classes=2 --model_variant="mobilenet_v3_large_seg" 
--dataset="123" 
--checkpoint_path=training 
--crop_size=384 
--crop_size=384 
--export_path=graph.pb

摘录网络图:

(<tf.Tensor 'MobilenetV3/MobilenetV3/input:0' shape=(1, 768, 768, 3) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/weights:0' shape=(3, 3, 3, 16) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/weights/read:0' shape=(3, 3, 3, 16) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/Conv2D:0' shape=(1, 384, 384, 16) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/gamma:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/gamma/read:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/beta:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/beta/read:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/moving_mean:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/moving_mean/read:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/moving_variance:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/moving_variance/read:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/FusedBatchNormV3:0' shape=(1, 384, 384, 16) dtype=float32>, <tf.Tensor 'MobilenetV3/Conv/BatchNorm/FusedBatchNormV3:1' shape=(16,) dtype=float32>, <tf.Tensor 'MobilenetV3/Conv/BatchNorm/FusedBatchNormV3:2' shape=(16,) dtype=float32>, <tf.Tensor 'MobilenetV3/Conv/BatchNorm/FusedBatchNormV3:3' shape=(16,) dtype=float32>, <tf.Tensor 'MobilenetV3/Conv/BatchNorm/FusedBatchNormV3:4' shape=(16,) dtype=float32>, <tf.Tensor 'MobilenetV3/Conv/BatchNorm/FusedBatchNormV3:5' shape=<unknown> dtype=float32>)
(<tf.Tensor 'MobilenetV3/Conv/hard_swish/add/y:0' shape=() dtype=float32>,)

两者都为真。批量规范化只适用于训练期间,就像Dropout一样。实际上你不需要自己处理。

在推理过程中,只需使用model.prpredict,库就会处理它。也就是说,所有的批处理规范化和丢弃层都将被停用。

如果你需要做比预测更花哨的事情,你也可以通过Training=False这一论点。查看文档。https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization

相关内容

  • 没有找到相关文章

最新更新