1 change: 1 addition & 0 deletions README.md
@@ -24,6 +24,7 @@ We have a list of candidate papers to implement: https://github.com/chainer/mode
- Neural Relational Inference for Interacting Systems [[paper](https://arxiv.org/abs/1802.04687)] [[code](https://github.com/chainer/models/tree/master/nri)]
- SiamRPN and SiamMask [[paper](https://arxiv.org/abs/1812.05050)] [[code](https://github.com/STVIR/pysot)]
- Learning to learn by gradient descent by gradient descent [[paper](https://arxiv.org/abs/1606.04474)] [[code](https://github.com/chainer/models/tree/master/learning_to_learn)]
- Learning to Simplify: Fully Convolutional Networks for Rough Sketch Cleanup [[paper](http://www.f.waseda.jp/hfs/SimoSerraSIGGRAPH2016.pdf)] [[code](https://github.com/chainer/models/tree/master/simplifying_rough_sketches)]

## License
MIT License (see `LICENSE` file).
133 changes: 133 additions & 0 deletions simplifying_rough_sketches/README.md
@@ -0,0 +1,133 @@
# Rough Sketch Simplification using FCNN in Chainer

This directory contains a Chainer implementation of the paper [Learning to Simplify: Fully Convolutional Networks for Rough Sketch Cleanup](http://www.f.waseda.jp/hfs/SimoSerraSIGGRAPH2016.pdf), trained and tested on a custom dataset.

## Overview

The paper presents a technique for simplifying rough sketch drawings by learning a series of convolution operators. The network is fully convolutional, so an image of any size can be fed in, and the output has the same dimensions as the input.

![model](images/model.png)

The architecture has three parts. The first acts as an encoder and spatially compresses the image. The second processes the compressed representation and extracts the essential lines. The third acts as a decoder, converting the small, simpler representation back to a grayscale image of the same resolution as the input. Every stage is built from convolutions.
The down- and up-convolution architecture may look similar to a simple filter bank. However, the number of channels is much larger where the resolution is lower, e.g., 1024 where the size is 1/8. This ensures that the information needed for clean lines is carried through the low-resolution part; training determines which information the encoder-decoder carries through. Padding compensates for the kernel size so that the output is the same size as the input when a stride of 1 is used. Instead of pooling layers, convolutional layers with larger strides lower the resolution from the previous layer.

## Contents
- [Rough Sketch Simplification using FCNN in Chainer](#rough-sketch-simplification-using-fcnn-in-chainer)
  - [Overview](#overview)
  - [Contents](#contents)
  - [1. Setup Instructions and Dependencies](#1-setup-instructions-and-dependencies)
  - [2. Dataset](#2-dataset)
  - [3. Training the model](#3-training-the-model)
  - [5. Model Architecture](#5-model-architecture)
  - [6. Observations](#6-observations)
    - [Training](#training)
    - [Predictions](#predicitons)
    - [Loss](#loss)
  - [7. Repository overview](#7-repository-overview)


## 1. Setup Instructions and Dependencies

Clone the repository to your local machine.


Create a virtual environment using Python 3:
``` Batchfile
virtualenv env
```


Install the dependencies
``` Batchfile
pip install -r requirements.txt
```

You can also use a Google Colab notebook.


## 2. Dataset

The authors have not released the dataset used in the paper, so I created my own. It is available on Google Drive [here](https://drive.google.com/open?id=14NQTqITAiw8o-JgdnumQ-K0asLRwJy7q). Feel free to use it.

Create two folders, `Input` and `Target`, inside the root directory of the dataset and place the images in the corresponding folder. Each input image and its target image must have the same file name.
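
Because inputs and targets are matched by file name, it can help to check the pairing before training. The sketch below is only an illustration under that assumption: `pair_images` is a hypothetical helper, not the `get_data` function from `read_data.py`, and it assumes the `Input`/`Target` folder names that `main.py` joins onto `--root`.

```python
import os


def pair_images(root):
    """Return sorted (input_path, target_path) pairs that share a file name.

    Hypothetical helper: assumes `root` contains `Input` and `Target` folders.
    """
    input_dir = os.path.join(root, "Input")
    target_dir = os.path.join(root, "Target")
    input_names = set(os.listdir(input_dir))
    target_names = set(os.listdir(target_dir))

    # Any name present in only one of the two folders breaks the pairing.
    unmatched = input_names ^ target_names
    if unmatched:
        raise ValueError(f"Files without a counterpart: {sorted(unmatched)}")

    return [
        (os.path.join(input_dir, name), os.path.join(target_dir, name))
        for name in sorted(input_names)
    ]
```

Calling `pair_images('.')` from the dataset root would either return the matched pairs or raise a `ValueError` naming the files that lack a partner.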

## 3. Training the model

To train the model, run

```Batchfile
python main.py --train=True
```
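
A caveat about boolean flags such as `--train=True`: the repository's `argument_parser` lives in `utils.py` and is not shown here, so the following is an assumption rather than a description of it. If such a flag were declared with `type=bool`, `argparse` would treat any non-empty string, including `"False"`, as true. A minimal sketch of an explicit converter (`str2bool` and `build_parser` are hypothetical names, and only a few of the documented flags are mirrored):

```python
import argparse


def str2bool(value):
    """Convert common textual booleans to bool (hypothetical helper)."""
    text = str(value).strip().lower()
    if text in ("true", "1", "yes", "y"):
        return True
    if text in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # Only a subset of the documented flags, with the documented defaults.
    parser = argparse.ArgumentParser(description="Rough sketch simplification")
    parser.add_argument("--train", type=str2bool, default=True,
                        help="train the model")
    parser.add_argument("--num_epochs", type=int, default=75)
    parser.add_argument("--batch_size", "-b", type=int, default=8)
    return parser


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(["--train=False", "-b", "4"])
print(args.train, args.batch_size, args.num_epochs)  # False 4 75
```

With `type=bool` instead of `str2bool`, the same command line would yield `args.train == True`, because `bool("False")` is `True`.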

Optional arguments:

| Argument | Default | Description |
| --- | --- | --- |
| -h, --help | None | show help message and exit |
| --gpu_id GPU_ID, -g GPU_ID | -1 | GPU ID (negative value indicates CPU) |
| --out OUT, -o OUT | result | Directory to output the result |
| --batch_size BATCH_SIZE, -b BATCH_SIZE | 8 | Batch Size |
| --height HEIGHT, -ht HEIGHT | 64 | height of the image to resize to |
| --width WIDTH, -wd WIDTH | 64 | width of the image to resize to |
| --samples SAMPLES | False | See sample training images |
| --num_epochs NUM_EPOCHS | 75 | Number of epochs to train on |
| --train TRAIN | True | train the model |
| --root ROOT | . | Root Directory for Input and Target images. |
| --n_folds N_FOLDS | 7 | Number of folds in k-fold cross validation. |
| --save_model SAVE_MODEL | True | Save model after training. |
| --load_model LOAD_MODEL | None | Path to existing model. |
| --predict PREDICT | None | Path of rough sketch to simplify using existing model |
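
The `--n_folds` option controls k-fold cross validation, which `main.py` delegates to `chainer.datasets.get_cross_validation_datasets_random`. The sketch below illustrates the general idea in plain Python; it is not Chainer's implementation, and `cross_validation_indices` is a hypothetical name.

```python
import random


def cross_validation_indices(n_samples, n_fold, seed=None):
    """Yield (train_indices, val_indices) once for each of `n_fold` folds."""
    order = list(range(n_samples))
    random.Random(seed).shuffle(order)  # random assignment of samples to folds
    for fold in range(n_fold):
        start = n_samples * fold // n_fold
        stop = n_samples * (fold + 1) // n_fold
        val = order[start:stop]                # this fold is held out
        train = order[:start] + order[stop:]   # every other sample trains
        yield train, val


for train, val in cross_validation_indices(n_samples=14, n_fold=7, seed=0):
    print(len(train), len(val))  # 12 2 for every fold
```

Each sample is held out exactly once. Note that `main.py` creates the model once, before the fold loop, so the weights carry over from one fold to the next and the run performs `n_folds * num_epochs` epochs in total.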

## 5. Model Architecture

![archi](images/archi.png)
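
The Overview's remarks on padding and strides can be checked with the standard output-size formulas for convolution and transposed convolution. This is illustrative arithmetic only: the kernel sizes, strides and paddings below are assumptions chosen to show the effect, not values read from `model.py`.

```python
def conv_out_size(size, kernel, stride, pad):
    """Output length of a convolution along one axis."""
    return (size + 2 * pad - kernel) // stride + 1


def deconv_out_size(size, kernel, stride, pad):
    """Output length of a transposed convolution along one axis."""
    return stride * (size - 1) + kernel - 2 * pad


size = 64

# Stride 1 with padding 1 and a 3x3 kernel keeps the resolution.
same = conv_out_size(size, kernel=3, stride=1, pad=1)
print(same)  # 64

# Three stride-2 convolutions take the place of pooling: 64 -> 32 -> 16 -> 8.
down = size
for _ in range(3):
    down = conv_out_size(down, kernel=3, stride=2, pad=1)
print(down)  # 8

# Three stride-2 up-convolutions restore the input size: 8 -> 16 -> 32 -> 64.
up = down
for _ in range(3):
    up = deconv_out_size(up, kernel=4, stride=2, pad=1)
print(up)  # 64
```

Under these assumed settings the lowest resolution is 1/8 of the input, which matches the ratio quoted in the Overview.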

## 6. Observations


### Training

| Epoch | Prediction |
| --- | --- |
| 2 | ![epoch2](pred/2.png) |
| 60 | ![epoch60](pred/60.png) |
| 100 | ![epoch100](pred/100.png) |
| 140 | ![epoch140](pred/140.png) |

### Predicitons

![pred3](pred/pred3.png)
![pred2](pred/pred8.png)
![pred1](pred/pred1.png)


### Loss

![loss](images/loss.png)

## 7. Repository overview

This repository contains the following files and folders:

1. **images**: Contains resource images.

2. **pred**: Contains prediction images.

3. `dataset.py`: code for dataset generation.

4. `model.py`: code for model as described in the paper.

5. `predict.py`: function to simplify image using model.

6. `read_data.py`: code to read images.

7. `utils.py`: Contains helper functions.

8. `train_val.py`: function to train and validate models.

9. `main.py`: contains main code to run the model.

10. `requirements.txt`: Lists dependencies for easy setup in virtual environments.

35 changes: 35 additions & 0 deletions simplifying_rough_sketches/dataset.py
@@ -0,0 +1,35 @@
from PIL import Image
import numpy as np

from chainer.dataset import dataset_mixin


class CustomDataset(dataset_mixin.DatasetMixin):
    """Paired rough/clean sketch dataset built from two lists of image paths."""

    def __init__(self, input_path, target_path, height=424, width=424):
        self.input_img = input_path
        self.target_img = target_path
        self.height = height
        self.width = width

    def __len__(self):
        return len(self.input_img)

    def preprocess(self, img_path):
        # Load as 8-bit grayscale and scale pixel values to [0, 1].
        img = Image.open(img_path)
        img = img.convert('L')
        # PIL's resize expects (width, height); the resulting array is
        # (height, width), so the reshape below is valid for non-square sizes.
        img = img.resize((self.width, self.height))
        img = np.asarray(img)/255.0
        # Add a leading channel axis: (1, height, width).
        img = img.reshape(1, self.height, self.width)
        return img

    def get_example(self, index):
        # DatasetMixin.__getitem__ dispatches to get_example, so a separate
        # __getitem__ override is not needed.
        X = self.preprocess(self.input_img[index])
        y = self.preprocess(self.target_img[index])

        return X, y
Binary file added simplifying_rough_sketches/images/archi.png
Binary file added simplifying_rough_sketches/images/loss.png
Binary file added simplifying_rough_sketches/images/model.png
151 changes: 151 additions & 0 deletions simplifying_rough_sketches/main.py
@@ -0,0 +1,151 @@
import os
import random
import time

import chainer
from chainer import serializers
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

from model import Net
from dataset import CustomDataset
from read_data import get_data
from train_val import train, validate
from utils import argument_parser, create_directory, samples
from predict import predict


def main():

    args = argument_parser()

    gpu_id = args.gpu_id
    height = args.height
    width = args.width
    BATCHSIZE = args.batch_size
    N_FOLD = args.n_folds
    num_epochs = args.num_epochs

    out = args.out
    create_directory(out)

    root = args.root
    Input = os.path.join(root, 'Input')
    Target = os.path.join(root, 'Target')

    input_images, target_images = get_data(Input, Target)

    if args.samples:
        # NOTE: show_samples is not imported in this file (only `samples` is
        # imported from utils), so this branch raises NameError as written.
        show_samples(input_images)
        show_samples(target_images)

    if args.predict and args.load_model:
        # Inference: load saved weights and simplify a single sketch.
        model_path = args.load_model
        img_path = args.predict

        model = Net()
        serializers.load_npz(model_path, model)
        print('Model Loaded')

        img = predict(model, img_path, height, width, gpu_id)
        pilimg = Image.fromarray(np.uint8(img))
        pilimg.save(os.path.join(out, f'predict{hash(img_path)}.png'))
        print(f'Image saved in {out} directory.')

    elif args.train:

        print('Training Started.')

        model = Net()

        if gpu_id >= 0:
            chainer.backends.cuda.get_device_from_id(gpu_id).use()
            model.to_gpu()
            print('Switched to CUDA backend')
        else:
            print('Training on CPU')

        optimizer = chainer.optimizers.AdaDelta(rho=0.9)
        optimizer.setup(model)
        optimizer.use_cleargrads(True)

        dataset = CustomDataset(input_path=input_images, target_path=target_images, height=height, width=width)
        train_val_set = chainer.datasets.get_cross_validation_datasets_random(dataset=dataset, n_fold=N_FOLD, seed=None)

        total_loss_val, total_loss_train = [], []

        since = time.time()
        best_val_loss = float('inf')

        # The same model instance is trained across all folds.
        for fold, (train_set, val_set) in enumerate(train_val_set):

            train_loader = chainer.iterators.MultithreadIterator(dataset=train_set, batch_size=BATCHSIZE, shuffle=True, repeat=False)
            val_loader = chainer.iterators.MultithreadIterator(dataset=val_set, batch_size=BATCHSIZE, shuffle=True, repeat=False)

            for epoch in range(1, num_epochs+1):
                avg_loss_train, loss_train = train(train_loader, model, optimizer, epoch, gpu_id)
                loss_val = validate(val_loader, model, optimizer, epoch, gpu_id)
                total_loss_val.append(loss_val)
                total_loss_train.append(avg_loss_train)

                if loss_val < best_val_loss:
                    best_val_loss = loss_val
                    print('*****************************************************')
                    print(f'best record: [Fold {fold}] [epoch {epoch}], [val loss {loss_val:.5f}]')
                    print('*****************************************************')

                # Save a preview of the first training image after every epoch.
                print('Fold: {0} Epoch: {1}'.format(fold, epoch))
                samples(input_images[0], target_images[0], model, out, height, width, gpu_id)
                img = predict(model, input_images[0], height, width, gpu_id)
                pilimg = Image.fromarray(np.uint8(img))
                pilimg.save(os.path.join(out, str(fold)+'_'+str(epoch)+'.png'))

        end = time.time()

        print('Training Completed')
        if args.save_model:
            print('Saving model in models directory')
            create_directory('models')
            serializers.save_npz(os.path.join('models', 'simplifying_rough_sketches.model'), model)

        print('Time Taken: ', end-since)

        # Plot training and validation loss over all folds and epochs.
        fig = plt.figure(num=2)
        fig1 = fig.add_subplot()
        fig1.plot(total_loss_train, label='training loss')
        fig1.plot(total_loss_val, label='validation loss')
        plt.legend(loc='upper left')
        plt.savefig(os.path.join(out, 'loss.png'))
        plt.close(fig)

        # Render a few randomly chosen pairs with the final model. Pass the
        # same arguments as the per-epoch call and stay within the dataset.
        try:
            for _ in range(5):
                k = random.randrange(len(input_images))
                samples(input_images[k], target_images[k], model, out, height, width, gpu_id)
        except Exception as e:
            print('Could not render final samples:', e)

    else:
        print('Exiting')


if __name__ == '__main__':
    main()