Introduction to Image Classification using CNNs in PyTorch

Subscribe to my newsletter and never miss my upcoming articles

In this article, we will learn how to use convolutional neural networks (CNNs) in PyTorch. We will use the standard CIFAR-10 dataset to perform image classification.

Importing Libraries

import torch
import matplotlib.pyplot as plt
import numpy as np

Data loading

import torchvision
import torchvision.transforms as transforms
# Load the CIFAR-10 training split from ./data (downloading it if missing);
# ToTensor() turns every image into a torch tensor.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
                                        download=True, 
                                        transform=transforms.ToTensor())

Output:

Files already downloaded and verified
# CIFAR-10 label names, indexed by the integer class id found in `labels`.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
dataiter = iter(trainloader)
# DataLoader iterators are ordinary Python iterators: use the built-in next().
# Their `.next()` method is not available in recent PyTorch releases.
images, labels = next(dataiter)

print(images.shape)

print(images[1].shape)
print(labels[1].item())

Output:

torch.Size([4, 3, 32, 32])
torch.Size([3, 32, 32])
0

Visualise data

# Take the second image of the batch and check what kind of object it is.
img = images[1]
print(type(img))

Output:

<class 'torch.Tensor'>
# Convert the tensor to a NumPy array; its shape shows channels come first.
npimg = img.numpy()
print(npimg.shape)

Output:

(3, 32, 32)
# Reorder axes (3, 32, 32) -> (32, 32, 3): move channels last for plt.imshow.
npimg = np.transpose(npimg, (1, 2, 0))
print(npimg.shape)

Output:

(32, 32, 3)
# Display the single 32x32 image in a very small figure.
plt.figure(figsize = (1,1))
plt.imshow(npimg)
plt.show()

Output:

output_10_0.png

def imshow(img):
    """Plot a channels-first image tensor and show the figure.

    The tensor is converted to NumPy and its axes are reordered from
    (C, H, W) to (H, W, C), the layout plt.imshow takes for colour images.
    """
    channels_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()
# Show the whole batch as one image grid, then print each image's class name.
imshow(torchvision.utils.make_grid(images))
# Iterate over the labels themselves so this works for any batch size.
print(' '.join(classes[int(label)] for label in labels))

Output:

output_12_0.png

ship plane ship ship

Single Convolutional Layer

import torch.nn as nn

class FirstCNN(nn.Module):
    """A single 3x3 convolution mapping 3 input channels to 16 feature maps.

    With padding 1 and stride 2, a (N, 3, 32, 32) batch becomes
    (N, 16, 16, 16).
    """

    def __init__(self):
        super().__init__()
        # Stride 2 halves the spatial size; padding 1 lets the 3x3 kernel
        # be centred on border pixels as well.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=3,
            stride=(2, 2),
            padding=(1, 1),
        )

    def forward(self, x):
        return self.conv1(x)
# Push the sample batch through the (untrained) network and inspect the shape.
net = FirstCNN()
out = net(images)
out.shape

Output:

torch.Size([4, 16, 16, 16])
# The conv layer holds one weight tensor (out_channels, in_channels, kH, kW)
# and one bias value per output channel.
for param in net.parameters():
    print(param.shape)

Output:

torch.Size([16, 3, 3, 3])
torch.Size([16])
# Channel 0 of the first image's output, detached from autograd and
# converted to a 2-D NumPy array.
out1 = out[0, 0, :, :].detach().numpy()
print(out1.shape)

Output:

(16, 16)
# Visualise that same feature map.
plt.imshow(out[0, 0, :, :].detach().numpy())
plt.show()

Output:

output_19_0.png

Deep Convolutional Network

