Image classification of chest X-rays into one of three classes: Normal, Viral Pneumonia, or COVID-19.
Dataset: the COVID-19 Radiography Database on Kaggle.
# Render matplotlib figures inline in the notebook.
%matplotlib inline
import os
import shutil
import random
import torch
import torchvision
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
# Fix the PyTorch RNG so weight init and shuffling are reproducible across runs.
torch.manual_seed(0)
print('Using PyTorch version', torch.__version__)
Using PyTorch version 1.13.1
/Users/thomas/miniforge3/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
class_names = ['normal', 'viral', 'covid']
root_dir = '/Users/thomas/Desktop/Data/COVID-19_Radiography_Dataset'
source_dirs = ['normal', 'viral', 'covid']

# One-time dataset split: only runs while the source class directory still
# exists, i.e. before the 'test' directory has been carved out.
if os.path.isdir(os.path.join(root_dir, source_dirs[1])):
    os.mkdir(os.path.join(root_dir, 'test'))
    # Rename the source dirs to the canonical class names
    # (a no-op here since source_dirs already equals class_names).
    for i, d in enumerate(source_dirs):
        os.rename(os.path.join(root_dir, d), os.path.join(root_dir, class_names[i]))
    for c in class_names:
        os.mkdir(os.path.join(root_dir, 'test', c))
    # Move 30 random images per class into the held-out test directory.
    for c in class_names:
        # Case-insensitive extension match, consistent with ChestXRayDataset
        # (the original `endswith('png')` silently skipped '.PNG' files).
        images = [x for x in os.listdir(os.path.join(root_dir, c, 'images')) if x.lower().endswith('png')]
        selected_images = random.sample(images, 30)
        for image in selected_images:
            source_path = os.path.join(root_dir, c, 'images', image)
            target_path = os.path.join(root_dir, 'test', c, image)
            shutil.move(source_path, target_path)
class ChestXRayDataset(torch.utils.data.Dataset):
    """Chest X-ray dataset over per-class image directories.

    NOTE(review): __getitem__ ignores the class implied by `index` and picks a
    *random* class on every access (presumably to balance the three classes
    during training); only the within-class index is derived from `index`.
    """

    def __init__(self, image_dirs, transform):
        # image_dirs maps class name -> directory containing that class's PNGs.
        self.image_dirs = image_dirs
        self.transform = transform
        self.class_names = ['normal', 'viral', 'covid']
        self.images = {}
        for class_name in self.class_names:
            images = [f for f in os.listdir(image_dirs[class_name]) if f.lower().endswith('png')]
            print(f'Found {len(images)} {class_name} examples')
            self.images[class_name] = images

    def __len__(self):
        # Total image count across all classes.
        total = 0
        for class_name in self.class_names:
            total += len(self.images[class_name])
        return total

    def __getitem__(self, index):
        # Sample a class uniformly, then wrap `index` into that class's range.
        class_name = random.choice(self.class_names)
        class_images = self.images[class_name]
        image_name = class_images[index % len(class_images)]
        image_path = os.path.join(self.image_dirs[class_name], image_name)
        image = Image.open(image_path).convert('RGB')
        return self.transform(image), self.class_names.index(class_name)
# ImageNet normalization statistics, matching the pretrained ResNet-18 inputs.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training preprocessing: resize to 224x224, random horizontal flip
# (augmentation), convert to tensor, normalize.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

# Test preprocessing: same pipeline but deterministic (no augmentation).
test_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(_norm_mean, _norm_std),
])
# Per-class training image directories (same paths as before, built in one pass).
train_dirs = {c: f'{root_dir}/{c}/images' for c in ['normal', 'viral', 'covid']}
train_dataset = ChestXRayDataset(train_dirs, train_transform)
Found 10162 normal examples Found 1315 viral examples Found 3586 covid examples
# Per-class held-out test directories.
test_dirs = {c: f'{root_dir}/test/{c}' for c in ['normal', 'viral', 'covid']}
test_dataset = ChestXRayDataset(test_dirs, test_transform)
Found 30 normal examples Found 30 viral examples Found 30 covid examples
batch_size = 6
# Shuffled loaders for both splits (test is shuffled so previews vary).
dl_train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dl_test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
for split_name, loader in (('training', dl_train), ('test', dl_test)):
    print(f'Number of {split_name} batches', len(loader))
Number of training batches 2511 Number of test batches 15
class_names = train_dataset.class_names

