Semantic Segmentation Using Fully Convolutional Networks

TL;DR: This blog post presents my implementation of the FCN paper. Code for this can be found on my GitHub.


Disclaimer: What follows is not a tutorial; these are my implementation notes. If you're looking for material on what FCNs are and such, then you probably don't stand to gain much from reading on. For that, I suggest going through the paper linked above. What follows is probably most useful to people who are either planning to implement an FCN themselves or are in the midst of doing so. My aim here is to talk about details that revealed themselves to me only after several torturous hours of scouring the depths of the Internet.



Not sure if reading this *really* long post is worth your time? Maybe jump to the results section first, and then come back to the rest of it.


What is Semantic Segmentation?

Semantic Segmentation is the task of labeling each pixel of an input image with class labels from a predetermined set of classes. For example, given the following image:

[Figure: input image]

the segmentation result should look like:

[Figure: segmentation output]

See how all the pixels in the input image that belong to the cat are color coded brown and the background is color coded black? That is the end result we hope to achieve.


How do we do it?

In this blog post, we will see how Fully Convolutional Networks (FCNs) can be used to perform semantic segmentation. A blazing fast walkthrough of the FCN paper by Long, Shelhamer and Darrell: we take a VGG-Net trained for image classification on the ILSVRC dataset and "convolutionalize" it (more on this "convolutionalization" in the following sections). We then append a score layer and upsampling layers to the end of the pretrained, convolutionalized VGG-Net and then use per-pixel softmax to obtain dense predictions. Skip connections from the VGG pooling layers to the upsampling layers are also added. We do all of this using Tensorflow.

I'll analyze each aspect of this walkthrough in greater detail, but first a small note on the dataset I've used. PASCAL-VOC 2012 is a very standard dataset for training segmentation models. It contains $\sim 1000$ images of varying dimensions in each of the training and testing sets. More information about the dataset can be found here.


The Pretrained VGG-Net

VGG-Net is a CNN architecture originally presented by Simonyan and Zisserman for the ImageNet (ILSVRC) 2014 classification task. It did not win first place in classification that year (it came second, behind GoogLeNet), but it still went on to become a VERY popular model because of its simplicity. The variants of VGG-Net presented in the original paper are built from plain stacks of convolutional, max-pooling and fully connected (dense) layers. The FCN that I've implemented (and the one presented in the FCN paper) uses a pretrained VGG-16 for its initial layers. Before proceeding further, it helps to be familiar with the VGG-16 architecture:

[Figure: the VGG-16 architecture]

About the notation in the above diagram: each convolution layer is specified by an ordered pair (filter size, number of output channels), and each fully connected layer is specified by a single number (number of output units). It can be seen from the diagram that the above VGG-Net has 16 "trainable" layers (that is, 16 layers with weights), hence the name VGG-16. If you're referring to the original VGG paper, you'll find that the above architecture is called the D variant. For our FCN implementation, we use all the convolution and pooling layers from the above architecture as is (i.e. all layers from conv1_1 to pool5, both inclusive, are used exactly as presented above). The fully connected layers are converted to convolutional layers (see the next section on convolutionalization) and then used.

Now for the more practical matter of obtaining and using the trained VGG weights. Luckily for us, the authors of VGG-Net have kindly made their trained models publicly available on their group's webpage. They wrote and trained their models in Caffe, which poses a minor hurdle for those of us using Tensorflow. I say minor because it only takes one extra step: using caffe-tensorflow to convert the Caffe weights into NumPy's binary format.

I should note, at this point, that one might be tempted not to use the pretrained weights at all. I initially had the misconception that initializing with the VGG weights is not very important, and that the network would learn the weights when I trained it end-to-end. This, unfortunately, is not the case. Training VGG-Net is not a single-step process: it is trained in stages, with the deeper configurations initialized from the layers of a shallower network that was trained first. The pretrained VGG that we take for granted in the remainder of this post is no small feat to achieve, and it is critical that we have it in our FCN. I cannot harp on this enough: training a VGG from scratch is a sizeable task (one that took the authors weeks of training to complete) and it is not to be underestimated.


Convolutionalize Dense Layers

Here's the deal with this. Dense layers have the limitation that their input must always have a fixed size. Images, on the other hand, can be of any size: the images in the training set may have different sizes, and the image we perform inference on need not have the same dimensions as any of the training images. This issue does not arise with convolutional layers. The input to a convolutional layer can be of any size, and the layer's weights still apply and still produce outputs. The VGG-16 architecture presented above has three fully connected (dense) layers at the very end (fc1, fc2 and fc3), and we must deal with them if we are to build an FCN. As the name suggests, a "Fully Convolutional Network" should only have convolutional layers. Hence, to live up to the name, we must convert the dense VGG layers to convolutional layers. There exists a beautiful way to go about this.

Suppose we have a dense layer that takes $HW$ input units and produces $D$ outputs, i.e. its weight matrix has dimensions $D \times HW$. This dense layer can be convolutionalized as follows: each row of the weight matrix is reshaped into a filter of dimensions $H \times W$, so we end up with $D$ filters of size $H \times W$. (If the input also has $C$ channels, the matrix is $D \times HWC$ and each row becomes an $H \times W \times C$ filter.) Note that this convolutionalized layer can now take arbitrarily sized inputs; it is exactly the same as any other convolution layer. Also note that if the input to this layer is of size $H \times W$ (and no padding is applied), then it performs the exact same operation as the original dense layer. I find this quite intuitive and easy to understand, but perhaps a better explanation of this convolutionalization procedure is available in the CS231n notes here. In code, this reshaping of the fully connected weights of VGG-16 looks something like this:
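A small NumPy sketch can make the equivalence concrete. It is illustrative only: the names `dense` and `conv_valid` are made up for this example, the convolution is a naive hand-written "VALID" one, and the dense weights are assumed to be stored with one column per output unit (an $HWC \times D$ array whose inputs are flattened in (height, width, channel) order), i.e. the transpose of the $D \times HW$ convention in the paragraph above.

```python
import numpy as np


def dense(x, w, b):
    """Fully connected layer on an (H, W, C) input flattened in (h, w, c) order.

    w has shape (H*W*C, D), i.e. one column per output unit.
    """
    return x.reshape(-1) @ w + b


def conv_valid(x, filters, b):
    """Naive stride-1 'VALID' convolution (cross-correlation, as in CNN libraries).

    x: (H_in, W_in, C), filters: (kh, kw, C, D) -> (H_in - kh + 1, W_in - kw + 1, D)
    """
    kh, kw, _, d = filters.shape
    h_out, w_out = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.empty((h_out, w_out, d))
    for i in range(h_out):
        for j in range(w_out):
            patch = x[i:i + kh, j:j + kw, :]
            # Contract the patch with all D filters over (height, width, channel)
            out[i, j] = np.tensordot(patch, filters, axes=([0, 1, 2], [0, 1, 2])) + b
    return out


rng = np.random.default_rng(0)
H, W, C, D = 3, 3, 4, 5
w_dense = rng.standard_normal((H * W * C, D))
b = rng.standard_normal(D)
# "Convolutionalize": column d of w_dense becomes filter filters[:, :, :, d]
filters = w_dense.reshape(H, W, C, D)

# On an input of exactly H x W x C, the conv layer reproduces the dense layer
x = rng.standard_normal((H, W, C))
print(np.allclose(conv_valid(x, filters, b)[0, 0], dense(x, w_dense, b)))  # True

# A larger input still works and yields a spatial grid of D-dimensional outputs
big = rng.standard_normal((6, 8, C))
print(conv_valid(big, filters, b).shape)  # (4, 6, 5)
```

With an input of exactly $H \times W$ the convolution collapses to the dense layer; with anything larger it simply slides the same weights around.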

fc1_convolutionalized_filters = numpy.reshape(vgg_params['fc1']['weights'], [7, 7, 512, 4096])

How did I get the above line of code? Note that the original VGG-16 was trained on the ImageNet classification dataset, where each RGB image was of dimensions $224 \times 224$. So the input to fc1 in the original VGG-16 would have been $7 \times 7$. This is easy to work out from the following facts: the spatial resolution remains unchanged at convolutional layers on account of SAME padding, and it is reduced by a factor of $2$ at each pooling layer. These observations will come in handy later in the post as well. There are five pooling layers between conv1_1 and fc1, and hence the input to fc1 has dimensions $224/2^5 \times 224/2^5$, or $7 \times 7$. Also, since the last conv layer before fc1 (i.e. conv5_3) has $512$ output channels, the input volume of fc1 also has $512$ channels. All in all, in the original VGG architecture, fc1 sees an input of dimensions $7 \times 7 \times 512$ and gives $4096$ outputs. That is, the weight matrix of layer fc1 has $4096$ rows and $25088$ ($= 7 \cdot 7 \cdot 512$) columns. Following the convolutionalization logic presented in the last paragraph, each row (of $25088$ weights) can be reshaped into a filter of size $7 \times 7 \times 512$, and we hence end up with $4096$ such filters. In Tensorflow, convolution layer weights are specified in the format [filter_height, filter_width, filter_depth, output_channels]. Hence we simply reshape the VGG weights into the shape [7, 7, 512, 4096]. Note that the bias terms can be used as they are:
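The size bookkeeping above can be checked with a few lines of Python. This is a sketch that assumes every pooling layer is a 2×2, stride-2 max-pool with "SAME" padding (so a side of length $n$ becomes $\lceil n/2 \rceil$); `pooled_size` is an illustrative name. Note also that `numpy.reshape` reads the array in row-major order, so the fc1 one-liner presumes the converted weights form a $25088 \times 4096$ array (one column per output unit) with inputs flattened in (height, width, channel) order. It is worth checking that your converted file really is laid out that way.

```python
def pooled_size(n, num_pools=5):
    # 3x3 convs with SAME padding keep the spatial size; each 2x2, stride-2
    # max-pool with 'SAME' padding maps n -> ceil(n / 2) == (n + 1) // 2.
    for _ in range(num_pools):
        n = (n + 1) // 2
    return n


side = pooled_size(224)          # 224 -> 112 -> 56 -> 28 -> 14 -> 7
fc1_fan_in = side * side * 512   # conv5_3 has 512 output channels
print(side, fc1_fan_in)          # 7 25088
print([side, side, 512, 4096])   # [7, 7, 512, 4096], the fc1 filter shape

# Arbitrary image sizes just give a bigger (or smaller) pool5 grid
print(pooled_size(500), pooled_size(375))  # 16 12
```

For a $500 \times 375$ image (a common size in PASCAL-VOC), pool5 therefore outputs a $16 \times 12$ grid, which the convolutionalized fc1 accepts without complaint.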

fc1_convolutionalized_biases = vgg_params['fc1']['biases']

Finally, we can use these numpy arrays to create the convolutionalized fc1 Tensorflow graph node as follows:

fc1 = \
  tf.nn.relu(
    tf.nn.bias_add(
      tf.nn.conv2d(
        pool5,
        fc1_convolutionalized_filters,
        [1, 1, 1, 1],
        'SAME'),
      fc1_convolutionalized_biases
    )
  )

Similarly, for fc2 and fc3:

fc2_convolutionalized_filters = numpy.reshape(vgg_params['fc2']['weights'], [1, 1, 4096, 4096])
fc2_convolutionalized_biases  = vgg_params['fc2']['biases']
fc2 = \
  tf.nn.relu(
    tf.nn.bias_add(
      tf.nn.conv2d(
        fc1,
        fc2_convolutionalized_filters,
        [1, 1, 1, 1],
        'SAME'),
      fc2_convolutionalized_biases
    )
  )
# fc3 is VGG-16's 1000-way ImageNet classifier, so it has 1000 (not 4096) outputs
fc3_convolutionalized_filters = numpy.reshape(vgg_params['fc3']['weights'], [1, 1, 4096, 1000])
fc3_convolutionalized_biases  = vgg_params['fc3']['biases']
fc3 = \
  tf.nn.relu(
    tf.nn.bias_add(
      tf.nn.conv2d(
        fc2,
        fc3_convolutionalized_filters,
        [1, 1, 1, 1],
        'SAME'),
      fc3_convolutionalized_biases
    )
  )

Score Layers and Color Maps

For the first time ever in this post, we present the complete FCN architecture:

[Figure: the complete FCN architecture]

First off, note how the fully connected layers from the VGG architecture have now become convolution layers. Also now we see the addition of three new kinds of layers: score layers, upsample layers and addition layers. In this section we will tackle score layers.

Score layers are nothing but convolutional layers having kernel dimensions $1 \times 1$ and $21$ output channels. My (intuitive) understanding of the nomenclature is as follows:

Suppose the input volume to a score layer has dimensions $H \times W \times F$. This is thought of as an image of dimensions $H \times W$ in which each pixel is made up of $F$ features. The score layer goes over each pixel and, at each pixel location, predicts (on the basis of the $F$ features at that location) a score for each possible output class. In our case, there are a total of $21$ output classes: $20$ object classes and background. Hence, the score layer in our case predicts $21$ scores at each pixel location, each score representing how likely that pixel is to belong to the corresponding class.

Note that the output of our score layers has dimensions $H \times W \times 21$. At each pixel location, we can obtain a class prediction by simply picking the argmax of the $21$ scores at that location. However, a grid of class indices is not a very visually appealing representation of the segmentation. We would like to associate each output class with a color, so that we can map our $H \times W$ class predictions to an $H \times W$ color image (like the segmented cat image presented at the very beginning of this post). In other words, we need a color map.
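To make the score layer and the argmax concrete, here is a NumPy sketch that treats a $1 \times 1$ convolution as the same $F \times 21$ linear map applied at every pixel. The names `score_layer` and `scores_to_labels`, and the random inputs, are illustrative assumptions rather than code from my repository.

```python
import numpy as np

NUM_CLASSES = 21  # 20 PASCAL-VOC object classes + background


def score_layer(features, weights, biases):
    """1x1 convolution: the same linear map applied independently at every pixel.

    features: (H, W, F), weights: (1, 1, F, NUM_CLASSES), biases: (NUM_CLASSES,)
    returns:  (H, W, NUM_CLASSES) class scores
    """
    return np.tensordot(features, weights[0, 0], axes=([2], [0])) + biases


def scores_to_labels(scores):
    # Per-pixel argmax over the class axis -> (H, W) map of class indices
    return np.argmax(scores, axis=-1)


rng = np.random.default_rng(0)
features = rng.standard_normal((4, 6, 8))             # H=4, W=6, F=8
weights = rng.standard_normal((1, 1, 8, NUM_CLASSES))
biases = np.zeros(NUM_CLASSES)

scores = score_layer(features, weights, biases)
labels = scores_to_labels(scores)
print(scores.shape, labels.shape)  # (4, 6, 21) (4, 6)
```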

The MATLAB modules provided along with the PASCAL-VOC dataset have the following code:

% VOCLABELCOLORMAP Creates a label color map such that adjacent indices have different
% colors.  Useful for reading and writing index images which contain large indices,
% by encoding them as RGB images.
%
% CMAP = VOCLABELCOLORMAP(N) creates a label color map with N entries.
function cmap = VOClabelcolormap(N)

if nargin==0
    N=256;
end
cmap = zeros(N,3);
for i=1:N
    id = i-1; r=0;g=0;b=0;
    for j=0:7
        r = bitor(r, bitshift(bitget(id,1),7 - j));
        g = bitor(g, bitshift(bitget(id,2),7 - j));
        b = bitor(b, bitshift(bitget(id,3),7 - j));
        id = bitshift(id,-3);
    end
    cmap(i,1)=r; cmap(i,2)=g; cmap(i,3)=b;
end
cmap = cmap / 255;

For the MATLAB-averse, wllhf has kindly provided this gist, which contains a Python script that does the exact same thing as the above MATLAB code. That is, the script provides a mapping between output classes and colors:

[Figure: PASCAL-VOC class-to-color map]
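For reference, here is a NumPy sketch of the same bit-interleaving scheme as the MATLAB function above, written for this post (it is not wllhf's gist). `voc_colormap` is an illustrative name, and unlike the MATLAB version it keeps colors as 0-255 integers instead of dividing by 255.

```python
import numpy as np


def voc_colormap(n=256):
    """Python port of VOClabelcolormap: row i is the RGB color (0-255) of class i."""
    cmap = np.zeros((n, 3), dtype=np.uint8)
    for i in range(n):
        cid, r, g, b = i, 0, 0, 0
        for j in range(8):
            # Spread bits 0, 1, 2 of the id onto successively lower bits of R, G, B
            r |= ((cid >> 0) & 1) << (7 - j)
            g |= ((cid >> 1) & 1) << (7 - j)
            b |= ((cid >> 2) & 1) << (7 - j)
            cid >>= 3
        cmap[i] = (r, g, b)
    return cmap


cmap = voc_colormap()
print(cmap[0].tolist(), cmap[8].tolist(), cmap[15].tolist())  # [0, 0, 0] [64, 0, 0] [192, 128, 128]

# Color an (H, W) map of class indices by fancy indexing -> (H, W, 3) image
labels = np.array([[0, 8], [15, 0]])
rgb = cmap[labels]
print(rgb.shape)  # (2, 2, 3)
```

Indexing the color table with the $H \times W$ argmax map gives the colored segmentation in one step; background (class 0) is black, class 8 (cat) is (64, 0, 0) and class 15 (person) is (192, 128, 128).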


Upsampling Layers

The second new kind of layer we see in the FCN architecture is the upsampling layer. Specifically, there are three of these: upsample2, upsample4 and upsample32. As one might expect from the name, upsampling layers increase the dimensions of their input by some factor. I've named each of my upsampling layers upsampleXYZ, where XYZ is the upsampling factor with respect to the lowest resolution that the network's input is pooled down to (i.e. the output of pool5). So, if the output of pool5 has spatial dimensions $H \times W$, then the output of upsample2 has dimensions $2H \times 2W$, the output of upsample4 has dimensions $4H \times 4W$ and the output of upsample32 has dimensions $32H \times 32W$.

The upsampling layers are just convolutional layers of a special kind (transposed convolutions). If you don't understand how I can state that so directly, please refer to chapter 4 of this document on convolution arithmetic. The FCN paper tells us to initialize the filters of the upsampling layers with bilinear interpolation filters. It would take a complete post to explain resampling and bilinear interpolation, so I would refer the reader to this post on Daniil Pakhomov's blog. I strongly encourage reading the post at least three times word for word; it contains extremely valuable theory and practical information about resampling, specifically bilinear interpolation. I have adapted the code presented in Daniil's blog post to generate the filters and create my upsampling layers:

Method to generate the upsampling filters:

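import numpy as np
import tensorflow as tf

def _get_upconv_params(factor, out_channels, name):
  # Kernel size for upsampling factor f: 2f for even f, 2f - 1 for odd f
  kernel_sz = 2*factor - factor%2
  # conv2d_transpose filters have shape [height, width, out_channels, in_channels]
  weights = np.zeros([kernel_sz, kernel_sz, out_channels, out_channels], dtype=np.float32)

  # Populate weights with a 2-D bilinear interpolation kernel
  if kernel_sz % 2 == 1:
    center = factor - 1
  else:
    center = factor - 0.5
  tmp = np.ogrid[:kernel_sz, :kernel_sz]
  kernel = (1 - abs(tmp[0] - center)/factor) * (1 - abs(tmp[1] - center)/factor)
  # Each channel is upsampled independently (no mixing across channels)
  for i in range(out_channels):
    weights[:, :, i, i] = kernel

  # Populate biases
  biases = np.zeros([out_channels,], dtype=np.float32)

  # name= must be passed by keyword: tf.Variable's second positional argument is not the name
  dic = {
    'weights': tf.Variable(weights, name='{}_weights'.format(name)),
    'biases': tf.Variable(biases, name='{}_biases'.format(name))
  }
  return dic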

upsample2 node in the Tensorflow computation graph:

params['upsample2'] = _get_upconv_params(2, 21, 'upsample2')
net['upsample2'] = tf.nn.conv2d_transpose(
    net['score_fc'],
    params['upsample2']['weights'],
    output_shape = [1, tf.shape(net['score_fc'])[1] * 2, tf.shape(net['score_fc'])[2] * 2, 21],
    strides = [1,2,2,1]
  )
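To see what these bilinear filters actually do, here is a single-channel NumPy sketch of a stride-$f$ transposed convolution using the same kernel formula as `_get_upconv_params`. It is illustrative only: `bilinear_kernel` and `upsample` are made-up names, and the final crop is an attempt to mimic TensorFlow's "SAME" output of exactly $fH \times fW$ (crop $k - f$ pixels per axis, the smaller half at the start), not a documented guarantee.

```python
import numpy as np


def bilinear_kernel(factor):
    # Same formula as _get_upconv_params, for a single channel
    size = 2 * factor - factor % 2
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    og = np.ogrid[:size, :size]
    return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)


def upsample(x, factor):
    """Single-channel transposed convolution with stride `factor`.

    Every input pixel adds a scaled copy of the kernel to the output grid; the
    result is then cropped to (factor * H, factor * W).
    """
    k = bilinear_kernel(factor)
    size = k.shape[0]
    h, w = x.shape
    full = np.zeros(((h - 1) * factor + size, (w - 1) * factor + size))
    for i in range(h):
        for j in range(w):
            full[i * factor:i * factor + size, j * factor:j * factor + size] += x[i, j] * k
    # Crop (size - factor) pixels per axis, the smaller half at the start
    start = (size - factor) // 2
    return full[start:start + factor * h, start:start + factor * w]


ramp = np.tile(np.arange(4.0), (4, 1))   # every row is [0, 1, 2, 3]
print(upsample(ramp, 2).shape)            # (8, 8)
print(upsample(ramp, 2)[3].tolist())      # [0.0, 0.25, 0.75, 1.25, 1.75, 2.25, 2.75, 2.25]
```

Away from the borders a linear ramp stays a linear ramp, sampled at half-pixel offsets, which is exactly bilinear interpolation. At the borders the output is attenuated (a constant input of 1 gives 0.75 along an edge) because the transposed convolution effectively pads with zeros.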

Cropping and Skip Connections

Finally, the last "new" kind of layer in the FCN architecture is the addition layer (also called a fuse layer or skip connection). This layer takes in two input volumes and adds them elementwise. An important prerequisite of this addition is that the dimensions of the two input volumes match exactly. Now, since the pool layers follow the 'SAME' scheme of max-pooling, they reduce an input of dimensions $H \times W$ to an output of dimensions $\lfloor (H+1)/2 \rfloor \times \lfloor (W+1)/2 \rfloor$. However, the upsample layer increases the dimensions of an input ($H \times W$) to $2H \times 2W$ (when the upsampling factor is 2). It is not too hard to show from these two observations that the output of upsample2 always has dimensions greater than or equal to those of the output of pool4 (equivalently, the output of score_pool4): if pool4 outputs $H$ rows, upsample2 outputs $2\lfloor (H+1)/2 \rfloor$ rows, which is $H$ when $H$ is even and $H+1$ when $H$ is odd. For example, suppose that the output of pool4 (the input to pool5) has size $5 \times 5$; then the output of pool5 will be of size $3 \times 3$. The upsampling (upsample2) would hence produce an output of $6 \times 6$, larger than the output of pool4. Thus, in order to create a fuse layer with score_pool4 and upsample2 as inputs, we must first crop upsample2 to an appropriate size. I have implemented this cropping as follows:

def _get_crop_layer(big_batch, small_batch):
  h_s = tf.shape(small_batch)[1]
  w_s = tf.shape(small_batch)[2]

  h_b = tf.shape(big_batch)[1]
  w_b = tf.shape(big_batch)[2]

  return big_batch[:,
                 (h_b - h_s)//2 : (h_b - h_s)//2 + h_s,
                 (w_b - w_s)//2 : (w_b - w_s)//2 + w_s,
                 :]
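The size argument and the crop can be checked numerically with a short NumPy sketch. `pooled_size` and `center_crop` are illustrative names, and `center_crop` mirrors `_get_crop_layer` for a single $H \times W \times C$ volume rather than a batched tensor.

```python
import numpy as np


def pooled_size(n, num_pools):
    # Each SAME-padded 2x2, stride-2 max-pool maps n -> floor((n + 1) / 2)
    for _ in range(num_pools):
        n = (n + 1) // 2
    return n


def center_crop(big, small_hw):
    """NumPy analogue of _get_crop_layer for one (H, W, C) volume."""
    h_s, w_s = small_hw
    h_b, w_b = big.shape[:2]
    top, left = (h_b - h_s) // 2, (w_b - w_s) // 2
    return big[top:top + h_s, left:left + w_s, :]


# upsample2 (2 x pool5) is never smaller than pool4 and exceeds it by at most 1
ok = all(2 * pooled_size(n, 5) - pooled_size(n, 4) in (0, 1) for n in range(1, 2049))
print(ok)  # True

# The example from the text: an 80-pixel side gives a 5x5 pool4 and a 3x3 pool5
p4, p5 = pooled_size(80, 4), pooled_size(80, 5)
upsample2 = np.zeros((2 * p5, 2 * p5, 21))  # 6 x 6 x 21
cropped = center_crop(upsample2, (p4, p4))
print(p4, p5, cropped.shape)  # 5 3 (5, 5, 21)
```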

Now, the fuse layer can be implemented as follows:

# Crop upsample2
net['cropped_upsample2'] = _get_crop_layer(net['upsample2'], net['pool4'])

# Score pool4
params['score_pool4'] = _get_score_layer_init_params('score_pool4')
net['score_pool4'] = _get_conv_layer(net['pool4'], params['score_pool4'])

# Fuse pool4
net['fuse_pool4'] = tf.add(net['score_pool4'], net['cropped_upsample2'])

A few minor notes about the above code. The method that initializes the score layer parameters (_get_score_layer_init_params) initializes the filter weights using the Xavier initialization scheme and the biases to zero. Also note that the batch size must be fixed to one: images of different sizes cannot be stacked into a single batch, and the output_shape passed to conv2d_transpose above hardcodes a batch dimension of 1. Finally, I've only presented cropping of the upsample2 output here. The outputs of upsample4 and upsample32 must also be cropped, to the sizes of pool3 and the input image respectively.
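The body of _get_score_layer_init_params is not shown here, so the following is a hypothetical NumPy stand-in that only sketches what Xavier initialization of a score layer amounts to. Its signature differs from the real helper (it takes the input channel count explicitly), and it assumes the uniform variant of Xavier/Glorot initialization; the normal variant is also common.

```python
import numpy as np


def get_score_layer_init_params(in_channels, num_classes=21, rng=None):
    """Hypothetical stand-in for _get_score_layer_init_params.

    Xavier/Glorot-uniform weights for a 1x1 convolution, zero biases.
    """
    rng = np.random.default_rng() if rng is None else rng
    fan_in = 1 * 1 * in_channels    # kernel height * width * input channels
    fan_out = 1 * 1 * num_classes   # kernel height * width * output channels
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    weights = rng.uniform(-limit, limit, size=(1, 1, in_channels, num_classes))
    biases = np.zeros(num_classes)
    return {'weights': weights.astype(np.float32), 'biases': biases.astype(np.float32)}


# score_pool4 sees the 512-channel output of pool4
params = get_score_layer_init_params(512, rng=np.random.default_rng(0))
print(params['weights'].shape, params['biases'].shape)  # (1, 1, 512, 21) (21,)
```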


Results

The standard metric for evaluating segmentation performance is mean intersection over union (mIoU). For each class, the IoU is the number of pixels that both the prediction and the ground truth assign to that class (the intersection), divided by the number of pixels that either of them assigns to it (the union); the mIoU is the mean of this ratio over all classes, with the pixel counts usually accumulated over the whole test set. I haven't gone to the extent of computing this metric for my model, mainly because my aim going into the project was to get a program to do something cool and produce visually gratifying results, rather than to produce a model that achieves state-of-the-art results. I suppose, though (and this guess may be well off the mark), that with a fair amount of fine-tuning it is possible to reproduce the results reported in the original paper, at least to within a small margin of error. Besides, I feel like the implementation process has already shown me the bulk of what the FCN paper had to offer, and sitting around hoping and praying to God that a particular hyperparameter setting pushes the mIoU up by a third decimal point is not (IMHO) the best use of my time.

*mic drop*
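Should you want the number anyway, here is a minimal NumPy sketch of mean IoU. The names are illustrative, not code from my repository. It accumulates a confusion matrix, skips pixels labelled 255 (the "void" label in PASCAL-VOC ground truth), and averages the per-class IoU $= TP / (TP + FP + FN)$ over the classes that occur. The toy example also shows how mIoU differs from plain pixel accuracy.

```python
import numpy as np


def confusion_matrix(gt, pred, num_classes=21, ignore_label=255):
    """counts[i, j] = number of pixels with ground truth i predicted as j."""
    mask = gt != ignore_label
    idx = num_classes * gt[mask].astype(np.int64) + pred[mask].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def mean_iou(conf):
    tp = np.diag(conf)
    fp = conf.sum(axis=0) - tp  # predicted as class c, but wrong
    fn = conf.sum(axis=1) - tp  # truly class c, but missed
    union = tp + fp + fn
    present = union > 0         # skip classes absent from both gt and predictions
    return float(np.mean(tp[present] / union[present]))


# Toy 2-class example; in practice, sum the confusion matrices over the test set
gt = np.array([[0, 0, 1, 1], [255, 255, 255, 255]])
pred = np.array([[0, 1, 1, 1], [1, 1, 1, 1]])
conf = confusion_matrix(gt, pred, num_classes=2)
print(mean_iou(conf))               # about 0.583 (per-class IoUs 1/2 and 2/3)
print(np.trace(conf) / conf.sum())  # 0.75 (pixel accuracy)
```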

Now for the moment we've all been holding our breaths for! Without further ado, here they are. Leftmost panel contains the original image, middle panel contains the ground truth segmentation and the rightmost panel contains the predicted segmentation.