class FirstCNN_v2(nn.Module):
    """Two stacked, unpadded 3x3 convolutions (3 -> 8 -> 16 channels).

    Each unpadded 3x3 convolution trims one pixel from every border:
    (N, 3, 32, 32) -> (N, 8, 30, 30) -> (N, 16, 28, 28).
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 8, 3),
            nn.Conv2d(8, 16, 3),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
# Run the two-convolution network on the sample batch.
net = FirstCNN_v2()
out = net(images)
out.shape

Output:

torch.Size([4, 16, 28, 28])
# First 28x28 feature map of the first image after both convolutions.
plt.imshow(out[0, 0, :, :].detach().numpy())

Output:

<matplotlib.image.AxesImage at 0x7fa818bc6710>

output_23_1.png

class FirstCNN_v3(nn.Module):
    """Two convolution + average-pooling stages, with no activation functions.

    Shape trace for a (N, 3, 32, 32) input:
        Conv2d(3, 6, 5)   -> (N, 6, 28, 28)
        AvgPool2d(2, 2)   -> (N, 6, 14, 14)
        Conv2d(6, 16, 5)  -> (N, 16, 10, 10)
        AvgPool2d(2, 2)   -> (N, 16, 5, 5)
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        return self.model(x)
# Run the conv + pooling network on the sample batch.
net = FirstCNN_v3()
out = net(images)
out.shape

Output:

torch.Size([4, 16, 5, 5])
# First 5x5 feature map of the first image after both pooling stages.
plt.imshow(out[0, 0, :, :].detach().numpy())

Output:

<matplotlib.image.AxesImage at 0x7fa818be2400>

output_26_1.png

LeNet

