ReFACT: Updating Text-to-Image Models by Editing the Text Encoder
Authors: Dana Arad, Hadas Orgad, Yonatan Belinkov
Abstract: Text-to-image models are trained on extensive amounts of data, leading them to implicitly encode factual knowledge within their parameters. While some facts are useful, others may be incorrect or become outdated (e.g., the current President of the United States). We introduce ReFACT, a novel approach for editing factual knowledge in text-to-image generative models. ReFACT updates the weights of a specific layer in the text encoder, only modifying a tiny portion of the model’s parameters, and leaving the rest of the model unaffected. We empirically evaluate ReFACT on an existing benchmark, alongside RoAD, a newly curated dataset. ReFACT achieves superior performance in terms of generalization to related concepts while preserving unrelated concepts. Furthermore, ReFACT maintains image generation quality, making it a valuable tool for updating and correcting factual information in text-to-image models.
What, Why and How
Here is a summary of the key points from the paper:
What:
- The paper presents ReFACT, a novel method for editing factual knowledge in text-to-image generative models.
Why:
- Text-to-image models implicitly encode factual knowledge within their parameters, some of which can become outdated. There is a need for efficient ways to update these models without full retraining.
How:
- ReFACT updates a specific layer in the text encoder using a rank one editing approach. This modifies only a tiny portion of the model’s parameters.
- It takes an edit prompt (e.g. “The President of the United States”) and a target text/image (e.g. “Joe Biden”), and edits the layer weights to align the edit prompt representation with the target.
- Experiments show ReFACT achieves superior performance in updating facts while preserving unrelated concepts and image quality.
- ReFACT demonstrates high generalization and specificity, and outperforms TIME, a previous editing method.
- The authors also introduce a new benchmark dataset RoAD for evaluating factual knowledge editing.
In summary, the paper presents an efficient and effective method called ReFACT to update factual knowledge in text-to-image models by making targeted edits to the text encoder parameters. It shows strong empirical results on updating diverse facts while minimizing changes to the model.
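To make the workflow concrete, a hypothetical top-level call might look like the sketch below. The refact_edit function and its arguments are illustrative placeholders, not the paper's released API:
# Hypothetical interface, for illustration only (not the released API)
edited_text_encoder = refact_edit(
    text_encoder=clip_text_encoder,
    edit_prompt="The President of the United States",
    target="Joe Biden",   # the target can be a text or an image
    layer=9,              # which text-encoder layer to edit
)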
Main Contributions
Here are the key contributions of this paper:
- Proposes ReFACT, a new method for editing factual knowledge in text-to-image models by updating a specific layer in the text encoder using a rank-one editing approach.
- Demonstrates ReFACT’s ability to update facts about people, objects, and abstract concepts while preserving unrelated concepts and maintaining image quality.
- Shows strong empirical results, with ReFACT outperforming TIME, a previous editing method, on two datasets in terms of efficacy, generalization, and specificity.
- Introduces a new benchmark dataset called RoAD (Roles and Appearances Dataset) for evaluating factual knowledge editing in text-to-image models.
- Provides analysis showing ReFACT can edit multiple facts sequentially while retaining performance.
- Uses ReFACT as a causal analysis tool to investigate where factual information is encoded across the layers of the text encoder.
- Makes code and data publicly available to support further research.
In summary, the key contributions are: (1) proposing the ReFACT editing method, (2) extensive empirical analysis demonstrating its capabilities, (3) introduction of the new RoAD dataset, (4) using ReFACT for text encoder analysis, and (5) releasing code and data.
Method Section
Here is a summary of the method section:
The paper proposes ReFACT, which edits a specific layer in the text encoder of a text-to-image model to update factual knowledge.
ReFACT is based on the idea that linear layers in the text encoder encode factual knowledge as key-value pairs. The keys are representations of phrases like “The President of the United States”, and the values encode the associated attribute, e.g., the identity “Donald Trump”.
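In this key-value view, if k is the hidden representation of the last subject token entering the edited layer with weight matrix W, the layer retrieves the associated value roughly as
$$W k \approx v ,$$
so updating a fact amounts to finding a new pair (k*, v*) and changing W so that W k* ≈ v* while leaving other key-value mappings as intact as possible.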
ReFACT takes an edit prompt (key) and a target text/image (new value) as input. It computes representations for the edit prompt and target using the frozen CLIP encoders.
To compute the new value v*, it optimizes a contrastive loss that brings the edit prompt representation close to the target while keeping it far from negative examples.
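The paper’s exact loss is not reproduced in this summary; a generic InfoNCE-style contrastive objective of the kind described above (an assumption, not necessarily the paper’s precise formulation) would be
$$v^* = \arg\min_{v} \; -\log \frac{\exp\big(\mathrm{sim}(e(v), t)/\tau\big)}{\exp\big(\mathrm{sim}(e(v), t)/\tau\big) + \sum_{n \in \mathcal{N}} \exp\big(\mathrm{sim}(e(v), n)/\tau\big)} ,$$
where e(v) is the edit prompt’s representation when the edited layer outputs v at the last subject token, t is the target representation, $\mathcal{N}$ is a set of negative examples, sim is cosine similarity, and $\tau$ is a temperature.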
The new key k* is computed as the average representation of the last subject token at the edited layer, taken over several templated variations of the edit prompt.
Using k* and v*, ReFACT performs rank one editing of the chosen layer’s weights to align the edit prompt representation with the target. This modifies only a small portion of weights.
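The closed form of the rank-one update is not spelled out in this summary. Assuming a ROME-style update (the line of rank-one model editing this work builds on), the edited weight matrix would be
$$\hat{W} = W + (v^* - W k^*)\,\frac{(C^{-1} k^*)^\top}{(C^{-1} k^*)^\top k^*} ,$$
where C is an estimate of the covariance (second moment) of keys at that layer; whether ReFACT uses exactly this form is an assumption here.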
The edit is applied to the multi-layer perceptron (MLP) layers in the text encoder, specifically the projection matrices, which are viewed as key-value stores.
In summary, ReFACT makes a targeted edit to a specific layer of the text encoder to update the value associated with the input edit key to match the provided target. It modifies only a small part of the model and does not require fine-tuning.
High-Level Pseudo Code
Here is the high-level pseudo code for the ReFACT method proposed in the paper:
# ReFACT Pseudocode
# Inputs:
edit_prompt # key phrase representing the fact to edit (e.g. "President of US")
target # target value (text or image)
layer # layer in text encoder to edit
# Load CLIP model
text_encoder = CLIP.text_encoder
image_encoder = CLIP.image_encoder
# Get target representation
if target is text:
    target_rep = text_encoder(target)[EOS]      # EOS token representation in CLIP space
else:
    target_rep = image_encoder(target)           # pooled image embedding in CLIP space
# Get hidden representations of the last subject token at the edited layer,
# for several templated variations of the edit prompt
subject_reps = [hidden_at_layer(text_encoder, p, layer)[last_subject_token] for p in edit_prompt_variations]
# Compute new value v* that brings the edit prompt representation close to the target
# while keeping it far from negative examples (contrastive objective)
v_star = optimize_contrastive_objective(edit_prompt, target_rep, negative_examples)
# Compute new key k* as the average last-subject-token representation
k_star = average(subject_reps)
# Perform rank-1 edit
updated_weights = edit_layer(layer, k_star, v_star)
# Load text encoder with updated weights
text_encoder.load_state_dict(updated_weights)
# Text encoder now encodes the updated fact
In summary, the key steps are:
- Get target representation from CLIP encoder
- Compute representations for edit prompt
- Optimize new value v* using contrastive loss
- Compute new key k* as average of last subject tokens
- Perform rank-1 edit of layer weights
- Load updated weights into text encoder
This allows editing a specific layer of the text encoder to align the edit prompt representation with the provided target.
Detailed Pseudo Code
Here is a more detailed pseudo code for implementing the ReFACT method:
import torch
from transformers import CLIPModel, CLIPProcessor

# Load CLIP and freeze it
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()
text_encoder = model.text_model      # CLIP text transformer
image_encoder = model.vision_model   # CLIP vision transformer

# Inputs
edit_prompt = "The President of the United States"
target = "Joe Biden"   # can also be an image
layer_to_edit = 9      # index of the text-encoder MLP layer to edit

# Get target representation (text embedding for text targets, image embedding otherwise)
with torch.no_grad():
    if isinstance(target, str):
        inputs = processor(text=target, return_tensors="pt")
        target_rep = model.get_text_features(**inputs)
    else:
        inputs = processor(images=target, return_tensors="pt")
        target_rep = model.get_image_features(**inputs)

# Templated variations of the edit prompt
edit_prompt_vars = [
    "The President of the United States",
    "A photo of the President of the United States",
    # ... additional templates
]

# Hidden representation of the last subject token at the input of the edited layer, per variation
edit_reps = []
with torch.no_grad():
    for p in edit_prompt_vars:
        inputs = processor(text=p, return_tensors="pt")
        hidden = text_encoder(**inputs, output_hidden_states=True).hidden_states[layer_to_edit]
        edit_reps.append(hidden[0, last_subject_token_index(p)])   # helper: index of last subject token

# Compute new value v*: optimize the edited layer's output at the last subject token
# so the prompt's final representation moves toward the target and away from negatives
def contrastive_loss(v):
    edit_rep_new = encode_with_patched_value(text_encoder, edit_prompt, layer_to_edit, v)
    loss = distance(target_rep, edit_rep_new)          # pull the edit prompt toward the target
    for neg_prompt in negative_prompts:
        neg_rep = encode_with_patched_value(text_encoder, neg_prompt, layer_to_edit, v)
        loss -= distance(target_rep, neg_rep)          # push negative examples away
    return loss

v_star = optimize(contrastive_loss)   # e.g. gradient descent on v

# Compute new key k* as the average last-subject-token representation
k_star = torch.stack(edit_reps).mean(dim=0)

# Perform the rank-1 edit on the chosen MLP projection matrix
mlp_proj = text_encoder.encoder.layers[layer_to_edit].mlp.fc2
mlp_proj.weight.data = rank_one_edit(mlp_proj.weight.data, k_star, v_star)
This fills in more implementation detail: obtaining the target and edit prompt representations, optimizing v* with a contrastive loss, computing k*, and applying the rank-1 edit to the chosen layer’s projection matrix. Helper functions such as last_subject_token_index, encode_with_patched_value, distance, optimize, and rank_one_edit are left as placeholders.
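Because only the text encoder is modified, using the edit at generation time amounts to running the text-to-image pipeline with the edited encoder. A minimal sketch with the diffusers library (the checkpoint name is an assumption for illustration; Stable Diffusion exposes its CLIP text encoder as pipe.text_encoder):
from diffusers import StableDiffusionPipeline

# Checkpoint name is illustrative; any Stable Diffusion checkpoint with a CLIP text encoder works
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
text_model = pipe.text_encoder.text_model   # CLIP text transformer to apply the rank-one edit to
# ... apply the edit to text_model.encoder.layers[layer_to_edit].mlp.fc2 as above ...
image = pipe("The President of the United States").images[0]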