InstantBooth: Personalized Text-to-Image Generation without Test-Time Finetuning

Authors: Jing Shi, Wei Xiong, Zhe Lin, Hyun Joon Jung

Abstract: Recent advances in personalized image generation allow a pre-trained text-to-image model to learn a new concept from a set of images. However, existing personalization approaches usually require heavy test-time finetuning for each concept, which is time-consuming and difficult to scale. We propose InstantBooth, a novel approach built upon pre-trained text-to-image models that enables instant text-guided image personalization without any test-time finetuning. We achieve this with several major components. First, we learn the general concept of the input images by converting them to a textual token with a learnable image encoder. Second, to keep the fine details of the identity, we learn rich visual feature representation by introducing a few adapter layers to the pre-trained model. We train our components only on text-image pairs without using paired images of the same concept. Compared to test-time finetuning-based methods like DreamBooth and Textual-Inversion, our model can generate competitive results on unseen concepts concerning language-image alignment, image fidelity, and identity preservation while being 100 times faster.

What, Why and How

Here is a summary of the key points from this paper:

What:

  • The paper proposes a novel approach called InstantBooth for personalized text-to-image generation without requiring test-time finetuning.

Why:

  • Existing personalized image generation methods like DreamBooth and Textual Inversion require heavy finetuning of the model for each new concept, which is time-consuming and difficult to scale.

How:

  • They introduce a learnable image encoder to convert input images to a compact textual concept embedding, capturing the general semantic information.
  • They add trainable adapter layers to inject rich visual features from the input images into a frozen pre-trained text-to-image model, preserving identity details.
  • The model is trained on text-image pairs without paired images of the same concept.
  • During inference, the frozen pre-trained model and adapter layers generate personalized images in a single forward pass for new concepts.

In summary, this paper proposes a novel architecture and training approach to achieve fast personalized text-to-image generation without test-time finetuning, while generating results competitive with state-of-the-art finetuning-based methods. The efficiency improvement enables scalable practical applications.

Main Contributions

Here are the main contributions of this paper:

  • Proposes one of the first approaches for personalized text-to-image generation without requiring test-time finetuning of the model. This is a significant improvement in efficiency over existing methods.

  • Designs a new model architecture with a learnable image encoder to convert input images to a textual concept embedding for capturing general semantic information about a new concept.

  • Introduces trainable adapter layers to inject rich visual features from input images into a frozen pre-trained text-to-image model to preserve fine-grained identity details.

  • Demonstrates that the model can be trained on text-image pairs without requiring paired images of the same concept, improving generalization.

  • Shows that the proposed model can achieve results competitive with state-of-the-art finetuning-based methods like DreamBooth and Textual Inversion in terms of image quality, alignment, and identity preservation.

  • Provides extensive analysis of different model design choices through ablation studies.

  • Provides both quantitative and qualitative results on personalized image generation for unseen concepts with only a single forward pass, which is 100x faster than finetuning-based approaches.

In summary, the key contributions are around the novel model architecture, training approach, generalization ability, and significant efficiency improvement over existing methods.

Method Section

Here is a summary of the method section from the paper:

The overall framework is built on top of a pre-trained text-to-image diffusion model.

Data Pre-Processing:

  • Crop and mask the input images to focus on the target object. Apply augmentations during training.
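
The paper does not give code for this step; below is a minimal sketch of the crop-and-mask operation, assuming a binary subject mask (e.g. from an off-the-shelf segmenter) is available. The array layout and padding margin are illustrative assumptions, not the paper's exact pipeline.

import numpy as np

def crop_and_mask(image: np.ndarray, mask: np.ndarray, pad: int = 16) -> np.ndarray:
  """Crop a padded box around the subject and zero out the background.

  image: HxWx3 array; mask: HxW array with 1 on the subject (assumed layouts).
  """
  ys, xs = np.nonzero(mask)
  y0, y1 = max(ys.min() - pad, 0), min(ys.max() + pad, image.shape[0])
  x0, x1 = max(xs.min() - pad, 0), min(xs.max() + pad, image.shape[1])
  return image[y0:y1, x0:x1] * mask[y0:y1, x0:x1, None]   # background removed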

Prompt Creation:

  • Inject a unique identifier token into the prompt to represent the concept.
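
A minimal sketch of the prompt construction; the identifier string "<v>", the default category word, and the placement rule are assumptions for illustration, not the paper's exact choices.

def inject_identifier(prompt: str, category: str = "person", identifier: str = "<v>") -> str:
  """Place the unique identifier token in front of the category word.

  e.g. "a photo of a person on the beach" -> "a photo of a <v> person on the beach"
  """
  return prompt.replace(category, f"{identifier} {category}", 1)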

Concept Embedding Learning:

  • Use a learnable image encoder to map the input images to a compact concept embedding vector in the textual space.
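
A minimal sketch of such an image encoder, assuming a CLIP image backbone and a small MLP head; the feature dimensions and head architecture are assumptions, not the paper's exact design.

import torch
import torch.nn as nn

class ConceptEncoder(nn.Module):
  """Map input image(s) to a single pseudo-token in the text-embedding space."""

  def __init__(self, clip_image_encoder: nn.Module, clip_dim: int = 512, text_dim: int = 768):
    super().__init__()
    self.backbone = clip_image_encoder                     # global CLIP image features
    self.head = nn.Sequential(nn.Linear(clip_dim, text_dim),
                              nn.GELU(),
                              nn.Linear(text_dim, text_dim))

  def forward(self, images: torch.Tensor) -> torch.Tensor:
    feats = self.backbone(images)                          # (N, clip_dim)
    tokens = self.head(feats)                              # (N, text_dim), one pseudo-token per image
    return tokens.mean(dim=0, keepdim=True)                # averaged over the N input views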

Representation Learning with Adapters:

  • Extract rich patch features from input images using another encoder.
  • Introduce trainable adapter layers in the U-Net of the pre-trained model to inject these visual features.
  • The adapter layers are formulated as self-attention over the visual tokens concatenated with patch features.
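
A minimal sketch of such an adapter layer. The constructor signature, head count, zero-initialized gate, and residual placement are assumptions; the paper only specifies self-attention over the U-Net's visual tokens concatenated with the image patch features.

import torch
import torch.nn as nn

class Adapter(nn.Module):
  """Trainable self-attention over U-Net tokens concatenated with image patch features."""

  def __init__(self, dim: int, patch_dim: int, num_heads: int = 8):
    super().__init__()
    self.proj_patch = nn.Linear(patch_dim, dim)            # bring patch features to the U-Net width
    self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
    self.gate = nn.Parameter(torch.zeros(1))               # zero gate: training starts from the frozen model

  def forward(self, x: torch.Tensor, patch_feats: torch.Tensor) -> torch.Tensor:
    # x: (B, L, dim) spatial tokens of a U-Net block; patch_feats: (B, P, patch_dim)
    ctx = torch.cat([x, self.proj_patch(patch_feats)], dim=1)
    out, _ = self.attn(ctx, ctx, ctx)
    return x + self.gate * out[:, : x.shape[1]]            # keep only the original tokens, residual add

Gating the adapter output and adding it residually keeps the frozen model's behavior unchanged at initialization, a common choice when attaching new layers to a pre-trained diffusion model.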

Model Training:

  • Train only the image encoders and adapter layers, while freezing the pre-trained model weights.
  • Use the standard denoising loss with text-image pairs as training data.

Model Inference:

  • Can take an arbitrary number of input images, since the adapter layers simply concatenate the patch features from all inputs.
  • Use a balanced sampling schedule between the adapter branch and the pre-trained model.
  • Apply concept embedding renormalization for better text-image alignment.
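
The paper does not spell out the renormalization rule; the sketch below rescales the concept embedding to the average norm of the ordinary prompt-token embeddings, which is one plausible reading of "concept embedding renormalization".

import torch

def renormalize_concept(concept_embed: torch.Tensor, prompt_embed: torch.Tensor, id_idx: int) -> torch.Tensor:
  """Rescale the concept embedding to the norm scale of regular text-token embeddings."""
  others = torch.cat([prompt_embed[:, :id_idx], prompt_embed[:, id_idx + 1:]], dim=1)
  target_norm = others.norm(dim=-1).mean()
  return concept_embed * (target_norm / concept_embed.norm(dim=-1, keepdim=True))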

In summary, the key ideas are learning compact concept embeddings, injecting rich visual features via adapters, and freezing the pre-trained model to retain its capabilities.

High-Level Pseudo Code

Here is the high-level pseudo code for the key components of the method proposed in this paper:

# Pre-process input images
cropped_images, augmented_images = [], []
for image in input_images:
  cropped_image = crop_and_mask(image)             # crop around the subject, mask the background
  cropped_images.append(cropped_image)
  augmented_images.append(augment(cropped_image))  # augmentation applied during training
 
# Create prompt 
prompt = inject_identifier(original_prompt)
 
# Learn concept embedding  
concept_embeddings = []
for image in augmented_images:
  embedding = image_encoder(image)
  concept_embeddings.append(embedding)
concept_embed = average(concept_embeddings)
 
# Get prompt embeddings
prompt_embed = text_encoder(prompt) 
prompt_embed[identifier_token_idx] = concept_embed
 
# Pass embeddings to model
adapters = build_adapters(pretrained_model)        # insert trainable adapter layers into the U-Net
patch_features = patch_encoder(cropped_images)

# Training step (pre-trained weights stay frozen)
t = sample_timestep(target_image)                  # random diffusion timestep
noisy_image, noise = add_noise(target_image, t)    # forward diffusion

# Forward pass
noise_pred = pretrained_model(noisy_image, t, prompt_embed, patch_features, adapters)

# Loss calculation: predict the added noise
loss = denoising_loss(noise_pred, noise)

# Backprop (updates only the encoders and adapters)
optimizer.zero_grad()
loss.backward()
optimizer.step()

The key steps are pre-processing the input images, creating the prompt with the identifier token, learning the concept embedding from the images, getting the prompt embedding from the text encoder, passing both to the model through the adapters, and training with the denoising loss.

Detailed Pseudo Code

Here is a more detailed pseudo code for implementing the key components of the method proposed in this paper:

import torch
import torch.nn as nn

# Image encoders (both heads share a CLIP image backbone)
image_encoder = ClipEncoder() # Initialize with CLIP weights
concept_encoder = nn.Sequential(image_encoder, nn.Linear(512, 512))    # concept-embedding head
patch_encoder = nn.Sequential(image_encoder, nn.Linear(512, 257*32))   # patch-feature head for the adapters
 
# Load pre-trained model
diffusion_model = StableDiffusion.load_from_ckpt() 
 
# Build adapter layers
adapters = []
for block in diffusion_model.unet.blocks:
  adapter = Adapter(block.attn) # Trainable self-attention adapter attached to this block
  adapters.append(adapter)
 
# Freeze the pre-trained model; only the encoders and adapters are trained
diffusion_model.requires_grad_(False)
trainable_params = set()
for module in [concept_encoder, patch_encoder] + adapters:
  trainable_params.update(module.parameters())
optimizer = torch.optim.AdamW(list(trainable_params), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):

  for batch in data_loader:

    text, images = batch
    cropped_images = crop_and_mask(images)
    aug_images = augment(cropped_images)

    # Create concept embedding (averaged over augmented views)
    concepts = []
    for img in aug_images:
      concept = concept_encoder(img)
      concepts.append(concept)
    concept_embed = torch.mean(torch.stack(concepts), dim=0)

    # Pass images through patch encoder
    patch_features = patch_encoder(cropped_images)

    # Modify text prompt with the identifier token
    prompt = inject_identifier(text)
    prompt_embed = text_encoder(prompt)
    prompt_embed[:, id_token_idx] = concept_embed

    # Forward diffusion: add noise at a random timestep
    # (Stable Diffusion actually diffuses VAE latents; images are used here for brevity)
    t = sample_timestep(images)
    noisy_images, noise = add_noise(images, t)

    # Forward pass and denoising loss on the predicted noise
    noise_pred = diffusion_model(noisy_images, t, prompt_embed, patch_features, adapters)
    loss = denoising_loss(noise_pred, noise)

    # Update encoders and adapters only
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
 
# Generate images for an unseen concept (no test-time finetuning)
with torch.no_grad():

  cropped = crop_and_mask(input_image)
  concept_embed = concept_encoder(cropped)
  patch_features = patch_encoder(cropped)

  prompt = inject_identifier(user_prompt)          # user_prompt: the new text prompt
  prompt_embed = text_encoder(prompt)
  prompt_embed[:, id_token_idx] = renormalize_concept(concept_embed, prompt_embed, id_token_idx)

  # Iterative denoising from random noise; the balanced sampling schedule between
  # the adapter branch and the plain pre-trained model is applied inside this loop
  latents = sampling_loop(diffusion_model, prompt_embed, patch_features, adapters)
  image = decoder(latents)                         # VAE decoder

The key aspects covered are:

  • Initializing the different encoders
  • Loading the pre-trained diffusion model
  • Building adapter layers
  • Creating concept embedding from images
  • Getting prompt embeddings
  • Passing both to diffusion model with adapters
  • Training loop with denoising loss
  • Generating images at test time
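
The pseudo code above calls sample_timestep, add_noise, and denoising_loss without defining them. Below is a minimal sketch under standard DDPM assumptions (linear beta schedule, epsilon-prediction loss); the paper builds on Stable Diffusion, whose actual schedule and latent-space handling differ.

import torch
import torch.nn.functional as F

# Stand-in linear beta schedule; the real schedule comes from the pre-trained diffusion model.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def sample_timestep(x0: torch.Tensor) -> torch.Tensor:
  """Draw one random diffusion timestep per sample in the batch."""
  return torch.randint(0, len(betas), (x0.shape[0],), device=x0.device)

def add_noise(x0: torch.Tensor, t: torch.Tensor):
  """Forward diffusion q(x_t | x_0): x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps."""
  noise = torch.randn_like(x0)
  a_t = alphas_cumprod.to(x0.device)[t].view(-1, 1, 1, 1)
  return a_t.sqrt() * x0 + (1.0 - a_t).sqrt() * noise, noise

def denoising_loss(noise_pred: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
  """Standard epsilon-prediction MSE objective."""
  return F.mse_loss(noise_pred, noise)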