What are we doing?

Sometimes we want to see the inputs and outputs of PyTorch layers to build an intuition of what they do. If I've read the docs and put a few tensors through the layer while checking the input and output shapes, generally that's enough.

But sometimes there are weird parameters that I can't get my head around, or I just want to see the layer working, so building interactive widgets helps me grow my understanding.

So in this post I'll show you how I built an interactive widget to explore PyTorch's ConvTranspose1d, while explaining a bit about the layer itself. We'll use Anaconda's HoloViz tools (HoloViews, Panel and Bokeh) for the plotting and interactivity.

The end goal is an interactive plot that lets us change ConvTranspose1d parameters and see the resulting output, like this tweet.

Introduction to Transposed Convolutions

Before learning about Transposed Convolutions, it's best to learn about Convolutions first. CS231n is a great resource for learning about them.

As you may know, Convolutions are often used to efficiently reduce the dimensions of the input in neural networks. In the case of image classification tasks, they are used to efficiently reduce an input image to a single class score.

Transposed Convolutions are useful when you want to grow your network in a certain dimension. For example, say you have an image segmentation task in which you want a class prediction per pixel. You can use strided Convolutions to reduce the dimensions and then grow the dimensions back to their original size with Transposed Convolutions. This is done in U-net style architectures.

Conveniently, PyTorch has implemented ConvTranspose1d such that if you give it the same parameters as a Conv1d and pass a tensor through the Conv1d and then through the ConvTranspose1d, the final output tensor will be the same shape as the original input tensor (provided you set output_padding correctly).
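
To make that shape round trip concrete, here is a minimal sketch. The kernel_size, stride, padding and input length are values I picked for illustration, not ones used later in the post. With stride 2, inputs of length 7 and 8 both shrink to length 4 under Conv1d, so output_padding is what tells ConvTranspose1d which of the two lengths to grow back to.

```python
import torch
import torch.nn as nn

# Illustrative values only (assumed for this sketch).
params = dict(kernel_size=3, stride=2, padding=1)

conv = nn.Conv1d(in_channels=1, out_channels=1, **params)
# output_padding=1 selects the longer of the two candidate lengths (8 rather than 7).
deconv = nn.ConvTranspose1d(in_channels=1, out_channels=1, output_padding=1, **params)

x = torch.randn(1, 1, 8)   # (batch, channels, length)
y = conv(x)                # strided convolution shrinks the length: 8 -> 4
x_hat = deconv(y)          # transposed convolution grows it back: 4 -> 8

print(x.shape, y.shape, x_hat.shape)
```

Only the shapes match here. The values in x_hat are not a reconstruction of x, because the two layers have independent, randomly initialised weights.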


import torch
import torch.nn as nn
from panel.interact import interact
from panel import widgets
import panel as pn
from IPython.display import display
import holoviews as hv
from holoviews import opts
import numpy as np
hv.extension('bokeh', logo=False)