The final trained weights are also provided.
Description¶
This solution is based on GCANet (https://github.com/cddlyf/GCANet). The network was trained for a long time with weight decay. The training protocol is presented below, and the final weights are also provided.
Some constants and gdrive mount¶
In [5]:
import torch

# Root directory that will hold the downloaded dataset.
workDir = '/home/data/'
# Run on the first GPU when CUDA is available, otherwise fall back to CPU.
has_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if has_gpu else "cpu")
In [6]:
# Colab-only cell: mount Google Drive (checkpoints are saved there later)
# and create the /home/data working directory used by workDir above.
# this mounts your Google Drive to the Colab VM.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd '/home'
!mkdir 'data'
%cd '/home/data'
Data loading¶
In [7]:
# Install the AIcrowd CLI used to download the dataset and submit results.
!pip install -U aicrowd-cli
In [9]:
# WARNING(security): a real API key is hard-coded below. Rotate it and load it
# from an environment variable or a Colab secret instead of committing it.
API_KEY = '52ab6eb031245b7028158e2f3e993174' #Please enter your API Key from [https://www.aicrowd.com/participants/me]
!aicrowd login --api-key $API_KEY
In [10]:
# Download the challenge dataset with 3 parallel connections.
!aicrowd dataset download --challenge f1-smoke-elimination -j 3
In [11]:
# Unpack the train/val/test archives into data/ and move the CSV metadata.
!rm -rf data
!mkdir data
!unzip -q train.zip -d data/train
!unzip -q val.zip -d data/val
!unzip -q test.zip -d data/test
!mv train.csv data/train.csv
!mv val.csv data/val.csv
!mv sample_submission.csv data/sample_submission.csv
Custom dataset class and dataloaders¶
In [12]:
import torch
from torch.utils.data import Dataset,DataLoader,RandomSampler
from torchvision import transforms as T
import pandas as pd
import numpy as np
from PIL import Image
class ImageDataset(Dataset):
    """Paired smoke/clear image dataset.

    Expects ``ImageFold`` to contain ``smoke/`` and ``clear/`` subfolders
    with JPEG files named ``<index>.jpg`` for indices ``0..dsLen-1``.
    ``__getitem__`` returns a (smoke, clear) pair of float tensors in
    CHW layout, scaled to [0, 1].
    """

    def __init__(self, ImageFold, dsLen):
        # ImageFold: dataset root (with trailing slash) holding smoke/ and clear/.
        # dsLen: number of image pairs available on disk.
        self.ImageFold = ImageFold
        self.dsLen = dsLen

    def __len__(self):
        return self.dsLen

    def __getitem__(self, ind):
        im_smoke = self._to_tensor(self.load_image(ind, 'smoke/'))
        im_clear = self._to_tensor(self.load_image(ind, 'clear/'))
        return im_smoke, im_clear

    @staticmethod
    def _to_tensor(img):
        # Scale 8-bit pixel values to [0, 1] and convert HWC -> CHW.
        arr = np.asarray(img) / 255
        return torch.from_numpy(arr).float().permute(2, 0, 1)

    def load_image(self, ind, t):
        # Force-load pixel data inside a context manager so the underlying
        # file handle is closed (the original left handles open, which can
        # exhaust descriptors with num_workers > 0).
        with Image.open(self.ImageFold + t + str(ind) + '.jpg') as im:
            im.load()
            return im
In [13]:
# Training set: 20 000 image pairs; validation: 2 000. Batch size 16 with two
# loader workers each. Shuffling the validation loader does not change the
# reported loss, only the order batches are seen in.
ds_train=ImageDataset(workDir+'data/train/',20000)
dl_train=DataLoader(ds_train,batch_size=16,shuffle=True,num_workers=2)
ds_val=ImageDataset(workDir+'data/val/',2000)
dl_val=DataLoader(ds_val,batch_size=16,shuffle=True,num_workers=2)
Net object creation¶
In [14]:
# Clone the GCANet reference implementation and instantiate the model.
# NOTE(review): the argument 3 is presumably the input-channel count — confirm
# against the signature in GCANet/GCANet.py before changing it.
!git clone https://github.com/cddlyf/GCANet
from GCANet import GCANet
net=GCANet.GCANet(3)
# Optionally resume from previously saved weights:
#net.load_state_dict(torch.load('/content/drive/MyDrive/state_dict_model_2.pt'))
Out[14]:
Function for drawing images¶
In [18]:
import matplotlib.pyplot as plt
%matplotlib inline
def show_image(hazy_image, gt_image, predicted_image):
    """Display three (C, H, W) image tensors side by side.

    NOTE(review): the parameter names and the `title` list disagree —
    `display_list[0]` is `hazy_image` but `title[0]` is 'Ground Truth Image'.
    The caller in train() passes (img_orig, img_haze, clean_image), i.e. the
    ground truth in the `hazy_image` slot, so the on-screen labels come out
    correct. Confirm the argument order before reusing this elsewhere.
    """
    title = ['Ground Truth Image','Hazy Image', 'Predicted']
    plt.figure(figsize=(15, 15))
    # Tensors are CHW on (possibly) GPU; move to CPU and convert to HWC numpy.
    display_list = [
        hazy_image.cpu().permute(1, 2, 0).numpy(),
        gt_image.cpu().permute(1, 2, 0).numpy(),
        predicted_image.detach().cpu().permute(1, 2, 0).numpy()
    ]
    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        plt.imshow(display_list[i])
        plt.axis('off')
    plt.show()
