fast.ai: How to get the per-batch loss during validation



I am implementing an AWD-LSTM model with fast.ai. Right now I can get the validation loss averaged over all batches:

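```python
from fastai.text import *

# Language-model DataBunch built from a CSV with an "is_valid" column for the split
data_lm = (TextList.from_csv("data/penn", "concatenated.csv", cols='text')
           .split_from_df("is_valid")
           .label_for_lm()
           .databunch())

# Train an AWD-LSTM from scratch and export it to the learner's path (data/penn)
learner = language_model_learner(data_lm, AWD_LSTM, pretrained=False)
learner.fit_one_cycle(10, 1e-2)
learner.export("exported.pkl")

# Reload the exported learner with the same texts attached as a test set
itemlist = TextList.from_csv("data/penn", "concatenated.csv", cols='text')
newlearner = load_learner(path="data/penn", test=itemlist, file="exported.pkl")
# Loss and accuracy, each averaged over all batches of test_dl
loss, acc = newlearner.validate(newlearner.data.test_dl)
```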

But how do I get the validation loss for each individual batch?

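Things I have tried:
1. Attaching a `Recorder`. But `Recorder` does not seem to track validation per batch: `learner.recorder.losses` only stores the per-batch training losses, and validation loss is only recorded once per epoch.
2. Calling `fastai.basic_train.loss_batch(learner.model, xb, yb, learner.loss_func)`, where `xb` and `yb` are plain `torch.Tensor`s. This raises the following `AttributeError`: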

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-14-5fa44a2d640f> in <module>
      5 xb = torch.ones((64, 20)).cuda().long()
      6 yb = torch.ones((64, 20)).cuda().long()
----> 7 loss_batch(newlearner.model, xb, yb, newlearner.loss_func)

~/anaconda3/envs/pytorch12/lib/python3.7/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     27     out = cb_handler.on_loss_begin(out)
     28     if not loss_func: return to_detach(out), to_detach(yb[0])
---> 29     loss = loss_func(out, *yb)
     30 
     31     if opt is not None:

~/anaconda3/envs/pytorch12/lib/python3.7/site-packages/fastai/layers.py in __call__(self, input, target, **kwargs)
    237 
    238     def __call__(self, input:Tensor, target:Tensor, **kwargs)->Rank0Tensor:
--> 239         input = input.transpose(self.axis,-1).contiguous()
    240         target = target.transpose(self.axis,-1).contiguous()
    241         if self.floatify: target = target.float()

AttributeError: 'tuple' object has no attribute 'transpose'
```
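The error suggests that the language model's forward pass returns a tuple rather than a single tensor, and with no callbacks `loss_batch` hands that tuple unchanged to the loss function (see the `cb_handler.on_loss_begin(out)` line in the traceback). The plain-Python sketch below only illustrates the pattern of keeping the first element of a tuple output before computing the loss. `loss_on_first_output`, `toy_model` and `mse` are hypothetical stand-ins, and the assumption that the prediction is the first element of the tuple reflects fastai v1's AWD_LSTM language model as far as I can tell.

```python
from typing import Any, Callable, Sequence


def loss_on_first_output(model: Callable[[Any], Any], xb: Any, yb: Any,
                         loss_func: Callable[[Any, Any], float]) -> float:
    """Hypothetical helper: compute a loss when the model may return a tuple.

    Assumes the real prediction is the first element of the tuple; the
    remaining elements (e.g. raw RNN outputs) are ignored.
    """
    out = model(xb)
    if isinstance(out, tuple):
        out = out[0]  # keep only the prediction, drop the extra outputs
    return loss_func(out, yb)


def toy_model(xb: Sequence[float]):
    """Stand-in for an RNN language model: returns (prediction, raw_outputs, outputs)."""
    pred = [2.0 * x for x in xb]
    return pred, ["raw"], ["dropped"]


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error over two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)
```

Calling `mse(toy_model(xb), yb)` directly fails in the same way as the traceback (the loss receives a tuple), while `loss_on_first_output(toy_model, xb, yb, mse)` works.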

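I have now found a solution: call `fastai.basic_train.validate` directly with `average=False`, passing a `CallbackHandler` built from the learner's callbacks.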

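```python
from fastai.basic_train import validate
from fastai.callback import CallbackHandler

# Copy of the learner's callbacks; for a language model this includes RNNTrainer,
# whose on_loss_begin keeps only the first element of the model's tuple output
cb_handler = CallbackHandler(newlearner.callbacks + [], None)
# With average=False, validate() returns a list with one loss tensor per batch
losses = validate(
    newlearner.model,
    newlearner.data.test_dl,
    newlearner.loss_func,
    cb_handler,  # This is necessary
    average=False)
```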
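To check the per-batch values against the single number from `newlearner.validate`: reading fastai v1's `validate`, the averaged result weights each batch's loss by its batch size (the first dimension of the target), i.e. sum(loss_i * n_i) / sum(n_i). The plain-Python sketch below reproduces that combination; `weighted_mean_loss` is a hypothetical helper, and in practice its inputs would be `[l.item() for l in losses]` plus the batch sizes collected while iterating over `test_dl`.

```python
from typing import Sequence


def weighted_mean_loss(batch_losses: Sequence[float], batch_sizes: Sequence[int]) -> float:
    """Hypothetical helper: combine per-batch mean losses, weighting each by its batch size.

    Computes sum(loss_i * n_i) / sum(n_i).
    """
    if len(batch_losses) != len(batch_sizes):
        raise ValueError("expected one batch size per batch loss")
    total = sum(batch_sizes)
    if total == 0:
        raise ValueError("batch sizes must not sum to zero")
    return sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / total


# Example: a full batch of 64 rows and a smaller final batch of 32 rows
per_batch = [2.0, 4.0]
sizes = [64, 32]
overall = weighted_mean_loss(per_batch, sizes)
```

For a language-model loader, where every batch normally has the same number of rows, this reduces to a plain mean of the per-batch losses.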
