You can often get away with using something magical. You can bury your head in the sand and ignore the mysterious methods behind it, all while enjoying the benefits that come from this magic. But at some point, either curiosity will get the better of you or you'll be missing the flexibility you need, and you'll want to try to demystify the sorcery.

In my opinion, the best libraries have an element of magic to them. They hide away some gory details with a little bit of polish and sleight of hand that leave the world looking orderly and simple. The really great libraries allow you to peek behind the curtain at your own pace, slowly revealing the complexity and flexibility within.

I believe PyTorch is one of those libraries. It has lots of composable abstractions that you can learn about independently, which neatly layer together to make a powerful, customisable and elegant framework. In this tutorial, we're going to dive into some of the details of PyTorch DataLoaders in the hope of discovering how they work behind the scenes and how we can customise them to our liking.

I recommend running this yourself as a Jupyter Notebook and creating your own Samplers and collate functions. All the code from this post is available on GitHub.

# What's the plan?

PyTorch DataLoaders are great for iterating over batches of a Dataset like:

```python
for xb, yb in dataloader:
    ...
```

where xb and yb are batches of your inputs and labels.

This tutorial covers some of the more advanced features of DataLoaders. It should explain what happens behind the scenes when you iterate over a DataLoader, and show how to customise different parts of that process using native PyTorch features.

To be specific, we're going to go over custom collate functions and Samplers.

# What are DataLoaders and Datasets?

For this tutorial to be useful, you should probably know what DataLoaders and Datasets are, but I will refresh your memory. For a deeper dive, I recommend Jeremy Howard's tutorial What is torch.nn really? and the PyTorch docs Writing Custom Datasets, DataLoaders and Transforms.

A quick refresher: PyTorch Datasets are just things that have a length and are indexable so that len(dataset) will work and dataset[index] will return a tuple of (x,y).

Here's a little example that's mostly taken from fastbook Chapter 4 to just quickly illustrate how simple a Dataset is:

So we'll create two lists for x and y values:

```python
xs = list(range(10))
ys = list(range(10, 20))
print('xs values: ', xs)
print('ys values: ', ys)
```

```
xs values:  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
ys values:  [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
```


Then we use Python's zip function to combine them so dataset[index] returns (x,y) for that index:

```python
dataset = list(zip(xs, ys))
dataset[0]  # returns the tuple (xs[0], ys[0])
```

```
(0, 10)
```

```python
len(dataset)
```

```
10
```

## Use __getitem__ and __len__

We could also get the same functionality by using a class with the "dunder/magic methods" __getitem__ (for dataset[index] functionality) and __len__ (for len(dataset) functionality):

```python
class MyDataset:
    def __init__(self, xs, ys):
        self.xs = xs
        self.ys = ys

    def __getitem__(self, i):
        # dataset[i] returns the (x, y) pair at position i
        return self.xs[i], self.ys[i]

    def __len__(self):
        # len(dataset) returns the number of samples
        return len(self.xs)

dataset = MyDataset(xs, ys)
dataset[2]  # returns the tuple (xs[2], ys[2])
```

```
(2, 12)
```

```python
len(dataset)
```

```
10
```

## Now use a DataLoader

Then we just wrap that in a DataLoader and iterate over it. Now the values are magically tensors, and we can use the DataLoader's handy options such as shuffling, batching and multi-processing:

```python
from torch.utils.data import DataLoader

for x, y in DataLoader(dataset):
    print(x, y)
```

```
tensor([0]) tensor([10])
tensor([1]) tensor([11])
tensor([2]) tensor([12])
tensor([3]) tensor([13])
tensor([4]) tensor([14])
tensor([5]) tensor([15])
tensor([6]) tensor([16])
tensor([7]) tensor([17])
tensor([8]) tensor([18])
tensor([9]) tensor([19])
```


But the real fun is that we can get batches of these by setting batch_size:

```python
for x, y in DataLoader(dataset, batch_size=2):
    print(x, y)
```

```
tensor([0, 1]) tensor([10, 11])
tensor([2, 3]) tensor([12, 13])
tensor([4, 5]) tensor([14, 15])
tensor([6, 7]) tensor([16, 17])
tensor([8, 9]) tensor([18, 19])
```


And we can shuffle these batches by just setting shuffle=True:

```python
for x, y in DataLoader(dataset, batch_size=2, shuffle=True):
    print(x, y)
```

```
tensor([2, 5]) tensor([12, 15])
tensor([7, 9]) tensor([17, 19])
tensor([1, 4]) tensor([11, 14])
tensor([0, 6]) tensor([10, 16])
tensor([8, 3]) tensor([18, 13])
```


As you can see, it doesn't just shuffle the order of the batches. It shuffles the data first and then batches it.

OK... but can we customise this shuffling or batching? Yes: we can customise the shuffling with a custom Sampler and the batching with a custom collate function.
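To make that division of labour concrete, here is a minimal pure-Python sketch of roughly what one pass over a DataLoader does when automatic batching is on and there are no worker processes. It is an illustration, not PyTorch's implementation. `simple_dataloader` and `list_collate` are hypothetical names, and the real DataLoader also handles tensor conversion, multi-processing and memory pinning.

```python
import random
from typing import Any, Callable, Iterator, List, Sequence, Tuple


def list_collate(batch: List[Tuple[Any, Any]]) -> Tuple[list, list]:
    # Hypothetical collate: regroup [(x0, y0), (x1, y1), ...] into ([x0, x1, ...], [y0, y1, ...]).
    xs, ys = zip(*batch)
    return list(xs), list(ys)


def simple_dataloader(dataset: Sequence, batch_size: int = 1, shuffle: bool = False,
                      collate_fn: Callable = list_collate) -> Iterator:
    # Hypothetical, simplified sketch; not PyTorch's actual DataLoader code.
    # 1. "Sampler": decide the order of the indices (a fresh order on every call).
    indices = list(range(len(dataset)))
    if shuffle:
        random.shuffle(indices)
    # 2. "Batch sampler": group consecutive indices into lists of batch_size.
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        # 3. Fetch each sample by indexing into the dataset.
        samples = [dataset[i] for i in batch_indices]
        # 4. "Collate function": merge the samples into a single batch.
        yield collate_fn(samples)


toy_dataset = list(zip(range(10), range(10, 20)))
for xb, yb in simple_dataloader(toy_dataset, batch_size=4):
    print(xb, yb)
# [0, 1, 2, 3] [10, 11, 12, 13]
# [4, 5, 6, 7] [14, 15, 16, 17]
# [8, 9] [18, 19]
```

Each numbered step corresponds to a piece this post goes on to customise: the sampler (step 1), the batch_sampler (step 2) and the collate function (step 4).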

# Samplers

Every DataLoader has a Sampler, which it uses internally to decide the order of the indices. Each index is then used to index into your Dataset to grab the data (x, y). You can ignore this for now, but DataLoaders also have a batch_sampler, which groups those indices into lists, one list per batch. It is created whenever batch_size is not None (the default batch_size is 1).

Don't worry if this is a bit confusing. Hopefully it'll be clearer after a few examples.

Let's have a look at the internal .sampler property of a few DataLoaders and see how it changes when the DataLoader configuration changes:

## SequentialSampler

When shuffle=False (the default), the sampler returns each index in order (0, 1, 2, 3, 4, ...) as you iterate.

```python
default_sampler = DataLoader(dataset).sampler
```

So when we iterate over the sampler we should get the indices:

```python
for i in default_sampler:
    # iterating over the SequentialSampler
    print(i)
```

```
0
1
2
3
4
5
6
7
8
9
```


👍

```python
type(default_sampler)
```

```
torch.utils.data.sampler.SequentialSampler
```

We can see it has a sampler property internally which is a SequentialSampler.

Let's import SequentialSampler to see if we can use it ourselves:

```python
from torch.utils.data.sampler import SequentialSampler

sampler = SequentialSampler(dataset)

for x in sampler:
    print(x)
```

```
0
1
2
3
4
5
6
7
8
9
```


So it just returns indices as you iterate over it. Great, what about when shuffle=True?

## RandomSampler

When shuffled, we should expect randomly shuffled indices:

```python
random_sampler = DataLoader(dataset, shuffle=True).sampler
for index in random_sampler:
    print(index)
```

```
3
0
7
5
2
4
6
9
8
1
```


So shuffle=True changes the sampler internally, which returns random indices each iteration.

```python
type(random_sampler)
```

```
torch.utils.data.sampler.RandomSampler
```

We can see it's a RandomSampler, so let's import that and use it ourselves.

```python
from torch.utils.data.sampler import RandomSampler

random_sampler = RandomSampler(dataset)

for x in random_sampler:
    print(x)
```

```
9
7
0
3
2
4
6
5
8
1
```


We can pass this in explicitly to a DataLoader using the sampler parameter like this:

```python
dl = DataLoader(dataset, sampler=random_sampler)
for i in dl.sampler:
    print(i)
```

```
2
1
0
6
8
3
9
5
7
4
```


So we've seen that every DataLoader has a sampler internally. Unless you pass your own, it is either a SequentialSampler or a RandomSampler depending on the value of shuffle, and the DataLoader iterates over it to get the indices of the Dataset to use.

## Custom Sampler

That's great and all, but what if we want to customise the order of the data beyond shuffled or sequential? That's where custom Samplers come in.

From the docs:

Every Sampler subclass has to provide an __iter__ method, providing a way to iterate over indices of dataset elements, and a __len__ method that returns the length of the returned iterators.

So all we have to do to create a custom sampler is subclass Sampler and have a __iter__ method (for iterating through the indices) and a __len__ method for the length.
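Since the contract is just those two methods, it can help to see it in isolation. Below is a hypothetical `ReverseOrderSampler` in plain Python (no torch import), written only to illustrate the protocol. In real code you would subclass torch.utils.data.Sampler, as the next example does. Note that `__iter__` runs again every time the sampler is looped over, which is what lets a sampler produce a new order each epoch.

```python
from typing import Iterator, Sized


class ReverseOrderSampler:
    """Hypothetical sampler (illustration only): yields dataset indices from last to first."""

    def __init__(self, data_source: Sized):
        self.n = len(data_source)

    def __iter__(self) -> Iterator[int]:
        # Called every time someone loops over the sampler (e.g. once per epoch).
        return iter(range(self.n - 1, -1, -1))

    def __len__(self) -> int:
        # Number of indices one full iteration will produce.
        return self.n


reverse_sampler = ReverseOrderSampler(list(range(5)))
print(list(reverse_sampler), len(reverse_sampler))  # [4, 3, 2, 1, 0] 5
```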

As a small toy example, say we wanted the first half of the dataset to always come first and the second half to come later in training, while still shuffling each half independently:

Note: Making the order of the data less random is generally bad for training neural networks but let’s forget about that for this example please.

```python
import random
from torch.utils.data.sampler import Sampler

class IndependentHalvesSampler(Sampler):
    def __init__(self, dataset):
        halfway_point = len(dataset) // 2
        self.first_half_indices = list(range(halfway_point))
        self.second_half_indices = list(range(halfway_point, len(dataset)))

    def __iter__(self):
        # Shuffle each half independently, then put the first half before the second
        random.shuffle(self.first_half_indices)
        random.shuffle(self.second_half_indices)
        return iter(self.first_half_indices + self.second_half_indices)

    def __len__(self):
        return len(self.first_half_indices) + len(self.second_half_indices)
```


So we've subclassed Sampler and stored both halves of the indices in two lists. When __iter__ is called (whenever the sampler is iterated over), it shuffles each list independently and returns an iterator over the two lists concatenated.

```python
our_sampler = IndependentHalvesSampler(dataset)
print('First half indices: ', our_sampler.first_half_indices)
print('Second half indices:', our_sampler.second_half_indices)
```

```
First half indices:  [0, 1, 2, 3, 4]
Second half indices: [5, 6, 7, 8, 9]
```

```python
for i in our_sampler:
    print(i)
```

```
1
2
4
3
0
7
9
8
5
6
```


So you can see that a shuffled [0, 1, 2, 3, 4] comes first, and then a shuffled [5, 6, 7, 8, 9] comes last.

And we can pass it to a DataLoader like so:

```python
dl = DataLoader(dataset, sampler=our_sampler)
for xb, yb in dl:
    print(xb, yb)
```

```
tensor([4]) tensor([14])
tensor([0]) tensor([10])
tensor([1]) tensor([11])
tensor([3]) tensor([13])
tensor([2]) tensor([12])
tensor([7]) tensor([17])
tensor([8]) tensor([18])
tensor([5]) tensor([15])
tensor([6]) tensor([16])
tensor([9]) tensor([19])
```


So we've seen what's responsible for the order of the indices and we've seen how PyTorch uses Samplers internally. We've also seen how to create our own Sampler subclass and pass it to PyTorch's DataLoader.

## A slight problem

You may have noticed a small problem above. If the batch size doesn't evenly divide the size of each half (for example, any batch size larger than half of the dataset), some batches will contain indices from both halves.

```python
dl = DataLoader(dataset, sampler=our_sampler, batch_size=7)
for xb, yb in dl:
    print(xb, yb)
```

```
tensor([4, 2, 0, 3, 1, 5, 7]) tensor([14, 12, 10, 13, 11, 15, 17])
tensor([8, 6, 9]) tensor([18, 16, 19])
```


This goes against our original goal because we wanted the first half of the dataset to always happen first. Let's say we want all batches in the first half to be separate from the second half... that's where batch_samplers come in.
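The root cause is that batching just slices the sampler's output into consecutive chunks, with no idea where one half ends. Here is a small pure-Python illustration that uses the index order from the output above and `batch_size=7`. The `mixed` check is a hypothetical helper for illustration, not part of PyTorch.

```python
# Index order taken from the output above: first half (0-4) shuffled, then second half (5-9) shuffled.
order = [4, 2, 0, 3, 1, 5, 7, 8, 6, 9]
batch_size = 7

# Batching just slices consecutive chunks of the order.
batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
print(batches)  # [[4, 2, 0, 3, 1, 5, 7], [8, 6, 9]]

# Hypothetical check: a batch with indices from both halves breaks "first half first".
halfway = len(order) // 2
mixed = [b for b in batches if min(b) < halfway <= max(b)]
print(mixed)  # [[4, 2, 0, 3, 1, 5, 7]]
```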

# BatchSampler

```python
batch_size = 3
default_batch_sampler = DataLoader(dataset, batch_size=batch_size).batch_sampler
for i, batch_indices in enumerate(default_batch_sampler):
    print(f'Batch #{i} indices: ', batch_indices)
```

```
Batch #0 indices:  [0, 1, 2]
Batch #1 indices:  [3, 4, 5]
Batch #2 indices:  [6, 7, 8]
Batch #3 indices:  [9]
```


Internally, PyTorch uses a BatchSampler to chunk together the indices into batches. We can make custom Samplers which return batches of indices and pass them using the batch_sampler argument. This is a bit more powerful in terms of customisation than sampler because you can choose both the order and the batches at the same time.


So rather than returning each index separately, the batch_sampler iterates through batches of indices. PyTorch uses the sampler internally to choose the order, and the batch_sampler to group batch_size indices at a time into a batch.

```python
type(default_batch_sampler)
```

```
torch.utils.data.sampler.BatchSampler
```

We can see it's a BatchSampler internally. Let's import this to see what it does:

```python
from torch.utils.data.sampler import BatchSampler
```

Here's the BatchSampler docstring:

```python
print(BatchSampler.__doc__)
```

```
Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If True, the sampler will drop the last batch if
            its size would be less than batch_size

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```



So you can initialise it with a Sampler, a batch_size and drop_last (whether to drop the last batch if it is smaller than batch_size), and it will return batches of indices when you iterate over it.
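Conceptually, BatchSampler just collects whatever indices the wrapped sampler yields into lists. Here is a rough pure-Python equivalent written as a generator. `batch_indices` is a hypothetical name, and PyTorch's own class differs in detail (for example, it also implements `__len__`). For these inputs it should still produce the same output as the docstring examples above.

```python
from typing import Iterable, Iterator, List


def batch_indices(sampler: Iterable[int], batch_size: int, drop_last: bool) -> Iterator[List[int]]:
    # Hypothetical re-implementation of the BatchSampler idea, for illustration only.
    batch: List[int] = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch          # a full batch is ready
            batch = []
    if batch and not drop_last:
        yield batch              # the final, smaller batch


print(list(batch_indices(range(10), batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(batch_indices(range(10), batch_size=3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```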

```python
batch_sampler = BatchSampler(our_sampler, batch_size=2, drop_last=False)
for i, batch_indices in enumerate(batch_sampler):
    print(f'Batch #{i} indices: ', batch_indices)
```

```
Batch #0 indices:  [2, 4]
Batch #1 indices:  [1, 3]
Batch #2 indices:  [0, 9]
Batch #3 indices:  [8, 6]
Batch #4 indices:  [5, 7]
```


As you can see, we can pass our custom Sampler to BatchSampler to control the order, and leave BatchSampler responsible for batching the indices. We can then pass the result to a DataLoader via the batch_sampler argument.

## Custom Batch Sampler

Similar to a custom sampler, you can also create a custom batch_sampler. Why? If for some reason you only wanted to batch certain things together (like only items of the same length), or wanted to show some examples more often than others, a custom batch_sampler is great for this.

To create a custom batch_sampler, we just do the same as we did with a custom Sampler but our iterator returns batches of indices, rather than individual indices.
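To illustrate the "only batch things of similar length" idea, here is a sketch of a hypothetical length-bucketing batch sampler in plain Python. It sorts indices by item length so that neighbours have similar lengths, chunks them, and shuffles only the order of the batches. The class name and its `lengths` argument are assumptions for this sketch. To use it with a DataLoader you would subclass torch.utils.data.Sampler and pass an instance as batch_sampler.

```python
import random
from typing import Iterator, List, Sequence


class LengthBucketBatchSampler:
    """Hypothetical batch sampler (illustration only): groups indices of similar length."""

    def __init__(self, lengths: Sequence[int], batch_size: int):
        self.lengths = list(lengths)
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[List[int]]:
        # Sort indices by the length of the item each one points at.
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        batches = [order[i:i + self.batch_size]
                   for i in range(0, len(order), self.batch_size)]
        random.shuffle(batches)  # shuffle the batch order, keep each batch's contents
        return iter(batches)

    def __len__(self) -> int:
        # Ceiling division: a final partial batch still counts as a batch.
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size


lengths = [5, 1, 9, 2, 8, 3]
bucket_sampler = LengthBucketBatchSampler(lengths, batch_size=2)
for batch in bucket_sampler:
    print(batch, [lengths[i] for i in batch])
```

Because the batches come from a sorted order, their contents are identical every epoch. You might want to add some randomness within groups of similar length, but that is beyond this sketch.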

Let's create a batch sampler which only batches together values from the same half of our dataset.

```python
import torch

def chunk(indices, chunk_size):
    # Split a list of indices into tensors of at most chunk_size elements
    return torch.split(torch.tensor(indices), chunk_size)

class EachHalfTogetherBatchSampler(Sampler):
    def __init__(self, dataset, batch_size):
        halfway_point = len(dataset) // 2
        self.first_half_indices = list(range(halfway_point))
        self.second_half_indices = list(range(halfway_point, len(dataset)))
        self.batch_size = batch_size

    def __iter__(self):
        random.shuffle(self.first_half_indices)
        random.shuffle(self.second_half_indices)
        # Batch each half separately so that no batch mixes the two halves
        first_half_batches = chunk(self.first_half_indices, self.batch_size)
        second_half_batches = chunk(self.second_half_indices, self.batch_size)
        combined = list(first_half_batches + second_half_batches)
        combined = [batch.tolist() for batch in combined]
        # Shuffle the order of the batches (but not their contents)
        random.shuffle(combined)
        return iter(combined)

    def __len__(self):
        # Ceiling division per half: each half may end with a smaller batch
        n_first = len(self.first_half_indices)
        n_second = len(self.second_half_indices)
        return ((n_first + self.batch_size - 1) // self.batch_size
                + (n_second + self.batch_size - 1) // self.batch_size)
```


So we've subclassed Sampler and stored the indices in two lists (as before). When __iter__ is called (whenever the batch_sampler is iterated over), it shuffles each half and then splits each one into batches using a helper function we've called chunk.

Then we merge the two lists of batches, shuffle the order of the batches, and return an iterator over them.
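If you'd rather keep the index logic in plain Python, a list-slicing chunk works too, and it makes the `.tolist()` conversion unnecessary. It also shows why `__len__` should use ceiling division: each half is chunked separately, so each can end in a partial batch. `chunk_list` and `num_batches` are hypothetical helpers for this sketch.

```python
from typing import List


def chunk_list(indices: List[int], chunk_size: int) -> List[List[int]]:
    # Hypothetical helper: slice the list into consecutive pieces of at most chunk_size.
    return [indices[i:i + chunk_size] for i in range(0, len(indices), chunk_size)]


def num_batches(half_sizes: List[int], batch_size: int) -> int:
    # Hypothetical helper: each half is chunked separately, so each may end with a partial batch.
    return sum((n + batch_size - 1) // batch_size for n in half_sizes)


print(chunk_list([3, 0, 4, 1, 2], 2))  # [[3, 0], [4, 1], [2]]
print(num_batches([5, 5], 2))          # 6
```

With five indices per half, batch_size=2 gives six batches, whereas plain `10 // 2` would report five. batch_size=3 gives four.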

```python
batch_size = 2
each_half_together_batch_sampler = EachHalfTogetherBatchSampler(dataset, batch_size)
for x in each_half_together_batch_sampler:
    print(x)
```

```
[1]
[5]
[8, 6]
[7, 9]
[4, 2]
[0, 3]
```


Great: as we hoped, no batch mixes indices from the first and second halves.

And now, we can pass this to DataLoader using the batch_sampler argument:

```python
# batch_size=3 here, matching the output shown
each_half_together_batch_sampler = EachHalfTogetherBatchSampler(dataset, batch_size=3)
for i, (xb, yb) in enumerate(DataLoader(dataset, batch_sampler=each_half_together_batch_sampler)):
    print(f'Batch #{i}. x{i}:', xb)
    print(f'          y{i}:', yb)
```

```
Batch #0. x0: tensor([7, 5, 8])
          y0: tensor([17, 15, 18])
Batch #1. x1: tensor([9, 6])
          y1: tensor([19, 16])
Batch #2. x2: tensor([2, 0])
          y2: tensor([12, 10])
Batch #3. x3: tensor([3, 1, 4])
          y3: tensor([13, 11, 14])
```


OK, great. That's how PyTorch chooses which elements of our Dataset to batch together... but where does the batching actually happen? And can we customise that?

# Custom Collate Functions

Internally, PyTorch uses a collate function to combine the samples in each batch (*see note). By default, a function called default_collate checks what type of data your Dataset returns and tries its best to combine the samples into a batch like (x_batch, y_batch).

Note: For simplicity, we’re going to assume automatic batching is enabled. See the PyTorch docs for details about collate functions when automatic batching is disabled: https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler
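At its core, the collate step regroups a list of samples field by field. Here is a rough pure-Python analogue using a hypothetical `naive_collate`. The real default_collate additionally stacks tensors, turns numbers into tensors and recurses into dicts and nested sequences.

```python
from typing import Any, List, Sequence, Tuple


def naive_collate(batch: List[Sequence[Any]]) -> Tuple[list, ...]:
    # Hypothetical, simplified collate (illustration only):
    # [(x0, y0), (x1, y1), (x2, y2)] -> ([x0, x1, x2], [y0, y1, y2])
    return tuple(list(field) for field in zip(*batch))


batch = [(0, 10), (4, 14), (2, 12)]
print(naive_collate(batch))  # ([0, 4, 2], [10, 14, 12])
```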

But what if we had custom types, or several different types of data, that default_collate couldn't merge? We could edit our Dataset so that they are mergeable, and that solves some of the type issues. But what if how we merge them depends on batch-level information, like the largest value in the batch?

Note: For other fancy uses of custom collate functions, there’s some cool examples in the popular huggingface/transformers library.

For problems like these, custom collate functions are a handy way of solving them.

From the PyTorch docs:

Users may use customized collate_fn to achieve custom batching, e.g., collating along a dimension other than the first, padding sequences of various lengths, or adding support for custom data types.

## Input to your collate function

You will need to match your custom collate function to the output of indexing your Dataset. If your dataset returns a tuple (x, y) when indexed (like dataset[0]), then your collate function will need to take a list of tuples like [(x0, y0), (x4, y4), (x2, y2), ...] whose length is batch_size (or less, for the last batch).

One common use of custom collate functions is padding variable-length batches. So let's change our dataset so that each x is a 1-D tensor, and they're all different lengths.

```python
import torch

# Ten 1-D tensors of random digits, with lengths 1, 2, ..., 10
xs = [torch.randint(0, 10, (x,)) for x in range(1, 11)]
xs
```

```
[tensor([5]),
 tensor([3, 8]),
 tensor([7, 7, 1]),
 tensor([2, 6, 5, 7]),
 tensor([5, 1, 4, 0, 7]),
 tensor([0, 1, 2, 1, 4, 9]),
 tensor([2, 3, 0, 9, 3, 4, 4]),
 tensor([2, 8, 8, 5, 7, 8, 2, 8]),
 tensor([5, 4, 0, 2, 1, 9, 5, 3, 2]),
 tensor([9, 2, 4, 7, 4, 3, 6, 6, 6, 7])]
```

```python
dataset = list(zip(xs, ys))
dataset[5]
```

```
(tensor([0, 1, 2, 1, 4, 9]), 15)
```

Now, if we try with the default collate function, it'll raise a RuntimeError:

```python
try:
    for xb, yb in DataLoader(dataset, batch_size=2):
        print(xb)
except RuntimeError as e:
    print('RuntimeError: ', e)
```

```
RuntimeError:  stack expects each tensor to be equal size, but got [1] at entry 0 and [2] at entry 1
```


With variable sized xs and a custom collate function, we could pad them to match the longest in the batch using torch.nn.utils.rnn.pad_sequence.
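pad_sequence does the padding for us, but the idea is simple enough to sketch by hand. Here is a pure-Python version of right-padding each batch to its longest item. `pad_to_longest` and `pad_value` are hypothetical names, and on plain lists this only mirrors what `pad_sequence(..., batch_first=True)` does for 1-D inputs.

```python
from typing import List


def pad_to_longest(seqs: List[List[int]], pad_value: int = 0) -> List[List[int]]:
    # Hypothetical helper: pad every sequence on the right to match the longest one in this batch.
    max_len = max(len(s) for s in seqs)
    return [s + [pad_value] * (max_len - len(s)) for s in seqs]


print(pad_to_longest([[7, 7, 1], [2, 6, 5, 7]]))
# [[7, 7, 1, 0], [2, 6, 5, 7]]
```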

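```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_x_collate_function(batch):
    # batch looks like [(x0, y0), (x4, y4), (x2, y2), ...]
    xs = [sample[0] for sample in batch]
    ys = [sample[1] for sample in batch]
    # If you want to be a little fancy, you can do the above in one line:
    # xs, ys = zip(*batch)
    # Pad every x with zeros up to the length of the longest x in this batch
    xs = pad_sequence(xs, batch_first=True, padding_value=0)
    return xs, torch.tensor(ys)
```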


And now, we can pass this pad_x_collate_function to collate_fn in DataLoader and it will pad each batch.

```python
for xb, yb in DataLoader(dataset, batch_size=2, collate_fn=pad_x_collate_function):
    print(xb)
```

```
tensor([[5, 0],
        [3, 8]])
tensor([[7, 7, 1, 0],
        [2, 6, 5, 7]])
tensor([[5, 1, 4, 0, 7, 0],
        [0, 1, 2, 1, 4, 9]])
tensor([[2, 3, 0, 9, 3, 4, 4, 0],
        [2, 8, 8, 5, 7, 8, 2, 8]])
tensor([[5, 4, 0, 2, 1, 9, 5, 3, 2, 0],
        [9, 2, 4, 7, 4, 3, 6, 6, 6, 7]])
```


Here it is with shuffle=True:

```python
for xb, yb in DataLoader(dataset, shuffle=True, batch_size=2, collate_fn=pad_x_collate_function):
    print('xs: ', xb)
    print('ys: ', yb)
```

```
xs:  tensor([[9, 2, 4, 7, 4, 3, 6, 6, 6, 7],
        [2, 8, 8, 5, 7, 8, 2, 8, 0, 0]])
ys:  tensor([19, 17])
xs:  tensor([[2, 6, 5, 7],
        [5, 0, 0, 0]])
ys:  tensor([13, 10])
xs:  tensor([[5, 4, 0, 2, 1, 9, 5, 3, 2],
        [2, 3, 0, 9, 3, 4, 4, 0, 0]])
ys:  tensor([18, 16])
xs:  tensor([[5, 1, 4, 0, 7],
        [3, 8, 0, 0, 0]])
ys:  tensor([14, 11])
xs:  tensor([[7, 7, 1, 0, 0, 0],
        [0, 1, 2, 1, 4, 9]])
ys:  tensor([12, 15])
```


## Another slight problem

But there's a bit of an issue: some of the shorter xs look like they have too much padding. Luckily, we've already created something that'll help here. We can use our EachHalfTogetherBatchSampler custom batch_sampler so that the first and second halves are batched separately.

```python
each_half_together_batch_sampler = EachHalfTogetherBatchSampler(dataset, batch_size=2)
for xb, yb in DataLoader(dataset, batch_sampler=each_half_together_batch_sampler,
                         collate_fn=pad_x_collate_function):
    print(xb)
```

```
tensor([[7, 7, 1, 0, 0],
        [5, 1, 4, 0, 7]])
tensor([[5, 4, 0, 2, 1, 9, 5, 3, 2],
        [0, 1, 2, 1, 4, 9, 0, 0, 0]])
tensor([[2, 3, 0, 9, 3, 4, 4, 0],
        [2, 8, 8, 5, 7, 8, 2, 8]])
tensor([[2, 6, 5, 7],
        [5, 0, 0, 0]])
tensor([[9, 2, 4, 7, 4, 3, 6, 6, 6, 7]])
tensor([[3, 8]])
```


And there we go, you can see the zero padding (but not too much!) at the end of each batch.
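To put a number on "not too much", it helps to count padding positions. The sketch below uses a hypothetical `padding_cost` helper. The lengths per batch are read off two of the printed runs above (the shuffled run and the EachHalfTogetherBatchSampler run), so this only compares those particular random draws.

```python
from typing import List


def padding_cost(batches: List[List[int]]) -> int:
    # Hypothetical helper: each batch is a list of sequence lengths,
    # and every item is padded up to the batch maximum.
    return sum(max(b) * len(b) - sum(b) for b in batches)


# Lengths per batch, read off the two printed runs above.
shuffled_run = [[10, 8], [4, 1], [9, 7], [5, 2], [3, 6]]
each_half_run = [[3, 5], [9, 6], [7, 8], [4, 1], [10], [2]]
print(padding_cost(shuffled_run))   # 13
print(padding_cost(each_half_run))  # 9
```

For those two runs, that is 13 padded positions versus 9. A different random draw would give different numbers.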

# Conclusion

I recommend running this yourself and creating your own Samplers and collate functions. All the code from this post is available on GitHub.

I personally love learning about new parts of PyTorch and finding ways to interact with them. What do you think about this style of exploration? Did you learn a bit about DataLoaders, Samplers and collate functions by reading this article? If so, feel free to share it, and you're also more than welcome to contact me (via Twitter) if you have any questions, comments, or feedback.