Pytorch DataLoader - выберите набор данных класса STL10

Можно ли вытащить только тогда, когда class = 0 в наборе данных STL10 в PyTorch torchvision? Я могу проверять их в цикле, но мне нужно получать партии изображений класса 0

# STL10 dataset
train_dataset = torchvision.datasets.STL10(root='./data/',
                                           transform=transforms.Compose([
                                               transforms.Grayscale(),
                                               transforms.ToTensor()
                                           ]),
                                           split='train',
                                           download=True)


# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

for i, (images, labels) in enumerate(train_loader):
    if labels[0] == 0:...

изменить на основе ответа iacolippo - теперь это работает:

# Set params
batch_size = 25
label_class = 0   # only airplane images

# Return only images of certain class (eg. airplanes = class 0)
def get_same_index(target, label):
    label_indices = []

    for i in range(len(target)):
        if target[i] == label:
            label_indices.append(i)

    return label_indices

# STL10 dataset
train_dataset = torchvision.datasets.STL10(root='./data/',
                                           transform=transforms.Compose([
                                               transforms.Grayscale(),
                                               transforms.ToTensor()
                                           ]),
                                           split='train',
                                           download=True)

# Get indices of label_class
train_indices = get_same_index(train_dataset.labels, label_class)

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))

person Adam12344    schedule 14.07.2018    source источник


Ответы (1)


Если вам нужны образцы только из одного класса, вы можете получить индексы образцов с тем же классом из экземпляра Dataset с чем-то вроде

def get_same_index(target, label):
    label_indices = []

    for i in range(len(target)):
        if target[i] == label:
            label_indices.append(i)

    return label_indices

тогда вы можете использовать SubsetRandomSampler для рисования выборок только из списка индексов одного класса

torch.utils.data.sampler.SubsetRandomSampler(indices)
person iacolippo    schedule 14.07.2018