Generate an image grid using Python

I have been working on some image processing tasks recently, and I am surprised to see that the Python ecosystem has a plethora of libraries available to cover almost every image processing use case.

In my previous post, I covered how to convert a PDF document to a set of images. This post is somewhat the reverse of it, but not entirely. We will see how to generate an image grid by combining a set of images using Python and PyTorch's torchvision.

All you need for this is torchvision installed on your machine and a few images lying around to be trapped into a grid:

pip install torchvision

These are the individual images that will be added to the resultant grid

Input images

I like Python for its simplicity because, oftentimes, complicated problems can be boiled down to just a few lines of code. Below is all the code it takes to generate an image grid.

from torchvision.io import read_image
from torchvision.utils import make_grid, save_image
from torchvision import transforms

import torch
import os

tensors = []

# Crop every image to the same size and convert it to a float tensor
transform = transforms.Compose([
    transforms.CenterCrop(1000),
    transforms.ConvertImageDtype(dtype=torch.float),
])

# Read each image, run it through the transform pipeline, and collect the tensors
for file in os.listdir('./images'):
    image = os.path.join('./images', file)
    transformed_tensor = transform(read_image(image))
    tensors.append(transformed_tensor)

# Arrange the tensors into a grid with two images per row and save it
grid = make_grid(tensors, nrow=2, padding=5)

save_image(grid, "grid.jpg")

Now let's see what each block does

Transformers

transform = transforms.Compose([
    transforms.CenterCrop(1000),
    transforms.ConvertImageDtype(dtype=torch.float),
])

The above block composes a list of transformers to transform the images into the desired tensor form.

torchvision expects all the input images to be of the same dimension. If your first image is [1024 x 480], then the rest of the images should have the same dimensions, or else you will be greeted with an exception:

RuntimeError: stack expects each tensor to be equal size, but got [3, 816, 1456] at entry 0 and [3, 1024, 1024] at entry 1

To mitigate this, we crop each image from the center using CenterCrop with a uniform, square dimension of [1000 x 1000]. If the original image is smaller than this value, the empty space is filled with black pixels (like the black letterbox bars you see in a movie, which preserve the aspect ratio).

read_image returns a tensor with uint8 as the datatype (and the crop transformer preserves it), but the save_image function expects the dtype to be float. This is where the ConvertImageDtype transformer comes in handy: it converts the tensor datatype to float.
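To convince yourself of both behaviours, you can run a quick sanity check. The snippet below is only a sketch: it feeds a hypothetical 3 x 400 x 600 uint8 tensor (standing in for a small image) through the same pipeline and prints the resulting shape and dtype.

import torch
from torchvision import transforms

transform = transforms.Compose([
    transforms.CenterCrop(1000),
    transforms.ConvertImageDtype(dtype=torch.float),
])

# Hypothetical image tensor, smaller than the crop size on both sides
dummy = torch.randint(0, 256, (3, 400, 600), dtype=torch.uint8)

out = transform(dummy)
print(out.shape)  # torch.Size([3, 1000, 1000]) -- padded out with black pixels
print(out.dtype)  # torch.float32 -- ready for save_image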

If all your images already share the same dimensions, feel free to skip the CenterCrop step; keep ConvertImageDtype, though, since save_image still expects a float tensor.
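If cropping is too lossy for your images, one possible alternative (a sketch, not part of the original snippet) is to resize every image to the same square instead. Note that Resize with both dimensions given ignores the aspect ratio, so non-square images may look stretched.

import torch
from torchvision import transforms

# Alternative pipeline: resize instead of crop.
# Resize((1000, 1000)) forces every image to 1000 x 1000, which avoids the
# black padding bars but may distort images whose aspect ratio isn't square.
transform = transforms.Compose([
    transforms.Resize((1000, 1000)),
    transforms.ConvertImageDtype(dtype=torch.float),
])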

Read the images

for file in os.listdir('./images'):
    image = os.path.join('./images', file)
    transformed_tensor = transform(read_image(image))
    tensors.append(transformed_tensor)

Once the transformers are set up, we can go ahead and read the images. The make_grid function expects a list of tensors, and the following are the steps required to build one.

  • List the directory to get all the image file names (a variation that also skips non-image files is sketched after this list)

  • Use read_image function from the torchvision.io module to read the images one by one

  • The read_image function returns a tensor, and we transform it to the desired form using the transformers we have already set up

  • Store the tensors in a List that will be used to generate the grid
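One thing to watch out for: os.listdir returns every file in the directory, so a stray non-image file (a .DS_Store, a README, and so on) will make read_image throw. A hedged variation of the loop that filters on file extension might look like this; the extension list is just an assumption, so adjust it to your files.

import os
import torch
from torchvision import transforms
from torchvision.io import read_image

# Same pipeline as in the main listing
transform = transforms.Compose([
    transforms.CenterCrop(1000),
    transforms.ConvertImageDtype(dtype=torch.float),
])

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')  # assumed extensions, adjust as needed

tensors = []
for file in sorted(os.listdir('./images')):
    # Skip anything that does not look like an image file
    if not file.lower().endswith(IMAGE_EXTENSIONS):
        continue
    tensors.append(transform(read_image(os.path.join('./images', file))))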

Generate and save the grid

grid = make_grid(tensors, nrow=2, padding=5)
save_image(grid, "grid.jpg")

The make_grid function expects the tensor list to be the first argument.

  • nrow argument defines the number of images you expect to see in a single row

  • padding argument just adds padding between the images in the grid

You get a float tensor back from make_grid, and it can be passed straight to the save_image function to save the grid as an image file (grid.jpg in this example).
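If you want to eyeball the result before writing it to disk, one option (a sketch that continues from the snippet above and assumes you have matplotlib installed, which the rest of the post does not need) is to move the channel dimension last and plot the grid:

import matplotlib.pyplot as plt

# make_grid returns a CHW float tensor; with four 3 x 1000 x 1000 inputs,
# nrow=2 and padding=5 the grid ends up roughly 3 x 2015 x 2015
plt.imshow(grid.permute(1, 2, 0))  # imshow expects HWC
plt.axis('off')
plt.show()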

The following is the result

Output grid

Conclusion

If you are not familiar with PyTorch, it is a widely used deep learning library. torchvision is part of the PyTorch project and is used for solving computer vision problems. This post covers a rather minuscule problem that I recently solved in one of my projects.

To learn more about this library, you can take a look at the official docs:

PyTorch

Torchvision

Google Colab notebook