I wrote this:
def forward(self, x):
    x = self.bert(x)
    x = x.view(x.shape[0], -1)
    x = self.fc(self.dropout(self.bn(x)))
    return x
But it doesn't work. The error is 'MaskedLMOutput' object has no attribute 'view'. I thought the input might not be of type "tensor", so I changed it as follows:
def forward(self, x):
    x = torch.tensor(x)  # this part
    x = self.bert(x)
    x = x.view(x.shape[0], -1)
    x = self.fc(self.dropout(self.bn(x)))
    return x
But it still fails with the same error: 'MaskedLMOutput' object has no attribute 'view'.
Can anyone tell me how to fix this? Many thanks.
Full error message:
------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Input In [5], in <cell line: 8>()
6 optimizer = optim.Adam(bert_punc.parameters(), lr=learning_rate_top)
7 criterion = nn.CrossEntropyLoss()
----> 8 bert_punc, optimizer, best_val_loss = train(bert_punc, optimizer, criterion, epochs_top,
9 data_loader_train, data_loader_valid, save_path, punctuation_enc, iterations_top, best_val_loss=1e9)
Input In [3], in train(model, optimizer, criterion, epochs, data_loader_train, data_loader_valid, save_path, punctuation_enc, iterations, best_val_loss)
17 inputs.requires_grad = False
18 labels.requires_grad = False
---> 19 output = model(inputs)
20 loss = criterion(output, labels)
21 loss.backward()
File ~\anaconda3\lib\site-packages\torch\nn\modules\module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File ~\anaconda3\lib\site-packages\torch\nn\parallel\data_parallel.py:166, in DataParallel.forward(self, *inputs, **kwargs)
163 kwargs = ({},)
165 if len(self.device_ids) == 1:
--> 166 return self.module(*inputs[0], **kwargs[0])
167 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
168 outputs = self.parallel_apply(replicas, inputs, kwargs)
File ~\anaconda3\lib\site-packages\torch\nn\modules\module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File D:\BertPunc-original\model.py:21, in BertPunc.forward(self, x)
18 x = torch.tensor(x)
19 x = self.bert(x)
---> 21 x = x.view(x.shape[0], -1)
22 x = self.fc(self.dropout(self.bn(x)))
23 return x
AttributeError: 'MaskedLMOutput' object has no attribute 'view'
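I think this should help you solve the error: https://stackoverflow.com/a/72601533/13748930
The output of self.bert(x) is an object of the MaskedLMOutput class, not a tensor, so it has no view method. Converting the input with torch.tensor does not change that, because the problem is the output of self.bert, not its input.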
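A minimal sketch of the fix, assuming `self.bert` is a Hugging Face `transformers` `BertForMaskedLM` (which is what returns `MaskedLMOutput`): take the `logits` tensor out of the output object before calling `view`. The class name `BertPuncSketch`, the tiny randomly initialised `BertConfig`, and the sizes `segment_size=8`, `output_size=4`, `vocab_size=100` are hypothetical, chosen only so the example runs without downloading weights; they are not the sizes used in BertPunc.

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertForMaskedLM


class BertPuncSketch(nn.Module):
    """Hypothetical, scaled-down stand-in for the model in the question."""

    def __init__(self, segment_size, output_size, dropout, config):
        super().__init__()
        # Randomly initialised MLM model; the real code would load pretrained weights.
        self.bert = BertForMaskedLM(config)
        flat = segment_size * config.vocab_size
        self.bn = nn.BatchNorm1d(flat)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(flat, output_size)

    def forward(self, x):
        out = self.bert(x)           # MaskedLMOutput, a container object, not a tensor
        x = out.logits               # tensor of shape (batch, segment_size, vocab_size)
        x = x.view(x.shape[0], -1)   # flatten to (batch, segment_size * vocab_size)
        return self.fc(self.dropout(self.bn(x)))


# Tiny assumed configuration so the sketch runs quickly on CPU.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64)
model = BertPuncSketch(segment_size=8, output_size=4, dropout=0.3, config=config)
inputs = torch.randint(0, config.vocab_size, (2, 8))  # batch of 2 token-id sequences
logits = model(inputs)
print(logits.shape)
```

Indexing with `self.bert(x)[0]` should also return the logits when no `labels` are passed, because `MaskedLMOutput` drops the `None` loss from its tuple view. The `.logits` attribute is used above because it states the intent explicitly.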