今天在写代码时想把我的数据集按8:2划分为训练集和验证集,然后就去网上搜索PyTorch有没有现成的函数,发现PyTorch提供一个函数torch.utils.data.random_split(dataset, lengths)
可以按照给定的长度将数据集划分成没有重叠的新数据集组合。下面以CIFAR100的训练集为例,首先读取数据集:
fullset = torchvision.datasets.CIFAR100(root='./data', train=True, download=False,transform=transform_train)
然后根据数据集的长度,计算出训练集和验证集的长度
train_size = int(0.8 * len(fullset))
test_size = len(fullset) - train_size
接着使用torch.utils.data.random_split(dataset, lengths)
划分数据集
trainset, testset = torch.utils.data.random_split(fullset, [train_size, test_size])
最后使用torch.utils.data.DataLoader
来包装数据集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=True, num_workers=2)
添加评论
评论
久别人潮
博主的网站我测了下 访问好快 请问是国内几兆的
或者是用了什么加速 wp主题能这么快 我第一次看到
浅笑顾盼
@久别人潮 阿里云5M轻量应用服务器