class LeNet(nn.Module):
    """LeNet-5 style classifier for 3x32x32 images, producing 10 scores.

    This demonstration variant prints the tensor shape at four points:
    the input, after the convolutional stack, after flattening, and after
    the fully connected head.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two (conv -> tanh -> average-pool) stages.
        #   (N, 3, 32, 32) -> (N, 6, 28, 28) -> (N, 6, 14, 14)
        #   -> (N, 16, 10, 10) -> (N, 16, 5, 5)
        self.cnn_model = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.Tanh(),
            nn.AvgPool2d(2, stride=2),
            nn.Conv2d(6, 16, 5),
            nn.Tanh(),
            nn.AvgPool2d(2, stride=2),
        )
        # Classifier head: 16 * 5 * 5 = 400 features -> 120 -> 84 -> 10.
        self.fc_model = nn.Sequential(
            nn.Linear(400, 120),
            nn.Tanh(),
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        print(x.shape)
        features = self.cnn_model(x)
        print(features.shape)
        flat = features.view(features.size(0), -1)
        print(flat.shape)
        scores = self.fc_model(flat)
        print(scores.shape)
        return scores
# The forward pass of this LeNet prints the shape after every stage.
net = LeNet()
out = net(images)

Output:

torch.Size([4, 3, 32, 32])
torch.Size([4, 16, 5, 5])
torch.Size([4, 400])
torch.Size([4, 10])
# Raw scores from the untrained network's final Linear layer:
# one row per image, one column per class.
print(out)

Output:

tensor([[ 0.0818,  0.0688, -0.0482, -0.1218, -0.1274, -0.0431,  0.1100,  0.0679,
          0.0587, -0.1229],
        [ 0.0884,  0.0577, -0.0446, -0.1189, -0.1232, -0.0427,  0.1160,  0.0535,
          0.0494, -0.1148],
        [ 0.0672,  0.0677, -0.0417, -0.1379, -0.1311, -0.0554,  0.1107,  0.0425,
          0.0451, -0.1298],
        [ 0.0750,  0.0620, -0.0458, -0.1272, -0.1160, -0.0595,  0.1157,  0.0523,
          0.0496, -0.1173]], grad_fn=<AddmmBackward>)
# The index of the largest score in each row is the predicted class id.
# Use torch.no_grad() rather than the legacy `.data` attribute so this
# bookkeeping step is explicitly excluded from autograd.
with torch.no_grad():
    max_values, pred_class = torch.max(out, 1)
print(pred_class)

Output:

tensor([6, 6, 6, 6])

Training LeNet

class LeNet(nn.Module):
    """LeNet-5 style classifier for 3x32x32 images, producing 10 scores.

    Convolutional stack (tanh activations, 2x2 average pooling):
        (N, 3, 32, 32) -> (N, 6, 28, 28) -> (N, 6, 14, 14)
        -> (N, 16, 10, 10) -> (N, 16, 5, 5)
    The result is flattened to (N, 400) and passed through fully connected
    layers 400 -> 120 -> 84 -> 10.
    """

    def __init__(self):
        super().__init__()
        self.cnn_model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.fc_model = nn.Sequential(
            nn.Linear(in_features=400, out_features=120),
            nn.Tanh(),
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=10),
        )

    def forward(self, x):
        n = x.size(0)
        return self.fc_model(self.cnn_model(x).view(n, -1))
# Full training and test loaders with mini-batches of 128 images; only the
# training set is shuffled.
batch_size = 128
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

Output:

Files already downloaded and verified
Files already downloaded and verified
def evaluation(dataloader, model=None):
    """Return a model's classification accuracy, in percent, over a dataset.

    Args:
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        model: callable mapping a batch of inputs to per-class scores.
            Defaults to the module-level ``net``.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if model is None:
        model = net
    # Score in eval mode (matters for dropout/batch-norm models) and restore
    # the caller's mode afterwards, since this runs in the middle of training.
    is_module = isinstance(model, nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    total, correct = 0, 0
    try:
        # Gradients are never needed for scoring; skipping autograd
        # bookkeeping saves memory and time.
        with torch.no_grad():
            for inputs, labels in dataloader:
                outputs = model(inputs)
                _, pred = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (pred == labels).sum().item()
    finally:
        if was_training:
            model.train()

    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return 100 * correct / total
# Fresh, untrained model plus a cross-entropy loss and an Adam optimiser
# with its default hyper-parameters.
net = LeNet()
import torch.optim as optim

loss_fn = nn.CrossEntropyLoss()
opt = optim.Adam(net.parameters())
%%time
loss_arr = []
loss_epoch_arr = []
max_epochs = 16

for epoch in range(max_epochs):

    for i, data in enumerate(trainloader, 0):

        inputs, labels = data

        opt.zero_grad()

        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()

        loss_arr.append(loss.item())

    loss_epoch_arr.append(loss.item())

    print('Epoch: %d/%d, Test acc: %0.2f, Train acc: %0.2f' % (epoch, max_epochs, evaluation(testloader), evaluation(trainloader)))


plt.plot(loss_epoch_arr)
plt.show()

Output:

Epoch: 0/16, Test acc: 38.39, Train acc: 38.13
Epoch: 1/16, Test acc: 43.67, Train acc: 43.74
Epoch: 2/16, Test acc: 46.30, Train acc: 46.62
Epoch: 3/16, Test acc: 49.37, Train acc: 50.37
Epoch: 4/16, Test acc: 50.15, Train acc: 51.86
Epoch: 5/16, Test acc: 52.14, Train acc: 54.40
Epoch: 6/16, Test acc: 52.72, Train acc: 56.28
Epoch: 7/16, Test acc: 53.53, Train acc: 57.73
Epoch: 8/16, Test acc: 54.44, Train acc: 58.82
Epoch: 9/16, Test acc: 54.61, Train acc: 59.97
Epoch: 10/16, Test acc: 55.91, Train acc: 61.58
Epoch: 11/16, Test acc: 55.41, Train acc: 61.88
Epoch: 12/16, Test acc: 55.28, Train acc: 63.09
Epoch: 13/16, Test acc: 56.54, Train acc: 64.56
Epoch: 14/16, Test acc: 56.37, Train acc: 64.63
Epoch: 15/16, Test acc: 56.54, Train acc: 66.50

output_38_1.png

CPU times: user 7min 39s, sys: 9.16 s, total: 7min 48s
Wall time: 7min 49s

Move to GPU

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

Output:

cuda:0
def evaluation(dataloader, model=None, eval_device=None):
    """Return a model's classification accuracy, in percent, over a dataset.

    Every batch is moved to the evaluation device before the forward pass.

    Args:
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        model: callable mapping a batch of inputs to per-class scores.
            Defaults to the module-level ``net``.
        eval_device: device the batches are moved to; should match the
            model's device. Defaults to the module-level ``device``.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if model is None:
        model = net
    if eval_device is None:
        eval_device = device
    # Score in eval mode (matters for dropout/batch-norm models) and restore
    # the caller's mode afterwards.
    is_module = isinstance(model, nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    total, correct = 0, 0
    try:
        # Gradients are never needed for scoring; skipping autograd
        # bookkeeping saves (GPU) memory and time.
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(eval_device), labels.to(eval_device)
                outputs = model(inputs)
                _, pred = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (pred == labels).sum().item()
    finally:
        if was_training:
            model.train()

    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return 100 * correct / total
# Re-create the model on `device` with a fresh loss and optimiser, so GPU
# training starts from scratch.
net = LeNet().to(device)
loss_fn = nn.CrossEntropyLoss()
opt = optim.Adam(net.parameters())
%%time
# Same training loop as on the CPU, except that every batch is moved to
# `device` first. No accuracy is computed per epoch, so %%time measures
# training alone.
max_epochs = 16

for epoch in range(max_epochs):

    for i, data in enumerate(trainloader, 0):

        inputs, labels = data
        # Inputs and targets must live on the same device as the model.
        inputs, labels = inputs.to(device), labels.to(device)

        opt.zero_grad()

        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()

    print('Epoch: %d/%d' % (epoch, max_epochs))

Output:

Epoch: 0/16
Epoch: 1/16
Epoch: 2/16
Epoch: 3/16
Epoch: 4/16
Epoch: 5/16
Epoch: 6/16
Epoch: 7/16
Epoch: 8/16
Epoch: 9/16
Epoch: 10/16
Epoch: 11/16
Epoch: 12/16
Epoch: 13/16
Epoch: 14/16
Epoch: 15/16
CPU times: user 1min 37s, sys: 1.97 s, total: 1min 38s
Wall time: 1min 39s
# Final accuracy (percent) of the GPU-trained model on the test and training sets.
print('Test acc: %0.2f, Train acc: %0.2f' % (evaluation(testloader), evaluation(trainloader)))

Output:

Test acc: 55.23, Train acc: 65.33

Basic Visualisation

# Show the 4-image sample batch from the start of the article again.
imshow(torchvision.utils.make_grid(images))

Output:

output_46_0.png

# Move the trained model back to the CPU so it can take the CPU `images` batch.
net = net.to('cpu')
out = net(images)
print(out.shape)

Output:

torch.Size([4, 10])
# Apply only the first layer of the conv stack (Conv2d(3, 6, 5)) to get
# its 6 feature maps per image.
out = net.cnn_model[0](images)
out.shape

Output:

torch.Size([4, 6, 28, 28])
# Pick one image from the sample batch and display it in a small figure.
image_id = 3
plt.figure(figsize = (2,2))
imshow(images[image_id,])

Output:

output_50_0.png

plt.figure(figsize=(6, 6))
# Plot the 6 first-layer feature maps of the chosen image on a 3x2 grid.
# Each subplot is created exactly once, inside the loop; creating
# subplot(3, 2, 1) twice triggers matplotlib's duplicate-axes warning.
feature_maps = out[image_id].detach().numpy()  # (6, 28, 28)
for i in range(6):
    ax = plt.subplot(3, 2, i + 1)
    ax.imshow(feature_maps[i], cmap="binary")
plt.show()

Output:

/usr/local/lib/python3.6/dist-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: 
Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.
  "Adding an axes using the same arguments as a previous axes "

output_51_1.png

Thanks for reading :) Please share any feedback in the comments 👨‍💻️.

By the way, the New Year is coming!

happy.new_.year_[1].gif

Nisarg Kapkar's photo

Really informative! Thanks for sharing!

Aankit sharma's photo

Nicely done!