我是 PyTorch 的新手。现在,我有两个数据集 A 和 B(例如 MNIST)。我想把 A 和 B 混合成一个新的数据集,并将这个新数据集打乱。在训练期间,我需要判断一个批次中的每个样本是否来自 A。该如何实现?
具体有两个问题:1)如何合并两个数据集并将其打乱?2)如何判断新数据集中的某个样本是否来自原始数据集 A?
可以自定义一个 Dataset,为每个样本附加一个标志标签(来自 A 为 0,来自 B 为 1),再在 DataLoader 中设置 shuffle=True 来打乱合并后的数据。示例代码如下:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
class to_dataset(Dataset):
    """Concatenate two array datasets and tag every sample with its origin.

    Indices ``0 .. len(A) - 1`` map to samples of ``data_A``; the remaining
    indices map to samples of ``data_B``. Each item is returned as
    ``(sample, flag)`` where ``flag`` is ``0`` for A and ``1`` for B, so a
    shuffled ``DataLoader`` over this dataset still lets the training loop
    tell the two sources apart.

    Args:
        data_A: NumPy array whose first axis indexes the samples of A.
        data_B: NumPy array whose first axis indexes the samples of B.
            Both are converted with ``torch.from_numpy`` and cast to float32.
    """

    def __init__(self, data_A, data_B):
        # Number of samples in A; combined indices below this come from A.
        self.lena = data_A.shape[0]
        # Total number of samples in the combined dataset.
        self.len = data_A.shape[0] + data_B.shape[0]
        # from_numpy shares memory with the input array; .float() casts to
        # float32 (returning the same tensor if it already is float32).
        self.A = torch.from_numpy(data_A).float()
        self.B = torch.from_numpy(data_B).float()

    def __getitem__(self, index):
        """Return ``(sample, flag)``: flag 0 for dataset A, 1 for dataset B.

        Negative indices count from the end of the combined dataset; an
        index outside ``[-len, len)`` raises ``IndexError``.
        """
        if index < 0:
            # Resolve against the combined length; without this, ds[-1]
            # would return A's last sample instead of B's.
            index += self.len
        if not 0 <= index < self.len:
            raise IndexError(
                f"index out of range for dataset of length {self.len}"
            )
        if index < self.lena:
            return self.A[index], 0
        return self.B[index - self.lena], 1

    def __len__(self):
        """Return the total number of samples (len(A) + len(B))."""
        return self.len
# Load the two source datasets from .npy files.
# NOTE(review): pathofA, pathofB, bsize and epochs are placeholders and must
# be defined before this script runs.
data_a = np.load(pathofA)
data_b = np.load(pathofB)

# Combine A and B into one dataset; each item is (sample, flag),
# with flag 0 for A and 1 for B.
dataset = to_dataset(data_a, data_b)

# shuffle=True draws the combined indices in a new random order each epoch,
# so batches mix samples from A and B.
train_loader = DataLoader(dataset=dataset, batch_size=bsize, shuffle=True)

# Sample training loop.
for epoch in range(epochs):
    for data, label in train_loader:
        # Default collation stacks the per-sample int flags into a 1-D
        # tensor, so one comparison yields a boolean mask for the batch.
        # data[from_a] / data[~from_a] select the A / B samples directly.
        from_a = label == 0
        for is_a in from_a.tolist():
            print('from A' if is_a else 'from B')