TorchData: PyTorch Composable Data loading utility library
Learn how to load image data with TorchData and train an image classifier
- In this tutorial, we will learn about TorchData
- Bonus: Training Image Classifier with TorchData x Lightning Flash ⚡️
PyTorch 1.11 came with a new library called TorchData. It provides common data loading primitives for easily constructing flexible and performant data pipelines. TorchData promotes composable data loading for code reusability with DataPipes.
A DataPipe is the building block of TorchData and works out of the box with the PyTorch DataLoader. DataPipes can be chained to form a data pipeline in which the data is transformed by each DataPipe in turn.
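As a minimal sketch of that chaining (not part of the original pipeline), the example below wraps a range of integers, maps a function over it and filters the result. It uses the DataPipe classes bundled with PyTorch itself in torch.utils.data.datapipes; the helper functions double and is_multiple_of_four are made up for illustration.

```python
from torch.utils.data.datapipes.iter import IterableWrapper


def double(x):
    return x * 2


def is_multiple_of_four(x):
    return x % 4 == 0


# Each call returns a new DataPipe that wraps the previous one.
dp = IterableWrapper(range(6))        # yields 0, 1, 2, 3, 4, 5
dp = dp.map(double)                   # yields 0, 2, 4, 6, 8, 10
dp = dp.filter(is_multiple_of_four)   # keeps only the multiples of four

# Nothing is computed until the final DataPipe is iterated.
result = list(dp)
print(result)  # [0, 4, 8]
```

Named functions are used instead of lambdas here because they are easier to pickle if the pipeline is later handed to worker processes.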
For example, suppose we have an image dataset in a folder with a CSV file that maps each image to its class, and we want to create a DataLoader that returns batches of image Tensors and labels.
For this we need to do the following steps:
- Read and Parse the CSV
  a. Get the image filepath
  b. Decode the label
- Read Image
- Convert Image to Tensor
- Return image Tensor and Label index
These steps can be chained into a single DataPipe, where the data flows from the first step to the last and a transformation is applied at each step.
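Conceptually, this is the same lazy flow you get from chaining Python generators. The sketch below is plain Python, not TorchData, and it covers only the CSV parsing and label decoding steps. The CSV text, the "train" root and the helper names are all hypothetical.

```python
import csv
import io

# Hypothetical in-memory stand-in for a labels CSV file.
CSV_TEXT = "id,label\n1,frog\n2,truck\n3,frog\n"


def parse_rows(text):
    """Read and parse the CSV, skipping the header row."""
    reader = csv.reader(io.StringIO(text))
    next(reader)  # drop the header
    yield from reader


def to_path_and_index(rows, root, label_to_index):
    """Build the image file path and decode the label into an index."""
    for file_id, label in rows:
        yield f"{root}/{file_id}.png", label_to_index[label]


label_to_index = {"frog": 0, "truck": 1}

# Each stage pulls items from the previous one only when asked.
rows = parse_rows(CSV_TEXT)
samples = to_path_and_index(rows, "train", label_to_index)

result = list(samples)
print(result)  # [('train/1.png', 0), ('train/2.png', 1), ('train/3.png', 0)]
```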
Now, let's see how to do the same with TorchData code.
# install libraries
!pip install torchdata lightning-flash -q
from torchdata.datapipes.iter import (
    FileOpener,
    FileLister,
)
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_tensor
We will use the CIFAR-10 dataset, which has the same structure as discussed above.
We use FileLister, FileOpener and parse_csv to read the CSV data.
ROOT = "/Users/aniket/datasets/cifar-10/train"
csv_dp = FileLister(f"{ROOT}/../trainLabels.csv")
csv_dp = FileOpener(csv_dp)
csv_dp = csv_dp.parse_csv()
for i, e in enumerate(csv_dp):
    if i > 10:
        break
    print(e)
We don't need the CSV header ([id, label]) in our DataPipe, so we will pass skip_lines=1 to the parse_csv method.
csv_dp = FileLister(f"{ROOT}/../trainLabels.csv")
csv_dp = FileOpener(csv_dp)
csv_dp = csv_dp.parse_csv(skip_lines=1)
for i, e in enumerate(csv_dp):
    if i > 10:
        break
    print(e)
# Sort the unique labels so that each label gets the same index on every run.
labels = {e: i for i, e in enumerate(sorted({e[1] for e in csv_dp}))}
Now we have a DataPipe, csv_dp, which yields the file id and the label. We need to convert the file id into a filepath and the label into a label index.
We can map functions onto the DataPipe and even chain several maps to apply transformations one after another.
def get_filename(data):
    idx, label = data
    return f"{ROOT}/{idx}.png", label
dp = csv_dp.map(get_filename)
for i, e in enumerate(dp):
    if i > 4:
        break
    print(e)
from IPython.display import display
def load_image(data):
    file, label = data
    return Image.open(file), label
dp = dp.map(load_image)
for i, e in enumerate(dp):
    display(e[0])
    print(e[1])
    if i >= 5:
        break
Finally, we map the DataPipe to convert the image to a Tensor and the label to its index.
def process(data):
    img, label = data
    return to_tensor(img), labels[label]
dp = dp.map(process)
Bonus: Training Image Classifier with TorchData x Lightning Flash ⚡️
If you have come this far then I have a bonus for you. Train an image classifier using DataPipe and PyTorch Lightning Flash ⚡️
Flash is a high-level deep learning framework for fast prototyping, baselining, finetuning and solving deep learning problems. It features a set of tasks for you to use for inference and finetuning out of the box, and an easy to implement API to customize every step of the process for full flexibility.
Flash expects the dataloader to yield dictionaries with the keys input and target, where input contains our image tensor and target is the label index.
dp = dp.map(lambda x: {"input": x[0], "target": x[1]})
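To see what a DataLoader does with samples in this form, here is a small self-contained sketch. Random tensors stand in for the CIFAR-10 images and the targets are made-up indices; the point is only that the default collation keeps the dictionary keys and stacks the values into batches.

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

# Random tensors stand in for images (3 channels, 32x32 pixels);
# the targets are made-up label indices.
samples = [{"input": torch.rand(3, 32, 32), "target": i % 2} for i in range(8)]
toy_dp = IterableWrapper(samples)

toy_dl = DataLoader(toy_dp, batch_size=4)
batch = next(iter(toy_dl))

# The batch is still a dictionary, with the values stacked along a new first dimension.
print(batch["input"].shape)  # torch.Size([4, 3, 32, 32])
print(batch["target"])       # tensor([0, 1, 0, 1])
```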
As discussed, DataPipes are fully compatible with DataLoader, so this is how you convert a DataPipe to a DataLoader.
DataPipe supports shuffling and batching as well, but each should be done in only one place, either in the DataPipe or in the DataLoader; otherwise the samples will be batched/shuffled more than once.
dl = DataLoader(
    dp,
    batch_size=32,
    shuffle=True,
)
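The DataLoader above does the batching. As a sketch of the alternative, batching inside the DataPipe instead, the example below calls .batch() on a toy DataPipe of integers and passes batch_size=None so that the DataLoader does not batch a second time. The integer range is only a stand-in for the real samples.

```python
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

# Toy integer samples stand in for the real (image, label) samples.
batched_dp = IterableWrapper(range(10)).batch(4)

# batch_size=None turns off the DataLoader's own batching, so every
# item it yields is one batch that was built by the DataPipe.
batched_dl = DataLoader(batched_dp, batch_size=None)

batches = [list(batch) for batch in batched_dl]
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Leaving batch_size=32 on the DataLoader here would instead group the already batched items again, which is the double batching described above.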
Training an image classifier with Flash is super easy. Flash provides task-based deep learning APIs that you can use to train your model. Our task here is image classification, so let's import the ImageClassifier and build our model.
from flash.image import ImageClassifier
import flash
# Create an Image Classifier Model
model = ImageClassifier(num_classes=len(labels), backbone="efficientnet_b0", pretrained=False)
# Create the trainer and train the model from scratch (pretrained=False)
trainer = flash.Trainer(max_epochs=3)
trainer.fit(model, dl)