How to make a PyTorch module (with in-place operations) differentiable



My layer looks like this (I'm building an LSTM layer that applies dropout at every time step, passes the input through 10 times, and returns the mean of the outputs):

import torch
from torch import nn

class StochasticLSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, dropout_rate: float):
        """
        Args:
        - dropout_rate: should be between 0 and 1
        """
        super(StochasticLSTM, self).__init__()
        self.iter = 10
        self.input_size = input_size
        self.hidden_size = hidden_size
        if not 0 <= dropout_rate <= 1:
            raise Exception("Dropout rate should be between 0 and 1")
        self.dropout = dropout_rate
        self.bernoulli_x = torch.distributions.Bernoulli(
            torch.full((self.input_size,), 1 - self.dropout)
        )
        self.bernoulli_h = torch.distributions.Bernoulli(
            torch.full((hidden_size,), 1 - self.dropout)
        )
        self.Wi = nn.Linear(self.input_size, self.hidden_size)
        self.Ui = nn.Linear(self.hidden_size, self.hidden_size)
        self.Wf = nn.Linear(self.input_size, self.hidden_size)
        self.Uf = nn.Linear(self.hidden_size, self.hidden_size)
        self.Wo = nn.Linear(self.input_size, self.hidden_size)
        self.Uo = nn.Linear(self.hidden_size, self.hidden_size)
        self.Wg = nn.Linear(self.input_size, self.hidden_size)
        self.Ug = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, input, hx=None):
        """
        input shape (sequence, batch, input dimension)
        output shape (sequence, batch, output dimension)
        return output, (hidden_state, cell_state)
        """
        T, B, _ = input.shape
        if hx is None:
            hx = torch.zeros((self.iter, T + 1, B, self.hidden_size), dtype=input.dtype)
        else:
            hx = hx.unsqueeze(0).repeat(self.iter, T + 1, B, self.hidden_size)
        c = torch.zeros((self.iter, T + 1, B, self.hidden_size), dtype=input.dtype)
        o = torch.zeros((self.iter, T, B, self.hidden_size), dtype=input.dtype)
        for it in range(self.iter):
            # Dropout
            zx = self.bernoulli_x.sample()
            zh = self.bernoulli_h.sample()
            for t in range(T):
                x = input[t] * zx
                h = hx[it, t] * zh
                i = torch.sigmoid(self.Ui(h) + self.Wi(x))
                f = torch.sigmoid(self.Uf(h) + self.Wf(x))
                o[it, t] = torch.sigmoid(self.Uo(h) + self.Wo(x))
                g = torch.tanh(self.Ug(h) + self.Wg(x))
                c[it, t + 1] = f * c[it, t] + i * g
                hx[it, t + 1] = o[it, t] * torch.tanh(c[it, t + 1])
        o = torch.mean(o, axis=0)
        c = torch.mean(c[:, 1:], axis=0)
        hx = torch.mean(hx[:, 1:], axis=0)
        return o, (hx, c)

When I optimize the network, I get the error "one of the variables needed for gradient computation has been modified by an inplace operation". There are several in-place operations in the code, such as o[it, t] = torch.sigmoid(self.Uo(h) + self.Wo(x)).
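The error can be reproduced in a few lines (a minimal sketch, independent of the LSTM above): sigmoid saves its output for the backward pass, so an in-place slice write into that output invalidates the saved tensor. The assignments into o, c and hx are exactly this pattern.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = torch.sigmoid(x)   # autograd saves y to compute sigmoid's gradient
y[0] = 0.0             # in-place slice write bumps y's version counter

raised = False
try:
    y.sum().backward()
except RuntimeError as err:
    raised = True      # "...modified by an inplace operation..."
    print("backward failed:", err)
```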

How can I avoid these in-place assignments while still being able to average the outputs at the end?

Thanks.

Instead, collect the tensor results in a Python list and stack the list into a single tensor at the end. For example, instead of

t = torch.zeros(5, 5)
for i in range(5):
    t[i,:] = ...

do this:

t = []
for i in range(5):
    t.append(...)
t = torch.stack(t)
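Applied to the forward pass from the question, the inner time loop appends each step's tensors to lists and stacks them once afterwards, so no preallocated tensor is ever written in place. A sketch under simplifying assumptions: one dropout iteration, gate weights created locally, and small made-up sizes so the snippet is self-contained.

```python
import torch
from torch import nn

input_size, hidden_size, T, B = 3, 4, 5, 2

# Gate layers, mirroring the question's self.W*/self.U* attributes
Wi, Ui = nn.Linear(input_size, hidden_size), nn.Linear(hidden_size, hidden_size)
Wf, Uf = nn.Linear(input_size, hidden_size), nn.Linear(hidden_size, hidden_size)
Wo, Uo = nn.Linear(input_size, hidden_size), nn.Linear(hidden_size, hidden_size)
Wg, Ug = nn.Linear(input_size, hidden_size), nn.Linear(hidden_size, hidden_size)

inp = torch.randn(T, B, input_size)
h = torch.zeros(B, hidden_size)
c = torch.zeros(B, hidden_size)
outputs, hs, cs = [], [], []
for t in range(T):
    x = inp[t]
    i = torch.sigmoid(Ui(h) + Wi(x))
    f = torch.sigmoid(Uf(h) + Wf(x))
    o = torch.sigmoid(Uo(h) + Wo(x))
    g = torch.tanh(Ug(h) + Wg(x))
    c = f * c + i * g        # fresh tensor each step, not a slice write
    h = o * torch.tanh(c)
    outputs.append(o)
    hs.append(h)
    cs.append(c)
out = torch.stack(outputs)    # shape (T, B, hidden_size)
out.mean().backward()         # backward now runs without the RuntimeError
```

With several dropout iterations, the same idea applies one level up: append each iteration's stacked output to an outer list, then torch.stack that list and take torch.mean over the new dimension.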
