On-the-fly Augmentation with PyTorch Geometric and Lightning: What Tutorials Don't Teach
So much of life, it seems to me, is determined by pure randomness.
— Sidney Poitier
On-the-fly data augmentation is the practice of applying random noise to data at load time. This significantly increases the effective size of your dataset, as each piece of data can be randomly transformed every time it’s loaded, providing a potentially unlimited number of variations. As you iterate over your dataset, randomness can be injected in creative ways, including executing some, none, or all of a series of augmentation steps at random, each with its own stochastic direction and magnitude.
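To make "some, none, or all" concrete, here is a minimal sketch in plain Python. The step functions (add_noise, drop_values), the probabilities, and the toy list of floats are all made up for illustration; a real pipeline would operate on tensors or graphs instead.

```python
import random
from typing import Callable, List, Sequence, Tuple

Step = Callable[[List[float], random.Random], List[float]]


def add_noise(values: List[float], rng: random.Random) -> List[float]:
    """Shift every value by an offset with random direction and magnitude."""
    magnitude = rng.uniform(0.0, 0.5)
    direction = rng.choice([-1.0, 1.0])
    return [v + direction * magnitude for v in values]


def drop_values(values: List[float], rng: random.Random) -> List[float]:
    """Zero out each value independently with 20% probability."""
    return [0.0 if rng.random() < 0.2 else v for v in values]


def augment(
    values: List[float],
    steps: Sequence[Tuple[Step, float]],
    rng: random.Random,
) -> List[float]:
    """Apply each step with its own probability, so some, none, or all may run."""
    out = list(values)  # work on a copy; never mutate the caller's data
    for step, probability in steps:
        if rng.random() < probability:
            out = step(out, rng)
    return out


rng = random.Random(0)  # seeded only so the sketch is repeatable
steps = [(add_noise, 0.5), (drop_values, 0.3)]
original = [1.0, 2.0, 3.0]

# Each "load" of the same sample can yield a different variant.
variants = [augment(original, steps, rng) for _ in range(5)]
```

Passing the random generator in explicitly is a design choice of this sketch: it keeps the randomness in one place, which makes it easier to seed when you need a reproducible run.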
A Lonely Melody Beyond Computer Vision
Deep Learning curricula treat augmentation as a core component. Tutorials include sections on data augmentation’s ability to mitigate common challenges such as class imbalance, error propagation, and scarce data, yet little is published outside of computer vision. You’ll quickly find that resources become scarce once you venture beyond cropping and rotating images of cats. The reality is that augmentation strategies are not as well documented for other types of data, like text or graphs. The journey to implementing custom augmentations here can feel like wandering uncharted territory.
In my case, working with complex machine learning pipelines that chain models together means augmentation can be a daunting but necessary voyage. I did some research on open-source solutions to simulate noise that can be introduced by OCR models and found nlpaug to be initially promising as it was the only one which touted a “pre-defined OCR mapping” that could apply noise conditioned on common OCR mistakes. I was intrigued, so I dug through the source code until I found the highly coveted mapping file. Needless to say, I was a bit disappointed. . . but ChatGPT quickly helped me whip up a much more comprehensive mapping that could be easily substituted. If you’re thinking of using this library I would also caution about this issue.
Harmonizing Graph Augmentation
Unlike images or text, graph data requires specialized augmentation techniques. The connectivity of nodes and edges presents a unique topology that must be preserved during augmentation, or altered only with careful intention.
The PyTorch ecosystem provides a powerful and flexible toolkit for data augmentation, primarily through transforms such as those in torchvision.transforms. At its core, a transform is a callable that takes in some data and returns a transformed version of that data. This could be as simple as resizing an image, flipping text characters at random, or moving data to the GPU. Transforms can be chained together and then attached to a PyTorch Dataset class, where they are used to modify data each time it is loaded, generally by a PyTorch DataLoader.
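The mechanism is easier to see stripped of any framework. The sketch below is plain Python, not PyTorch's actual implementation: Compose and ListDataset are simplified stand-ins, and jitter and scale are hypothetical transforms on a toy list of floats.

```python
import random
from typing import Callable, List, Optional, Sequence


class Compose:
    """Chain transforms; each one receives the previous one's output."""

    def __init__(self, transforms: Sequence[Callable]):
        self.transforms = list(transforms)

    def __call__(self, sample):
        for transform in self.transforms:
            sample = transform(sample)
        return sample


class ListDataset:
    """Simplified stand-in for a Dataset with an attached transform."""

    def __init__(self, samples: List[List[float]], transform: Optional[Callable] = None):
        self.samples = samples
        self.transform = transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> List[float]:
        sample = list(self.samples[idx])  # copy so the stored sample stays clean
        if self.transform is not None:
            sample = self.transform(sample)  # runs on every single access
        return sample


def jitter(sample: List[float]) -> List[float]:
    """Add small uniform noise to every value."""
    return [v + random.uniform(-0.1, 0.1) for v in sample]


def scale(sample: List[float]) -> List[float]:
    """Multiply every value by one shared random factor."""
    factor = random.uniform(0.9, 1.1)
    return [v * factor for v in sample]


dataset = ListDataset([[1.0, 2.0], [3.0, 4.0]], transform=Compose([jitter, scale]))
first_access = dataset[0]
second_access = dataset[0]  # same index, freshly drawn augmentation
```

The point to notice is that the transform lives inside __getitem__: anything that indexes the dataset, not just the training loop, triggers it.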
PyTorch Geometric extends this paradigm to graph data with its own Dataset, DataLoader, and Transform classes. Just like with pure PyTorch, custom transforms can be defined simply by inheriting from the BaseTransform class and implementing the __call__ and optionally __init__ methods. Transforms can then be attached to your Dataset and will be called as the data is fetched, like so [1]:
import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset

path = 'data/TUDataset'  # hypothetical download directory; use your own
transform = T.Compose([T.ToUndirected(), T.AddSelfLoops()])
dataset = TUDataset(path, name='MUTAG', transform=transform)
While this is the recommended method according to both frameworks, there is a more straightforward approach where you call the transform explicitly on data points or batches:
transform = T.Compose([T.ToUndirected(), T.AddSelfLoops()])
dataset = TUDataset(path, name='MUTAG')  # no transform attached to the dataset
# Call the transform explicitly on the first observation in the dataset
data = transform(dataset[0])
I caution against attaching transforms directly to the dataset so that you don’t make the same mistakes I did:
Pre-process on Clean Data
What the tutorials don’t consider is that people often do a lot of things with their datasets before starting to train on them. Attaching transforms directly to the dataset means your data will be augmented any time you access it. This muddles crucial pre-processing steps such as normalization, standardization, and imputation, where you want to obtain static and reproducible constants. If you’re calculating normalization constants on your training dataset and then applying those same constants for inference, the last thing you want is reproducibility issues. Before wising up I actually committed this terrible sin in several places in my pre-processing code to mitigate the issue (writing this blog is part of my atonement):
# Temporarily remove the transforms
original_transform = None
if dataset.transform is not None:
    original_transform = dataset.transform
    dataset.transform = None
# Normalize data
...
# Restore the transforms
dataset.transform = original_transform
Don’t do this! Take control of your transforms and apply them only when you need to. Don’t attach them at the hip of your dataset.
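One way to picture the alternative is to compute constants on clean data and let randomness enter only afterwards. This is a toy sketch in plain Python, assuming a made-up feature list (clean_train) and a hypothetical augment function; it is not taken from my actual pipeline.

```python
import random
import statistics
from typing import List

clean_train = [2.0, 4.0, 6.0, 8.0]  # toy stand-in for un-augmented training features

# Constants come from the clean data, so they are identical on every run
# and can be reused unchanged at inference time.
mean = statistics.mean(clean_train)
std = statistics.pstdev(clean_train)


def normalize(values: List[float]) -> List[float]:
    """Deterministic pre-processing using the stored constants."""
    return [(v - mean) / std for v in values]


def augment(values: List[float], rng: random.Random) -> List[float]:
    """Hypothetical augmentation: small additive jitter."""
    return [v + rng.uniform(-0.05, 0.05) for v in values]


rng = random.Random()
normalized = normalize(clean_train)     # no randomness involved
train_batch = augment(normalized, rng)  # randomness enters only here, explicitly
```

Because the dataset itself never carries a transform, there is nothing to remove and restore around the normalization step.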
Guard Your Original Data
In the midst of handling transformations, it’s easy to overlook a crucial point: preserving the integrity of your original data. In-place transformations can inadvertently modify your data, leading to potential inconsistencies and headaches down the line. In my case, I found augmentations were being compounded epoch-by-epoch as my original data was overwritten. In fact, this mistake was carried on from official PyG examples I was building off. Fortunately, the PyG team was able to address the issue, which should hopefully be in the next release. In the meantime, remember to create shallow copies before you transform!
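Here is a rough illustration of why the copy matters. Graph is a hypothetical stand-in for a PyG Data object, and jitter_in_place is a made-up transform that reassigns an attribute on whatever object it receives.

```python
import copy
import random
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Graph:
    """Hypothetical stand-in for a PyG Data object."""
    x: List[float]                     # node features
    edge_index: List[Tuple[int, int]]  # edges as (source, target) pairs


def jitter_in_place(graph: Graph) -> Graph:
    """Reassigns graph.x on whatever object it is given."""
    graph.x = [v + random.uniform(-0.1, 0.1) for v in graph.x]
    return graph


def safe_jitter(graph: Graph) -> Graph:
    """Shallow-copy first so the reassignment lands on the copy."""
    return jitter_in_place(copy.copy(graph))


original = Graph(x=[1.0, 2.0], edge_index=[(0, 1)])

for _ in range(3):  # three "epochs"
    augmented = safe_jitter(original)

# original.x is still [1.0, 2.0]; the noise did not compound across epochs.
```

One caveat: a shallow copy only protects against attribute reassignment. The copy still shares the underlying containers with the original, so a transform that mutates them element by element would leak through, and a deep copy would be the safer choice in that case.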
Augment the Dataset, Not the Training Time
Attaching transforms to your dataset seems like an intuitive and efficient approach, but losing control over when and where your augmentation is applied can mean significant slowdowns with hidden causes. In my case, I had implemented a custom PyG Dataset class following the official PyG example where I had some attribute self.data_list which contained a list of independent graphs, represented by PyG Data objects. It was only after my epoch training time increased from 22 seconds to 38 minutes that I realized the data augmentations were being called for every graph in the data_list each time a single mini-batch was being loaded. I had never implemented a custom __getitem__ method in my custom dataset and was relying on inheriting the method from the base Dataset class. It had never failed me until I realized that it was indexing my entire data_list and applying transforms on it without my knowledge.
PyTorch and PyTorch Geometric Datasets are incredibly powerful tools with excellent flexibility, but that can make things go awry if you haven’t stored your dataset in canonical ways, and even sometimes if you do. . . This was the last straw for me. I was determined to take control of the randomness in my life and call transform(data) explicitly!
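A cheap way to catch this kind of hidden cost is to count how often the transform actually runs. The sketch below is plain Python with made-up classes: EagerDataset exaggerates the symptom described above (it is not my actual dataset class), and LazyDataset shows the behavior I had assumed I was getting.

```python
from typing import Callable, List


class CountingTransform:
    """Wrap a transform and count how often it is invoked."""

    def __init__(self, transform: Callable):
        self.transform = transform
        self.calls = 0

    def __call__(self, sample):
        self.calls += 1
        return self.transform(sample)


class EagerDataset:
    """Hypothetical dataset that transforms the whole list on every access."""

    def __init__(self, data_list: List[int], transform: Callable):
        self.data_list = data_list
        self.transform = transform

    def __getitem__(self, idx: int) -> int:
        transformed = [self.transform(d) for d in self.data_list]  # every item!
        return transformed[idx]


class LazyDataset(EagerDataset):
    """Transforms only the requested item."""

    def __getitem__(self, idx: int) -> int:
        return self.transform(self.data_list[idx])


def calls_per_batch(dataset_cls, n_items: int = 1000, batch_size: int = 32) -> int:
    """Simulate loading one mini-batch and report the number of transform calls."""
    counter = CountingTransform(lambda d: d + 1)
    dataset = dataset_cls(list(range(n_items)), counter)
    _ = [dataset[i] for i in range(batch_size)]
    return counter.calls


eager_calls = calls_per_batch(EagerDataset)  # 1000 items x 32 accesses = 32000
lazy_calls = calls_per_batch(LazyDataset)    # one call per accessed item = 32
```

Wrapping a real transform in a counter like this for a single batch would have told me in seconds what took me a 38-minute epoch to notice.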
Composing Augmentation with PyTorch Lightning
Lastly, let’s talk about PyTorch Lightning. While it’s a fantastic tool that simplifies a lot of the boilerplate code in PyTorch, it comes with its own quirks. Certain aspects require a little extra care and handling to ensure that your model training runs smoothly. Lightning abstracts away most of the training loop and requires only that users specify train_dataloader and val_dataloader methods that return some iterable, generally a PyTorch DataLoader. [2] These methods can be implemented either directly in the LightningModule or in the optional LightningDataModule. After this, mini-batches are sampled and dispatched according to Lightning’s smart “auto” acceleration strategies, which generally means they end up where you want them to go, such as on the GPU/TPU.
Attaching transforms directly to your Dataset seems to fit in well with this paradigm and, just like most things with PyTorch Lightning, somehow they are magically taken care of! The problem is when they are taken care of. Depending on the operations underpinning the augmentation strategies, it may be more efficient to apply them on the CPU or on the accelerator. Some augmentations, like those from Kornia, are designed to be entirely differentiable and may be faster operating directly on tensors already stored on the accelerator. Most strategies are best done on the CPU, which can take advantage of PyTorch’s multiprocessing features built into DataLoader workers. [3]
Lightning has implemented two very useful data hooks which give users more control over what happens to batches before and after they are sent to accelerators. These can be implemented as methods in either the LightningModule or the LightningDataModule:
# Operates on a mini-batch before it is transferred to the accelerator
def on_before_batch_transfer(self, batch, dataloader_idx): ...
# Operates on a mini-batch after it is transferred to the accelerator
def on_after_batch_transfer(self, batch, dataloader_idx): ...
Reading this excellent Lightning tutorial on batched data augmentation finally gave me the tools I needed to take full control of how my data augmentation was being performed. A simple additional hook in my LightningDataModule solved the last of my problems:
from typing import Any

def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
    """
    Called before a batch is transferred to the device.
    We apply data augmentation here only during training.
    """
    if self.trainer.training:
        batch = self.transforms(batch)
    return batch
This ensured that my augmentation was performed (1) only on specific mini-batches as they are loaded, (2) not on validation or testing data, (3) on the correct hardware, and (4) only within the training loop.
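The gating logic can be checked without Lightning installed. In the sketch below, AugmentingDataModule and the SimpleNamespace trainer are stand-ins I made up for illustration (a real LightningDataModule gets its trainer attached by Lightning), and add_one is a deterministic placeholder for a random augmentation so the effect is easy to see.

```python
from types import SimpleNamespace
from typing import Any, Callable, List


class AugmentingDataModule:
    """Hypothetical stand-in for a LightningDataModule; only the hook is modeled."""

    def __init__(self, transforms: Callable, trainer: Any):
        self.transforms = transforms
        self.trainer = trainer  # faked here; Lightning normally sets this

    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        # Augment only while the trainer is in its training loop.
        if self.trainer.training:
            batch = self.transforms(batch)
        return batch


def add_one(batch: List[int]) -> List[int]:
    """Deterministic placeholder for a random augmentation."""
    return [v + 1 for v in batch]


fake_trainer = SimpleNamespace(training=True)
datamodule = AugmentingDataModule(add_one, fake_trainer)

train_batch = datamodule.on_before_batch_transfer([1, 2, 3], dataloader_idx=0)

fake_trainer.training = False  # as during validation or testing
val_batch = datamodule.on_before_batch_transfer([1, 2, 3], dataloader_idx=0)
```

With the flag on, the batch comes back augmented; with it off, the batch passes through untouched, which is the behavior behind points (2) and (4).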
Conducting the Symphony of Randomness
Randomness is not just fundamental to life but also embedded deep within machine learning craft. Respect and wield the power of randomness wisely. Balance it with stability and reproducibility. Navigating through the uncertain terrains of machine learning means not only adapting to randomness but also shaping it to our advantage. In doing so, we turn Poitier’s quote on its head: life may be determined by pure randomness, but we are its agent.
[1] See https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html.
[2] There are also optional test_dataloader options as well.
[3] See https://pytorch.org/docs/stable/data.html#multi-process-data-loading.