Can we avoid specifying "num_features" in PyTorch's batch norm, just like in TensorFlow?



Here is batch norm in TF:

model = BatchNormalization(momentum=0.15, axis=-1)(model)

Here is batch norm in Torch:

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

As you can see, there is one extra argument: num_features. That's annoying.

Assuming I don't want affine in Torch, batch norm in TF and in Torch should then be identical. Is there a way to avoid specifying "num_features" in PyTorch's batch norm, just like in TensorFlow?

If you really don't like specifying this argument, you may want to look at lazy batch norm (torch.nn.LazyBatchNorm1d), which infers num_features from the first input it sees.
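
For illustration, here is a minimal sketch of how the lazy variant can be used; the momentum value and the tensor shapes are just placeholders for the example:

import torch
import torch.nn as nn

# LazyBatchNorm1d defers creating weight, bias and the running statistics
# until the first forward call, so num_features never has to be spelled out.
bn = nn.LazyBatchNorm1d(momentum=0.1)

x = torch.randn(8, 32)    # batch of 8 samples with 32 features
out = bn(x)               # num_features is inferred as 32 on this first call
print(bn.weight.shape)    # -> torch.Size([32]), materialized after the first pass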

Otherwise, you can specify num_features as anything you like (None?), as long as affine and track_running_stats are both False. If you look at the base class of the batch norm modules (available at this link):

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: float
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

As you can see, num_features is only used to set up self.weight and self.bias when affine is True, and to set up running_mean and running_var when track_running_stats is True.
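
As a quick sanity check, here is a sketch (under the assumption that nothing else in your PyTorch version reads num_features) showing that a placeholder value works once both flags are off:

import torch
import torch.nn as nn

# With affine=False and track_running_stats=False, no tensor of size
# num_features is ever created, so the value passed here is only stored.
bn = nn.BatchNorm1d(num_features=1, affine=False, track_running_stats=False)

x = torch.randn(8, 32)   # 32 features, deliberately different from num_features=1
out = bn(x)              # normalization only uses the per-batch statistics
print(out.shape)         # -> torch.Size([8, 32])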
