如何判断新数据集中的样本是否属于PyTorch中的原始数据集



我是PyTorch的新手。现在,我有两个名为A和B的数据集(例如:MNIST(。我想把A和B混合在一起形成一个新的数据集。我想打乱这个新的数据集。在培训期间,我需要确定批次中的样品是否属于A。我该如何做到这一点?

这两个问题如下:1( 如何混合两个数据集并对其进行混洗?2( 如何确定新数据集中的样本是否属于原始数据集A?

通过定义自定义数据集和一些标志标签,您可以实现这一点。这是示例代码:

import torch
from torch.utils.data import DataLoader, Dataset 
class to_dataset(Dataset):
def __init__(self , data_A, data_B):
self.lena   = data_A.shape[0]
self.len    = data_A.shape[0] + data_B.shape[0]
self.A      = torch.from_numpy(data_A).float()
self.B      = torch.from_numpy(data_B).float()

# returns dataset A data with flag label 0
# dataset B data with flag label 1
def __getitem__(self, index):
if index < self.lena:            
return self.A[index], 0
retutn self.B[index-self.lena], 1
def __len__(self):
return self.len

#reading sample numpy dataset
data_a  = np.load(pathofA)
data_b  = np.load(pathofB)

# loading custom dataset
dataset = to_dataset(data_a, data_b)
#loading dataloader with training data
train_loader = DataLoader(dataset=dataset, batch_size=bsize, shuffle=True)
#sample train loop
for epoch in range(epochs):
for data, label in train_loader:
for d,l in zip(data, label):  
if l == 0:
print('from A')
else:
print('from B')

相关内容

  • 没有找到相关文章

最新更新