Diffusion Transformed: A new class of diffusion models based on the transformer architecture

A tweak to diffusion models, which are responsible for most of the recent excitement about AI-generated images, enables them to produce more realistic output.

What’s new: William Peebles at UC Berkeley and Saining Xie at New York University improved a diffusion model by replacing a key component, a U-Net convolutional neural network, with a transformer. They call the work Diffusion Transformer (DiT).

Diffusion basics: During training, a diffusion model takes an image to which noise has been added, a descriptive embedding (typically an embedding of a text phrase that describes the original image; in this experiment, an embedding of the image’s class), and an embedding of the current time step. The system learns to use the descriptive embedding to remove the noise over successive time steps. At inference, it generates an image by starting with pure noise and a descriptive embedding and removing noise iteratively according to that embedding. A variant known as a latent diffusion model saves computation by removing noise not from an image but from an image embedding that represents it.
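The noise-prediction objective described above can be summarized in a few lines of PyTorch. This is a minimal sketch, not the authors’ code: `model` stands for any network that estimates noise from a noisy latent, a timestep, and a class label, and the linear noise schedule is a placeholder for the tuned schedules real systems use.

```python
import torch
import torch.nn.functional as F

def training_step(model, latents, class_labels, num_timesteps=1000):
    """One simplified noise-prediction step for a latent diffusion model."""
    # Sample a random timestep and Gaussian noise for each example.
    t = torch.randint(0, num_timesteps, (latents.shape[0],), device=latents.device)
    noise = torch.randn_like(latents)

    # Placeholder linear noise schedule; real models use carefully tuned schedules.
    alpha_bar = 1.0 - (t.float() + 1) / num_timesteps
    alpha_bar = alpha_bar.view(-1, 1, 1, 1)
    noisy_latents = alpha_bar.sqrt() * latents + (1 - alpha_bar).sqrt() * noise

    # The model learns to predict the added noise, conditioned on the
    # class embedding and the current timestep.
    pred_noise = model(noisy_latents, t, class_labels)
    return F.mse_loss(pred_noise, noise)
```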

Key insight: In a typical diffusion model, a U-Net convolutional neural network (CNN) learns to estimate the noise to be removed from an image. Recent work showed that transformers outperform CNNs in many computer vision tasks, so replacing the U-Net with a transformer could yield similar gains in image generation.

How it works: The authors modified a latent diffusion model (specifically Stable Diffusion) by putting a transformer at its core. They trained it on ImageNet in the usual manner for diffusion models.

  • To accommodate the transformer, the system broke the noisy image embeddings into a sequence of tokens.
  • Within the transformer, modified transformer blocks learned to process the tokens to produce an estimate of the noise.
  • Before each attention and fully connected layer, the system multiplied the tokens by a separate vector based on the image class and time step embeddings. (A vanilla neural network, trained with the transformer, computed this vector; see the sketch after this list.)
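The sketch below illustrates the kind of conditioned transformer block these bullets describe: a small network maps the class-plus-timestep embedding to modulation vectors applied before the attention and feed-forward sublayers. Dimensions, layer names, and the exact modulation scheme are illustrative assumptions, not the authors’ implementation.

```python
import torch
import torch.nn as nn

class DiTBlockSketch(nn.Module):
    """Simplified transformer block conditioned on class + timestep embeddings."""

    def __init__(self, dim=384, num_heads=6):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        # Vanilla network that turns the conditioning embedding into
        # per-channel scale and shift vectors for the two sublayers.
        self.modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim))

    def forward(self, tokens, cond):
        # tokens: (batch, num_tokens, dim); cond: (batch, dim) class + timestep embedding
        scale1, shift1, scale2, shift2 = self.modulation(cond).chunk(4, dim=-1)
        x = self.norm1(tokens) * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)
        tokens = tokens + self.attn(x, x, x, need_weights=False)[0]
        x = self.norm2(tokens) * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)
        tokens = tokens + self.mlp(x)
        return tokens
```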

Results: The authors assessed the quality of DiT’s output according to Fréchet Inception Distance (FID), which measures the similarity between the distribution of generated images and the distribution of real images (lower is better). FID improved with the processing budget: On 256-by-256-pixel ImageNet images, a small DiT with 6 gigaflops of compute achieved 68.4 FID, a large DiT with 80.7 gigaflops achieved 23.3 FID, and the largest DiT with 119 gigaflops achieved 9.62 FID. A latent diffusion model that used a U-Net (104 gigaflops) achieved 10.56 FID.
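For reference, FID fits a Gaussian to feature statistics (typically Inception-network activations) of real and generated images and measures the distance between the two fits. Below is a minimal NumPy/SciPy sketch of the standard formula, not the authors’ evaluation code; `real_feats` and `gen_feats` are assumed to be precomputed feature arrays.

```python
import numpy as np
from scipy import linalg

def fid(real_feats, gen_feats):
    """Fréchet Inception Distance between two (N, D) arrays of features.

    Returns ||mu_r - mu_g||^2 + Tr(C_r + C_g - 2 (C_r C_g)^(1/2)).
    """
    mu_r, mu_g = real_feats.mean(axis=0), gen_feats.mean(axis=0)
    cov_r = np.cov(real_feats, rowvar=False)
    cov_g = np.cov(gen_feats, rowvar=False)
    covmean = linalg.sqrtm(cov_r @ cov_g)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerical error
    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(cov_r + cov_g - 2 * covmean))
```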

Why it matters: Given more processing power and data, transformers achieve better performance than other architectures in numerous tasks. This goes for the authors’ transformer-enhanced diffusion model as well.

We’re thinking: Transformers continue to replace CNNs for many tasks. We’ll see if this replacement sticks.
