什么是Dataloader?
它用于从
数据集中批量检索。
基本上使用
作为图像,我认为数据集是所有数据的列表,并且Dataloader就像每个迷你批处理的数据集的内容集一样。
所以
这将是。
(当然,因为它是一个迭代器,所以并不是所有数据都照原样包含。它只是一个图像。无法通过切片
数据加载器已定义
如果将其指定为
调试时很方便。
通过检查PyTorch转换/数据集/数据加载器的基本行为,可以轻松理解数据集和数据加载器的行为。
什么是采样器
sampler是Dataloader的参数,就像一个可以决定如何批处理数据集的设置。
基本上,采样器是一个类,它返回一个数据索引。
我认为
但是,当每个班级的训练图像有很大偏差并且您想要以相同的比例提供它们时,或者当您想要从每个班级中取出相同的数字并将它们放在网络中以进行远程学习时,等等。
torch.utils.data中大约有4个采样器,但是我认为我不会使用它们太多,所以这次我想自己动手。
但是,创建的不是采样器,而是batch_sampler。 batch_sampler也是Dataloader的参数之一,它返回多个数据的索引,而不是一一对应。
这次作为一个假设,让我们考虑从上一示例中所示的所有类中选择一些,并从每个类中提取相同的数字。假设
此代码基于此。
采样器
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | import numpy as np import torch from torch.utils.data import DataLoader from torch.utils.data.sampler import BatchSampler class BalancedBatchSampler(BatchSampler): """ BatchSampler - from a MNIST-like dataset, samples n_classes and within these classes samples n_samples. Returns batches of size n_classes * n_samples """ def __init__(self, dataset, n_classes, n_samples): loader = DataLoader(dataset) self.labels_list = [] for _, label in loader: self.labels_list.append(label) self.labels = torch.LongTensor(self.labels_list) self.labels_set = list(set(self.labels.numpy())) self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0] for label in self.labels_set} for l in self.labels_set: np.random.shuffle(self.label_to_indices[l]) self.used_label_indices_count = {label: 0 for label in self.labels_set} self.count = 0 self.n_classes = n_classes self.n_samples = n_samples self.dataset = dataset self.batch_size = self.n_samples * self.n_classes def __iter__(self): self.count = 0 while self.count + self.batch_size < len(self.dataset): classes = np.random.choice(self.labels_set, self.n_classes, replace=False) indices = [] for class_ in classes: indices.extend(self.label_to_indices[class_][ self.used_label_indices_count[class_]:self.used_label_indices_count[ class_] + self.n_samples]) self.used_label_indices_count[class_] += self.n_samples if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]): np.random.shuffle(self.label_to_indices[class_]) self.used_label_indices_count[class_] = 0 yield indices self.count += self.n_classes * self.n_samples def __len__(self): return len(self.dataset) // self.batch_size |
采样器和batch_sampler都需要定义__iter__。
__iter__返回
__len__是
让我们看看它是如何实际取出的。这次我想尝试mnist。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | import torch import torchvision import torchvision.transforms as transforms from torchvision import datasets import numpy as np import matplotlib.pyplot as plt n_classes = 5 n_samples = 8 mnist_train = torchvision.datasets.MNIST(root="mnist/mnist_train", train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),])) balanced_batch_sampler = BalancedBatchSampler(mnist_train, n_classes, n_samples) dataloader = torch.utils.data.DataLoader(mnist_train, batch_sampler=balanced_batch_sampler) my_testiter = iter(dataloader) images, target = my_testiter.next() def imshow(img): npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) imshow(torchvision.utils.make_grid(images)) |
您可以看到从0到9中选择了5种类型,并且取出了8张纸。
参考
检查PyTorch变换/数据集/数据加载器的基本操作
[详细信息(?)] pytorch?CNN CIFAR10?简介
pytorh官方网站采样器