Free-ATM: Exploring Unsupervised Learning on Diffusion-Generated Images with Free Attention Masks

Authors: David Junhao Zhang, Mutian Xu, Chuhui Xue, Wenqing Zhang, Xiaoguang Han, Song Bai, Mike Zheng Shou

Abstract: Despite the rapid advancement of unsupervised learning in visual representation, it requires training on large-scale datasets that demand costly data collection and poses additional challenges due to concerns regarding data privacy. Recently, synthetic images generated by text-to-image diffusion models have shown great potential for benefiting image recognition. Although promising, there has been inadequate exploration dedicated to unsupervised learning on diffusion-generated images. To address this, we start by uncovering that diffusion models’ cross-attention layers inherently provide annotation-free attention masks aligned with corresponding text inputs on generated images. We then investigate the problems of three prevalent unsupervised learning techniques (i.e., contrastive learning, masked modeling, and vision-language pretraining) and introduce customized solutions by fully exploiting the aforementioned free attention masks. Our approach is validated through extensive experiments that show consistent improvements in baseline models across various downstream tasks, including image classification, detection, segmentation, and image-text retrieval. By utilizing our method, it is possible to close the performance gap between unsupervised pretraining on synthetic data and real-world scenarios.

What, Why and How

Here is a summary of the key points from this paper:

What:

  • Proposes a technique called Free-ATM to improve unsupervised learning on images generated by text-to-image diffusion models.

Why:

  • Diffusion models provide free attention masks aligned to text prompts, which can enhance unsupervised learning.
  • Unsupervised pretraining on synthetic images can avoid costly data collection and annotation.

How:

  • Leverages attention masks for contrastive learning, masked modeling, and vision-language pretraining.

  • For contrastive learning, uses masks to get positive/negative pairs from instances.

  • For masked modeling, gradually masks more foreground patches based on attention scores.

  • For vision-language pretraining, uses masks for position prompts.

  • Shows consistent performance improvements on various downstream tasks like classification, detection, segmentation, retrieval.

  • Closes the gap between unsupervised pretraining on synthetic vs real images.

  • Mixing synthetic and real data can further boost performance.

In summary, this paper provides a method to better utilize the free attention masks from diffusion models to enhance unsupervised learning on synthetic images, avoiding costly data collection while improving performance on downstream tasks.

Main Contributions

Here are the main contributions of this paper:

  • Uncovers that diffusion models provide free attention masks aligned with text prompts on generated images, without needing manual annotations.

  • Identifies limitations of current unsupervised learning techniques like contrastive learning, masked modeling, and vision-language pretraining when applied to diverse synthetic images.

  • Proposes customized solutions called Free-ATM that utilize the free attention masks to enhance these unsupervised learning frameworks on synthetic data.

  • For contrastive learning, uses attention masks to extract positive/negative pairs at the instance level rather than image level.

  • For masked modeling, gradually increases masking of foreground patches based on attention masks.

  • For vision-language pretraining, uses masks to create positional prompts.

  • Validates the effectiveness of Free-ATM through extensive experiments on downstream tasks like classification, detection, segmentation, and retrieval.

  • Shows consistent improvements over baseline models trained on synthetic images, closing the gap with models trained on real images.

  • Demonstrates additional gains by mixing synthetic and real data for pretraining.

In summary, the key contributions are identifying limitations of existing methods on synthetic data, proposing tailored solutions to address those limitations by exploiting free attention masks, and thoroughly validating the improvements enabled by the proposed Free-ATM framework.

Method Section

Here is a summary of the method section from the paper:

The paper proposes Free-ATM, a technique to enhance unsupervised learning on synthetic images from diffusion models by utilizing the free attention masks.

Attention Mask Extraction:

  • Uses a latent diffusion model to generate images from text prompts.
  • The cross-attention layers provide attention masks aligned with input text.
  • Aggregates attention maps across layers and timesteps into the final mask (see the sketch below).
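
To make the aggregation step concrete, here is a minimal sketch, assuming the cross-attention maps are collected as a list of [h, w, num_text_tokens] tensors over layers and timesteps; the helper name token_mask and the min-max normalization are illustrative assumptions rather than the paper's actual implementation.

import torch
import torch.nn.functional as F

def token_mask(cross_attn_maps, token_idx, image_size):
  # cross_attn_maps: list over (layer, timestep) of tensors shaped [h, w, num_text_tokens]
  acc = 0.0
  for attn in cross_attn_maps:
    a = attn[..., token_idx]  # attention of every patch to one text token
    a = F.interpolate(a[None, None], size=image_size,
                      mode="bilinear", align_corners=False)[0, 0]
    acc = acc + a  # sum over layers and timesteps
  mask = acc / len(cross_attn_maps)  # average
  return (mask - mask.min()) / (mask.max() - mask.min() + 1e-6)  # soft mask in [0, 1]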

Prompt Generation:

  • Uses ImageNet label space and GPT-3.5 to create diverse and realistic prompts.
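
The exact instruction given to GPT-3.5 is not spelled out in this summary, so the sketch below is only an assumption about the tooling: it queries the OpenAI chat-completions API with a hypothetical instruction string to turn each ImageNet class name into a caption-style prompt.

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

def generate_prompt(label):
  # Ask GPT-3.5 to turn an ImageNet class name into a short, realistic caption.
  resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user",
               "content": f"Write one short, realistic photo caption that contains a {label}."}])
  return resp.choices[0].message.content.strip()

imagenet_labels = ["goldfish", "tabby cat", "school bus"]  # toy subset for illustration
prompts = [generate_prompt(label) for label in imagenet_labels]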

Utilizing Attention Masks:

  • Applies masks to improve contrastive learning, masked modeling, and vision-language pretraining.

For contrastive learning:

  • Uses instance-level features pooled under the attention masks instead of image-level features.
  • Positive pairs come from the same instance across the two augmented views; negatives come from different instances.
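
Below is a minimal sketch of mask-weighted (attentive) pooling and an InfoNCE-style instance loss; the function names, tensor shapes, and temperature are assumptions made for illustration, not the paper's exact formulation.

import torch
import torch.nn.functional as F

def attentive_pool(feat_map, mask):
  # feat_map: [C, H, W] encoder features before average pooling; mask: [H, W] soft attention mask
  w = mask / (mask.sum() + 1e-6)
  return (feat_map * w[None]).sum(dim=(1, 2))  # mask-weighted average -> [C]

def instance_contrastive_loss(z, z_pos, negatives, tau=0.2):
  # z, z_pos: [D] features of the same instance from two views; negatives: [N, D] other instances
  z, z_pos = F.normalize(z, dim=0), F.normalize(z_pos, dim=0)
  negatives = F.normalize(negatives, dim=1)
  logits = torch.cat([(z * z_pos).sum()[None], negatives @ z]) / tau
  return F.cross_entropy(logits[None], torch.zeros(1, dtype=torch.long))  # positive sits at index 0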

For masked modeling:

  • Gradually increases masking of foreground patches based on attention scores.
  • Balances learning generic whole-image representations with representations focused on the foreground.
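
One possible realization of the schedule is sketched below, assuming each patch carries a mean attention score and that training progress is available as a value in [0, 1]; the exact scoring and schedule used in the paper may differ.

import torch

def gradually_mask(patch_scores, mask_ratio, progress):
  # patch_scores: [N] mean attention per patch; progress: fraction of training completed, in [0, 1]
  n_mask = int(mask_ratio * patch_scores.numel())
  n_fg = int(progress * n_mask)  # mask more foreground patches as training proceeds
  fg_idx = patch_scores.topk(n_fg).indices  # highest-attention (foreground) patches
  rest = torch.tensor([i for i in range(patch_scores.numel())
                       if i not in set(fg_idx.tolist())], dtype=torch.long)
  rand_idx = rest[torch.randperm(rest.numel())[: n_mask - n_fg]]  # fill the rest randomly
  masked = torch.zeros(patch_scores.numel(), dtype=torch.bool)
  masked[fg_idx] = True
  masked[rand_idx] = True
  return masked  # True = patch hidden from the encoder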

For vision-language pretraining:

  • Gets bounding boxes from masks to create positional prompts.
  • Avoids costly object detection and enhances grounding.
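
As a rough illustration, a bounding box can be thresholded out of the soft mask and turned into a coarse positional phrase; the threshold, the phrase template, and the function signatures below are assumptions.

import numpy as np

def get_bbox(mask, thresh=0.5):
  # mask: [H, W] numpy array in [0, 1]; assumes at least one pixel exceeds the threshold
  ys, xs = np.where(mask > thresh)
  return xs.min(), ys.min(), xs.max(), ys.max()  # (x1, y1, x2, y2)

def generate_position_prompt(prompt, bbox, image_size):
  # Append the coarse location of the masked object to the original prompt.
  x1, y1, x2, y2 = bbox
  w, h = image_size
  cx, cy = (x1 + x2) / (2 * w), (y1 + y2) / (2 * h)
  horiz = "left" if cx < 0.4 else "right" if cx > 0.6 else "center"
  vert = "top" if cy < 0.4 else "bottom" if cy > 0.6 else "middle"
  return f"{prompt}, located at the {vert} {horiz} of the image"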

In summary, the method section explains how free attention masks are extracted from diffusion models, how prompts are created, and how the masks are utilized to enhance three different unsupervised learning frameworks.

High-Level Pseudo Code

Here is the high-level pseudo code for the key components of the paper:

# Generate images and extract attention masks
prompts = generate_diverse_prompts(imagenet_labels) 
images, masks = diffusion_model(prompts)
 
# Contrastive learning 
for x, x' in augmented_image_crops:
  z, z' = encoder(x), encoder(x')  
  for m in masks:
    z_inst = attentive_pool(z, m) 
    z'_inst = attentive_pool(z', m)
    loss = instance_contrastive_loss(z_inst, z'_inst)
 
# Masked modeling
for x, mask in zip(images, masks):
  x_masked = gradually_mask(x, mask)
  x_recon = decoder(encoder(x_masked))
  loss = reconstruction_loss(x, x_recon)
 
# Vision-language pretraining  
for (x, p), mask in zip(image_prompt_pairs, masks):
  bbox = get_bbox(mask)
  p_pos = generate_position_prompt(p, bbox) 
  z = encoder(x)
  h = text_encoder(p_pos)
  loss = vision_language_pretrain_loss(z, h) 

This covers:

  • Generating images and extracting attention masks
  • Using masks for contrastive learning on instances
  • Gradually masking foreground for masked modeling
  • Getting bboxes from masks for positional prompts in vision-language pretraining

The overall framework utilizes the “free” attention masks from diffusion models to enhance unsupervised learning on synthetic images.

Detailed Pseudo Code

Here is a more detailed pseudo code to implement the key components of the paper:

# Image generation and mask extraction
 
prompt_templates = ["template 1", "template 2", ...] 
 
prompts = []
for label in imagenet_labels:
  p = gpt3.generate_prompt(label, prompt_templates)  
  prompts.append(p)
 
images, masks = [], []
for p in prompts:
  image, mask = diffusion_model(p) 
  images.append(image)
  masks.append(mask)
 
# Contrastive learning
 
for x, x' in zip(augmented_crops(images), augmented_crops(images)):
  
  # Encoder
  z = encoder(x) # activations before avg pool
  z' = encoder(x')
 
  inst_feats = []
  inst_feats' = []
 
  for m, m' in zip(crop_masks(masks), crop_masks(masks)):
    
    # Attentive pooling 
    z_inst = attentive_pool(z, m)
    z_inst' = attentive_pool(z', m')
    
    inst_feats.append(z_inst)
    inst_feats'.append(z_inst')
 
  # Contrastive loss
  loss = 0
  for z, z' in zip(inst_feats, inst_feats'):
    loss += instance_contrastive_loss(z, z') 
 
# Masked modeling
 
for x, mask in zip(images, masks):
 
  x_masked = gradually_mask(x, mask)
 
  # Reconstruction from the visible patches
  x_recon = decoder(encoder(x_masked))
 
  # Loss
  loss = reconstruction_loss(x, x_recon)
 
# Vision-language pretraining
 
for (x, p), mask in zip(image_text_pairs, masks):
 
  # Position prompt from the attention mask
  bbox = get_bbox(mask)
  p_pos = generate_position_prompt(p, bbox)
 
  # Encoder
  z = encoder(x)  
  h = text_encoder(p_pos)
 
  # Pretrain loss
  loss = pretrain_loss(z, h) 

This provides more implementation details for generating images, extracting masks, and using the masks to enhance contrastive learning, masked modeling, and vision-language pretraining.