StylePrompter: All Styles Need Is Attention
Authors: Chenyi Zhuang, Pan Gao, Aljosa Smolic
Abstract: GAN inversion aims at inverting given images into corresponding latent codes for Generative Adversarial Networks (GANs), especially StyleGAN where exists a disentangled latent space that allows attribute-based image manipulation at latent level. As most inversion methods build upon Convolutional Neural Networks (CNNs), we transfer a hierarchical vision Transformer backbone innovatively to predict W+ latent codes at token level. We further apply a Style-driven Multi-scale Adaptive Refinement Transformer (SMART) in F space to refine the intermediate style features of the generator. By treating style features as queries to retrieve lost identity information from the encoder’s feature maps, SMART can not only produce high-quality inverted images but also surprisingly adapt to editing tasks. We then prove that StylePrompter lies in a more disentangled W+ and show the controllability of SMART. Finally, quantitative and qualitative experiments demonstrate that StylePrompter can achieve desirable performance in balancing reconstruction quality and editability, and is “smart” enough to fit into most edits, outperforming other F-involved inversion methods.
What, Why and How
Here is a summary of the key points from the paper:
\textbf{What}:
- Proposes StylePrompter, a new GAN inversion framework using Transformers to map images to latent codes in StyleGAN.
\textbf{Why}:
- Most prior GAN inversion methods use CNN backbones. This explores using Transformers which can predict latent codes at the token level.
\textbf{How}:
- Uses a hierarchical vision Transformer (Swin Transformer) as the backbone to predict latent codes in the W+ space of StyleGAN. Latent codes are embedded as tokens that interact with image patch tokens.
- Proposes a Style-driven Multi-scale Adaptive Refinement Transformer (SMART) to refine intermediate style features of the StyleGAN generator in the F space. SMART treats style features as queries to retrieve lost identity information from the encoder features.
- SMART uses controllable weights to balance reconstruction quality and editability. Unlike prior F space methods, it adapts to edited cases.
- Training follows a two-stage strategy: the Transformer backbone is trained first, then SMART is trained with the backbone kept fixed.
- Experiments show StylePrompter balances quality and editability well, outperforming CNN-based and other Transformer-based inversion methods. SMART produces high-quality inversions while maintaining editability, in contrast to other F space methods.
In summary, this paper innovatively uses a hierarchical vision Transformer for GAN inversion to predict disentangled latent codes at the token level. The proposed SMART refines style features in a controllable way to balance reconstruction quality and editing flexibility.
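The two-stage strategy listed above (train the backbone first, then train the refinement module with the backbone frozen) can be illustrated with a toy NumPy sketch. Everything in it is an assumption made for illustration: the linear "backbone" `E` that only sees a coarse part of the input, the linear "refiner" `R`, the squared-error loss, and the plain gradient-descent loop stand in for the real networks and losses, which this summary does not specify.

```python
import numpy as np

def train_two_stage(x, y, coarse_dims=2, steps=500, lr=0.1):
    """Toy two-stage fit: stage 1 trains E, stage 2 trains R with E frozen."""
    n = x.shape[0]
    x_coarse = x[:, :coarse_dims]             # the toy backbone only sees these inputs
    E = np.zeros((coarse_dims, y.shape[1]))   # stand-in for the backbone
    R = np.zeros((x.shape[1], y.shape[1]))    # stand-in for the refinement module

    # Stage 1: only the backbone receives gradient updates.
    for _ in range(steps):
        err = x_coarse @ E - y
        E -= lr * x_coarse.T @ err / n
    E_stage1 = E.copy()
    loss_stage1 = float(np.mean((x_coarse @ E - y) ** 2))

    # Stage 2: the backbone output is fixed; only the refiner is updated.
    base = x_coarse @ E
    for _ in range(steps):
        err = base + x @ R - y
        R -= lr * x.T @ err / n
    loss_stage2 = float(np.mean((base + x @ R - y) ** 2))
    return E_stage1, E, R, loss_stage1, loss_stage2

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4))
y = x @ rng.normal(size=(4, 3))
E1, E2, R, loss1, loss2 = train_two_stage(x, y)
print("backbone unchanged in stage 2:", np.allclose(E1, E2))
print("refinement lowered the loss:", loss2 < loss1)
```

The point of the split is that the second stage can only add detail on top of a fixed first-stage prediction, which mirrors the idea of refining a frozen encoder's output rather than retraining it.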
Main Contributions
Based on my reading of the paper, the main contributions are:
- Proposes a novel Transformer-based architecture (StylePrompter) for GAN inversion, using a hierarchical vision Transformer backbone to predict latent codes at the token level. To my knowledge, this is the first work to use a hierarchical vision Transformer as the inversion backbone; other Transformer-based inversion methods exist and serve as baselines in the experiments.
- Introduces a Style-driven Multi-scale Adaptive Refinement Transformer (SMART) to refine intermediate style features of the generator. SMART treats style features as queries to retrieve identity information from encoder features.
- Designs SMART with controllable weights to balance reconstruction quality and editing flexibility, and shows that SMART adapts to edited cases and maintains editability, unlike prior methods.
- Explores the disentanglement of the W+ space predicted by StylePrompter, and shows it is more disentangled than that of other inversion methods.
- Conducts experiments showing StylePrompter balances quality and editability well, outperforming CNN-based and other Transformer-based inversion methods.
- Demonstrates that SMART produces high-quality inversions while maintaining the editability of the latent space, compared to other F space inversion methods.
In summary, the key contributions are proposing a novel Transformer architecture for inversion, the SMART module for controllable refinement, exploring disentanglement, and showing experimentally that StylePrompter balances reconstruction quality and editing flexibility well compared to prior inversion methods.
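One way to picture the controllable refinement is as a weighted residual. The sketch below is an assumption about the general form, not the paper's exact formulation: a hypothetical scalar `alpha` scales how much of a retrieved correction is added back to the original style feature, and the function name and tensor shapes are illustrative only.

```python
import numpy as np

def blend_refinement(style_feature, correction, alpha):
    """Weighted residual: alpha=0 keeps the original feature, alpha=1 adds the full correction."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha is expected to lie in [0, 1]")
    return style_feature + alpha * correction

style = np.ones((4, 8, 8))             # (channels, height, width), illustrative shape
correction = np.full((4, 8, 8), 0.5)   # stand-in for the output of a refinement module
for alpha in (0.0, 0.5, 1.0):
    out = blend_refinement(style, correction, alpha)
    print(alpha, float(out.mean()))
```

Under this reading, a small weight stays close to the feature produced from the predicted latent codes (favoring editability), while a large weight injects more image-specific detail (favoring reconstruction).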
Method Section
Here is a summary of the method section from the paper:
The StylePrompter framework has two main components:
\textbf{Transformer-based Encoder:}
- Uses a hierarchical vision Transformer (Swin Transformer) as the backbone.
- Latent codes are embedded as tokens and concatenated with image patch tokens as input to the Transformer.
- Latent tokens interact with patches during self-attention to learn image-level and latent-level features.
- Latent tokens are replicated across windows and merged after self-attention.
- The encoder outputs latent codes in the W+ space of StyleGAN.
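The window-level interaction described in the bullets above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: attention is single-head with identity projections, the per-window copies of the latent tokens are merged by averaging, and the function names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(tokens):
    """Single-head self-attention with identity projections (kept minimal on purpose)."""
    d = tokens.shape[-1]
    scores = tokens @ tokens.T / np.sqrt(d)
    return softmax(scores) @ tokens

def window_attention_with_latents(patch_grid, latent_tokens, window_size):
    """patch_grid: (H, W, C); latent_tokens: (K, C). Returns updated grid and merged latents."""
    H, W, C = patch_grid.shape
    if H % window_size or W % window_size:
        raise ValueError("grid size must be divisible by window_size")
    out_grid = np.empty_like(patch_grid)
    latent_updates = []
    for i in range(0, H, window_size):
        for j in range(0, W, window_size):
            window = patch_grid[i:i + window_size, j:j + window_size].reshape(-1, C)
            n_patch = window.shape[0]
            # Each window gets its own copy of the latent tokens.
            tokens = np.concatenate([window, latent_tokens], axis=0)
            attended = self_attention(tokens)
            out_grid[i:i + window_size, j:j + window_size] = (
                attended[:n_patch].reshape(window_size, window_size, C))
            latent_updates.append(attended[n_patch:])
    # Merge the per-window copies (averaging is an assumption of this sketch).
    merged_latents = np.mean(latent_updates, axis=0)
    return out_grid, merged_latents

rng = np.random.default_rng(0)
grid = rng.normal(size=(8, 8, 16))
latents = rng.normal(size=(18, 16))
new_grid, new_latents = window_attention_with_latents(grid, latents, window_size=4)
print(new_grid.shape, new_latents.shape)  # (8, 8, 16) (18, 16)
```

Because every window carries a copy of the same latent tokens, the merged tokens aggregate information from all windows even though patch tokens only attend locally.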
\textbf{Style-driven Multi-scale Adaptive Refinement Transformer (SMART):}
- Refines intermediate style features of the StyleGAN generator in F space.
- Style features are projected as queries. Encoder features are projected as keys and values.
- Queries retrieve identity information from keys/values via cross-attention.
- Uses local attention: each query attends only to the encoder features at its corresponding spatial location.
- Additional weights provide control over residual connections.
- Training is two-stage: backbone first, then SMART.
In summary, the method uses a hierarchical Transformer backbone to predict disentangled latent codes in W+. Then SMART refines style features in F space in a controllable way to balance reconstruction quality and editing flexibility. Together, the two components aim to achieve high-quality inversion while maintaining editability.
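The local cross-attention idea in SMART (style features as queries, encoder features as keys and values, each query restricted to its own spatial region) can be sketched as below. This is a hedged illustration, not the paper's module: it assumes both maps are cut into the same `grid x grid` layout of regions, uses identity projections instead of learned ones, applies a single unweighted residual, and all names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def local_cross_attention(style_map, encoder_map, grid):
    """style_map: (Hs, Ws, C) queries; encoder_map: (He, We, C) keys/values.

    Both maps are cut into grid x grid regions, and queries in a region
    attend only to encoder tokens in the matching region.
    """
    Hs, Ws, C = style_map.shape
    He, We, Ce = encoder_map.shape
    if C != Ce or Hs % grid or Ws % grid or He % grid or We % grid:
        raise ValueError("channel counts must match and sizes must be divisible by grid")
    qh, qw = Hs // grid, Ws // grid
    kh, kw = He // grid, We // grid
    out = np.empty_like(style_map)
    for i in range(grid):
        for j in range(grid):
            q = style_map[i * qh:(i + 1) * qh, j * qw:(j + 1) * qw].reshape(-1, C)
            kv = encoder_map[i * kh:(i + 1) * kh, j * kw:(j + 1) * kw].reshape(-1, C)
            attn = softmax(q @ kv.T / np.sqrt(C))
            retrieved = attn @ kv
            # Residual connection: keep the style feature and add the retrieved detail.
            out[i * qh:(i + 1) * qh, j * qw:(j + 1) * qw] = (q + retrieved).reshape(qh, qw, C)
    return out

rng = np.random.default_rng(0)
style = rng.normal(size=(8, 8, 4))
enc = rng.normal(size=(4, 4, 4))
refined = local_cross_attention(style, enc, grid=2)
print(refined.shape)  # (8, 8, 4)
```

Restricting each query to its matching region keeps the attention cost low and ties retrieved details to the right place in the image; the region layout used here is only one possible way to realize that constraint.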
High-Level Pseudo Code
Here is the high-level pseudo code for the StylePrompter method:
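# Encoder (Swin Transformer backbone); helper functions are placeholders
img_patches = partition(img)
latent_tokens = initialize_latent_tokens()
merged_latents = 0
for window_patches in split_into_windows(img_patches):
    # Each window sees its own patches plus a copy of the latent tokens
    window_input = concat(window_patches, latent_tokens)
    attended_tokens = mlp(multi_head_self_attention(window_input))
    window_patches, window_latents = split(attended_tokens)
    merged_latents = merged_latents + window_latents  # merge across windows
w_plus = projection_head(merged_latents)

# SMART refinement of the intermediate feature at layer_l
style_features = generator_features(w_plus, up_to=layer_l)
queries = query_proj(style_features)
keys, values = key_proj(encoder_features), value_proj(encoder_features)
refined = cross_attention(queries, keys, values)
refined = refined + style_features  # residual
refined = mlp(refined)
refined = refined + style_features  # residual

# Generator: resume synthesis from layer_l with the refined feature
gen_output = generator(w_plus, replace_feature_at=layer_l, feature=refined)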
This shows the overall flow:
- Encoder backbone predicts w_plus latent codes
- SMART refines intermediate style features
- Pass refined features to generator for high quality inversion
The key steps are the hierarchical Transformer encoder, controllable cross-attention refinement in SMART, and integrating refined features into the generator.
Detailed Pseudo Code
Here is a more detailed pseudo code for implementing the StylePrompter method:
# Hyperparameters
NUM_LATENT_TOKENS = 18

# Swin Transformer backbone (helper functions are placeholders)
def encoder(img):
    img_patches = patchify(img)  # (N, H, W, C) grid of patch tokens
    latent_codes = random_init(NUM_LATENT_TOKENS, C)
    for layer in range(NUM_LAYERS):
        _, H, W, _ = img_patches.shape
        # Window partition: every window gets a copy of the latent tokens
        windows = []
        for i in range(H // WINDOW_SIZE):
            for j in range(W // WINDOW_SIZE):
                rows = slice(i * WINDOW_SIZE, (i + 1) * WINDOW_SIZE)
                cols = slice(j * WINDOW_SIZE, (j + 1) * WINDOW_SIZE)
                window_patches = flatten(img_patches[:, rows, cols, :])
                windows.append(concat(window_patches, latent_codes))
        # Window multi-head self-attention (pre-norm blocks with residuals)
        attended_windows = []
        for window in windows:
            x = window + multi_head_self_attn(layer_norm(window))
            x = x + mlp(layer_norm(x))
            attended_windows.append(x)
        # Split each window back into patch tokens and latent tokens
        num_patch_tokens = WINDOW_SIZE ** 2
        patch_parts = [x[:, :num_patch_tokens] for x in attended_windows]
        latent_parts = [x[:, num_patch_tokens:] for x in attended_windows]
        img_patches = window_reverse(patch_parts, H, W)
        latent_codes = sum(latent_parts)  # merge latent tokens across windows
        # Patch merging for image tokens
        img_patches = patch_merging(img_patches)
        latent_codes = mlp(layer_norm(latent_codes))
    w_plus = projection_head(latent_codes)
    return w_plus

# SMART refinement
def smart(style_features, encoder_features):
    queries = query_proj(style_features)
    keys = key_proj(encoder_features)
    values = value_proj(encoder_features)
    refined = cross_attn(queries, keys, values)
    refined = refined + style_features  # residual 1
    refined = mlp(refined)
    refined = refined + style_features  # residual 2
    return refined

# Generator
def generator(w_plus, encoder_features):
    style_features = constant_input()
    for layer in range(NUM_GENERATOR_LAYERS):
        style_features = modulated_conv(style_features, get_style(w_plus[layer]))
        if layer == SMART_LAYER:
            # Refine the intermediate feature map with encoder information
            style_features = smart(style_features, encoder_features)
    image = to_rgb(style_features)
    return image
This shows:
- Details of window partition and merging latent tokens in encoder
- Linear projections and cross attention in SMART
- Passing refined styles back into the generator
The key aspects are the window processing for latent tokens, controllable cross attention refinement, and integrating refined styles into the existing generator.
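The value `NUM_LATENT_TOKENS = 18` and the number of attention windows per layer both follow from simple arithmetic. The helpers below are a small sketch with hypothetical names: the first assumes the standard StyleGAN convention of two style inputs per resolution level from 4x4 upward (which gives 18 codes at 1024x1024), and the second assumes non-overlapping square windows on a token grid whose size is an illustrative choice.

```python
import math

def num_style_codes(resolution):
    """Number of W+ style vectors for a StyleGAN-style generator (power-of-two resolution)."""
    log2_res = int(math.log2(resolution))
    if resolution < 4 or 2 ** log2_res != resolution:
        raise ValueError("resolution must be a power of two, at least 4")
    return 2 * log2_res - 2

def num_windows(grid_h, grid_w, window_size):
    """Number of non-overlapping windows on a grid_h x grid_w token grid."""
    if grid_h % window_size or grid_w % window_size:
        raise ValueError("grid size must be divisible by window_size")
    return (grid_h // window_size) * (grid_w // window_size)

print(num_style_codes(1024))    # 18
print(num_windows(64, 64, 8))   # 64
```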