今天在写代码时想把我的数据集按8:2划分为训练集和验证集,然后就去网上搜索PyTorch有没有现成的函数,发现PyTorch提供一个函数torch.utils.data.random_split(dataset, lengths) 可以按照给定的长度将数据集划分成没有重叠的新数据集组合。下面以CIFAR100的训练集为例,首先读取数据集:

fullset = torchvision.datasets.CIFAR100(root='./data', train=True, download=False,transform=transform_train)

然后根据数据集的长度,计算出训练集和验证集的长度

train_size = int(0.8 * len(fullset))
test_size = len(fullset) - train_size

接着使用torch.utils.data.random_split(dataset, lengths)划分数据集

trainset, testset = torch.utils.data.random_split(fullset, [train_size, test_size])

最后使用torch.utils.data.DataLoader来包装数据集

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=True, num_workers=2)