
MNIST Dataset (Image Recognition with PyTorch)

Syntax

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST('./data', train=False, download=True, transform=transform)

Example

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # Plain Dropout here: this layer runs on flattened (batch, features) input
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

model = Net()

criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

epochs = 10
for e in range(epochs):
    running_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean loss per batch for this epoch
    print(f"Training loss: {running_loss/len(train_loader)}")

Output

Training loss: 0.22206621855286126
Training loss: 0.0886614791141126
Training loss: 0.06659299924081652
Training loss: 0.05443531216634639
Training loss: 0.04667200877327673
Training loss: 0.039775651846576455
Training loss: 0.03484904246182252
Training loss: 0.03386085802796453
Training loss: 0.02903437405911022
Training loss: 0.026367333260664717

Explanation

The MNIST dataset is a collection of handwritten digit images (0 through 9) that is commonly used for training image recognition models. It consists of 60,000 training images and 10,000 test images, each a 28x28-pixel grayscale image labeled with the digit it shows.

The PyTorch code above shows how to load the MNIST dataset and train a simple convolutional neural network (CNN) to recognize handwritten digits. The CNN has two convolutional layers followed by max pooling and dropout, then two fully connected layers. Because the network ends with log_softmax and therefore outputs log-probabilities, the matching loss function is negative log-likelihood loss (NLLLoss). The optimizer is Adam with a learning rate of 0.003.
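The input size of 9216 for the first fully connected layer follows from the layer shapes. As a rough sketch, the same convolution and pooling settings can be applied to a dummy 28x28 input to trace how the tensor shrinks. The all-zero input is only a placeholder, not MNIST data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Dummy batch: 1 image, 1 channel, 28x28 pixels
x = torch.zeros(1, 1, 28, 28)

conv1 = nn.Conv2d(1, 32, 3, 1)   # 3x3 kernel, stride 1, no padding: 28 -> 26
conv2 = nn.Conv2d(32, 64, 3, 1)  # 3x3 kernel, stride 1, no padding: 26 -> 24

x = conv1(x)
x = conv2(x)
x = F.max_pool2d(x, 2)           # 2x2 pooling halves each side: 24 -> 12

flat = torch.flatten(x, 1)       # 64 channels * 12 * 12 = 9216 features
print(flat.shape)                # torch.Size([1, 9216])
```

If the kernel sizes, padding, or pooling were changed, the first argument of `nn.Linear(9216, 128)` would have to change to match.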

Use

The MNIST dataset is commonly used as a benchmark for image recognition models. It can be used to train and test CNNs, deep neural networks (DNNs), and other machine learning models.
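The example above only reports training loss. To use MNIST as a benchmark, a model is usually also scored on the test set. The following is one possible sketch of an accuracy check. The `evaluate` helper is hypothetical (it is not part of PyTorch), and the random data and tiny linear model are stand-ins so the block runs without downloading the dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader):
    """Return the fraction of samples in `loader` that `model` classifies correctly."""
    model.eval()  # disable dropout while evaluating
    correct = 0
    total = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for images, labels in loader:
            predictions = model(images).argmax(dim=1)  # most likely class per image
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Stand-in data and model (assumptions for illustration only)
fake_images = torch.randn(8, 1, 28, 28)
fake_labels = torch.randint(0, 10, (8,))
fake_loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=4)
fake_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10))

accuracy = evaluate(fake_model, fake_loader)
print(f"Accuracy: {accuracy:.2%}")
```

With the names from the example above, the call would be `evaluate(model, test_loader)` after training has finished.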

Important Points

  • The MNIST dataset is a collection of handwritten digits used for training image recognition models
  • It consists of 60,000 training images and 10,000 testing images of size 28x28 pixels
  • The PyTorch code above shows how to load the MNIST dataset and create a simple CNN model
  • The CNN has two convolutional layers and two fully connected layers
  • The loss function used is negative log-likelihood loss (NLLLoss) and the optimizer used is Adam
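The pairing of `log_softmax` with NLLLoss can be verified numerically: applied together, they give the same value as `nn.CrossEntropyLoss` applied to the raw scores. A small sketch, using random scores and made-up labels rather than real model output:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Made-up raw scores (logits) for 4 samples and 10 classes, with made-up labels
logits = torch.randn(4, 10)
labels = torch.tensor([3, 1, 7, 0])

# Path used in the example: log_softmax in the model, NLLLoss as the criterion
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), labels)

# Equivalent path: CrossEntropyLoss applies log_softmax internally
ce = nn.CrossEntropyLoss()(logits, labels)

print(torch.allclose(nll, ce))  # True
```

Either combination works, but mixing them (for example, `log_softmax` output passed to `CrossEntropyLoss`) applies the softmax step twice.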

Summary

In summary, the MNIST dataset is a widely used benchmark for image recognition models. It can be loaded and used in PyTorch to train and test CNNs and other machine learning models. The PyTorch code example above illustrates how to load the dataset and create a simple CNN model for recognizing handwritten digits.
