Facial Expression ResNet

In [2]:
import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
from PIL import Image
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
%matplotlib inline
In [3]:
data_dir = 'Dataset'
print(os.listdir(data_dir))
classes = os.listdir(data_dir + "/train")
print(classes)
len(classes)
['test', 'labels.txt', 'train']
['fear', 'surprise', 'sadness', 'neutral', 'happiness', 'anger', 'disgust']
Out[3]:
7
In [5]:
len(os.listdir(data_dir+'/train/sadness'))
Out[5]:
5483
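The same check can be run over every class folder to gauge class balance in the training set; a small sketch using the directory layout listed above:

# Count training images per class folder (sketch based on the folders listed above)
for c in classes:
    print(c, len(os.listdir(os.path.join(data_dir, 'train', c))))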
In [7]:
train_tfms = tt.Compose([
#                          tt.RandomCrop(32, padding=4, padding_mode='reflect'),
                         tt.RandomHorizontalFlip(),
                         tt.RandomRotation(30),
                         tt.ColorJitter(brightness=0.1, contrast=0.25, saturation=0.35, hue=0.05),
                         tt.ToTensor()
                        ])

valid_tfms = tt.Compose([tt.ToTensor()])
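Since the pretrained ResNet-18 weights used below were trained on ImageNet-normalized inputs, appending tt.Normalize with the ImageNet statistics is a common refinement. It was not applied for the run recorded here; a minimal sketch:

# Optional variant: add ImageNet normalization (not used in the run below)
stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
train_tfms_norm = tt.Compose([tt.RandomHorizontalFlip(),
                              tt.RandomRotation(30),
                              tt.ColorJitter(brightness=0.1, contrast=0.25, saturation=0.35, hue=0.05),
                              tt.ToTensor(),
                              tt.Normalize(*stats)])
valid_tfms_norm = tt.Compose([tt.ToTensor(), tt.Normalize(*stats)])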
In [8]:
train_ds = ImageFolder(data_dir+'/train', train_tfms)
valid_ds = ImageFolder(data_dir+'/test', valid_tfms)
In [9]:
batch_size = 128
In [10]:
train_dl = DataLoader(
    train_ds, 
    batch_size, 
    shuffle=True, 
    num_workers=4, 
    pin_memory=True
                     )

valid_dl = DataLoader(
    valid_ds, 
    batch_size*2, 
    num_workers=4, 
    pin_memory=True
                     )
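Before training, it is worth pulling a single batch to confirm the tensor shapes the model will receive; a quick sketch (the spatial size depends on the source images):

# Peek at one batch: expect [batch_size, channels, height, width] and [batch_size]
images, labels = next(iter(train_dl))
print(images.shape, labels.shape)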
In [11]:
def show_batch(dl):
    for images, labels in dl:
        fig, ax = plt.subplots(figsize=(16, 12))
        ax.set_xticks([]) 
        ax.set_yticks([])
        ax.imshow(make_grid(images[:64], nrow=16).permute(1, 2, 0))
        break
In [12]:
show_batch(train_dl)
In [13]:
def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
    
def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)
In [14]:
device = get_default_device()
device
Out[14]:
device(type='cpu')
In [15]:
train_dl = DeviceDataLoader(train_dl, device)
valid_dl = DeviceDataLoader(valid_dl, device)
In [16]:
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

class ImageClassification(nn.Module):
    def training_step(self, batch):
        images, labels = batch 
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels) # Calculate loss
        return loss
    
    def validation_step(self, batch):
        images, labels = batch 
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc}
        
    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
    
    def epoch_end(self, epoch, result):
        print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_acc']))
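As a quick sanity check, the accuracy helper can be exercised on a toy batch of scores (hypothetical values):

# Toy check of accuracy(): argmax over dim=1, then the fraction of matches
outputs = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = torch.tensor([1, 0, 0])
print(accuracy(outputs, labels))  # tensor(0.6667) -- 2 of 3 predictions match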
In [17]:
class ResNet18(ImageClassification):
    def __init__(self, num_classes):
        super().__init__()
        
        self.network = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        num_ftrs = self.network.fc.in_features
        self.network.fc = nn.Linear(num_ftrs, num_classes)
        
    def forward(self, x):
        return self.network(x)
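A common variation when fine-tuning a pretrained backbone is to freeze every layer except the new classification head. That is not done for the run below, but a sketch of the option looks like this (model_frozen is an illustrative name):

# Optional: freeze the pretrained layers and train only the new fc head (sketch)
model_frozen = ResNet18(num_classes=7)
for param in model_frozen.network.parameters():
    param.requires_grad = False
for param in model_frozen.network.fc.parameters():
    param.requires_grad = True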
In [18]:
model = ResNet18(7)
model = to_device(model, device)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /Users/thomas/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
In [19]:
from tqdm.notebook import tqdm
In [20]:
@torch.no_grad()
def evaluate(model, val_loader):
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader, 
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.Adam):
    torch.cuda.empty_cache()
    history = []
    
    # Set up custom optimizer with weight decay
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # Set up one-cycle learning rate scheduler
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs, 
                                                steps_per_epoch=len(train_loader))
    
    for epoch in range(epochs):
        # Training Phase 
        model.train()
        train_losses = []
        lrs = []
        for batch in tqdm(train_loader):
            loss = model.training_step(batch)
            train_losses.append(loss)
            loss.backward()
            
            # Gradient clipping
            if grad_clip: 
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            
            optimizer.step()
            optimizer.zero_grad()
            
            # Record & update learning rate
            lrs.append(get_lr(optimizer))
            sched.step()
        
        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        model.epoch_end(epoch, result)
        history.append(result)
    return history
In [21]:
history = [evaluate(model, valid_dl)]
history
Out[21]:
[{'val_loss': 2.684027910232544, 'val_acc': 0.02994791604578495}]
In [22]:
epochs = 25
max_lr = 0.05
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.Adam
In [23]:
%%time
history += fit_one_cycle(epochs, max_lr, model, train_dl, valid_dl, grad_clip=grad_clip, weight_decay=weight_decay, opt_func=opt_func)
Epoch [0], last_lr: 0.00406, train_loss: 1.6611, val_loss: 1.8482, val_acc: 0.2928
Epoch [1], last_lr: 0.00992, train_loss: 1.7300, val_loss: 1.7319, val_acc: 0.3085
Epoch [2], last_lr: 0.01856, train_loss: 1.6462, val_loss: 1.6059, val_acc: 0.3702
Epoch [3], last_lr: 0.02849, train_loss: 1.5116, val_loss: 1.6611, val_acc: 0.3869
Epoch [4], last_lr: 0.03799, train_loss: 1.4874, val_loss: 1.4343, val_acc: 0.4660
Epoch [5], last_lr: 0.04541, train_loss: 1.4927, val_loss: 1.5085, val_acc: 0.4231
Epoch [6], last_lr: 0.04947, train_loss: 1.5001, val_loss: 1.9278, val_acc: 0.3288
Epoch [7], last_lr: 0.04990, train_loss: 1.4964, val_loss: 1.4895, val_acc: 0.4457
Epoch [8], last_lr: 0.04910, train_loss: 1.4920, val_loss: 1.6443, val_acc: 0.3742
Epoch [9], last_lr: 0.04752, train_loss: 1.4967, val_loss: 1.6195, val_acc: 0.3502
Epoch [10], last_lr: 0.04523, train_loss: 1.4969, val_loss: 1.6724, val_acc: 0.3827
Epoch [11], last_lr: 0.04228, train_loss: 1.4838, val_loss: 1.7898, val_acc: 0.2256
Epoch [12], last_lr: 0.03877, train_loss: 1.4774, val_loss: 1.6228, val_acc: 0.3634
Epoch [13], last_lr: 0.03483, train_loss: 1.4704, val_loss: 1.6102, val_acc: 0.3926
Epoch [14], last_lr: 0.03056, train_loss: 1.4525, val_loss: 1.4849, val_acc: 0.4395
Epoch [15], last_lr: 0.02612, train_loss: 1.4423, val_loss: 1.5967, val_acc: 0.3918
Epoch [16], last_lr: 0.02164, train_loss: 1.4195, val_loss: 1.4332, val_acc: 0.4630
Epoch [17], last_lr: 0.01727, train_loss: 1.4058, val_loss: 1.4902, val_acc: 0.4108
Epoch [18], last_lr: 0.01315, train_loss: 1.3789, val_loss: 1.3356, val_acc: 0.4869
Epoch [19], last_lr: 0.00941, train_loss: 1.3568, val_loss: 1.2901, val_acc: 0.5106
Epoch [20], last_lr: 0.00617, train_loss: 1.3294, val_loss: 1.2360, val_acc: 0.5267
Epoch [21], last_lr: 0.00354, train_loss: 1.2994, val_loss: 1.2406, val_acc: 0.5307
Epoch [22], last_lr: 0.00159, train_loss: 1.2760, val_loss: 1.1773, val_acc: 0.5497
Epoch [23], last_lr: 0.00040, train_loss: 1.2574, val_loss: 1.1584, val_acc: 0.5552
Epoch [24], last_lr: 0.00000, train_loss: 1.2502, val_loss: 1.1527, val_acc: 0.5614
CPU times: user 7h 46min 14s, sys: 52min 12s, total: 8h 38min 27s
Wall time: 2h 53min 23s
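Given the nearly three-hour wall time, it is worth saving the fine-tuned weights before going further; a minimal sketch (the filename is arbitrary):

# Persist the trained weights so the run does not need repeating (filename is arbitrary)
torch.save(model.state_dict(), 'resnet18-facial-expression.pth')
# Reload later with: model.load_state_dict(torch.load('resnet18-facial-expression.pth'))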
In [24]:
def plot_acc(history):
    acc = [x['val_acc'] for x in history]
    plt.plot(acc,'-x')
    plt.xlabel('epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy vs. no. of epochs')
In [25]:
plot_acc(history)
In [26]:
def plot_losses(history):
    train_loss = [x.get('train_loss') for x in history]
    val_loss = [x['val_loss'] for x in history]
    plt.plot(train_loss, '-bx')
    plt.plot(val_loss, '-rx')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Training', 'Validation'])
In [27]:
plot_losses(history)
In [28]:
def plot_lrs(history):
    lrs = np.concatenate([x.get('lrs', []) for x in history])
    plt.plot(lrs)
    plt.xlabel('batch number')
    plt.ylabel('learning rate')
    plt.title('Learning Rate vs. batch number')
In [29]:
plot_lrs(history)
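Finally, the trained model can be applied to an individual validation image by adding a batch dimension, moving it to the device, and mapping the argmax back to a class name; a sketch:

# Predict the class of a single image from the validation set (sketch)
@torch.no_grad()
def predict_image(img, model):
    xb = to_device(img.unsqueeze(0), device)  # add a batch dimension and move to the device
    yb = model(xb)                            # forward pass
    _, pred = torch.max(yb, dim=1)            # index of the highest score
    return valid_ds.classes[pred[0].item()]   # map the index back to a class name

img, label = valid_ds[0]
print('Label:', valid_ds.classes[label], '| Predicted:', predict_image(img, model))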
In [ ]: