Image Captioning in TensorFlow 2.0

Since I first used base-level TensorFlow back in 2016, the open-source code has changed quite a bit. With further adoption of Keras, TensorFlow has become more high level, which makes it a lot easier to adopt. It took away a lot of the low-level functions I used to write to set up layers and even to get arrays into tensors, but they will probably not be missed.

Image Captioning Code Updates

With the release of TensorFlow 2.0, the image captioning code base has been updated to benefit from the functionality of the latest version. The main change is the use of tf.function and tf.keras to replace a lot of the low-level functions of TensorFlow 1.x.

The code is based on the paper Show, Attend and Tell: Neural Image Caption Generation with Visual Attention. This was intriguing to me because it served my interest in both Computer Vision and NLP. It could also be a potential use case for our clients. When I saw Fei-Fei Li talk about this in one of her presentations, I was very impressed by the technology, and now I have the opportunity to try it myself alongside the new release of TensorFlow 2.0.

Image Captioning Use Cases

Below are five applications of image captioning:

use cases


Challenges and Resolution

Image descriptions are subjective, and every person who sees an image will focus their attention differently. Luckily, the MS COCO dataset has pretty clean descriptive captions. According to the paper, image captioning with “attention” takes a vector representation of an image and ties that representation to a meaningful sentence. It combines Convolutional Neural Networks with text generation by Recurrent Neural Networks, as shown in the picture.

feature map

Referencing the paper, “as the model generates each word, its attention changes to reflect the relevant parts of the image”. This method is quite interesting because, according to the paper, it doesn’t use object detection but “learns” from the images from scratch to look for “objectness”, which makes the model more adaptable and general rather than bound to a fixed set of trained objects.
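To make the idea of attention a bit more concrete, here is a minimal pure-Python sketch. It is not the code from the notebook: it assumes the model has already produced one relevance score per image region, and the feature values and scores below are made up for illustration. The scores are turned into weights with a softmax, and the context vector is the weighted sum of the region features.

```python
import math

def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    highest = max(scores)
    exps = [math.exp(s - highest) for s in scores]  # subtract max for stability
    total = sum(exps)
    return [e / total for e in exps]

def attend(region_features, scores):
    """Return (weights, context); context is the weighted sum of the features."""
    weights = softmax(scores)
    dim = len(region_features[0])
    context = [
        sum(w * feat[i] for w, feat in zip(weights, region_features))
        for i in range(dim)
    ]
    return weights, context

# Hypothetical example: three image regions, each a 2-dimensional feature vector.
region_features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = [2.0, 0.5, 0.1]  # made-up relevance scores for the word being generated

weights, context = attend(region_features, scores)
print(weights)   # the first region gets the largest weight
print(context)   # a blend of the region features, dominated by the first region
```

As the scores change from word to word, the weights shift to different regions, which is the behavior the quote describes.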

Running the Code

The code is in a Colab notebook, but it crashed quite a few times, and a GPU was not always available. The original code processed 30,000 images, which was quite a lot. I also had to re-download the datasets every time I exited Colab. I ended up running the code locally and making some adjustments.

Here are some tweaks that I made to the original code:

  1. The batch size didn’t work for a long time, so I tried a batch size divisible by 10 and it worked; I ended up using 100. I don’t know why, but with the original batch size of 64, the code base runs into this error:
    value error
    Even a Google engineer pointed it out here, but there doesn’t seem to be a resolution. At first, I tried a try/except block, but that never fed the entire matrix through properly, so the output of the captioning was always cut off. Luckily, the batch size divisible by 10 worked.
  2. I decreased the training set to 3,000 images but increased the batch size to 50, which helped the model do well on a smaller dataset. Here is an example of the loss over time.
    Loss Plot
  3. There weren’t many datasets to choose from, but the image dataset that I was able to apply this to was the flickr-image-dataset. Each image had 4 or 5 different captions, so I picked one per image to match the way the code base sets up its dataset.
    Flickr image dataset
  4. I tweaked the captions to reflect how the code ingests data, such as adding the ‘<start>’ and ‘<end>’ tokens to the string. The code looks for them later on to know where the sentence starts and ends for evaluation. Here is an example:
    ‘<start>  Several men in hard hats are operating a giant pulley system . <end>’
  5. I created a simple metric to quantify the accuracy of the predicted captions. I used the Jaccard Index to find words that appeared in both the predicted and the actual captions. The first step was to remove stop words to get to the meaningful words of the caption. This is typically done with the NLTK package, but I found the stop-words package for Python much easier from an installation standpoint.
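Tweaks 3 and 4 above can be sketched roughly as follows. This is a minimal illustration, not the code from the notebook: the `raw_captions` dictionary, its file name, the second caption, and the helper name `prepare_captions` are all hypothetical, and taking the first caption is just one simple way to pick one per image.

```python
def prepare_captions(raw_captions):
    """Keep one caption per image and wrap it in <start>/<end> tokens."""
    prepared = {}
    for image_name, captions in raw_captions.items():
        if not captions:
            continue  # skip images that have no caption at all
        first = captions[0].strip()
        prepared[image_name] = f"<start> {first} <end>"
    return prepared

# Hypothetical input: each image maps to several human-written captions.
raw_captions = {
    "image_1.jpg": [
        "Several men in hard hats are operating a giant pulley system .",
        "A second, made-up caption for the same image .",
    ],
}

prepared = prepare_captions(raw_captions)
print(prepared["image_1.jpg"])
# <start> Several men in hard hats are operating a giant pulley system . <end>
```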

After removing words such as “the”, “a”, and “an”, there were still words that basically meant the same thing but were spelled differently. For example, ‘house’ and ‘building’, or ‘guys’ and ‘people’, would not count as a match because they are spelled differently, even though they have similar meanings. Using Google’s pretrained word2vec model in Python with the gensim package, if the similarity between two words was above a threshold, the predicted word was replaced with the matching word from the real caption. This aligned words of similar meaning with the actual caption, further increasing the metric.
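The replacement step could look something like the sketch below. It is an assumption about the logic, not the original code: the function name `align_to_reference`, the threshold of 0.5, and the toy similarity table are hypothetical, and the table stands in for word2vec so the example runs without downloading a model.

```python
def align_to_reference(predicted_words, reference_words, similarity, threshold=0.5):
    """Replace a predicted word with its most similar reference word
    when that similarity is above the threshold."""
    aligned = []
    for word in predicted_words:
        if word in reference_words:
            aligned.append(word)  # already an exact match
            continue
        best_word, best_score = word, threshold
        for ref in reference_words:
            score = similarity(word, ref)
            if score > best_score:
                best_word, best_score = ref, score
        aligned.append(best_word)
    return aligned

# Toy similarity scores standing in for word2vec; the numbers are made up.
TOY_SIMILARITY = {
    frozenset(["house", "building"]): 0.7,
    frozenset(["guys", "people"]): 0.6,
}

def toy_similarity(a, b):
    return TOY_SIMILARITY.get(frozenset([a, b]), 0.0)

predicted = ["colorful", "house", "hill"]
reference = ["remote", "colorful", "building"]
aligned = align_to_reference(predicted, reference, toy_similarity)
print(aligned)  # ['colorful', 'building', 'hill']
```

With gensim, the `similarity` argument could be the `similarity` method of a loaded `KeyedVectors` model. Words missing from the model's vocabulary would need separate handling.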


In the MS COCO picture below, ‘house’ and ‘building’ are similar in context.

remote colorful building

After removing stop words.

After doing word2vec.

The word ‘house’ was replaced by ‘building’, so when I analyze the words for an intersection, there is an overlap. At most, 2-3 keywords would overlap, and that overlap indicates how well the model is doing. The scores are generally between 0.1 and 0.3 on a similarity scale of 0.0 to 1.0. A similarity of 0.0 means there is no intersection at all.
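Here is a minimal sketch of the Jaccard score itself: the number of shared words divided by the number of distinct words across both captions, after stop words are removed. The tiny `STOP_WORDS` set is a hand-written stand-in for the stop-words package, and both captions are made up for illustration.

```python
# Tiny illustrative stop-word list; a real run would use a fuller one.
STOP_WORDS = {"the", "a", "an", "of", "in", "on", "is", "are", "and", "with"}

def content_words(caption):
    """Lowercase the caption, drop the <start>/<end> tokens, punctuation and stop words."""
    text = caption.lower().replace("<start>", " ").replace("<end>", " ")
    words = {w.strip(".,!?") for w in text.split()}
    return words - STOP_WORDS - {""}

def jaccard(predicted, actual):
    """Size of the intersection over size of the union of the content words."""
    a, b = content_words(predicted), content_words(actual)
    union = a | b
    if not union:
        return 0.0  # nothing left to compare
    return len(a & b) / len(union)

predicted = "<start> a colorful building on a hill near trees <end>"
actual = "<start> a remote colorful building in the mountains . <end>"

score = jaccard(predicted, actual)
print(round(score, 2))  # 0.29 (2 shared words out of 7 distinct words)
```

Two captions with no content words in common score 0.0, and identical captions score 1.0.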

There were some more academic metrics outlined, but they would take time to learn and apply. This was a quick and easy way to judge how well the captions were doing without going into too much detail.

Good Image Captioning

As for the generated image captions, there were some good ones:

generated caption
The caption identified a man, guitar, and microphone.

generated caption
The caption at least identified a man, kayak, and water.

generated caption
Although the predicted and the actual are different, the predicted caption reflects what is going on in the photo to some degree.

Bad Image Captioning

And there’s the bad:

generated caption
This one is nowhere close to the actual.

generated caption
The model failed to pick up that this is a bowling alley.

generated caption
The animals are mistaken as a group of people in the photo.

MS COCO Dataset

Overall, I feel the MS COCO dataset for image captioning has clearer, more concise descriptions, while Flickr30K has longer, more verbose and specific descriptions, which makes it more difficult to produce good captions.

generated caption
As you can see, the real caption here is quite long and the predicted text is not that long.


This is probably why the people who wrote the example used the MS COCO dataset and not Flickr30K. Even in the paper, the results on MS COCO were better by their metrics. The authors of the paper also used RMSProp for Flickr8k, while they used the Adam optimizer for Flickr30k, which is an intriguing choice.

Overall, this exercise helped me:

  1. Get more comfortable with TensorFlow 2.0, even fixing the bug in the code regarding the batch size.
  2. Find a way to quantify the results of the text generated by the model with Jaccard similarity (IoU).
  3. Adapt the code to another dataset, which definitely helped me understand the code base a lot better. That way, we can do the same for a customer.

How to Adapt to a Customer and Improve Accuracy

  1. The dataset is general, so for a specific use case, I would expect the model to perform better with a bit of work to build a custom dataset for a client with custom images and custom captions.
  2. There are more established academic metrics to try, such as BLEU and METEOR, which the paper reports, to see how they quantify the results.
  3. After creating a specific dataset of tens of thousands of images, a dedicated VM with a GPU on Google Compute Engine would be a great fit. We can also set that up for you.
  4. Running about 50 epochs seems to be the sweet spot when it comes to training the model.
  5. There are other pre-trained backbones to try, such as ResNet50 and MobileNetV2, instead of the InceptionV3 checkpoint.

We can adapt this idea with Tensorflow 2.0 code and many others such as:

And even some Google API use cases such as:

If you enjoyed this post, reach out to us or tweet us @springmlinc so we can show you how we can grow your business and keep you competitive in this rapidly changing marketplace with the latest in analytics, Machine Learning, and AI.