
Weighted sampler for multiple datasets #4

Open
deepsworld opened this issue Oct 25, 2021 · 1 comment
Labels
enhancement New feature or request

Comments


deepsworld commented Oct 25, 2021

A weighted sampler that draws samples from multiple datasets, making it easy to combine different datasets in controlled proportions.

deepsworld added the enhancement label Oct 25, 2021

deepsworld commented Oct 25, 2021

I have a rough implementation for it:

import torch
from torch.utils.data import Dataset, ConcatDataset, DataLoader, WeightedRandomSampler

class CustomDataset0(Dataset):
    def __init__(self):
        super().__init__()
        self.tensor_data = torch.arange(80)

    def __getitem__(self, index):
        return self.tensor_data[index], torch.tensor(0)

    def __len__(self):
        return len(self.tensor_data)

class CustomDataset1(Dataset):
    def __init__(self):
        super().__init__()
        self.tensor_data = torch.arange(20)

    def __getitem__(self, index):
        return self.tensor_data[index], torch.tensor(1)

    def __len__(self):
        return len(self.tensor_data)

dataset0 = CustomDataset0()
dataset1 = CustomDataset1()

datasets = [dataset0, dataset1]
concat_dataset = ConcatDataset(datasets)
lengths = torch.tensor([len(dataset) for dataset in datasets])
# Weight each sample by the inverse of its dataset's length, so every
# dataset contributes equal total probability mass to the sampler.
dataset_weights = 1.0 / lengths
# dataset_weights = torch.tensor([0.2, 0.8])  # custom weights also work
weights = torch.ones(int(lengths.sum()), dtype=torch.float32)
offset = 0
for i, length in enumerate(lengths.tolist()):
    weights[offset:offset + length] = dataset_weights[i]
    offset += length
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
dataloader = DataLoader(concat_dataset, batch_size=16, sampler=sampler)
for i, (val, dataset_no) in enumerate(dataloader):
    print(f"batch index {i}, dataset0/dataset1: {(dataset_no == 0).sum()}/{(dataset_no == 1).sum()}")
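As a quick sanity check (a condensed sketch of the same idea, not part of the original snippet; the `RangeDataset` helper name is hypothetical), one can draw a fixed number of samples and count how many came from each dataset. With the `1/len(dataset)` weighting, both the 80-item and the 20-item dataset should contribute roughly equal shares:

```python
import torch
from torch.utils.data import Dataset, ConcatDataset, DataLoader, WeightedRandomSampler

class RangeDataset(Dataset):
    """Toy dataset returning (value, dataset_label) pairs."""
    def __init__(self, n, label):
        self.data = torch.arange(n)
        self.label = label

    def __getitem__(self, index):
        return self.data[index], torch.tensor(self.label)

    def __len__(self):
        return len(self.data)

torch.manual_seed(0)  # for a reproducible check

datasets = [RangeDataset(80, 0), RangeDataset(20, 1)]
concat = ConcatDataset(datasets)
lengths = [len(d) for d in datasets]
# per-sample weight 1/len(dataset) => each dataset gets equal total mass
weights = torch.cat([torch.full((n,), 1.0 / n) for n in lengths])
sampler = WeightedRandomSampler(weights, num_samples=1000, replacement=True)
loader = DataLoader(concat, batch_size=100, sampler=sampler)

counts = torch.zeros(2)
for _, labels in loader:
    counts += torch.bincount(labels, minlength=2)
print(counts / counts.sum())  # both entries should be near 0.5
```

Note that with custom dataset-level proportions (e.g. 20%/80%), the per-sample weight should be the target proportion divided by the dataset's length, otherwise longer datasets are over-represented.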
