AI Blitz #8

Smoke Elimination Using Conditional GAN

A Conditional GAN to remove smoke from images

devesh_darshan

This is an implementation of a Conditional GAN that removes smoke from images, built with PyTorch and fastai.

Downloading Dataset

In [ ]:
!pip install --upgrade aicrowd-cli
In [ ]:
API_KEY = ""
!aicrowd login --api-key $API_KEY
In [ ]:
!aicrowd dataset download --challenge f1-smoke-elimination -j 3
In [ ]:
!rm -rf data
!mkdir data


!unzip train.zip -d data/train >/dev/null
!unzip val.zip -d data/val >/dev/null
!unzip test.zip -d data/test >/dev/null
!unzip sample_submission.zip -d data/sample_submission >/dev/null
In [ ]:
!pip install fastai --upgrade
In [ ]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import os

import time  # used by visualize() to timestamp saved figures
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import cv2

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
In [ ]:
class config:
  
  TRAIN_IMAGES = "/content/data/train"
  VAL_IMAGES = "/content/data/val"

cfg = config()

Creating the training, validation and testing DataFrames

In [ ]:
#training
filenames = os.listdir("/content/data/train/clear")
clear = [os.path.join(cfg.TRAIN_IMAGES, "clear", x) for x in filenames]
smokes = [os.path.join(cfg.TRAIN_IMAGES, "smoke", x) for x in filenames]
In [ ]:
dfx = pd.DataFrame({
    "filename":filenames,
    "clear":clear,
    "smoke":smokes
})
In [ ]:
#validation
filenames = os.listdir("/content/data/val/clear")
clear = [os.path.join(cfg.VAL_IMAGES, "clear", x) for x in filenames]
smokes = [os.path.join(cfg.VAL_IMAGES, "smoke", x) for x in filenames]
In [ ]:
dfx_valid = pd.DataFrame({
    "filename":filenames,
    "clear":clear,
    "smoke":smokes
})
In [ ]:
#testing
filenames = os.listdir("/content/data/test/smoke")
smokes = [os.path.join("/content/data/test", "smoke", x) for x in filenames]
In [ ]:
dfx_test = pd.DataFrame({
    "filename":filenames,
    "smoke":smokes
})
In [ ]:
dfx_valid.head(3)
Out[ ]:
filename clear smoke
0 320.jpg /content/data/val/clear/320.jpg /content/data/val/smoke/320.jpg
1 1450.jpg /content/data/val/clear/1450.jpg /content/data/val/smoke/1450.jpg
2 238.jpg /content/data/val/clear/238.jpg /content/data/val/smoke/238.jpg
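The DataFrames above assume that every file in the `clear` folder has a counterpart with the same name in the `smoke` folder. A minimal sketch of how that assumption could be checked is given here; the helper name `unpaired_files` and the demo file names are hypothetical, and the demo runs on a temporary directory rather than on the challenge data.

```python
import os
import tempfile

def unpaired_files(clear_dir, smoke_dir):
    """Return filenames that exist in only one of the two folders."""
    clear = set(os.listdir(clear_dir))
    smoke = set(os.listdir(smoke_dir))
    return sorted(clear ^ smoke)  # symmetric difference

# Demo on a throwaway directory tree (hypothetical file names).
with tempfile.TemporaryDirectory() as root:
    layout = {"clear": ["1.jpg", "2.jpg"], "smoke": ["1.jpg", "3.jpg"]}
    for sub, names in layout.items():
        os.makedirs(os.path.join(root, sub))
        for name in names:
            open(os.path.join(root, sub, name), "w").close()
    missing = unpaired_files(os.path.join(root, "clear"),
                             os.path.join(root, "smoke"))

print(missing)
```

An empty list would mean the two folders are perfectly paired, so the same filename list can safely be used to build both path columns.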

Training methodology:

  1. Generator: A U-Net architecture with a pretrained backbone (here, resnet18), inspired by the paper: Image-to-Image Translation with Conditional Adversarial Networks.
  2. Discriminator: A convolutional discriminator inspired by Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. Note that, as implemented in this notebook, it outputs a single logit per image rather than one prediction per patch.
  3. Loss Functions: L1 Loss (Mean Absolute Error) for supervised training, and GAN Loss (a binary cross entropy loss with fake [0] and real [1] labels) for adversarial training.
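To make the two objectives concrete, here is a small sketch in plain Python (no PyTorch) of how the losses in point 3 combine. It works on single logits and flat lists of pixel values purely for illustration; the function names are hypothetical, while the weight of 100 on the L1 term and the 0.5 averaging of the discriminator terms follow the values used in this notebook.

```python
import math

def bce_with_logits(logit, target):
    """Numerically stable binary cross entropy on a raw logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def l1(pred, target):
    """Mean absolute error between two equally sized sequences."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)

def generator_loss(d_logit_on_fake, fake, real, lambda_l1=100.0):
    # The generator wants the discriminator to label its output as real (1).
    adversarial = bce_with_logits(d_logit_on_fake, 1.0)
    return adversarial + lambda_l1 * l1(fake, real)

def discriminator_loss(d_logit_on_fake, d_logit_on_real):
    # Fake images are labelled 0 and real images 1; the two terms are averaged.
    fake_term = bce_with_logits(d_logit_on_fake, 0.0)
    real_term = bce_with_logits(d_logit_on_real, 1.0)
    return 0.5 * (fake_term + real_term)

# An undecided discriminator (logit 0) and a generator that is off by 1.0
# on one of two pixels.
g = generator_loss(0.0, fake=[0.5, -0.5], real=[0.5, 0.5])
d = discriminator_loss(0.0, 0.0)
```

With such a large weight, the L1 term dominates the generator loss early in training, which is why the generator first learns to reproduce the overall image and only later benefits from the adversarial signal.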

Important Points !!

One of the most important things here is to handle the mean and standard deviation of the images correctly.

  1. As we are using a pretrained backbone (here, resnet18 trained on the ImageNet dataset), the input image (the image with smoke) has to be normalized according to the ImageNet stats.
  2. As the images produced by the generator have a pixel range of [-1,1] (because of the tanh function applied at the end), the clear image also has to be scaled to this [-1,1] range.
  3. The second point ensures that there is no discrepancy when calculating the L1 Loss. The real and generated images are also both fed into the discriminator, so they have to be in the same range.
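The two value ranges described above can be illustrated with a short NumPy sketch. One assumption to flag: this sketch uses a fixed mapping (`x / 127.5 - 1`) for the [-1,1] range, which is not the same as the per-image min-max `scaling` helper used in this notebook. The function names are hypothetical; the ImageNet statistics are the ones used throughout.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def normalize_input(img_uint8):
    """Smoke image: uint8 (H, W, 3) -> ImageNet-normalized floats."""
    return (img_uint8 / 255.0 - IMAGENET_MEAN) / IMAGENET_STD

def scale_target(img_uint8):
    """Clear image: uint8 (H, W, 3) -> floats in [-1, 1] (fixed mapping)."""
    return img_uint8 / 127.5 - 1.0

def unscale_output(img):
    """Generator output in [-1, 1] -> uint8 image."""
    return np.clip((img + 1.0) * 127.5, 0, 255).round().astype(np.uint8)

# Random stand-in for a real image.
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)

model_input = normalize_input(img)   # what the pretrained backbone expects
target = scale_target(img)           # same range as the tanh output
restored = unscale_output(target)    # round trip back to pixel values
```

A fixed mapping is exactly invertible, so an image survives the round trip unchanged, whereas min-max scaling stretches each image to the full range and therefore changes its contrast.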

Utils, Dataset and DataLoader

In [ ]:
#some utility functions important for training and visualization

#ImageNet stats
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

#scaling function
def scaling(X, high, low):
  X_std = (X - X.min()) / (X.max() - X.min())
  X_scaled = X_std * (high - low) + low
  return X_scaled

#to convert image from numpy(H,W,C) to tensor(C,H,W)
def img2tensor(img,dtype:np.dtype=np.float32):
    if img.ndim==2 : img = np.expand_dims(img,2)
    img = np.transpose(img,(2,0,1))
    return torch.from_numpy(img.astype(dtype, copy=False))

#unnormalizing function
def unnormalize(img_tensor, type = "1-1"):
  img = img_tensor.permute((1,2,0)).cpu().detach().numpy()
  if type == '1-1':
    img = (img + 1)/2
    img = scaling(img, 255, 0)
  elif type == 'imagenet':
    img = ((img * std) + mean)
    img = scaling(img, 255, 0)
  elif type == '0-1':
    img = img * 255.0
  return img.astype('uint8')
In [ ]:
#Dataset Class
class Data(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
    
    def __getitem__(self, item):
        cl = self.dataframe.iloc[item]['clear']
        sm = self.dataframe.iloc[item]['smoke']
        cl_image = np.asarray(Image.open(cl).convert("RGB").resize((224,224))) #resizing from 256 to 224 because training time for one epoch was over 20 minutes for size of 256.
        sm_image = np.asarray(Image.open(sm).convert("RGB").resize((224,224))) #reducing the size to 224, reduced the training time for one epoch to about 10 minutes.
        
        return {
            "clear": img2tensor(scaling(cl_image, 1, -1)), #clear image scaled between [-1,1]
            "smoke": img2tensor((sm_image/255.0 - mean)/std) #smoke image normalize acc to ImageNet stats
        }
    
    def __len__(self):
        return self.dataframe.shape[0]
In [ ]:
#creating dataset and dataloader

ds = Data(dfx)
ds_valid = Data(dfx_valid)

train_dl = DataLoader(ds, batch_size = 16, num_workers = 4, pin_memory=True, shuffle = True)
val_dl = DataLoader(ds_valid, batch_size = 16, num_workers = 4, pin_memory=True)

Generator and Discriminator

In [ ]:
#Generator

from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18, resnet34
from fastai.vision.models.unet import DynamicUnet

def build_res_unet(n_input=3, n_output=3, size=224):
    body = create_body(resnet18, pretrained=True, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size))
    return net_G
In [ ]:
#Discriminator

class ConvolutionalBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, batch_norm=False, activation=None):
        super(ConvolutionalBlock, self).__init__()

        if activation is not None:
            activation = activation.lower()
            assert activation in {'prelu', 'leakyrelu', 'tanh'}
        layers = list()
        layers.append(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                      padding=kernel_size // 2))

        if batch_norm is True:
            layers.append(nn.BatchNorm2d(num_features=out_channels))

        # An activation layer, if wanted
        if activation == 'prelu':
            layers.append(nn.PReLU())
        elif activation == 'leakyrelu':
            layers.append(nn.LeakyReLU(0.2))
        elif activation == 'tanh':
            layers.append(nn.Tanh())
        self.conv_block = nn.Sequential(*layers)

    def forward(self, input):
        output = self.conv_block(input)  # (N, out_channels, w, h)

        return output

class Discriminator(nn.Module):
    def __init__(self, kernel_size=3, n_channels=64, n_blocks=8, fc_size=1024):
        super(Discriminator, self).__init__()

        in_channels = 3
        conv_blocks = list()
        for i in range(n_blocks):
            out_channels = (n_channels if i == 0 else in_channels * 2) if i % 2 == 0 else in_channels
            conv_blocks.append(
                ConvolutionalBlock(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   stride=1 if i % 2 == 0 else 2, batch_norm=i != 0, activation='LeakyReLu'))
            in_channels = out_channels
        self.conv_blocks = nn.Sequential(*conv_blocks)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((6, 6))

        self.fc1 = nn.Linear(out_channels * 6 * 6, fc_size)

        self.leaky_relu = nn.LeakyReLU(0.2)

        self.fc2 = nn.Linear(fc_size, 1)

    def forward(self, imgs):
        batch_size = imgs.size(0)
        output = self.conv_blocks(imgs)
        output = self.adaptive_pool(output)
        output = self.fc1(output.view(batch_size, -1))
        output = self.leaky_relu(output)
        logit = self.fc2(output)

        return logit
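To see what the convolutional stack in the discriminator does to a 224x224 input, the channel and stride rules of its loop can be traced in plain Python without building the network. This is only a sketch: `conv_out` and `trace_discriminator` are hypothetical helpers that mirror the loop above, and the output-size formula assumes `padding = kernel_size // 2` as in `ConvolutionalBlock`.

```python
def conv_out(size, kernel_size=3, stride=1):
    """Spatial output size of a conv layer with padding = kernel_size // 2."""
    padding = kernel_size // 2
    return (size + 2 * padding - kernel_size) // stride + 1

def trace_discriminator(size=224, n_channels=64, n_blocks=8, kernel_size=3):
    """Return (channels, height, width) after each convolutional block."""
    in_channels = 3
    shapes = []
    for i in range(n_blocks):
        if i % 2 == 0:
            # Even blocks widen the network (64 for the first one).
            out_channels = n_channels if i == 0 else in_channels * 2
        else:
            # Odd blocks keep the width and halve the resolution.
            out_channels = in_channels
        stride = 1 if i % 2 == 0 else 2
        size = conv_out(size, kernel_size, stride)
        shapes.append((out_channels, size, size))
        in_channels = out_channels
    return shapes

shapes = trace_discriminator()
for i, shape in enumerate(shapes):
    print(f"block {i}: {shape}")

# After adaptive average pooling to 6x6, fc1 sees this many inputs.
fc1_inputs = shapes[-1][0] * 6 * 6
```

The four stride-2 blocks reduce the resolution by a factor of 16, and the adaptive pooling layer then fixes the feature map at 6x6, so the fully connected layers do not depend on the input image size.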
In [ ]:
#GAN Loss

class GANLoss(nn.Module):
    def __init__(self, real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        self.loss = nn.BCEWithLogitsLoss()
    
    def get_labels(self, preds, target_is_real):
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)
    
    def forward(self, preds, target_is_real):
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss

Main Training Module

In [ ]:
class MainModel(nn.Module):
    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4, beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        self.net_G = net_G.to(self.device)
        self.net_D = Discriminator().to(self.device)
        self.GANcriterion = GANLoss().to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))
    
    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad
        
    def setup_input(self, data):
        self.cl = data['clear'].to(self.device)
        self.sm = data['smoke'].to(self.device)
        
    def forward(self):
        self.fake_cl = torch.tanh(self.net_G(self.sm)) #tanH function to ensure that the outputs are between [-1,1]
    
    #to train discriminator
    def backward_D(self):
        fake_preds = self.net_D(self.fake_cl.detach()) #here we are detaching the generator output from the training graph 
                                                      #because we don't want the loss to backpropagate through generator also.
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_preds = self.net_D(self.cl)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()
    
    #to train generator
    def backward_G(self):
        fake_preds = self.net_D(self.fake_cl)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_cl, self.cl) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()
    
    def optimize(self):
      #training discriminator
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()
      #training generator
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
In [ ]:
#visualizing function to see the performance of the generator after defined iterations.

def visualize(model, data ,save=True):
    mse = []
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_cl = model.fake_cl #[-1,1]
    sm = model.sm #imagenet normalised
    cl = model.cl #[-1,1]
    fig = plt.figure(figsize=(15, 8))
    for i in range(5):
        s = unnormalize(sm[i], type = 'imagenet')
        c = unnormalize(cl[i], type = '1-1')
        f = unnormalize(fake_cl[i], type = '1-1')
        ax = plt.subplot(3, 5, i + 1)
        ax.imshow(s)
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 5)
        ax.imshow(c)
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 10)
        ax.imshow(f)
        ax.axis("off")
        mse.append(np.mean((c.astype(np.float32) - f.astype(np.float32))**2))  # cast first: uint8 subtraction wraps around
    plt.show()
    print(f"Mean MSE of 5 images {np.mean(mse)}")
    if save:
        fig.savefig(f"smoke_clear_{time.time()}.png")
In [ ]:
#Average meters to keep track of the losses

class AverageMeter:
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3
    
    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    loss_D_fake = AverageMeter()
    loss_D_real = AverageMeter()
    loss_D = AverageMeter()
    loss_G_GAN = AverageMeter()
    loss_G_L1 = AverageMeter()
    loss_G = AverageMeter()
    
    return {'loss_D_fake': loss_D_fake,
            'loss_D_real': loss_D_real,
            'loss_D': loss_D,
            'loss_G_GAN': loss_G_GAN,
            'loss_G_L1': loss_G_L1,
            'loss_G': loss_G}

def update_losses(model, loss_meter_dict, count):
    for loss_name, loss_meter in loss_meter_dict.items():
        loss = getattr(model, loss_name)
        loss_meter.update(loss.item(), count=count)

def log_results(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        print(f"{loss_name}: {loss_meter.avg:.5f}")
In [ ]:
#training function
def train_model(model, train_dl, epochs, display_every=100):
    val_data = next(iter(val_dl))  # fixed validation batch, used only for visualization
    for e in range(epochs):
        loss_meter_dict = create_loss_meters()
        i = 0
        for data in tqdm(train_dl):
            model.setup_input(data)
            model.optimize()
            update_losses(model, loss_meter_dict, count=data['clear'].size(0))
            i += 1
            if i % display_every == 0:
                print(f"\nEpoch {e+1}/{epochs}")
                print(f"Iteration {i}/{len(train_dl)}")
                log_results(loss_meter_dict)
                visualize(model, val_data, save=False)  # previously the training batch was shown by mistake
In [ ]:
#the weights loaded here are of the model trained 10 epochs
model = MainModel(net_G = build_res_unet())
model.load_state_dict(torch.load("/content/drive/MyDrive/Projects/AI_Crowd/Smoke Elimination/gan.pt", map_location=device))
Out[ ]:
<All keys matched successfully>
In [ ]:
train_model(model, train_dl, 10)
Output hidden; open in https://colab.research.google.com to view.
In [ ]:
#saving weights
torch.save(model.net_G.state_dict(), "/content/drive/MyDrive/Projects/AI_Crowd/Smoke Elimination/res18-unet.pt")
torch.save(model.state_dict(), "/content/drive/MyDrive/Projects/AI_Crowd/Smoke Elimination/gan.pt")

Submission

In [ ]:
!rm -rf clear
!mkdir clear
In [ ]:
#testing
filenames = os.listdir("/content/data/test/smoke")
smokes = [os.path.join("/content/data/test", "smoke", x) for x in filenames]
In [ ]:
dfx_test = pd.DataFrame({
    "filename":filenames,
    "smoke":smokes
})
In [ ]:
class TestData(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
    
    def __getitem__(self, item):
        filename = self.dataframe.iloc[item]['filename']
        sm = self.dataframe.iloc[item]['smoke']
        sm_image = np.asarray(Image.open(sm).convert("RGB").resize((224,224)))
        
        return {
            "filename": filename,
            "smoke": img2tensor((sm_image/255.0 - mean)/std)
        }
    
    def __len__(self):
        return self.dataframe.shape[0]
In [ ]:
ds_test = TestData(dfx_test)
In [ ]:
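model.eval()
with torch.no_grad():  # inference only, so skip gradient tracking to save memory
  for idx in tqdm(range(len(ds_test))):
    sample = ds_test[idx]  # fetch once so each image is read from disk a single time
    filename = sample['filename']
    data = sample['smoke'].unsqueeze(0).to(device)
    out = torch.tanh(model.net_G(data)).squeeze(0)
    f = unnormalize(out, type = '1-1')
    f = Image.fromarray(f).resize((256, 256))  # back to the original 256x256 size
    f.save(os.path.join("clear", filename))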
In [ ]:
!zip submission.zip -r clear/ > /dev/null
In [ ]:
!aicrowd submission create -c f1-smoke-elimination -f submission.zip

Training Tips

  1. Result breakdown: after 10 epochs the leaderboard (lb) score was 76; after 20 epochs the score came down to 54.
  2. I wasn't able to train more because of academic reasons :(
  3. To what extent should I train? Training should last until the discriminator loss starts to increase continuously. A rising discriminator loss means that the generator is now producing images that fool even the discriminator, i.e. the generated images are quite realistic.
  4. Pretraining your Generator: GANs are generally trained for a very long time (about 100 epochs, sometimes more). To reduce this training time, one can pretrain the generator on the training data, then load these pretrained weights into the GAN and start training.
  5. More Loss Functions: To stabilize training one can use Wasserstein Loss, and to compare high-level differences between two similar images, Perceptual Loss.
  6. Image Size: To reduce training time and memory usage, I reduced the image size from 256 to 224. If you have a powerful workstation, train at size 256. This should give better results, as my lb result is greatly affected by resizing the images back from 224 to 256.
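The stopping rule from tip 3 can be turned into a simple check on the per-epoch discriminator loss. This is only a sketch of one possible heuristic: the helper name, the `patience` parameter, and the loss values in the demo are all hypothetical and do not come from an actual training run.

```python
def discriminator_loss_rising(epoch_losses, patience=3):
    """True if the last `patience` epoch-to-epoch changes are all increases."""
    if len(epoch_losses) < patience + 1:
        return False  # not enough history yet
    recent = epoch_losses[-(patience + 1):]
    return all(later > earlier for earlier, later in zip(recent, recent[1:]))

# Made-up average discriminator losses, one value per epoch.
still_improving = [0.60, 0.50, 0.52, 0.51, 0.60]
steadily_rising = [0.60, 0.50, 0.52, 0.55, 0.60]

print(discriminator_loss_rising(still_improving))
print(discriminator_loss_rising(steadily_rising))
```

In the training loop, the value to append after each epoch would be the average of `loss_D` tracked by the loss meters. GAN losses are noisy, so it is worth checking the generated images as well before stopping.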

Thank You !

