DreamTeacher: Pretraining Image Backbones with Deep Generative Models
Authors: Daiqing Li, Huan Ling, Amlan Kar, David Acuna, Seung Wook Kim, Karsten Kreis, Antonio Torralba, Sanja Fidler
Abstract: In this work, we introduce a self-supervised feature representation learning framework DreamTeacher that utilizes generative networks for pre-training downstream image backbones. We propose to distill knowledge from a trained generative model into standard image backbones that have been well engineered for specific perception tasks. We investigate two types of knowledge distillation: 1) distilling learned generative features onto target image backbones as an alternative to pretraining these backbones on large labeled datasets such as ImageNet, and 2) distilling labels obtained from generative networks with task heads onto logits of target backbones. We perform extensive analyses on multiple generative models, dense prediction benchmarks, and several pretraining regimes. We empirically find that our DreamTeacher significantly outperforms existing self-supervised representation learning approaches across the board. Unsupervised ImageNet pre-training with DreamTeacher leads to significant improvements over ImageNet classification pre-training on downstream datasets, showcasing generative models, and diffusion generative models specifically, as a promising approach to representation learning on large, diverse datasets without requiring manual annotation.
What, Why and How
Here is a summary of the key points in the paper:
What:
- The paper proposes a new self-supervised representation learning framework called DreamTeacher.
- It utilizes generative models such as GANs and diffusion models to distill knowledge into target image backbones (CNNs) for pre-training.
- Two types of distillation are investigated: 1) feature distillation, which transfers generative features into image backbones, and 2) label distillation, which transfers soft labels predicted by a task head on the generative model onto target backbones.
Why:
- Leverages generative models as powerful unsupervised representation learners, avoiding the need for large labeled datasets.
- Shows that generative model features contain rich semantic information that transfers well to downstream tasks.
- Avoids having to jointly train generative and discriminative models, and can distill into any target backbone.
How:
- Creates a feature dataset of images and corresponding generative model features.
- Uses a feature regressor attached to the target backbone to align with and regress the generative features.
- Employs distillation losses such as MSE and Attention Transfer to transfer features (see the sketch after this list).
- For label distillation, trains a task head on generative features and distills its predictions onto the target backbone.
- Evaluates on ImageNet pre-training and transfer, plus in-domain pre-training on datasets like BDD100K.
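The Attention Transfer loss named above is not spelled out in the summary. Below is a minimal PyTorch sketch of the standard formulation from Zagoruyko and Komodakis, which matches L2-normalized spatial attention maps between student and teacher features; the helper names and the use of mean squared activations are illustrative assumptions, not details taken from the paper.

import torch

def attention_map(feat: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # feat: (B, C, H, W). Collapse channels by averaging squared
    # activations, flatten spatially, and L2-normalize so the loss
    # compares the shape of the attention map, not its magnitude.
    am = feat.pow(2).mean(dim=1).flatten(1)          # (B, H*W)
    return am / (am.norm(dim=1, keepdim=True) + eps)

def attention_transfer_loss(f_s: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
    # Match the normalized spatial attention maps of student and teacher
    return (attention_map(f_s) - attention_map(f_t)).pow(2).mean()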
In summary, DreamTeacher provides a new way to leverage generative models to distill visual knowledge into target discriminative backbones in a self-supervised manner, avoiding the need for large labeled datasets. It shows strong performance on multiple benchmarks.
Main Contributions
The main contributions of this paper are:
- Proposes DreamTeacher, a new framework for distilling knowledge from generative models into target image backbones for self-supervised pre-training.
- Introduces two distillation methods: feature distillation to transfer semantic features, and label distillation to transfer predictions from a task head.
- Evaluates various generative models, including GANs (BigGAN, StyleGAN) and diffusion models (DDPM), as teachers.
- Demonstrates strong performance on ImageNet pre-training and transfer learning, outperforming prior self-supervised methods.
- Shows impressive results for in-domain pre-training, significantly outperforming contrastive and masked image modeling methods on datasets like BDD100K and on label-efficient segmentation benchmarks.
- Provides extensive ablations on design choices such as distillation objectives, feature regressors, and encoding techniques for diffusion models.
- Makes code and models publicly available.
In summary, the key contributions are:
- Proposing DreamTeacher as a new self-supervised representation learning paradigm using generative model distillation
- Comprehensive evaluation across various settings, including ImageNet pre-training, in-domain pre-training, and label-efficient tasks
- Extensive ablations to validate design choices
- Releasing code and models to the community
The paper shows generative model distillation is a promising approach for self-supervised representation learning, achieving new state-of-the-art across multiple benchmarks.
Method Section
Here is a summary of the method section:
The paper proposes the DreamTeacher framework for distilling knowledge from a pretrained generative model G into a target image backbone f in a self-supervised manner.
Two distillation settings are explored:
- Unsupervised representation learning:
  - Create a feature dataset D = {xi, f^g_i} of images and corresponding features from G.
  - Features can come from sampling G (synthesized dataset) or from encoding real images into G's latent space (encoded dataset; see the sketch after this list).
  - Attach feature regressors to f to map/align its features with f^g_i.
  - Use losses such as MSE and Attention Transfer to distill f^g_i into f's features.
- Semi-supervised learning:
  - Train a task-specific head (feature interpreter) on top of the frozen G.
  - Distill soft labels predicted by the interpreter into the logit layer of f.
  - Combine the feature and label distillation losses to train f.
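To make the encoded-dataset option concrete, here is a hedged sketch of extracting diffusion features from a real image: noise the image with the standard DDPM forward process at timestep t, run the denoising U-Net once, and capture intermediate activations with forward hooks. The names unet, layers, and alphas_cumprod are assumptions for illustration; the paper's exact encoding procedure may differ.

import torch

@torch.no_grad()
def encode_features(unet, img, t, layers, alphas_cumprod):
    # Corrupt a real image at diffusion timestep t using the standard
    # DDPM forward process: sqrt(a_bar) * x + sqrt(1 - a_bar) * noise
    a_bar = alphas_cumprod[t]
    noisy = a_bar.sqrt() * img + (1.0 - a_bar).sqrt() * torch.randn_like(img)

    # Run one denoising pass and record intermediate activations
    feats, hooks = {}, []
    for name, module in unet.named_modules():
        if name in layers:
            hooks.append(module.register_forward_hook(
                lambda mod, inp, out, n=name: feats.__setitem__(n, out)))
    unet(noisy, torch.tensor([t], device=img.device))  # output image unused
    for h in hooks:
        h.remove()
    return feats  # {layer name: multi-scale teacher features}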
For feature regression, an FPN-like top-down architecture is used to align multi-scale features from f to f^g.
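A minimal sketch of such an FPN-like regressor follows, assuming the student exposes a list of multi-scale features ordered from high to low resolution; the class name, channel widths, and fusion details are hypothetical, not the paper's exact architecture.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureRegressor(nn.Module):
    # FPN-style top-down regressor: lateral 1x1 convs project the
    # student's multi-scale features to a shared width, a top-down
    # pathway fuses them coarse-to-fine, and 3x3 output convs match
    # the teacher's channel dimension at each level.
    def __init__(self, student_dims, teacher_dims, width=256):
        super().__init__()
        self.lateral = nn.ModuleList(nn.Conv2d(d, width, 1) for d in student_dims)
        self.output = nn.ModuleList(nn.Conv2d(width, d, 3, padding=1) for d in teacher_dims)

    def forward(self, feats):  # feats: high-res first, e.g. [C2, C3, C4, C5]
        laterals = [conv(f) for conv, f in zip(self.lateral, feats)]
        for i in range(len(laterals) - 2, -1, -1):  # top-down fusion
            laterals[i] = laterals[i] + F.interpolate(
                laterals[i + 1], size=laterals[i].shape[-2:], mode="nearest")
        return [conv(x) for conv, x in zip(self.output, laterals)]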
For label distillation, an interpreter takes G’s features as input and predicts soft labels that are distilled onto f.
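For the soft-label distillation itself, a common implementation is a temperature-scaled KL divergence between teacher and student logits, sketched below; the temperature T and the KL form are standard knowledge distillation choices assumed here, not necessarily the paper's exact loss.

import torch.nn.functional as F

def label_distillation_loss(student_logits, teacher_logits, T=1.0):
    # Soft-label distillation: KL divergence between the temperature-
    # scaled teacher and student distributions, computed per pixel
    # for dense tasks such as segmentation.
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * T * T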
In summary, DreamTeacher provides a flexible framework to distill knowledge from pretrained generative models into target image backbones. It exploits feature and optional label distillation to transfer semantic and perceptual knowledge in a self-supervised manner.
High-Level Pseudo Code
Here is a high-level pseudo code for the DreamTeacher method:
# 1. Generate the feature dataset of (image, teacher features) pairs
features = []
for image in images:
    if use_gan:
        # Synthesized dataset: sample an image from the GAN and record
        # the generator's intermediate features for that same sample
        fake_image = GAN.sample(z)
        fake_feats = GAN.extract_features(fake_image)
        features.append((fake_image, fake_feats))
    elif use_diffusion:
        # Encoded dataset: encode the real image into the diffusion model
        # and take the denoising network's intermediate features
        encoded_z = Diffusion.encode(image)
        fake_feats = Diffusion.extract_features(encoded_z)
        features.append((image, fake_feats))

# 2. Define the student network (CNN backbone)
student = CNN()
# Attach a feature regressor to the student
student = add_feature_regressor(student)

# 3. Feature distillation with MSE + Attention Transfer losses
feat_loss = MSELoss() + AttentionTransferLoss()
for image, fake_feats in features:
    # Get student features
    student_feats = student(image)
    # Feature distillation
    loss = feat_loss(student_feats, fake_feats)
    # Update student weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 4. Semi-supervised label distillation
# Train a task head on the frozen generative model's features
head = SegmentationHead(Diffusion.features)
head.train(labeled_subset)

# Distill the head's soft labels into the student
for image in images:
    # Soft labels from the teacher's head
    fake_label = head(image)
    student_label = student(image)
    # Label distillation
    loss = CrossEntropyLoss(student_label, fake_label)
    # Update student weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
In summary, the key steps are:
- Generate feature dataset from GAN/Diffusion model
- Attach regressor to student CNN backbone
- Perform feature distillation with MSE and attention losses
- For semi-supervised case, train and distill from a task head
The student backbone is trained with generative features and soft labels in a self-supervised manner.
Detailed Pseudo Code
Here is a more detailed pseudo code for implementing DreamTeacher:
# Configs
student = ResNet50()            # Student backbone
teacher = DiffusionModel()      # Teacher generative model
feat_levels = [2, 3, 4, 5]      # Feature levels to distill

# 1. Create the feature dataset
features = []
for img in images:
    # Encode the image into the diffusion model: corrupt it with
    # Gaussian noise at timestep t (forward diffusion process)
    noisy_img = add_gaussian_noise(img, t)
    # Teacher features are the intermediate activations of the
    # denoising network on the noisy image, not the decoded image
    teacher_feats = teacher.extract_features(noisy_img, t)
    features.append((img, teacher_feats))

# 2. Define the student model
student = ResNet50()
regressor = FeatureRegressor()  # FPN-like top-down architecture
for l in feat_levels:
    # Expose the student's intermediate features at each level
    student = inject_layer(student, 'layer' + str(l))
student = regressor(student)

# 3. Define the distillation losses
mse_loss = MSELoss()
at_loss = AttentionTransferLoss()
feat_loss = lambda f_t, f_s: mse_loss(f_t, f_s) + at_loss(f_t, f_s)

# 4. Feature distillation
opt = SGD(student.parameters())
for img, teacher_feats in features:
    student_feats = student(img)
    # Get student features at the distillation levels
    student_feats = [student_feats['layer' + str(l)] for l in feat_levels]
    loss = 0
    for f_t, f_s in zip(teacher_feats, student_feats):
        loss += feat_loss(f_t, f_s)
    opt.zero_grad()
    loss.backward()
    opt.step()

# 5. Semi-supervised label distillation
# Train a segmentation head on the frozen teacher's features
head = SegmentationHead(teacher.features)
head.train(subset_labeled_data)

# Distill soft labels into a student head
student_head = SegmentationHead(student.features)
for img in images:
    with torch.no_grad():
        # Get soft labels from the teacher's head
        teacher_label = head(img)
    # Student predictions must keep gradients, so stay outside no_grad
    student_label = student_head(img)
    # Label distillation
    ld_loss = CrossEntropyLoss(student_label, teacher_label)
    opt.zero_grad()
    ld_loss.backward()
    opt.step()
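The inject_layer call above is pseudocode; one concrete way to expose multi-scale features from a torchvision ResNet-50 is to register forward hooks on layer1 through layer4 (strides 4/8/16/32, i.e. roughly levels 2 to 5), as in this sketch:

import torch
from torchvision.models import resnet50

# Register forward hooks that stash each stage's output during forward()
student = resnet50()
feats = {}
for name in ["layer1", "layer2", "layer3", "layer4"]:
    getattr(student, name).register_forward_hook(
        lambda mod, inp, out, n=name: feats.__setitem__(n, out))

x = torch.randn(1, 3, 224, 224)
_ = student(x)  # populates feats as a side effect of the forward pass
student_feats = [feats[n] for n in ["layer1", "layer2", "layer3", "layer4"]]
print([f.shape for f in student_feats])
# For a 224x224 input: (1,256,56,56), (1,512,28,28), (1,1024,14,14), (1,2048,7,7)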
The key steps are:
- Create feature dataset by encoding images into teacher
- Inject layers into student backbone to collect features
- Attach feature regressor to align features
- Compute distillation losses and update student
- For semi-supervised, train and distill from a task head
This provides a detailed pseudo code for implementing DreamTeacher for feature and label distillation.