加载程序的数据类型无效-Pytorch Lightning DataModule



我正在尝试一个文本摘要练习,我有两列文本和摘要(标签(的训练和测试数据集。我使用的是T5、Pytorch和Lightning包装器,我有一个Pytorch数据集类,我可以确认它工作正常,并将以下内容作为文本字典返回,将id、标签和掩码作为张量返回。

return dict(
text=text,
summary = data_row['summary'],
text_input_ids = text_encoding['input_ids'].flatten(),
text_attention_mask = text_encoding['attention_mask'].flatten(),
labels = labels.flatten(),
labels_attention_mask = summary_encoding['attention_mask'].flatten()
)

然后,我有一个Lightning Data Module类,它将数据帧转换为PyTorch数据集,并将它们安装到数据加载器、返回train、val和测试数据加载器

class TextSummaryDataModule(pl.LightningModule):
def __init__(
self, 
train_df: pd.DataFrame, 
test_df: pd.DataFrame, 
tokenizer: T5Tokenizer, 
batch_size: int=8, 
text_max_token_len: int=512, 
summary_max_token_len: int=128
):

super().__init__()

self.train_df = train_df
self.test_df = test_df
self.tokenizer = tokenizer
self.batch_size = batch_size
self.text_max_token_len = text_max_token_len
self.summary_max_token_len = summary_max_token_len
def setup(self):
self.train_dataset = TextSummaryDataset(
self.train_df,
self.tokenizer,
self.text_max_token_len,
self.summary_max_token_len
)
self.test_dataset = TextSummaryDataset(
self.test_df,
self.tokenizer,
self.text_max_token_len,
self.summary_max_token_len
)
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size = self.batch_size,
shuffle=True,
num_workers=2
)
def val_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size = self.batch_size,
shuffle=False,
num_workers=2
)
def test_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size = self.batch_size,
shuffle=False,
num_workers=2
)

一切都在工作,直到我尝试执行模型,我得到以下警告和错误

  1. 用户警告:您定义了validation_step,但没有val_datalader。跳过验证循环-我已经在数据模块中明确定义并返回了这一点

  2. 加载程序的数据类型无效:TextSummaryDataModule-我已经确认我正在返回一个包含标记、attention\ymask和文本和摘要标签的字典

遗憾的是,我在这里使用了pl.LightningModule而不是DataModule。。。

最新更新