Seeing What Comes Next: Transformers predict future video frames


If a robot can predict what it’s likely to see next, it may have a better basis for choosing an appropriate action — but it has to predict quickly. Transformers, for all their utility in computer vision, aren’t well suited to this because of their steep computational and memory requirements. A new approach could change that.

What’s new: Agrim Gupta and colleagues at Stanford devised Masked Visual Pre-Training for Video Prediction (MaskViT), a transformer model that generates likely future video frames with far less computation than earlier transformer-based approaches. You can see its output here.

Key insight: Transformers typically predict one token per forward pass (processing every layer in the model from first to last). The amount of processing this requires is manageable when generating an image, which may be divided into hundreds or thousands of tokens. But it becomes very time-consuming when generating video, which involves many images. Predicting multiple tokens at once reduces the number of forward passes needed to generate video, significantly accelerating the process.
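
For a sense of scale, here's a back-of-the-envelope comparison based on the BAIR figures reported below (15 generated frames, each tokenized into a 16x16 grid); the numbers come from the article, and the snippet is just arithmetic:

```python
tokens_per_frame = 16 * 16        # each frame is tokenized into a 16x16 grid
frames_to_generate = 15           # the BAIR setting reported below

# Autoregressive decoding: one token per forward pass.
autoregressive_passes = tokens_per_frame * frames_to_generate   # 3,840

# Iterative parallel decoding (MaskViT-style): a fixed number of passes,
# each committing a batch of predicted tokens at once.
parallel_passes = 24              # the figure reported for MaskViT on BAIR

print(autoregressive_passes // parallel_passes)   # 160x fewer forward passes
```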

How it works: MaskViT consists of an image tokenizer (VQ-GAN, a discrete variational autoencoder) and a transformer. The authors trained and tested it on three video datasets: RoboNet (15 million frames that depict robotic arms interacting with objects), BAIR (a smaller dataset that shows a robot pushing things on a tabletop), and KITTI (57 videos recorded from a car driving on roads in Germany). Depending on the dataset, the model generated 10 to 25 video frames following one to five initial frames.
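
As a rough illustration of the tokenizer's role, the snippet below shows the vector-quantization step that turns a frame's continuous encoder features into a 16x16 grid of discrete token ids. It's a generic sketch of VQ-style quantization, not the authors' VQ-GAN code; the `features` and `codebook` inputs are assumed.

```python
import torch

def quantize(features, codebook):
    """Map encoder features to discrete token ids via the nearest codebook entry.

    features: (16, 16, d) grid of continuous encoder outputs for one frame.
    codebook: (vocab_size, d) learned embedding table.
    Returns a (16, 16) grid of integer token ids for the transformer to consume.
    """
    flat = features.reshape(-1, features.shape[-1])   # (256, d)
    dists = torch.cdist(flat, codebook)               # distance to every code
    ids = dists.argmin(dim=-1)                        # nearest code per grid cell
    return ids.reshape(features.shape[:-1])           # (16, 16)
```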

  • The authors trained VQ-GAN to reconstruct video frames. Given all frames in a video, the trained VQ-GAN encoder tokenized each frame into a 16x16 grid of tokens.
  • The system randomly masked from 50 percent to almost 100 percent of tokens.
  • The transformer processed the tokens through two alternating types of layers, each a modified version of the base transformer layer. The first type learned spatial patterns by applying self-attention to each of 16 sequential frames (16x16 tokens) individually. The second type learned temporal patterns by limiting attention to a window of 4x4 tokens across the frames.
  • The loss function encouraged the model to generate masked tokens correctly.
  • Inference proceeded gradually, in 7 to 64 forward passes, depending on the dataset. In each forward pass, the model received the tokens that represented the initial frame(s) plus the tokens it had predicted so far, and it predicted a fixed percentage of the remaining masked tokens. The process repeated until all tokens were predicted (a simplified sketch of this loop appears after this list).
  • The VQ-GAN decoder turned the tokens back into frames.
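
To make the iterative decoding concrete, here is a minimal PyTorch-style sketch of one way such a loop can work. The `model` interface, the even unmasking schedule, and the confidence-based selection (a choice popularized by MaskGIT) are illustrative assumptions rather than the authors' exact procedure.

```python
import torch

def iterative_decode(model, context_tokens, num_future_tokens, num_steps, mask_id):
    """Fill in masked future-frame tokens over a fixed number of forward passes.

    context_tokens:    1D long tensor of token ids for the initial frame(s).
    num_future_tokens: number of future-frame tokens to predict.
    num_steps:         forward passes to spread the prediction over.
    mask_id:           id of the special [MASK] token.
    """
    # Start with every future token masked.
    tokens = torch.cat([
        context_tokens,
        torch.full((num_future_tokens,), mask_id, dtype=torch.long),
    ])
    still_masked = tokens == mask_id

    for step in range(num_steps):
        remaining = int(still_masked.sum().item())
        if remaining == 0:
            break

        # One forward pass over the full sequence; assumed output shape
        # is (1, seq_len, vocab_size).
        logits = model(tokens.unsqueeze(0)).squeeze(0)
        confidence, prediction = logits.softmax(dim=-1).max(dim=-1)

        # Spread the remaining masked tokens evenly over the remaining passes
        # so that everything is filled in by the final step.
        num_to_commit = max(1, remaining // (num_steps - step))

        # Keep already-known tokens fixed; among masked positions, commit the
        # most confident predictions first.
        confidence = confidence.masked_fill(~still_masked, -1.0)
        chosen = confidence.topk(num_to_commit).indices

        tokens[chosen] = prediction[chosen]
        still_masked[chosen] = False

    return tokens   # the VQ-GAN decoder turns these back into frames
```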

Results: The authors compared their model’s efficiency at inference with that of earlier transformer-based approaches. On BAIR, for instance, MaskViT required 24 forward passes to generate 15 frames, while the previous state of the art, VT, needed 3,840. With respect to its predictive ability, on BAIR, MaskViT achieved 93.7 Fréchet Video Distance (FVD), a measure of how well a generated distribution resembles the original distribution, for which lower is better. That’s better than VT (94.0 FVD) and roughly equal to the best non-transformer approach, FitVid (93.6 FVD). On the more complicated RoboNet dataset, MaskViT achieved 133.5 FVD, while FitVid achieved 62.5 FVD. (VT results on that dataset are not reported.)

Yes, but: The authors compared numbers of forward passes at inference, but they didn't compare processing time. Different models take different amounts of time per forward pass, so a smaller number of passes doesn't guarantee a shorter run. That said, given differences in hardware, machine learning libraries, and programming languages across implementations, it would be hard to compare execution speeds directly.

Why it matters: While the reduction in forward passes is notable, the authors also came up with an interesting way to improve output quality. During inference, 100 percent of the tokens to be generated start out masked and fill in gradually as generation proceeds. Typical training practice masks a fixed percentage of tokens, so the model never encounters such a large proportion of missing tokens. Instead, during training, the authors masked a variable portion of tokens, up to nearly 100 percent. This better aligned the training and inference tasks, which yielded better results.
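
Here's a minimal sketch of that variable-ratio masking, assuming the ratio is drawn uniformly per training example (the exact distribution the authors used may differ):

```python
import torch

def sample_training_mask(num_tokens, min_ratio=0.5, max_ratio=1.0):
    """Return a boolean mask over one example's tokens for masked training.

    Drawing the masking ratio at random, here uniformly between 50 percent and
    nearly 100 percent, exposes the model to the mostly-masked inputs it sees
    at the start of inference.
    """
    ratio = torch.empty(1).uniform_(min_ratio, max_ratio).item()
    num_masked = int(round(ratio * num_tokens))
    mask = torch.zeros(num_tokens, dtype=torch.bool)
    mask[torch.randperm(num_tokens)[:num_masked]] = True   # True = replace with [MASK]
    return mask
```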

We’re thinking: Giving robots the ability to predict visual changes could make for a generation of much safer and more capable machines. We look forward to future work that integrates this capability with planning algorithms.
