I realized that in the last few months, I’ve spent a lot of time reading about generative modeling in general, with a fair bit of nonsense rhapsodizing about this and that, as one often does when one sees things for the first time. I can see that I’ll be working a lot with RNNs in the near future, so I decided to get my hands dirty with PyTorch’s RNN offerings. There is little else to do anyway in the heat.
I am working on the “DRAW” paper by Gregor et al.
This is a natural extension of the Variational Autoencoder formulation of Kingma and Welling, and of Rezende and Mohamed. The paper appeals to the idea that we can improve upon the VAE’s handiwork by iteratively refining its output over the course of several time steps. There is thus a temporal, sequential aspect that comes in. In addition, they also add a spatial attention mechanism wherein one ‘attends’ to portions of the image so as to improve it in small NxN patches of the larger image. I am posting on the first part now, which only uses the RNNs. It might be a few more days before I can finish the second part.
Initially, I thought that we would just have to pick from PyTorch’s RNN modules (LSTM, GRU, vanilla RNN, etc.) and build up the layers in a straightforward way, as one does on paper. But then some complications emerged, necessitating disconnected explorations to figure out the API. Firstly, we should all be aware of PyTorch’s way of creating arrays (well, I’ve not used any other frameworks except for caffe, and that too only for benchmarking runs, so it’s all quite new to me), which demands that we include the batch size during initialization.
So for example, if we want to create an input of size 784 (as in MNIST), we must also pass the batch size variable as input:
x = Variable(torch.randn(batch_size, input_size))
We do this in most of our initializations. The first dimension is the batch size. However, RNN variables have an additional dimension, which is the sequence length.
x_rnn = Variable(torch.zeros(seq_len, batch_size, input_size))
This is apparent in retrospect from the documentation (http://pytorch.org/docs/master/nn.html – see under RNN), but we have to play with it in order to make sure. When we unroll this entity, we get seq_len slices, each of shape (batch_size, input_size):
seq_len * Variable(batch_size, input_size)
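To make this shape convention concrete, here is a minimal sketch in plain Python, with nested lists standing in for the tensor (no PyTorch required; only the bookkeeping matters):

```python
seq_len, batch_size, input_size = 5, 3, 10

# Stand-in for Variable(torch.zeros(seq_len, batch_size, input_size)):
# the outermost index is time, then batch, then features.
x_rnn = [[[0.0] * input_size for _ in range(batch_size)]
         for _ in range(seq_len)]

# Unrolling: indexing by time gives seq_len slices,
# each of shape (batch_size, input_size).
slices = [x_rnn[t] for t in range(seq_len)]
assert len(slices) == seq_len
assert len(slices[0]) == batch_size
assert len(slices[0][0]) == input_size
```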
Let us look at the example given in the documentation page:
>>> rnn = nn.GRU(10, 20, 2)
>>> input = Variable(torch.randn(5, 3, 10))
>>> h0 = Variable(torch.randn(2, 3, 20))
>>> output, hn = rnn(input, h0)
This defines a GRU of the following form:
rnn = nn.GRU(input_size, hidden_size, num_layers)
We should note that the function curiously returns two outputs: output, hidden. The first (‘output’) contains the outputs of the last hidden layer at every timestep, while ‘hidden’ contains the hidden states of all the layers from the last time step, which we can verify from the ‘size()’ method.
‘output’ is of shape (seq_len, batch_size, hidden_size). It contains the sequence of hidden layer outputs from the last hidden layer.
>>> torch.Size([5, 3, 20])
(or torch.Size([seq_len, batch_size, hidden_size]))
I find the purpose of ‘hidden’ a little enigmatic. It contains the hidden layer outputs from the last timestep in the sequence, t = seq_len.
>>> torch.Size([2, 3, 20])
(or torch.Size([num_hidden_layers, batch_size, hidden_size]))
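To see why the two return values have these shapes, here is a toy unrolled two-layer recurrence in plain Python. The `cell` function is a made-up placeholder for the actual GRU math; the point is only which states get collected where:

```python
seq_len, batch_size, hidden_size, num_layers = 5, 3, 20, 2

def cell(x, h, layer):
    # Placeholder for the real GRU cell: any function of (input, hidden).
    return [0.5 * (xi + hi) + layer for xi, hi in zip(x, h)]

# initial hidden state: one (batch, hidden) block per layer
h = [[[0.0] * hidden_size for _ in range(batch_size)]
     for _ in range(num_layers)]
inputs = [[[1.0] * hidden_size for _ in range(batch_size)]
          for _ in range(seq_len)]

output = []
for t in range(seq_len):
    layer_in = inputs[t]
    for l in range(num_layers):
        # each layer consumes the layer below and its own previous state
        h[l] = [cell(layer_in[b], h[l][b], l) for b in range(batch_size)]
        layer_in = h[l]
    output.append(layer_in)  # top layer only, at every timestep

# 'output' is (seq_len, batch, hidden); 'h' is (num_layers, batch, hidden)
assert (len(output), len(output[0]), len(output[0][0])) == (seq_len, batch_size, hidden_size)
assert (len(h), len(h[0]), len(h[0][0])) == (num_layers, batch_size, hidden_size)
assert output[-1] == h[-1]  # last timestep of the top layer appears in both
```

The final assertion is the useful takeaway: for a uni-directional RNN, the last slice of ‘output’ coincides with the top layer of ‘hidden’.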
The hidden layer can be bi-directional. Apparently, the default (as we might expect) is a standard uni-directional RNN. The documentation clarifies this:
h_n (num_layers * num_directions, batch, hidden_size)
It helps to remember that the quantity they call ‘output’ is really the hidden layer. The output of an RNN is the hidden variable, which we then do something with.
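For instance, a common thing to do is apply a readout layer to the hidden state – in PyTorch this would be something like nn.Linear(hidden_size, output_size) applied to ‘output’. A toy sketch of the idea in plain Python, with made-up weights:

```python
hidden_size, output_size = 3, 2

# Hypothetical readout weights and bias, standing in for nn.Linear's parameters.
W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 1.0]]
b = [0.5, -0.5]

def readout(h):
    # y = W h + b, one row of W per output unit
    return [sum(w_i * h_i for w_i, h_i in zip(row, h)) + b_j
            for row, b_j in zip(W, b)]

h = [1.0, 2.0, 3.0]  # a single hidden-state vector
y = readout(h)
assert y == [1.5, 4.5]
```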
In my experiments I used GRUCell because it seemed intuitive to set up at that time.
Note: I think in the above, we can replace the 2 RNNs used in the encoder (one each for the mean and the log-variance) with a single RNN – as can be made out from the clipping below from “Generating Sentences from a Continuous Space”:
In DRAW, we also need a connection from the decoder at the previous timestep. Specifically:
They define a cat operation to concatenate two tensors.
As we can make out from the hand-written figure (sorry, but that’s just the most efficient way, factoring in things such as laziness), at each time step the VAE encoder takes in the output of the read operation, which then gets encoded into the latent embedding z. Furthermore, at each timestep we take in the same input image x, and give it the previous timestep’s error image to create the sequence, together with the previous decoder output as well. In that sense the sequence is actually defined by r, the output of read:
The read operation remains to be implemented. This is done differently depending on whether or not we put in spatial attention. Nevertheless, we can in a rough way make sense of it from a line in the paper:
“Moreover the encoder is privy to the decoder’s previous outputs, allowing it to tailor the codes it sends according to the decoder’s behaviour so far”
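In the no-attention case, read reduces to concatenating the input image with the error image along the feature dimension – in the actual code this would be something like torch.cat([x, x_hat], 1). A minimal pure-Python sketch, with a short list standing in for one batch row:

```python
input_size = 4  # toy size; MNIST would be 784

x     = [0.9, 0.1, 0.4, 0.7]    # one row of the input image batch
x_hat = [0.1, -0.05, 0.2, 0.3]  # corresponding row of the error image

# read without attention: just concatenate along the feature axis
r = x + x_hat                   # list '+' is concatenation
assert len(r) == 2 * input_size # the read vector doubles the feature size
```

This is why, without attention, the encoder’s input size is twice the image size.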
In our experiments, we used the RNNCell (or more precisely, the GRUCell) to handle the sequence, with a manual for loop to do the time stepping – the most intuitive way, if I may say so. In the forward method of the class, we create the set of operations comprising DRAW:
for seq in range(T):
    x_hat = x - F.sigmoid(c)  # error image
    r = self.read(x, x_hat)  # cat operation
    # encoder output == Q_mu, Q_sigma
    mu, h_mu, logvar, h_logvar = self.encoder_RNN(r, h_mu, h_logvar, h_dec, seq)
    z = self.reparametrize_and_sample(mu, logvar)
    c, h_dec = self.decoder_network(z, h_dec, c)  # c is the canvas
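The reparametrize_and_sample step above is the standard VAE trick: draw eps ~ N(0, 1) and form z = mu + sigma * eps, with sigma = exp(logvar / 2). A pure-Python sketch of that one step (eps is fixed here for illustration; the real code would draw it with torch.randn):

```python
import math

def reparametrize_and_sample(mu, logvar, eps):
    # z = mu + sigma * eps, where sigma = exp(logvar / 2)
    return [m + math.exp(lv / 2.0) * e for m, lv, e in zip(mu, logvar, eps)]

mu     = [0.0, 1.0]
logvar = [0.0, 0.0]   # logvar = 0 means sigma = 1 in both dimensions
eps    = [0.5, -0.5]  # normally standard-normal noise

z = reparametrize_and_sample(mu, logvar, eps)
assert z == [0.5, 0.5]
```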
Naturally, the RNN cell layers handle each individual timestep rather than consuming the whole sequence at once.
The API is as follows:
>>> rnn = nn.RNNCell(10, 20)  # (input_size, hidden_size)
>>> input = Variable(torch.randn(6, 3, 10))  # (seq_len, batch_size, input_size)
>>> hx = Variable(torch.randn(3, 20))  # (batch_size, hidden_size)
>>> output = []
>>> for i in range(6):  # time advance
...     hx = rnn(input[i], hx)  # one step of the recurrence
...     output.append(hx)  # add to sequence
In addition to the vanilla RNNCell, PyTorch also includes the GRUCell and LSTMCell variants.
I hope to put up a more descriptive post (with feeling!) on DRAW. But for now, I have what seems to be a quasi-working implementation without the attention mechanism. In the figures below, we can see that there is a qualitative improvement in the figures as we add refinement timesteps.
The code may be found here: https://github.com/pravn/vae_draw