def show_images(images, labels, preds):
    """Plot a batch of images side by side after undoing ImageNet normalization.

    X-axis label is the true class; y-axis label is the predicted class,
    colored green when correct and red when wrong.
    """
    plt.figure(figsize=(8, 4))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    for i, image in enumerate(images):
        # Generalized from a hard-coded 6 columns to the actual batch size,
        # so a short final batch no longer leaves empty axes.
        plt.subplot(1, len(images), i + 1, xticks=[], yticks=[])
        # CHW tensor -> HWC numpy, denormalize, clamp into displayable [0, 1].
        image = image.numpy().transpose((1, 2, 0))
        image = np.clip(image * std + mean, 0., 1.)
        plt.imshow(image)
        col = 'green'
        if preds[i] != labels[i]:
            col = 'red'
        plt.xlabel(f'{class_names[int(labels[i].numpy())]}')
        plt.ylabel(f'{class_names[int(preds[i].numpy())]}', color=col)
    plt.tight_layout()
    plt.show()
# Sanity check: visualize one batch from each loader, passing the ground-truth
# labels as "predictions" so everything renders green.
for loader in (dl_train, dl_test):
    images, labels = next(iter(loader))
    show_images(images, labels, labels)
# Load ResNet-18 with ImageNet weights via the modern `weights=` API:
# `pretrained=True` has been deprecated since torchvision 0.13 (it emitted the
# UserWarning visible in the original run) and maps to exactly this enum value.
resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
print(resnet18)
ResNet( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer2): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer3): Sequential( (0): BasicBlock( (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer4): Sequential( (0): BasicBlock( (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (avgpool): AdaptiveAvgPool2d(output_size=(1, 1)) (fc): Linear(in_features=512, out_features=1000, bias=True) )
/Users/thomas/miniforge3/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead. warnings.warn( /Users/thomas/miniforge3/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights. warnings.warn(msg)
# Swap the 1000-way ImageNet classifier head for a 3-way one (normal/viral/covid).
resnet18.fc = torch.nn.Linear(512, 3)

# Cross-entropy over the 3 classes; Adam with a small learning rate,
# appropriate for fine-tuning pretrained weights.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet18.parameters(), lr=3e-5)
def show_preds():
    """Run the model on one test batch and visualize its predictions."""
    resnet18.eval()
    images, labels = next(iter(dl_test))
    # Inference only: disable autograd so no graph is built (faster, less memory).
    with torch.no_grad():
        outputs = resnet18(images)
    _, preds = torch.max(outputs, 1)
    show_images(images, labels, preds)
show_preds()
def train(epochs):
    """Fine-tune resnet18 on dl_train, validating on dl_test every 20 steps.

    Stops early and returns as soon as validation accuracy reaches 99.5%.
    Uses the module-level resnet18, optimizer, loss_fn, dl_train, dl_test,
    test_dataset and show_preds.
    """
    print('Starting training..')
    for e in range(0, epochs):
        print('='*20)
        print(f'Starting epoch {e + 1}/{epochs}')
        print('='*20)
        train_loss = 0.
        resnet18.train()  # set model to training phase
        for train_step, (images, labels) in enumerate(dl_train):
            optimizer.zero_grad()
            outputs = resnet18(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            if train_step % 20 == 0:
                print('Evaluating at step', train_step)
                # Fix: reset the accumulators for *each* evaluation. The original
                # carried the previous evaluation's averaged val_loss into the
                # next sum, so every printed validation loss after the first
                # was skewed by the prior average.
                accuracy = 0
                val_loss = 0.
                resnet18.eval()  # set model to eval phase
                # Fix: validation needs no gradients; skip autograd bookkeeping.
                with torch.no_grad():
                    for val_step, (images, labels) in enumerate(dl_test):
                        outputs = resnet18(images)
                        loss = loss_fn(outputs, labels)
                        val_loss += loss.item()
                        _, preds = torch.max(outputs, 1)
                        accuracy += (preds == labels).sum().item()
                val_loss /= (val_step + 1)
                accuracy = accuracy / len(test_dataset)
                print(f'Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}')
                show_preds()
                resnet18.train()
                if accuracy >= 0.995:
                    print('Performance condition satisfied, stopping..')
                    return
        train_loss /= (train_step + 1)
        print(f'Training Loss: {train_loss:.4f}')
    print('Training complete..')
%%time
# Train for a single epoch; train() early-stops once validation accuracy >= 99.5%.
train(epochs=1)
Starting training.. ==================== Starting epoch 1/1 ==================== Evaluating at step 0 Validation Loss: 1.1196, Accuracy: 0.3556
Evaluating at step 20 Validation Loss: 0.8646, Accuracy: 0.6444
Evaluating at step 40 Validation Loss: 0.5868, Accuracy: 0.7556
Evaluating at step 60 Validation Loss: 0.5104, Accuracy: 0.8667
Evaluating at step 80 Validation Loss: 0.3629, Accuracy: 0.8889
Evaluating at step 100 Validation Loss: 0.2339, Accuracy: 0.9333
Evaluating at step 120 Validation Loss: 0.2098, Accuracy: 0.9333
Evaluating at step 140 Validation Loss: 0.2151, Accuracy: 0.9222
Evaluating at step 160 Validation Loss: 0.1966, Accuracy: 0.9667
Evaluating at step 180 Validation Loss: 0.1482, Accuracy: 0.9444
Evaluating at step 200 Validation Loss: 0.1808, Accuracy: 0.9333
Evaluating at step 220 Validation Loss: 0.1753, Accuracy: 0.9556
Evaluating at step 240 Validation Loss: 0.2091, Accuracy: 0.9444
Evaluating at step 260 Validation Loss: 0.1247, Accuracy: 0.9778
Evaluating at step 280 Validation Loss: 0.1527, Accuracy: 0.9444
Evaluating at step 300 Validation Loss: 0.1118, Accuracy: 0.9889
Evaluating at step 320 Validation Loss: 0.2177, Accuracy: 0.9222
Evaluating at step 340 Validation Loss: 0.1131, Accuracy: 0.9556
Evaluating at step 360 Validation Loss: 0.1544, Accuracy: 0.9111
Evaluating at step 380 Validation Loss: 0.1667, Accuracy: 0.9444
Evaluating at step 400 Validation Loss: 0.1877, Accuracy: 0.9556
Evaluating at step 420 Validation Loss: 0.0843, Accuracy: 0.9778
Evaluating at step 440 Validation Loss: 0.1029, Accuracy: 0.9556
Evaluating at step 460 Validation Loss: 0.1424, Accuracy: 0.9556
Evaluating at step 480 Validation Loss: 0.1003, Accuracy: 0.9667
Evaluating at step 500 Validation Loss: 0.1160, Accuracy: 0.9444
Evaluating at step 520 Validation Loss: 0.1007, Accuracy: 0.9667
Evaluating at step 540 Validation Loss: 0.1030, Accuracy: 0.9889
Evaluating at step 560 Validation Loss: 0.0995, Accuracy: 0.9667
Evaluating at step 580 Validation Loss: 0.1543, Accuracy: 0.9444
Evaluating at step 600 Validation Loss: 0.0523, Accuracy: 0.9889
Evaluating at step 620 Validation Loss: 0.0726, Accuracy: 0.9667
Evaluating at step 640 Validation Loss: 0.1794, Accuracy: 0.9333
Evaluating at step 660 Validation Loss: 0.0737, Accuracy: 0.9778
Evaluating at step 680 Validation Loss: 0.1839, Accuracy: 0.9000
Evaluating at step 700 Validation Loss: 0.0737, Accuracy: 1.0000
Performance condition satisfied, stopping.. CPU times: user 13min 20s, sys: 2min 36s, total: 15min 57s Wall time: 4min 51s
# Visualize predictions with the fine-tuned model, then persist its weights.
show_preds()
torch.save(resnet18.state_dict(), 'covid_classifier.pt')
# Load the model and set in eval
# Build an *untrained* ResNet-18 skeleton (weights=None): downloading the
# ImageNet weights here was pointless, since load_state_dict immediately
# overwrites every parameter — and `pretrained=True` is deprecated anyway.
resnet18 = torchvision.models.resnet18(weights=None)
resnet18.fc = torch.nn.Linear(in_features=512, out_features=3)
# map_location='cpu' makes the checkpoint load even if it was saved from a GPU run.
resnet18.load_state_dict(torch.load('covid_classifier.pt', map_location='cpu'))
resnet18.eval()
def predict_image_class(image_path):
    """Classify a single chest X-ray image file.

    Returns (probabilities, predicted_class_index, predicted_class_name),
    where probabilities is a length-3 numpy array over class_names.
    """
    image = Image.open(image_path).convert('RGB')
    image = test_transform(image)
    # Please note that the transform is defined already in a previous code cell
    image = image.unsqueeze(0)  # add batch dimension: (C,H,W) -> (1,C,H,W)
    # Fix: inference without autograd (the original built a graph and needed
    # .detach(); torch.softmax also replaces instantiating an nn.Softmax module).
    with torch.no_grad():
        output = resnet18(image)[0]
    probabilities = torch.softmax(output, dim=0).cpu().numpy()
    predicted_class_index = np.argmax(probabilities)
    predicted_class_name = class_names[predicted_class_index]
    return probabilities, predicted_class_index, predicted_class_name
# Classify one held-out COVID image and report the model's verdict.
image_path = '/Users/thomas/Desktop/Data/COVID-19_Radiography_Dataset/covid/images/COVID-1003.png'
probs, pred_idx, pred_name = predict_image_class(image_path)
print('Probabilities:', probs)
print('Predicted class index:', pred_idx)
print('Predicted class name:', pred_name)
Probabilities: [1.9799490e-01 3.5683042e-05 8.0196941e-01] Predicted class index: 2 Predicted class name: covid