Data Augmentation Process (Image Classification with PyTorch)
Syntax
```python
transforms.Compose([
    transforms.CenterCrop(10),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
```
Example
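```python
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms

# Training transforms: random augmentation, then resizing and normalization
train_transform = transforms.Compose([
    transforms.RandomRotation(30),        # rotate by a random angle in [-30, 30] degrees
    transforms.RandomHorizontalFlip(),    # mirror left-right with probability 0.5
    transforms.Resize((224, 224)),
    transforms.ToTensor(),                # PIL image -> float tensor in [0, 1], shape (C, H, W)
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
])

# Test transforms: no random augmentation, so evaluation is deterministic
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder expects one sub-directory per class, e.g. train/cat/*.jpg, train/dog/*.jpg
train_dataset = torchvision.datasets.ImageFolder(root='train', transform=train_transform)
test_dataset = torchvision.datasets.ImageFolder(root='test', transform=test_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

classes = train_dataset.classes

# Get one batch of augmented training images.
# Use the built-in next(); the iterator's .next() method is not available in recent PyTorch releases.
data_iter = iter(train_loader)
images, labels = next(data_iter)

def imshow(img):
    img = img / 2 + 0.5  # undo Normalize: [-1, 1] -> [0, 1]
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))  # (C, H, W) -> (H, W, C) for matplotlib
    plt.show()

imshow(torchvision.utils.make_grid(images))
# Iterate over the actual labels: a batch holds fewer than 32 images if the dataset is smaller than 32
print([classes[label] for label in labels.tolist()])
```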
Output
Running the code displays a grid of randomly augmented images from one batch of the training data and prints the class label of each image.
Explanation
Data augmentation is the process of creating modified versions of existing training examples, for instance rotated or mirrored images, to improve a model's performance. In PyTorch, this can be done with the `torchvision.transforms` module, which provides a set of predefined image transformations that can be chained together with `transforms.Compose`.
The example above defines two pipelines:

- `train_transform` applies a random rotation of up to 30 degrees, a random horizontal flip, resizing to 224×224, conversion to a tensor, and normalization.
- `test_transform` only resizes, converts to a tensor, and normalizes.

The random operations are applied only to the training data, so evaluation on the test set stays deterministic.
The `ImageFolder` class then builds the train and test datasets from directories of images, with one sub-directory per class. It applies the corresponding transform each time an image is loaded.
The `DataLoader` class wraps each dataset to produce batches of 32 images and shuffles the training data. Finally, one batch of augmented training images is displayed with the `imshow()` helper, which undoes the normalization before plotting, so you can inspect the effect of the transformations.
Use
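Data augmentation is an essential technique in deep learning for improving a model's performance, and it is particularly useful with image data. It introduces variability into the training data and reduces overfitting. Because torchvision transforms are applied on the fly, the number of stored samples does not change. Instead, the model sees a different random variant of each image in every epoch, which effectively enlarges the training set.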
Important Points
- Data augmentation creates modified versions of existing training data to improve a model's performance.
- PyTorch provides the `torchvision.transforms` module for a variety of image transformations.
- The `ImageFolder` class is used to create the train and test datasets.
- The `DataLoader` class is used to create data loaders for the train and test datasets.
- Data augmentation effectively enlarges the training set, introduces variability into the training data, and reduces overfitting.
Summary
In conclusion, data augmentation is an essential technique for improving the performance of deep learning models. PyTorch's `torchvision.transforms` module provides a range of image transformations for data augmentation. The `ImageFolder` and `DataLoader` classes are used to create the train and test datasets and their data loaders, respectively. By introducing variability into the training data, augmentation reduces overfitting and improves the model's ability to generalize.