如何使用Pytorch的Dataloader和Sampler


什么是Dataloader?

它用于从

数据集中批量检索。
基本上使用torch.utils.data.DataLoader
作为图像,我认为数据集是所有数据的列表,并且Dataloader就像每个迷你批处理的数据集的内容集一样。
datsets = [データセット全て]
Dataloader = [[batch_1], [batch_2], ... [batch_n]]
所以
len(datasets)="すべてのデータの数"
len(Dataloader)="イテレーションの数"
这将是。
(当然,因为它是一个迭代器,所以并不是所有数据都照原样包含。它只是一个图像。无法通过切片[:x]检索。)

数据加载器已定义__iter____next__,因此
如果将其指定为iter(Dataloader).__next__(),则可以从头开始一次取出一批。
调试时很方便。

通过检查PyTorch转换/数据集/数据加载器的基本行为,可以轻松理解数据集和数据加载器的行为。

什么是采样器

sampler是Dataloader的参数,就像一个可以决定如何批处理数据集的设置。
基本上,采样器是一个类,它返回一个数据索引。

我认为testloader = torch.utils.data.DataLoader(testset, batch_size=n,shuffle=True)对于正常学习已经足够。
但是,当每个班级的训练图像有很大偏差并且您想要以相同的比例提供它们时,或者当您想要从每个班级中取出相同的数字并将它们放在网络中以进行远程学习时,等等。

torch.utils.data中大约有4个采样器,但是我认为我不会使用它们太多,所以这次我想自己动手。
但是,创建的不是采样器,而是batch_sampler。 batch_sampler也是Dataloader的参数之一,它返回多个数据的索引,而不是一一对应。
这次作为一个假设,让我们考虑从上一示例中所示的所有类中选择一些,并从每个类中提取相同的数字。假设n_classes是要选择的类的数量,而n_samples是要从一个类中提取的数量。所有数据的数量将为n_classes*n_samples

此代码基于此。

采样器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
import torch

from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler


class BalancedBatchSampler(BatchSampler):
    """
    BatchSampler - from a MNIST-like dataset, samples n_classes and within these classes samples n_samples.
    Returns batches of size n_classes * n_samples
    """

    def __init__(self, dataset, n_classes, n_samples):
        loader = DataLoader(dataset)
        self.labels_list = []
        for _, label in loader:
            self.labels_list.append(label)
        self.labels = torch.LongTensor(self.labels_list)
        self.labels_set = list(set(self.labels.numpy()))
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < len(self.dataset):
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                indices.extend(self.label_to_indices[class_][
                               self.used_label_indices_count[class_]:self.used_label_indices_count[
                                                                         class_] + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return len(self.dataset) // self.batch_size

采样器和batch_sampler都需要定义__iter__。
__iter__返回n_classes*n_samples索引。
__len__是データセットの数//バッチサイズ,所以它是len(batch_sampler)="イテレーションの数"

让我们看看它是如何实际取出的。这次我想尝试mnist。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt

n_classes = 5
n_samples = 8

mnist_train =  torchvision.datasets.MNIST(root="mnist/mnist_train", train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),]))

balanced_batch_sampler = BalancedBatchSampler(mnist_train, n_classes, n_samples)

dataloader = torch.utils.data.DataLoader(mnist_train, batch_sampler=balanced_batch_sampler)
my_testiter = iter(dataloader)
images, target = my_testiter.next()


def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

imshow(torchvision.utils.make_grid(images))

image.png

您可以看到从0到9中选择了5种类型,并且取出了8张纸。

参考

检查PyTorch变换/数据集/数据加载器的基本操作
[详细信息(?)] pytorch?CNN CIFAR10?简介
pytorh官方网站采样器