Training loop¶
In [16]:
import torch
import torch.nn as nn
import torchvision
import torch.backends.cudnn as cudnn
import torch.optim
import os
import sys
import argparse
import time
#import dataloader
import numpy as np
from torchvision import transforms
import torch.nn.functional as F
def weights_init(m):
    """DCGAN-style weight initialization, applied via ``net.apply(weights_init)``.

    Conv* layers: weights ~ N(0, 0.02); conv biases are left at their default
    (matching the original). BatchNorm* layers: weights ~ N(1, 0.02), biases 0.
    Uses the ``torch.nn.init`` API instead of mutating ``.data`` directly,
    which is the supported way to initialize parameters in-place.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)
def train(dehaze_net, lr, n_epoch):
    """Train ``dehaze_net`` for ``n_epoch`` epochs with Adam at rate ``lr``.

    Relies on module-level globals: ``dl_train``, ``dl_val``, ``device`` and
    ``show_image``. Minimizes per-pixel MSE between the network output and the
    clear image. Every 200 iterations a checkpoint is written to Google Drive
    and a sample triple is displayed. After each epoch the validation MSE is
    reported on the 0-255 pixel scale (note: not directly comparable to the
    0-1-scale training loss). Returns the trained network.
    """
    criterion = nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(dehaze_net.parameters(), lr=lr)
    # NOTE(review): the original re-created the Adam optimizer every 600
    # iterations with an unchanged lr, which only wiped Adam's moment
    # estimates; that dead branch has been removed.
    for epoch in range(n_epoch):
        dehaze_net.train()
        print('Epoch :' +str(epoch))
        epoch_loss=0
        for iteration, (img_haze,img_orig) in enumerate(dl_train):
            img_orig = img_orig.to(device)
            img_haze = img_haze.to(device)
            clean_image = dehaze_net(img_haze)
            loss = criterion(clean_image, img_orig)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch average is per-sample.
            epoch_loss+= loss.detach().cpu().numpy()*len(img_orig)
            if ((iteration+1) % 10) == 0:
                print("Loss at iteration", iteration+1, ":", loss.item())
            if iteration % 200 == 0:
                torch.save(dehaze_net.state_dict(), '/content/drive/MyDrive/state_dict_model_checkpoint.pt')
                # show_image's caller contract: (ground truth, hazy, predicted).
                show_image(img_orig[0], img_haze[0], clean_image[0])
        print('Epoch train loss: '+str(epoch_loss/len(dl_train.dataset)))
        # Validation: switch to eval mode so norm-layer running statistics are
        # not updated by validation batches (the original stayed in train mode).
        dehaze_net.eval()
        epoch_loss=0
        with torch.no_grad():
            for iter_val, (img_haze,img_orig) in enumerate(dl_val):
                img_orig = img_orig.to(device)
                img_haze = img_haze.to(device)
                clean_image = dehaze_net(img_haze)
                # Rescale to 0-255 before the MSE, matching the challenge metric.
                clean_image=clean_image.detach().cpu().numpy()*255
                img_orig=img_orig.detach().cpu().numpy()*255
                loss = ((clean_image - img_orig)**2).mean()
                epoch_loss+= loss*len(img_orig)
        print('Epoch val loss: '+str(epoch_loss/len(dl_val.dataset)))
    return dehaze_net
In [17]:
# Staged learning-rate schedule: 4 epochs at each rate, decaying from 1e-3
# down to 1e-5. train() returns the same (mutated) network object.
net.to(device)
net=train(net,0.001,4)
net=train(net,0.0005,4)
net=train(net,0.0003,4)
net=train(net,0.0001,4)
net=train(net,0.00005,4)
net=train(net,0.00001,4)
In [ ]:
# Additional fine-tuning passes; weights are checkpointed to Google Drive.
net.to(device)
net=train(net,0.00001,5)
torch.save(net.state_dict(), '/content/drive/MyDrive/state_dict_model_2.pt')
In [ ]:
net.to(device)
net=train(net,0.00005,2)
In [ ]:
net.to(device)
net=train(net,0.00001,5)
torch.save(net.state_dict(), '/content/drive/MyDrive/state_dict_model_2.pt')
In [ ]:
# Final very-low-rate polish before saving the weights used for submission.
net.to(device)
net=train(net,0.000001,5)
torch.save(net.state_dict(), '/content/drive/MyDrive/state_dict_model_2.pt')
In [ ]:
!rm -rf '/home/data/data/test/clear/*'
net.eval()
i=0
from torchvision.utils import save_image
import os
for im in (os.listdir('/home/data/data/test/smoke')):
ind=im.split('.')[0]
hzim=Image.open('/home/data/data/test/smoke/'+str(ind)+'.jpg')
hzim=(np.asarray(hzim)/255)
hzim=torch.from_numpy(hzim).float()
hzim=hzim.permute(2,0,1)
hzim=hzim.reshape((1,hzim.size()[0],hzim.size()[1],hzim.size()[2]))
hzim=hzim.to(device)
clim=net(hzim)
clim=clim[0]
save_image(clim,'/home/data/data/test/clear/'+im)
!pwd
%cd '/home/data/data/test/'
!zip submission.zip -r 'clear/' > /dev/null
!aicrowd submission create -c f1-smoke-elimination -f submission.zip
In [ ]:
# Utility cell: convert a dict of numpy arrays into torch tensors in place.
# NOTE(review): `state_dict` is not defined anywhere in this notebook — this
# cell only works if a numpy-based state dict was loaded beforehand.
for key in state_dict.keys():
    state_dict[key] = torch.from_numpy(state_dict[key])
Content
Comments
You must login before you can post a comment.