Editing Implicit Assumptions in Text-to-Image Diffusion Models
Authors: Hadas Orgad, Bahjat Kawar, Yonatan Belinkov
Abstract: Text-to-image diffusion models often make implicit assumptions about the world when generating images. While some assumptions are useful (e.g., the sky is blue), they can also be outdated, incorrect, or reflective of social biases present in the training data. Thus, there is a need to control these assumptions without requiring explicit user input or costly re-training. In this work, we aim to edit a given implicit assumption in a pre-trained diffusion model. Our Text-to-Image Model Editing method, TIME for short, receives a pair of inputs: a “source” under-specified prompt for which the model makes an implicit assumption (e.g., “a pack of roses”), and a “destination” prompt that describes the same setting, but with a specified desired attribute (e.g., “a pack of blue roses”). TIME then updates the model’s cross-attention layers, as these layers assign visual meaning to textual tokens. We edit the projection matrices in these layers such that the source prompt is projected close to the destination prompt. Our method is highly efficient, as it modifies a mere 2.2% of the model’s parameters in under one second. To evaluate model editing approaches, we introduce TIMED (TIME Dataset), containing 147 source and destination prompt pairs from various domains. Our experiments (using Stable Diffusion) show that TIME is successful in model editing, generalizes well for related prompts unseen during editing, and imposes minimal effect on unrelated generations.
What, Why and How
Here is a summary of the key points from this paper:
What: This paper proposes a method called TIME (Text-to-Image Model Editing) for editing the implicit assumptions made by text-to-image diffusion models.
Why: Text-to-image models often make incorrect or biased assumptions based on the training data. There is a need to edit these assumptions without expensive retraining or requiring extra user input.
How:
- TIME takes as input a source prompt (e.g. “a pack of roses”) and a destination prompt (e.g. “a pack of blue roses”).
- It edits the text projection matrices in the cross-attention layers to map the source prompt embeddings close to the destination prompt.
- This shifts the visual meaning of the source prompt tokens towards the destination prompt.
- TIME modifies only 2.2% of model parameters and is efficient, taking under 1 second.
- Experiments show it successfully edits assumptions, generalizes to related prompts, and minimally affects unrelated prompts.
- TIME can also mitigate gender bias by equalizing gender representation across professions (illustrated in the sketch right after this summary).
In summary, TIME efficiently edits the implicit assumptions in text-to-image models by modifying the cross-attention layers based on source and destination text prompts. It shows promising results in editing assumptions and reducing bias.
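As a hedged illustration of the debiasing use case (not the paper's exact protocol), such edits can be phrased as ordinary source/destination prompt pairs that specify the under-represented gender; apply_time_edit is a hypothetical wrapper around the editing procedure described below, and the prompts are examples:

# Hedged illustration only: debiasing phrased as ordinary TIME edits.
# apply_time_edit is a hypothetical helper wrapping the procedure below;
# the paper's goal is an equal gender split (see Main Contributions),
# which this naive loop does not by itself guarantee.
profession_edits = [
    ("A photo of a CEO", "A photo of a female CEO"),
    ("A photo of a nurse", "A photo of a male nurse"),
]
for source_prompt, destination_prompt in profession_edits:
    apply_time_edit(model, source_prompt, destination_prompt)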
Main Contributions
Here are the main contributions of this paper:
- Proposes TIME, the first method for editing implicit assumptions in text-to-image diffusion models using only text prompts.
- Targets the cross-attention layers as the location for editing the mapping from text to image.
- Provides an efficient closed-form solution to update the text projection matrices.
- Modifies only 2.2% of model parameters in under 1 second.
- Introduces TIMED, a dataset for evaluating text-to-image model editing methods.
- Demonstrates TIME’s ability to successfully edit associations, generalize to related prompts, and minimally affect unrelated prompts.
- Shows TIME can mitigate gender bias by equalizing gender representations for professions.
- Analyzes model editing as a tradeoff between generality and specificity.
- Discusses social implications and limitations, such as inheriting generative limitations of the model.
In summary, the key contributions are proposing an efficient text-based model editing method, introducing a new dataset, and demonstrating applications for editing associations and reducing bias. The analyses and discussions also provide insights into this emerging research area.
Method Section
Here is a summary of the method section from the paper:
The proposed method is called TIME (Text-to-Image Model Editing). It takes as input a source prompt that lacks some specification (e.g. “a pack of roses”) and a destination prompt that specifies the desired attribute (e.g. “a pack of blue roses”).
TIME focuses on modifying the cross-attention layers in the diffusion model, specifically the projection matrices W_K and W_V that map text embeddings into keys and values. This is because these projections operate only on text, irrespective of the image generation process.
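Concretely, in each cross-attention layer the text token embeddings enter the computation only through these two projections (standard latent-diffusion cross-attention, written here for orientation):

$$K = W_K\,C, \qquad V = W_V\,C, \qquad Q = W_Q\,\varphi(z_t),$$
$$\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d}}\right) V,$$

where $C$ stacks the text embeddings and $\varphi(z_t)$ denotes intermediate image features. Editing $W_K$ and $W_V$ therefore changes how strongly image features attend to each text token ($K$) and what content each token contributes ($V$), without touching the image pathway ($Q$).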
For each source token, TIME identifies the corresponding destination token embedding. It then calculates the keys and values of these destination embeddings using the original W_K and W_V.
TIME optimizes for new projection matrices W’_K and W’_V that minimize the squared L2 distance between the projections of the source embeddings and the pre-computed destination keys/values, with a regularization term that keeps the new matrices close to the original ones.
The loss function has a closed form solution, allowing efficient computation of the updated projections. The new matrices map the source prompt close to the destination prompt in the cross-attention space.
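Written out for a single projection matrix $W$ (either $W_K$ or $W_V$), with source embeddings $c_i$ and pre-computed destination targets $v^*_i$ (keys or values), the objective and its ridge-regression-style minimizer are:

$$W' = \arg\min_{W'} \; \sum_i \lVert W' c_i - v^*_i \rVert_2^2 \; + \; \lambda \lVert W' - W \rVert_F^2,$$
$$W' = \Big(\lambda W + \sum_i v^*_i c_i^\top\Big)\Big(\lambda I + \sum_i c_i c_i^\top\Big)^{-1},$$

where $\lambda$ controls how strongly the edited matrix is pulled back toward the original one; the second line follows by setting the gradient of the first to zero.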
After updating the projections, TIME uses the modified diffusion model to generate images. The goal is that generations based on the source prompt will now align with the assumption in the destination prompt.
In summary, TIME efficiently edits the cross-attention layers to update the mapping from source text to destination visuals, promoting the desired attribute or assumption change.
High-Level Pseudo Code
Here is high-level pseudo code for the TIME method proposed in the paper:
# Input
source_prompt = "a pack of roses"
destination_prompt = "a pack of blue roses"
# Get embeddings
source_embeddings = text_encoder(source_prompt)
destination_embeddings = text_encoder(destination_prompt)
# Initialize matrices
W_K = model.cross_attention.W_K
W_V = model.cross_attention.W_V
# Get destination keys and values
destination_keys = W_K @ destination_embeddings
destination_values = W_V @ destination_embeddings
# Optimize new projection matrices
# (the K and V subproblems decouple, and the same edit is applied in
#  every cross-attention layer of the diffusion model)
loss = ||W'_K @ source_embeddings - destination_keys||^2
     + ||W'_V @ source_embeddings - destination_values||^2
     + lambda * ||W'_K - W_K||^2
     + lambda * ||W'_V - W_V||^2
W'_K, W'_V = closed_form_minimize(loss)
# Update model
model.cross_attention.W_K = W'_K
model.cross_attention.W_V = W'_V
# Generate images with edited model
images = model(source_prompt)
This shows the overall flow:
- Get embeddings for source and destination prompts
- Compute destination keys/values using original model projections
- Optimize new projections to map source close to destination
- Update the model’s cross-attention matrices
- Generate images using edited model
The key steps are computing the target projection and efficiently optimizing for the updated mappings that promote the desired attribute change.
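For concreteness, here is a minimal per-matrix sketch of the closed_form_minimize step, assuming embeddings are stacked row-wise and that lam is a hand-picked regularization strength (not a value taken from the paper):

import torch

def closed_form_minimize(W, source_embs, targets, lam=0.1):
    """Sketch of the ridge-regression-style closed-form update.

    W:           (d_out, d_emb) original projection matrix (W_K or W_V)
    source_embs: (n, d_emb) source-prompt token embeddings c_i
    targets:     (n, d_out) destination keys k*_i or values v*_i
    Minimizes sum_i ||W' c_i - target_i||^2 + lam * ||W' - W||_F^2.
    """
    d_emb = W.shape[1]
    eye = torch.eye(d_emb, dtype=W.dtype, device=W.device)
    cross = targets.T @ source_embs      # sum_i target_i c_i^T  -> (d_out, d_emb)
    gram = source_embs.T @ source_embs   # sum_i c_i c_i^T       -> (d_emb, d_emb)
    return (lam * W + cross) @ torch.linalg.inv(lam * eye + gram)

The same function would be called once with the destination keys (to update W_K) and once with the destination values (to update W_V), in every cross-attention layer.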
Detailed Pseudo Code
Here is more detailed pseudo code to implement the TIME model editing method:
import torch

from model import TextDiffusion  # placeholder: a diffusion model exposing a
                                 # text encoder and cross-attention projections
# (text_encoder, tokenize, and model are placeholders, as in the high-level sketch)

# Input prompts
source_prompt = "a pack of roses"
destination_prompt = "a pack of blue roses"

# Tokenize and embed both prompts with the (frozen) text encoder
source_tokens = tokenize(source_prompt)
destination_tokens = tokenize(destination_prompt)
source_embeddings = text_encoder(source_prompt)            # (len_src, d_emb)
destination_embeddings = text_encoder(destination_prompt)  # (len_dst, d_emb)

# Align each source token with its occurrence in the destination prompt.
# Destination-only tokens (e.g. "blue") are dropped here, but their effect is
# carried by the contextualized embeddings of the aligned tokens.
corresponding_source_embs, corresponding_dest_embs = [], []
for i, token in enumerate(source_tokens):
    if token in destination_tokens:
        j = destination_tokens.index(token)
        corresponding_source_embs.append(source_embeddings[i])
        corresponding_dest_embs.append(destination_embeddings[j])
source_embs = torch.stack(corresponding_source_embs)  # (n, d_emb)
dest_embs = torch.stack(corresponding_dest_embs)      # (n, d_emb)

# Regularization strength (hyperparameter)
lam = 0.1

# Get the diffusion model's cross-attention projections
# (the same update is repeated for every cross-attention layer)
W_K = model.cross_attention.W_K  # (d_out, d_emb)
W_V = model.cross_attention.W_V  # (d_out, d_emb)

# Target keys and values: project the destination embeddings with the ORIGINAL matrices
destination_keys = dest_embs @ W_K.T    # (n, d_out)
destination_values = dest_embs @ W_V.T  # (n, d_out)

# Minimize  sum_i ||W' c_i - target_i||^2 + lam * ||W' - W||_F^2
# in closed form (see the closed_form_minimize sketch above)
W_K_prime = closed_form_minimize(W_K, source_embs, destination_keys, lam)
W_V_prime = closed_form_minimize(W_V, source_embs, destination_values, lam)

# Update the model's cross-attention matrices
model.cross_attention.W_K = W_K_prime
model.cross_attention.W_V = W_V_prime

# Generate images with the edited model
images = model(source_prompt)
Key details:
- Get corresponding source and destination embeddings
- Compute target keys/values with original projections
- Set up the regularized least-squares loss between the source projections and the destination keys/values
- Minimize it in closed form to obtain the updated projections
- Update model with new cross-attention matrices
- Generate images with edited model
This shows a more complete implementation with embedding lookup, loss computation, and closed form optimization.
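For a sense of how such an edit would be wired into an actual Stable Diffusion checkpoint, here is a hedged sketch assuming the Hugging Face diffusers layout, where cross-attention blocks are named attn2 and expose to_k / to_v linear projections (module names can vary across library versions, and the checkpoint name is just an example); edit_layer is a hypothetical helper returning the updated (W_K', W_V') for one layer, e.g. via the closed-form update above:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

with torch.no_grad():
    for name, module in pipe.unet.named_modules():
        if name.endswith("attn2"):  # cross-attention (text-conditioning) blocks
            # edit_layer: hypothetical helper computing the TIME update for
            # this layer's key/value projection weights
            W_K_new, W_V_new = edit_layer(module.to_k.weight, module.to_v.weight)
            module.to_k.weight.copy_(W_K_new)
            module.to_v.weight.copy_(W_V_new)

# Sanity checks in the spirit of the paper's evaluation: the source and
# related prompts should reflect the edit, unrelated prompts should not
# (the related/unrelated prompts here are illustrative).
edited = pipe("a pack of roses").images[0]
related = pipe("a field of roses").images[0]
unrelated = pipe("a pack of wolves").images[0]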