Visualize Bounding Boxes

Zarrar Shehzad · November 22, 2021

For certain computer vision tasks, I have an image and bounding boxes associated with that image in a text file. How would I visualize these boxes?

Relevant Packages

# Plotting
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Loading the Image
from PIL import Image
# Loading the text file
import pandas as pd

Classes for the Boxes

For YOLO models, I might have a separate file that lists the class names, one per line. The row number (0-indexed) of each name is the class number used for the boxes.

with open('obj.names', 'r') as f:
    # One class name per line; the 0-indexed line number is the class number
    classes = [line.rstrip('\n') for line in f]
classes
['slide',
 'title',
 'banner',
 'speaker',
 'camera_on',
 'camera_off',
 'cue_zoom',
 'cue_teams',
 'cue_meet',
 'cue_uberconference']

Load Data

Here I load the image and its associated text file.

im = Image.open('sample_input.jpg')
im

df = pd.read_csv("sample_input.txt", sep=" ", header=None, names=['idx', 'x', 'y', 'w', 'h'])
df
idx  x         y         w         h
0    0.500000  0.500000  1.000000  1.000000
1    0.182031  0.133333  0.285938  0.066667

Note that each row is a bounding box. In YOLO label files the first column (idx here) is the class number. The x, y, w, h values are normalized: x and w are raw pixel values divided by the image width, and y and h are raw pixel values divided by the image height. I give some details on each value below in the viz_frame docstring.
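To make the normalization concrete, here is a minimal sketch of the conversion back to pixels. The helper name yolo_to_pixel_box and the 1280x720 image size are assumptions for illustration only; the actual size of sample_input.jpg is not stated in this post.

```python
def yolo_to_pixel_box(x, y, w, h, image_width, image_height):
    """Convert a normalized, center-based box to (left, top, width, height) in pixels.

    Hypothetical helper for illustration; not part of any library.
    """
    box_w = w * image_width
    box_h = h * image_height
    left = x * image_width - box_w / 2.0   # center x minus half the box width
    top = y * image_height - box_h / 2.0   # center y minus half the box height
    return left, top, box_w, box_h


# First row of the table above: a box covering the whole (assumed 1280x720) image
print(yolo_to_pixel_box(0.5, 0.5, 1.0, 1.0, 1280, 720))
# (0.0, 0.0, 1280.0, 720.0)

# Second row, rounded to whole pixels
print([round(v) for v in yolo_to_pixel_box(0.182031, 0.133333, 0.285938, 0.066667, 1280, 720)])
# [50, 72, 366, 48]
```

The returned left and top values are what matplotlib's Rectangle expects as its anchor point, which is why the center coordinates have half the box size subtracted.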

Visualize

I convert the normalized values in df to raw pixel values and plot the bounding boxes over the image.

def viz_frame(im, df):
    """
    Will visualize the coordinates in `df` on an image `im`.
    
    Each row of the input `df` is a particular bounding box.
    The columns of the input `df` should include:
    * x: Center x coordinate of the bounding box normalized by the width of the image.
    * y: Center y coordinate of the bounding box normalized by the height of the image.
    * w: Normalized width of the bounding box
    * h: Normalized height of the bounding box
    """
    # Create figure and axes
    fig, ax = plt.subplots()

    # Display the image
    ax.imshow(im)
    
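    # Image size in pixels (PIL gives it as (width, height))
    iw, ih = im.size

    for _,row in df.iterrows():
        # Convert the normalized coordinates to raw pixels
        w = row.w * iw
        h = row.h * ih
        x = row.x*iw - w/2. # left edge: center x minus half the width
        y = row.y*ih - h/2. # top edge: center y minus half the height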
        # Create a Rectangle patch
        rect = patches.Rectangle((x, y), w, h, linewidth=2, edgecolor='red', fill=False)
        # Add the patch to the Axes
        ax.add_patch(rect)

    plt.show()

viz_frame(im, df)
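The class names loaded earlier are not used by viz_frame. As a rough sketch of one way to show them, the variant below draws each box on an existing Axes and writes the class name at its top-left corner. The function name draw_labeled_boxes and its arguments are hypothetical, and it assumes the first value of each row is the class number.

```python
import matplotlib.pyplot as plt
import matplotlib.patches as patches


def draw_labeled_boxes(ax, boxes, classes, image_size):
    """Draw normalized, center-based boxes on `ax` and label each with its class name.

    `boxes` is an iterable of (idx, x, y, w, h) rows, for example
    `df.itertuples(index=False)`. `image_size` is (width, height) in pixels.
    Hypothetical helper for illustration.
    """
    iw, ih = image_size
    for idx, x, y, w, h in boxes:
        # Convert the normalized values to pixels
        bw, bh = w * iw, h * ih
        left, top = x * iw - bw / 2.0, y * ih - bh / 2.0
        ax.add_patch(patches.Rectangle((left, top), bw, bh,
                                       linewidth=2, edgecolor='red', fill=False))
        # Assumption: idx is the 0-indexed row of the class in obj.names
        ax.text(left, top, classes[int(idx)], color='white', va='bottom',
                bbox=dict(facecolor='red', edgecolor='none', pad=1))
    return ax


# Usage with the objects loaded earlier in this post:
# fig, ax = plt.subplots()
# ax.imshow(im)
# draw_labeled_boxes(ax, df.itertuples(index=False), classes, im.size)
# plt.show()
```

Passing the Axes in, rather than creating the figure inside the function, keeps the drawing step separate from displaying the image, so the same helper can be reused across several subplots.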

And that’s it!
