Train an Image captioning model 🧠 using Pytorch


Today, we are going to train deep learning model using Pytorch for image captioning. I will use the quite famous dataset called Common Objects in Context - COCO.

Setup

To start doing that, we need to first setup our workspace and download the data. Please execute the following statements in your bash or command line:

mkdir -p data/img
pip install wget
pip install torch
pip install torchvision

The first line allows us to have a directory where to save our JSON and JPG files. Next lines allow to install some python packages.

Download dataset

Download JSON files

To download our JSON files, you could write the following code inside a Python file and later execute it.

import wget
import zipfile

remote_url = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip'
local_file = 'data/annotations_trainval2014.zip'
wget.download(remote_url, local_file)

with zipfile.ZipFile(local_file, 'r') as zip:
    zip.extractall('data')

Before continue, we need to define some path variables

data_path = 'data'
imgs_path = "data/img"

If you execute the following code

captions_json = {f.replace('captions_','').replace('2014.json',''):f for f in os.listdir(data_path) if f.startswith('caption')}
print(captions_json)

You should see something like this 👇:

{'val': 'captions_val2014.json', 'train': 'captions_train2014.json'}

Now, we could load our JSON files to see how is the structure:

import json

with open(os.path.join(data_path, captions_json['train']), 'r') as f:
    data_train = json.load(f)

The key images keeps information about the JPG files, and it looks like this:

print(data_train['images'][0])
{'license': 5,
 'file_name': 'COCO_train2014_000000057870.jpg',
 'coco_url': 'http://images.cocodataset.org/train2014/COCO_train2014_000000057870.jpg',
 'height': 480,
 'width': 640,
 'date_captured': '2013-11-14 16:28:13',
 'flickr_url': 'http://farm4.staticflickr.com/3153/2970773875_164f0c0b83_z.jpg',
 'id': 57870}

The another important key annotations keeps information about the captions.

print(data_train['annotations'][0])
{'image_id': 318556,
 'id': 48,
 'caption': 'A very clean and well decorated empty bathroom'}

We have 82783 images for the train dataset, and each of them have between 5, 6 or 7 captions.

Previous to download images, we could create a dictionary of captions per images because I just knew images could have more than 5 captions.

annot_by_imageid = {}
if unique:
    for annot in data_train['annotations']:
        if annot_by_imageid.get(annot['image_id'], 0) == 0:
            annot_by_imageid[annot['image_id']] = annot['caption'].strip('\n')

Download images

You could download all images or just using some of them. To download the images you could use the following code:

import os
from tqdm.auto import tqdm
import wget 

imgs_train = {
    'id':[],
    'filename':[],
    'captions': []
}
for img in tqdm(data_train['images']):
    try:
        remote_url = img['coco_url']
        local_file = os.path.join(imgs_path, img['file_name'])
        if not os.path.exists(local_file):
            wget.download(remote_url, local_file)
            imgs_train['filename'].append(img['file_name'])
            imgs_train['id'].append(img['id'])
            imgs['captions'].append(annot_by_imageid[int(img['id'])])
    except:
        pass

After last step, you can create a dataframe from imgs_train variable.

import pandas as pd
df_train = pd.DataFrame(imgs_train) 

Later, you could take a look at images using the following function:

Plot images

def plot_images(df, img_path, num_images=5):
    """
        df: dataframe of ids, image filanames, captions
        img_path: where images are located
        num_images, number of images to plot
    """
    df_tmp = df.sample(num_images).reset_index(drop=True)
    plt.figure(figsize=(20, 20))
    
    for idx, row in df_tmp.iterrows():
        
        ax = plt.subplot(1, num_images, idx + 1)
        
        row_image = Image.open(os.path.join(img_path, row['filename']))
        
        caption = row['captions']
        caption = "\n".join(wrap(caption, 32))
        
        plt.title(caption)
        plt.imshow(row_image)
        plt.axis("off")
© 2025 Giordano Alvitez