But what are PyTorch DataLoaders really?
Creating custom ways (without magic) to order, batch and combine your data with PyTorch DataLoaders.
- DataLoaders are magic.
- What's the plan?
- What are DataLoaders and Datasets?
- Samplers
- BatchSampler
- Custom Collate Functions
- Conclusion
You can often get away with using something magical. You can bury your head in the sand and ignore the mysterious methods behind it, all while enjoying the benefits that come from this magic. But at some point, either curiosity will get the better of you or you'll be missing the flexibility you need, and you'll want to try to demystify the sorcery.
In my opinion, the best libraries have an element of magic to them. They hide away some gory details with a little bit of polish and sleight of hand that leaves the world looking orderly and simple. The really great libraries allow you to peek behind the curtain at your own pace, slowly revealing the complexity and flexibility within.
I believe PyTorch is one of those libraries. It has lots of composable abstractions that you can learn about independently, which neatly layer together to make a powerful, customisable and elegant framework. In this tutorial, we're going to dive into some of the details of PyTorch DataLoaders in the hope of discovering how they work behind the scenes and how we can customise them to our liking.
I recommend running this yourself as a Jupyter Notebook and creating your own Samplers and collate functions. All the code from this post is available on Github.
PyTorch DataLoaders are great for iterating over batches of a Dataset like:
for xb, yb in dataloader:
    ...
where xb and yb are batches of your inputs and labels.
This tutorial covers some of the more advanced features of DataLoaders, which should explain what happens behind the scenes when you iterate over your dataloaders and help you customise different parts of that process using PyTorch's native features.
To be specific, we're going to go over custom collate functions and Samplers.
For this tutorial to be useful, you should probably know what DataLoaders and Datasets are, but I will refresh your memory. For a deeper dive, I recommend Jeremy Howard's tutorial What is torch.nn really? and the PyTorch tutorial Writing Custom Datasets, DataLoaders and Transforms.
A quick refresher: PyTorch Datasets are just things that have a length and are indexable, so that len(dataset) will work and dataset[index] will return a tuple of (x,y).
So we'll create two lists for x and y values:
xs = list(range(10))
ys = list(range(10,20))
print('xs values: ', xs)
print('ys values: ', ys)
Then we use Python's zip function to combine them so dataset[index] returns (x,y) for that index:
dataset = list(zip(xs,ys))
dataset[0]  # returns the tuple (xs[0], ys[0])
len(dataset)
We could also get the same functionality by using a class with the "dunder/magic methods" __getitem__ (for dataset[index] functionality) and __len__ (for len(dataset) functionality):
class MyDataset:
    def __init__(self, xs, ys):
        self.xs = xs
        self.ys = ys

    def __getitem__(self, i):
        # Called by dataset[i]
        return self.xs[i], self.ys[i]

    def __len__(self):
        # Called by len(dataset)
        return len(self.xs)

dataset = MyDataset(xs, ys)
dataset[2]  # returns the tuple (xs[2], ys[2])
len(dataset)
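For data that is already numeric, PyTorch also ships torch.utils.data.TensorDataset, which implements the same two methods for you by indexing each tensor along its first dimension. A small sketch, assuming xs and ys are equal-length lists of numbers:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

xs = list(range(10))
ys = list(range(10, 20))

# TensorDataset indexes every tensor it holds along the first dimension
tensor_dataset = TensorDataset(torch.tensor(xs), torch.tensor(ys))

x2, y2 = tensor_dataset[2]
print(x2, y2)               # tensor(2) tensor(12)
print(len(tensor_dataset))  # 10

# It plugs into a DataLoader exactly like MyDataset does
for xb, yb in DataLoader(tensor_dataset, batch_size=4):
    print(xb, yb)
```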
Then we just wrap that in a DataLoader and we can iterate over it, but now the values are magically tensors, and we can use DataLoader's handy configurations like shuffling, batching, multi-processing, etc.:
from torch.utils.data import DataLoader
for x, y in DataLoader(dataset):
    print(x, y)
But the real fun is that we can get batches of these by setting batch_size:
for x, y in DataLoader(dataset, batch_size=2):
    print(x, y)
And we can shuffle these batches by just setting shuffle=True:
for x, y in DataLoader(dataset, batch_size=2, shuffle=True):
    print(x, y)
As you can see, it doesn't just shuffle the batches but instead, it shuffles the data and then batches.
OK... but can we customise this shuffling or batching??? Yes, we can customise the shuffling with a custom Sampler and we can customise the batching with a custom collate function.
Every DataLoader has a Sampler which is used internally to produce the indices of the dataset elements, in order. Each index is used to index into your Dataset to grab the data (x, y). You can ignore this for now, but DataLoaders also have a batch_sampler, which groups those indices into lists of batch_size indices, one list per batch.
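A quick way to check this, assuming a recent PyTorch version: whenever batch_size is not None (including the default of 1), the DataLoader wraps its sampler in a BatchSampler, while batch_size=None switches automatic batching off and leaves batch_sampler unset.

```python
from torch.utils.data import DataLoader

dataset = list(zip(range(10), range(10, 20)))

dl = DataLoader(dataset)  # batch_size defaults to 1
print(type(dl.sampler).__name__)        # SequentialSampler
print(type(dl.batch_sampler).__name__)  # BatchSampler
print(list(dl.batch_sampler)[:3])       # [[0], [1], [2]]

dl3 = DataLoader(dataset, batch_size=3)
print(list(dl3.batch_sampler))          # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

# batch_size=None turns automatic batching off, so there is no batch_sampler
print(DataLoader(dataset, batch_size=None).batch_sampler)  # None
```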
Don't worry if this is a bit confusing; hopefully it'll be clearer after a few examples:
Let's have a look at the internal .sampler property of a few DataLoaders and see how it changes when the DataLoader configuration changes:
When shuffle=False (the default) and batch_size is left at its default of 1, the sampler returns each index in 0,1,2,3,4... as you iterate.
default_sampler = DataLoader(dataset).sampler
So when we iterate over the sampler we should get the indices:
for i in default_sampler:
    # iterating over the SequentialSampler
    print(i)
👍
type(default_sampler)
We can see that the DataLoader's sampler property is internally a SequentialSampler.
Let's import SequentialSampler to see if we can use it ourselves:
from torch.utils.data.sampler import SequentialSampler
sampler = SequentialSampler(dataset)
for x in sampler:
    print(x)
So it just returns indices as you iterate over it. Great, what about when shuffle=True?
When shuffled, we should expect randomly shuffled indices:
random_sampler = DataLoader(dataset, shuffle=True).sampler
for index in random_sampler:
    print(index)
So shuffle=True changes the sampler internally to one that returns the indices in a random order each time it's iterated over.
type(random_sampler)
We can see it's a RandomSampler, so let's import that and use it ourselves.
from torch.utils.data.sampler import RandomSampler
random_sampler = RandomSampler(dataset)
for x in random_sampler:
    print(x)
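RandomSampler also accepts a generator argument. A sketch, assuming you want reproducible shuffles: seeding a dedicated torch.Generator gives the same order every time, and each index still appears exactly once because replacement defaults to False. The seeded_order helper is hypothetical.

```python
import torch
from torch.utils.data.sampler import RandomSampler

dataset = list(zip(range(10), range(10, 20)))

def seeded_order(seed):
    # A dedicated, seeded generator makes the shuffled order reproducible
    g = torch.Generator()
    g.manual_seed(seed)
    return list(RandomSampler(dataset, generator=g))

order_a = seeded_order(42)
order_b = seeded_order(42)
print(order_a)
print(order_a == order_b)                            # True: same seed, same order
print(sorted(order_a) == list(range(len(dataset))))  # True: a permutation, no repeats
```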
We can pass this in explicitly to a DataLoader using the sampler parameter like this:
dl = DataLoader(dataset, sampler=random_sampler)
for i in dl.sampler:
    print(i)
So we've seen that every DataLoader has a sampler internally, which is either a SequentialSampler or a RandomSampler depending on the value of shuffle, and it's iterated over to get the indices of the Dataset to use.
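Because shuffle=True simply picks a RandomSampler for you, DataLoader refuses to combine it with an explicit sampler, and likewise refuses a batch_sampler alongside batch_size, shuffle, sampler or drop_last. A small check (the exact error messages may differ between PyTorch versions):

```python
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, SequentialSampler

dataset = list(zip(range(10), range(10, 20)))
sampler = SequentialSampler(dataset)

errors = []
try:
    # shuffle and sampler would both try to decide the order
    DataLoader(dataset, shuffle=True, sampler=sampler)
except ValueError as e:
    errors.append(str(e))

try:
    # a batch_sampler already decides the batches, so batch_size must stay at its default
    DataLoader(dataset, batch_size=4, batch_sampler=BatchSampler(sampler, 2, drop_last=False))
except ValueError as e:
    errors.append(str(e))

for message in errors:
    print('ValueError:', message)
```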
That's great and all, but what if we want to customise the order of the data beyond shuffled or sequential? That's where custom Samplers come in.
From the docs:
Every Sampler subclass has to provide an __iter__ method, providing a way to iterate over indices of dataset elements, and a __len__ method that returns the length of the returned iterators.
So all we have to do to create a custom sampler is subclass Sampler and implement an __iter__ method (for iterating through the indices) and a __len__ method (for the length).
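A minimal sketch of that contract, using a hypothetical ReverseSampler (not part of PyTorch) that yields indices from last to first:

```python
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler

class ReverseSampler(Sampler):
    """Hypothetical example sampler: yields dataset indices from last to first."""

    def __init__(self, dataset):
        self.num_items = len(dataset)

    def __iter__(self):
        # __iter__ must return an iterator over dataset indices
        return iter(range(self.num_items - 1, -1, -1))

    def __len__(self):
        # __len__ reports how many indices __iter__ will produce
        return self.num_items

dataset = list(zip(range(10), range(10, 20)))
reverse_sampler = ReverseSampler(dataset)
print(list(reverse_sampler))  # [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

for xb, yb in DataLoader(dataset, batch_size=4, sampler=reverse_sampler):
    print(xb, yb)
```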
As a small toy example, say we wanted the first half of the dataset to always come first and the second half to come later in training, while still shuffling these two halves independently:
import random
from torch.utils.data.sampler import Sampler

class IndependentHalvesSampler(Sampler):
    def __init__(self, dataset):
        halfway_point = int(len(dataset)/2)
        self.first_half_indices = list(range(halfway_point))
        self.second_half_indices = list(range(halfway_point, len(dataset)))

    def __iter__(self):
        random.shuffle(self.first_half_indices)
        random.shuffle(self.second_half_indices)
        return iter(self.first_half_indices + self.second_half_indices)

    def __len__(self):
        return len(self.first_half_indices) + len(self.second_half_indices)
So we've subclassed Sampler and stored both halves of the indices in two lists. When __iter__ is called (whenever the sampler is iterated over), it shuffles them independently and returns an iterator over the two lists merged.
our_sampler = IndependentHalvesSampler(dataset)
print('First half indices: ', our_sampler.first_half_indices)
print('Second half indices:', our_sampler.second_half_indices)
for i in our_sampler:
    print(i)
So you can see that a shuffled [0, 1, 2, 3, 4] comes first, and then a shuffled [5, 6, 7, 8, 9] comes last.
And we can pass it to a DataLoader like so:
dl = DataLoader(dataset, sampler=our_sampler)
for xb, yb in dl:
    print(xb, yb)
So we've seen what's responsible for the order of the indices and how PyTorch uses Samplers internally. We've also seen how to create our own Sampler subclass and pass it to PyTorch's DataLoader.
You may have noticed a small problem above: if the batch size doesn't evenly divide each half (for example, a batch size greater than half of the dataset), indices from the two halves of the dataset will appear in the same batch.
batch_size = 7
dl = DataLoader(dataset, batch_size=batch_size, sampler=our_sampler)
for xb, yb in dl:
    print(xb, yb)
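More generally, a batch mixes the two halves whenever the batch size doesn't evenly divide the size of the first half, because BatchSampler simply chunks the sampler's output in order. A small pure-Python sketch (straddling_batches is a hypothetical helper) that works on positions in the sampler's output for a 10-item dataset, assuming the first half is always yielded before the second:

```python
def straddling_batches(dataset_size, batch_size):
    """Return the batches of output positions that contain items from both halves."""
    halfway_point = dataset_size // 2
    positions = list(range(dataset_size))
    # BatchSampler-style chunking: consecutive runs of batch_size positions
    batches = [positions[i:i + batch_size] for i in range(0, dataset_size, batch_size)]
    return [b for b in batches if b[0] < halfway_point <= b[-1]]

for batch_size in range(1, 8):
    print(batch_size, straddling_batches(10, batch_size))
```

Only batch sizes 1 and 5 keep the halves apart here, which is why a sampler on its own can't guarantee the separation.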
This goes against our original goal because we wanted the first half of the dataset to always come first. Let's say we want all batches in the first half to be separate from the second half... that's where batch_samplers come in.
batch_size = 3
default_batch_sampler = DataLoader(dataset, batch_size=batch_size).batch_sampler
for i, batch_indices in enumerate(default_batch_sampler):
    print(f'Batch #{i} indices: ', batch_indices)
Internally, PyTorch uses a BatchSampler to chunk the indices together into batches. We can make custom Samplers that return batches of indices and pass them in using the batch_sampler argument. This is a bit more powerful in terms of customisation than the sampler argument because you can choose both the order and the batches at the same time.
So rather than returning each index separately, the batch_sampler iterates through batches of indices. PyTorch uses the sampler internally to select the order, and the batch_sampler to group batch_size indices together at a time.
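To make that division of labour concrete, here's a rough single-process sketch of what a DataLoader does with these pieces. It's a simplification (no worker processes, no pinned memory), it assumes PyTorch 1.11 or later where default_collate is importable from torch.utils.data, and manual_batches is a hypothetical helper name:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, default_collate

dataset = list(zip(range(10), range(10, 20)))

def manual_batches(dataset, batch_size):
    """Rough single-process sketch of a DataLoader's main loop."""
    sampler = SequentialSampler(dataset)                                # chooses the order
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=False)  # groups indices
    for batch_indices in batch_sampler:
        samples = [dataset[i] for i in batch_indices]  # list of (x, y) tuples
        yield default_collate(samples)                 # merged into [x_batch, y_batch]

manual = list(manual_batches(dataset, batch_size=3))
automatic = list(DataLoader(dataset, batch_size=3))

# Compare our hand-rolled batches with the real DataLoader's output
for (mx, my), (ax, ay) in zip(manual, automatic):
    print(mx, my, torch.equal(mx, ax) and torch.equal(my, ay))
```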
type(default_batch_sampler)
We can see it's a BatchSampler internally. Let's import it to see what it does:
from torch.utils.data.sampler import BatchSampler
Here's the BatchSampler docstring:
print(BatchSampler.__doc__)
So you can initialise it with a Sampler, a batch_size and drop_last (whether to drop the last batch if it has fewer than batch_size indices), and it will return batches of indices when you iterate over it.
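A quick sketch of what drop_last changes, using a 10-element range as the data source and a batch size of 3:

```python
from torch.utils.data.sampler import BatchSampler, SequentialSampler

indices = range(10)
with_last = BatchSampler(SequentialSampler(indices), batch_size=3, drop_last=False)
without_last = BatchSampler(SequentialSampler(indices), batch_size=3, drop_last=True)

print(list(with_last), len(with_last))        # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 4
print(list(without_last), len(without_last))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 3
```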
batch_sampler = BatchSampler(our_sampler, batch_size=2, drop_last=False)
for i, batch_indices in enumerate(batch_sampler):
    print(f'Batch #{i} indices: ', batch_indices)
As you can see, we can pass our custom Sampler to BatchSampler to control the order, and leave BatchSampler responsible for batching the indices. We can then pass it to a DataLoader in the batch_sampler argument.
Similar to a custom sampler, you can also create a custom batch_sampler. Why? If for some reason you wanted to only batch certain things together (like only if they're the same length), or if you wanted to show some examples more often than others, a custom BatchSampler is great for this.
To create a custom batch_sampler, we do the same as we did with a custom Sampler, but our iterator returns batches of indices rather than individual indices.
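To make the "only batch things of the same length together" idea concrete, here's a sketch of a hypothetical SameLengthBatchSampler (not a PyTorch class). It assumes each dataset item is an (x, y) pair where len(x) is meaningful, groups indices into buckets by that length, and only ever batches within a bucket:

```python
import random
from collections import defaultdict

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler

class SameLengthBatchSampler(Sampler):
    """Hypothetical batch sampler: only batches together items whose x has the same length."""

    def __init__(self, dataset, batch_size):
        self.batch_size = batch_size
        # Group dataset indices by the length of their x
        self.buckets = defaultdict(list)
        for index, (x, _) in enumerate(dataset):
            self.buckets[len(x)].append(index)

    def __iter__(self):
        batches = []
        for indices in self.buckets.values():
            indices = indices[:]  # copy so the stored buckets are not reordered
            random.shuffle(indices)
            batches += [indices[i:i + self.batch_size]
                        for i in range(0, len(indices), self.batch_size)]
        random.shuffle(batches)  # mix up the order of the batches
        return iter(batches)

    def __len__(self):
        # Each bucket is chunked separately, so round up per bucket
        return sum((len(indices) + self.batch_size - 1) // self.batch_size
                   for indices in self.buckets.values())

lengths = [1, 2, 1, 2, 3, 3, 1, 2]
dataset = [(torch.zeros(n), label) for label, n in enumerate(lengths)]
batch_sampler = SameLengthBatchSampler(dataset, batch_size=2)

# Every batch has equal-length xs, so the default collate function can stack them
for xb, yb in DataLoader(dataset, batch_sampler=batch_sampler):
    print(xb.shape, yb)
```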
Let's create a BatchSampler that only batches together values from the same half of our dataset, so that no batch ever mixes the two halves.
import math
import random
import torch
from torch.utils.data.sampler import Sampler

def chunk(indices, chunk_size):
    # Split a list of indices into tensors of (at most) chunk_size elements
    return torch.split(torch.tensor(indices), chunk_size)

class EachHalfTogetherBatchSampler(Sampler):
    def __init__(self, dataset, batch_size):
        halfway_point = len(dataset) // 2
        self.first_half_indices = list(range(halfway_point))
        self.second_half_indices = list(range(halfway_point, len(dataset)))
        self.batch_size = batch_size

    def __iter__(self):
        # Shuffle within each half, then batch each half separately
        random.shuffle(self.first_half_indices)
        random.shuffle(self.second_half_indices)
        first_half_batches = chunk(self.first_half_indices, self.batch_size)
        second_half_batches = chunk(self.second_half_indices, self.batch_size)
        # Merge the batches, turn each tensor back into a list of ints, shuffle the batch order
        combined = list(first_half_batches + second_half_batches)
        combined = [batch.tolist() for batch in combined]
        random.shuffle(combined)
        return iter(combined)

    def __len__(self):
        # Number of batches: each half is chunked separately, so round up per half
        return (math.ceil(len(self.first_half_indices) / self.batch_size)
                + math.ceil(len(self.second_half_indices) / self.batch_size))
So we've subclassed Sampler and stored the indices in two lists (as before). When __iter__ is called (whenever the batch_sampler is iterated over), it shuffles each half and then batches them using a helper function we've called chunk.
Then we merge the batches and finally, we shuffle the batches and return an iterator of them.
batch_size = 2
each_half_together_batch_sampler = EachHalfTogetherBatchSampler(dataset, batch_size)
for x in each_half_together_batch_sampler:
    print(x)
Great, as we hoped, no batch mixes indices from the first and second halves.
And now, we can pass this to DataLoader using the batch_sampler argument:
for i, (xb, yb) in enumerate(DataLoader(dataset, batch_sampler=each_half_together_batch_sampler)):
    print(f'Batch #{i}. x{i}:', xb)
    print(f'          y{i}:', yb)
OK, great. That's how PyTorch chooses which elements in our Dataset to batch together... but where does that batching actually happen? And can we customise it?
Internally, PyTorch uses a Collate Function to combine the data in your batches together (*see note).
By default, a function called default_collate checks what type of data your Dataset returns and tries its best to combine the data into a batch like (x_batch, y_batch).
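A small sketch of what default_collate does with the kind of list a DataLoader hands it, assuming PyTorch 1.11 or later for the torch.utils.data import: a list of (x, y) tuples becomes one tensor per position, and same-shaped tensors are stacked along a new first dimension.

```python
import torch
from torch.utils.data import default_collate

batch = [(0, 10), (1, 11), (2, 12)]  # what a DataLoader passes to the collate function
x_batch, y_batch = default_collate(batch)
print(x_batch)  # tensor([0, 1, 2])
print(y_batch)  # tensor([10, 11, 12])

# Tensors of the same shape are stacked along a new first dimension
stacked = default_collate([torch.zeros(3), torch.ones(3)])
print(stacked.shape)  # torch.Size([2, 3])
```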
But what if we had custom types, or multiple different types of data, that default_collate couldn't merge? We could edit our Dataset so that its outputs are mergeable, which solves some of the type issues, BUT what if how we merge them depends on 'batch-level' information, like the largest value in the batch?
For problems like these, custom collate functions are a handy way of solving them.
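As a sketch of that 'batch-level' case, here's a hypothetical scale_by_batch_max_collate that divides each x by the largest x in its own batch, something a Dataset can't do because it only ever sees one item at a time:

```python
import torch
from torch.utils.data import DataLoader

dataset = list(zip(range(1, 11), range(10, 20)))

def scale_by_batch_max_collate(batch):
    """Hypothetical collate: divide each x by the largest x in *this* batch."""
    xs = torch.tensor([x for x, _ in batch], dtype=torch.float32)
    ys = torch.tensor([y for _, y in batch])
    return xs / xs.max(), ys

for xb, yb in DataLoader(dataset, batch_size=4, collate_fn=scale_by_batch_max_collate):
    print(xb, yb)
```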
From the PyTorch docs:
Users may use customized collate_fn to achieve custom batching, e.g., collating along a dimension other than the first, padding sequences of various lengths, or adding support for custom data types.
You will need to match your custom collate function with the output of indexing your Dataset. If your dataset returns a tuple (x, y) when indexed into (like dataset[0]), then your collate function will need to take a list of tuples like [(x0,y0), (x4,y4), (x2,y2)... ] which is batch_size in length (the last batch may be shorter unless drop_last=True).
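An easy way to see exactly what your collate function receives is to pass an identity function as collate_fn; the DataLoader then hands back the raw list of samples unchanged. A sketch using the original integer dataset:

```python
from torch.utils.data import DataLoader

dataset = list(zip(range(10), range(10, 20)))

# An "identity" collate function returns the batch exactly as the DataLoader builds it
for batch in DataLoader(dataset, batch_size=3, collate_fn=lambda batch: batch):
    print(batch)
# [(0, 10), (1, 11), (2, 12)]
# [(3, 13), (4, 14), (5, 15)]
# [(6, 16), (7, 17), (8, 18)]
# [(9, 19)]
```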
One thing custom collate functions are often used for is padding variable-length batches. So let's change our dataset so that each x is a 1-D tensor, and they're all different lengths.
import torch
xs = [torch.randint(0, 10, (x,)) for x in range(1, 11)]
xs
dataset = list(zip(xs, ys))
dataset[5]
Now, if we try with the default collate function, it'll raise a RuntimeError because it can't stack tensors of different sizes.
try:
    for xb, yb in DataLoader(dataset, batch_size=2):
        print(xb)
except RuntimeError as e:
    print('RuntimeError: ', e)
With variable-sized xs and a custom collate function, we could pad them to match the longest in the batch using torch.nn.utils.rnn.pad_sequence.
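On its own, pad_sequence takes a list of 1-D tensors and right-pads each one with padding_value up to the longest length; batch_first=True puts the batch dimension first. A tiny sketch:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

sequences = [torch.tensor([1, 2, 3]), torch.tensor([4]), torch.tensor([5, 6])]
padded = pad_sequence(sequences, batch_first=True, padding_value=0)
print(padded)
# tensor([[1, 2, 3],
#         [4, 0, 0],
#         [5, 6, 0]])
print(padded.shape)  # torch.Size([3, 3])
```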
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_x_collate_function(batch):
    # batch looks like [(x0,y0), (x4,y4), (x2,y2)... ]
    xs = [sample[0] for sample in batch]
    ys = [sample[1] for sample in batch]
    # If you want to be a little fancy, you can do the above in one line
    # xs, ys = zip(*batch)
    # Pad every x with zeros up to the length of the longest x in this batch
    xs = pad_sequence(xs, batch_first=True, padding_value=0)
    return xs, torch.tensor(ys)
And now, we can pass this pad_x_collate_function to collate_fn in DataLoader and it will pad each batch.
for xb, yb in DataLoader(dataset, batch_size=2, collate_fn=pad_x_collate_function):
    print(xb)
Here it is with shuffle=True.
for xb, yb in DataLoader(dataset, shuffle=True, batch_size=2, collate_fn=pad_x_collate_function):
    print('xs: ', xb)
    print('ys: ', yb)
But there's a bit of an issue: some of the shorter sequences look like they have too much padding. Luckily, we've already created something that'll help here. We can use our EachHalfTogetherBatchSampler custom batch_sampler so that the first and second halves are batched separately.
each_half_together_batch_sampler = EachHalfTogetherBatchSampler(dataset, batch_size=2)
for xb, yb in DataLoader(dataset, collate_fn=pad_x_collate_function, batch_sampler=each_half_together_batch_sampler):
    print(xb)
And there we go, you can see the zero padding (but not too much!) at the end of the shorter sequences in each batch.
I recommend running this yourself and creating your own Samplers and collate functions. All the code from this post is available on Github.
I personally love learning about new parts of PyTorch and finding ways to interact with them. What do you think about this style of exploration? Did you learn a bit about DataLoaders, Samplers and collate functions by reading this article? If so, feel free to share it, and you’re also more than welcome to contact me (via Twitter) if you have any questions, comments, or feedback.
Thanks for reading!
Follow me on Twitter here for more stuff like this.