SegPrompt: Boosting Open-world Segmentation via Category-level Prompt Learning

Authors: Muzhi Zhu, Hengtao Li, Hao Chen, Chengxiang Fan, Weian Mao, Chenchen Jing, Yifan Liu, Chunhua Shen

Abstract: Current closed-set instance segmentation models rely on pre-defined class labels for each mask during training and evaluation, largely limiting their ability to detect novel objects. Open-world instance segmentation (OWIS) models address this challenge by detecting unknown objects in a class-agnostic manner. However, previous OWIS approaches completely erase category information during training to keep the model’s ability to generalize to unknown objects. In this work, we propose a novel training mechanism termed SegPrompt that uses category information to improve the model’s class-agnostic segmentation ability for both known and unknown categories. In addition, the previous OWIS training setting exposes the unknown classes to the training set and brings information leakage, which is unreasonable in the real world. Therefore, we provide a new open-world benchmark closer to a real-world scenario by dividing the dataset classes into known-seen-unseen parts. For the first time, we focus on the model’s ability to discover objects that never appear in the training set images. Experiments show that SegPrompt can improve the overall and unseen detection performance by 5.6% and 6.1% in AR on our new benchmark without affecting the inference efficiency. We further demonstrate the effectiveness of our method on existing cross-dataset transfer and strongly supervised settings, leading to 5.5% and 12.3% relative improvement.

What, Why and How

Here is a summary of the key points of the paper:

What:

  • The paper proposes SegPrompt, a method to boost open-world segmentation by learning category-level prompts.

Why:

  • Current segmentation methods perform poorly on rare, diverse, and novel objects.
  • Existing benchmarks have too few categories or include most unknowns in training.
  • SegPrompt improves segmentation of unseen classes without additional inference cost.

How:

  • They introduce a new benchmark called LVIS-OW with known, seen, and unseen splits (see the sketch after this list).
  • SegPrompt learns prompts for seen classes with no annotations.
  • The prompts provide top-down information to help segment unseen classes.
  • Results show improved segmentation on unseen classes compared to baselines.
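
To make the three-way split concrete, here is a minimal illustration. The category names and their assignment to splits below are hypothetical placeholders, not the actual LVIS-OW partition.

# Illustrative split of dataset categories into known / seen / unseen
# (placeholder category names, not the real LVIS-OW assignment)
splits = {
    "known":  ["person", "car", "dog"],        # annotated with masks and labels in training
    "seen":   ["fire hydrant", "teapot"],      # appear in training images, but without annotations
    "unseen": ["sewing machine", "unicycle"],  # never appear in any training image
}

def split_of(category):
    # Look up which split a category belongs to (placeholder helper).
    return next(s for s, cats in splits.items() if category in cats)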

In summary, the paper introduces a new open-world segmentation benchmark and presents SegPrompt to improve segmentation of novel objects by learning category-level prompts, without extra inference cost. The prompts provide top-down guidance to help discover and segment diverse unseen classes.

Main Contributions

Based on my understanding, the main contributions of this paper are:

  • They propose a new benchmark called LVIS-OW for evaluating open-world segmentation with known, seen, and unseen class splits.

  • They introduce a method called SegPrompt that learns category-level prompts on seen classes with no annotations to provide top-down guidance for segmenting unseen classes.

  • SegPrompt can be plugged into existing models like Mask2Former without changing the inference process, so there is no extra computational overhead.

  • Experiments show SegPrompt boosts segmentation performance on unseen classes compared to baselines, improving AR on unseen classes by 6.1% on the new benchmark.

  • They design new metrics to evaluate open-world segmentation performance on known, seen, and unseen classes separately, as sketched below.
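
As a rough illustration of such per-split evaluation, the sketch below computes recall separately over known, seen, and unseen ground-truth masks. The exact AR protocol in the paper (IoU thresholds, proposal budget) may differ, and the function and argument names are assumptions for this sketch.

import numpy as np

def mask_iou(a, b):
    # IoU between two binary masks (numpy boolean arrays).
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return inter / union if union > 0 else 0.0

def recall_per_split(pred_masks, gt_masks, gt_splits, iou_thr=0.5):
    # pred_masks: class-agnostic proposal masks;
    # gt_masks: ground-truth masks; gt_splits: "known" / "seen" / "unseen" per ground truth.
    hits = {"known": 0, "seen": 0, "unseen": 0}
    totals = {"known": 0, "seen": 0, "unseen": 0}
    for gt, split in zip(gt_masks, gt_splits):
        totals[split] += 1
        if any(mask_iou(p, gt) >= iou_thr for p in pred_masks):
            hits[split] += 1
    return {s: hits[s] / totals[s] if totals[s] else 0.0 for s in totals}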

In summary, the key contributions are:

  1. New open-world segmentation benchmark
  2. SegPrompt method to learn prompts on seen classes
  3. Improved unseen class segmentation
  4. New evaluation metrics for open-world segmentation
  5. Modular design without extra inference cost

Method Section

Here is a summary of the method section:

  • SegPrompt has two components - a query generator and a prompt learner.

  • The query generator samples representative pixels from the training set to generate queries.

  • The prompt learner is trained to predict prompts for the sampled queries.

  • The prompts are initialized randomly and updated via gradient descent during training.

  • The learned prompts provide top-down guidance to the segmentation model.

  • The prompt-guided training helps the model discover and segment unseen classes at test time.

  • No changes are needed during inference - prompts are only used during training.

  • The overall framework is modular and can be inserted into existing models like Mask2Former.

  • Specifically, the prompt vectors are concatenated to the encoder output before being fed to the decoder (see the sketch after this list).

  • The decoder segmentation head now takes both encoder features and prompt vectors as input.

  • This helps transfer knowledge from prompts to improve unseen class segmentation.

  • The training alternates between prompt learning and segmentation model optimization.
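
A minimal sketch of the prompt-concatenation step is shown below. It assumes the encoder output can be viewed as a sequence of feature tokens of shape (B, N, C); the shapes and the exact way the SegPrompt decoder consumes the prompts are assumptions for illustration, not the authors' implementation.

import torch
import torch.nn as nn

B, N, C, num_classes = 2, 1024, 256, 100              # illustrative sizes
encoder_tokens = torch.randn(B, N, C)                 # stand-in for the encoder output
prompts = nn.Parameter(torch.randn(num_classes, C))   # one learnable prompt per category

# Append the category prompts to the encoder tokens before the decoder.
prompt_tokens = prompts.unsqueeze(0).expand(B, -1, -1)             # (B, num_classes, C)
decoder_input = torch.cat([encoder_tokens, prompt_tokens], dim=1)  # (B, N + num_classes, C)

# During training, decoder_input is fed to the mask decoder; at inference the
# prompts are dropped and the decoder sees encoder_tokens unchanged.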

In summary, SegPrompt learns customizable prompt vectors for each category that provide top-down guidance to the segmentation model to improve discovery and segmentation of unseen classes, without any changes to inference.

High-Level Pseudo Code

Here is the high-level pseudo code for the SegPrompt method:

# SegPrompt Pseudo Code

# Prompt Learner
initialize prompt vectors p randomly   # one prompt per category

for iter in range(num_iters):

  # Query Generator
  sample representative pixels from the training images
  generate queries q from the sampled pixels

  # Prompt Learning
  get encoder features f for the queries q
  predict the categories of the queries q using f and the prompts p
  update p via gradient descent

  # Segmentation Model Training
  get encoder features f for the training images
  concatenate the prompts p to f
  predict masks with decoder(f + p)
  compute the class-agnostic segmentation loss
  update the encoder and decoder

# Inference
segment as usual with encoder + decoder
(the prompts p are not needed)

The key steps are:

  1. Sample representative queries
  2. Learn prompt vectors for each category
  3. Concatenate learned prompts to encoder features during training
  4. Segmentation model takes encoder + prompts as input
  5. Only encoder + decoder needed during inference

So in summary, SegPrompt learns customizable prompts for each category that provide guidance to the segmentation model during training to improve unseen class segmentation. The prompts are only used during training.

Detailed Pseudo Code

Here is a more detailed, runnable-style sketch of how SegPrompt training might be implemented (the helper functions are placeholders, not the authors' code):

# A runnable-style sketch of SegPrompt training. The helpers marked as
# placeholders (build_encoder, sample_pixels, ...) stand in for the authors'
# actual components, and tensor shapes are illustrative assumptions.
import torch
import torch.nn.functional as F

# Define Segmentation Model (e.g., a Mask2Former-style pixel encoder + mask decoder)
encoder = build_encoder()          # placeholder: backbone + pixel encoder (nn.Module)
decoder = build_mask_decoder()     # placeholder: transformer mask decoder (nn.Module)

# Query Generator
num_queries = 300

# Prompt Learner
num_classes = 1000
embed_dim = 256
prompts = torch.randn(num_classes, embed_dim, requires_grad=True)  # one learnable prompt per category

def prompt_learner(query_features):
    # Classify each query by its similarity to the category prompts so that the
    # cross-entropy gradient flows into the prompt vectors (a simple choice for
    # this sketch; the paper's prompt learner may differ).
    return query_features @ prompts.t()     # (num_queries, num_classes)

prompt_optimizer = torch.optim.Adam([prompts])
seg_optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()))

for it in range(num_training_iters):

    # Query Generator: sample representative pixels and their category labels
    query_pixels = sample_pixels(train_set, num_queries)   # placeholder
    class_labels = get_class_labels(query_pixels)          # placeholder: LongTensor of class ids

    # Prompt Learning
    # (detached here so this step only updates the prompts; whether the prompt
    # loss also trains the encoder is a design choice of the full method)
    query_features = encoder(query_pixels).detach()        # assumed (num_queries, embed_dim)
    prompt_logits = prompt_learner(query_features)
    prompt_loss = F.cross_entropy(prompt_logits, class_labels)
    prompt_optimizer.zero_grad()
    prompt_loss.backward()
    prompt_optimizer.step()

    # Segmentation Model Training with prompt guidance
    img_features = encoder(train_images)                   # assumed (B, N, embed_dim) feature tokens
    prompt_tokens = prompts.detach().unsqueeze(0).expand(img_features.size(0), -1, -1)
    concat_features = torch.cat([img_features, prompt_tokens], dim=1)  # append prompts as extra tokens
    pred_masks = decoder(concat_features)
    seg_loss = calc_segmentation_loss(pred_masks, gt_masks)  # class-agnostic mask loss (placeholder)
    seg_optimizer.zero_grad()
    seg_loss.backward()
    seg_optimizer.step()

# Inference: the prompts are discarded and the model runs unchanged
img_features = encoder(test_images)
pred_masks = decoder(img_features)   # no prompts needed

The key steps are:

  1. Sample query pixels and generate queries
  2. Get image and query features from encoder
  3. Learn prompts by predicting class labels for queries
  4. Concatenate prompts to image features for decoder
  5. Train segmentation model with encoder + prompts
  6. Only use encoder + decoder during inference
