Training-free Object Counting with Prompts

Authors: Zenglin Shi, Ying Sun, Mengmi Zhang

Abstract: This paper tackles the problem of object counting in images. Existing approaches rely on extensive training data with point annotations for each object, making data collection labor-intensive and time-consuming. To overcome this, we propose a training-free object counter that treats the counting task as a segmentation problem. Our approach leverages the Segment Anything Model (SAM), known for its high-quality masks and zero-shot segmentation capability. However, the vanilla mask generation method of SAM lacks class-specific information in the masks, resulting in inferior counting accuracy. To overcome this limitation, we introduce a prior-guided mask generation method that incorporates three types of priors into the segmentation process, enhancing efficiency and accuracy. Additionally, we tackle the issue of counting objects specified through free-form text by proposing a two-stage approach that combines reference object selection and prior-guided mask generation. Extensive experiments on standard datasets demonstrate the competitive performance of our training-free counter compared to learning-based approaches. This paper presents a promising solution for counting objects in various scenarios without the need for extensive data collection and model training. Code is available at trainingfree-object-counter.

What, Why and How

Here is a summary of the key points in this paper:

What:

  • This paper proposes a training-free object counting method that can count objects specified by prompts like points, boxes or text descriptions.

Why:

  • Existing object counting methods rely on extensive training data which is laborious to collect. This limits their applicability to new scenarios.
  • A training-free method eliminates the need for collecting training data and retraining models, making object counting more accessible.

How:

  • Formulate counting as a segmentation task using the Segment Anything Model (SAM).
  • Propose a prior-guided mask generation method to improve SAM’s segmentation by incorporating three priors: similarity, segment and semantic priors.
  • For text-specified objects, propose a 2-stage approach:
    1. Select reference objects from the image based on an initial text-image similarity map.
    2. Use the selected references with prior-guided mask generation.
  • Experiments show the method achieves competitive performance compared to learning-based approaches, demonstrating the promise of training-free counting.

In summary, this paper tackles the limitations of existing learning-based counting methods by developing a training-free approach. It cleverly formulates counting as segmentation and enhances SAM’s masks using prior information. The training-free nature makes counting more accessible without requiring extensive data collection and model retraining.

Main Contributions

Here are the main contributions of this paper:

  1. Proposes a training-free class-agnostic object counting approach by formulating counting as a segmentation task using the Segment Anything Model (SAM).

  2. Introduces a prior-guided mask generation method that improves SAM’s segmentation by incorporating three types of priors:

    • Similarity prior using a similarity map
    • Segment prior using existing segmented regions
    • Semantic prior using reference object features
  3. Presents a two-stage approach to handle counting objects specified by free-form text:

    • Stage 1: Select reference objects from the image based on an initial text-image similarity map
    • Stage 2: Use the selected reference objects with prior-guided mask generation
  4. Conducts extensive experiments on standard datasets, demonstrating that the proposed training-free counting approach achieves competitive performance compared to learning-based methods.

  5. The training-free nature makes object counting more accessible without requiring extensive data collection and model retraining for new scenarios.

In summary, the main contributions are: formulating counting as segmentation, proposing prior-guided mask generation, handling text-specified counting, and experimentally validating the effectiveness of training-free counting.

Method Section

Here is a summary of the method section from the paper:

The paper tackles class-agnostic object counting where the objects to count are specified by prompts like points, boxes or text.

The key idea is to formulate counting as a segmentation task using the Segment Anything Model (SAM). SAM can generate high-quality masks from prompts and performs zero-shot segmentation.

However, SAM’s vanilla mask generation produces masks without class-specific information, which leads to poor counting results. To address this, the paper proposes a prior-guided mask generation method that incorporates three types of priors:

  • Similarity prior: Compute a similarity map between reference and image features. Use this to distinguish between positive and negative points when prompting SAM (a minimal sketch follows this list).

  • Segment prior: Maintain an overall segment map to avoid redundant segmentation. Use existing segments to guide SAM’s predictions.

  • Semantic prior: Incorporate reference object features as context when generating masks. This helps SAM identify the target objects.
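
To make the similarity prior concrete, here is a minimal runnable sketch of point labeling, assuming a (H, W, C) image feature map (e.g., from SAM’s encoder) and a boolean reference mask; the function name and feature layout are illustrative assumptions, not the paper’s exact implementation:

import numpy as np
from skimage.filters import threshold_otsu

def label_grid_points(feature, ref_mask, points):
  """Label grid points as positive/negative via a binarized similarity map.

  feature:  (H, W, C) image feature map
  ref_mask: (H, W) boolean mask of one reference object
  points:   list of (row, col) grid points in feature-map coordinates
  """
  # Pool a single reference feature by masked average pooling
  ref_feat = feature[ref_mask].mean(axis=0)  # (C,)

  # Cosine similarity between every location and the reference
  norms = np.linalg.norm(feature, axis=-1) * np.linalg.norm(ref_feat) + 1e-8
  sim_map = (feature @ ref_feat) / norms  # (H, W)

  # Otsu's method picks the binarization threshold from the data itself
  binary = sim_map > threshold_otsu(sim_map)

  pos = [p for p in points if binary[p]]
  neg = [p for p in points if not binary[p]]
  return pos, neg

Otsu thresholding is attractive here because a training-free pipeline has no validation data on which to tune a threshold by hand.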

For text-specified counting, a two-stage approach is proposed:

  • Stage 1: Select reference objects from the image based on an initial text-image similarity map from CLIP-Surgery. Refine the references using binarization and bounding boxes (a sketch follows this list).

  • Stage 2: Use the refined references with prior-guided mask generation.
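
As a rough illustration of Stage 1, the sketch below extracts candidate reference boxes from a binarized coarse similarity map using connected components; max_refs and the size-based ranking are illustrative assumptions, and the paper additionally splits the largest component into sub-contours and applies NMS (see the detailed pseudo code further down):

import numpy as np
from scipy import ndimage

def select_reference_boxes(bin_sim_map, max_refs=3):
  """Pick reference bounding boxes from a binarized coarse similarity map.

  bin_sim_map: (H, W) boolean map, e.g. an Otsu-thresholded
               CLIP-Surgery text-image similarity map.
  """
  # Connected components of the binary map are candidate object regions
  labeled, num = ndimage.label(bin_sim_map)
  if num == 0:
    return []

  # Rank components by size and keep the largest ones as references
  sizes = ndimage.sum(bin_sim_map, labeled, index=range(1, num + 1))
  order = np.argsort(sizes)[::-1][:max_refs]

  boxes = []
  for idx in order:
    rows, cols = np.where(labeled == idx + 1)
    boxes.append((rows.min(), cols.min(), rows.max(), cols.max()))
  return boxes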

In summary, the key ideas are: formulating counting as segmentation, enhancing SAM’s masks using similarity, segment and semantic priors, and a two-stage approach to handle text-specified counting.

High-Level Pseudo Code

Here is the high-level pseudo code for the key components of the paper:

# Prior-guided mask generation
def generate_masks(image, prompts, sim_map=None):

  # Get image feature
  feature = get_image_feature(image)

  # Get reference masks and features
  ref_masks, ref_features = get_reference(prompts)

  # Compute similarity map unless a precomputed one is supplied
  if sim_map is None:
    sim_map = get_similarity_map(feature, ref_features)

  # Initialize an empty overall segment map
  segment_map = init_segment_map(image)

  # Divide point grid into batches
  point_batches = divide_points_into_batches(image)

  # Process each batch
  for batch in point_batches:

    # Remove already segmented points (segment prior)
    batch = remove_segmented_points(batch, segment_map)

    # Label points as pos/neg using sim_map (similarity prior)
    batch = label_points(batch, sim_map)

    # Pass image feature, ref features (semantic prior) and labeled points to SAM
    masks = sam(feature, ref_features, batch)

    # Update segment map with the new masks
    segment_map.update(masks)

  # Return all masks
  return segment_map
 
# Text-specified counting
def count_text(image, text):
 
  # Get coarse similarity map using CLIP-Surgery 
  sim_map_coarse = clip_surgery(image, text)
 
  # Select reference objects
  ref_boxes = select_references(sim_map_coarse)
 
  # Get fine similarity map using references
  sim_map_fine = get_similarity_map(image, ref_boxes)
 
  # Generate masks
  masks = generate_masks(image, ref_boxes, sim_map_fine)
 
  # Return mask count
  return len(masks)

The key steps are:

  • Incorporate priors during mask generation
  • For text, select good references using a coarse similarity map
  • Pass selected references to mask generation
  • Count masks as object count
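
For concreteness, the sketch below shows how labeled points could be fed to SAM through the public segment-anything predictor API. The checkpoint filename matches the released ViT-H weights but should be treated as a placeholder, and plain point prompting like this omits the paper’s semantic prior, which additionally injects reference features into SAM:

import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Placeholder checkpoint path for the released ViT-H weights
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

def segment_positive_points(image, pos_points, neg_points=()):
  """Prompt SAM once per positive point; negative points are shared.

  image:      (H, W, 3) uint8 RGB array
  pos_points: (x, y) points labeled positive by the similarity prior
  neg_points: (x, y) points labeled negative
  """
  predictor.set_image(image)
  masks = []
  for p in pos_points:
    coords = np.array([p] + list(neg_points))
    labels = np.array([1] + [0] * len(neg_points))
    m, _, _ = predictor.predict(
      point_coords=coords,
      point_labels=labels,
      multimask_output=False,
    )
    masks.append(m[0])  # (H, W) boolean mask
  return masks

Prompting once per positive point mirrors how SAM’s automatic mask generator turns a point grid into one mask per point; the resulting masks can then be deduplicated and counted.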

Detailed Pseudo Code

Here is a more detailed pseudo code to implement the key components of the paper:

# Import SAM, CLIP-Surgery etc.
 
# Prior-guided mask generation
def generate_masks(image, prompts, sim_map=None):

  # Get image feature
  feature = sam.encode_image(image)

  # Get reference masks, then pool reference features from the image feature
  ref_masks = get_reference_masks(prompts)
  ref_features = masked_average_pool(feature, ref_masks)

  # Compute similarity map unless a precomputed one is supplied
  if sim_map is None:
    sim_map = cosine_similarity(feature, ref_features)

  # Binarize similarity map
  bin_sim_map = otsu_threshold(sim_map)

  # Track output masks and already-segmented points
  all_masks = []
  segmented_points = set()

  # Get point grid and divide it into batches
  point_grid = get_point_grid(image)
  point_batches = divide_into_batches(point_grid)

  # Process each batch
  for batch in point_batches:

    # Remove already segmented points (segment prior)
    batch = [p for p in batch if p not in segmented_points]

    # Label points as pos/neg via the binarized map (similarity prior)
    pos_points = [p for p in batch if bin_sim_map[p] > 0]
    neg_points = [p for p in batch if bin_sim_map[p] == 0]

    # Pass image feature, ref features (semantic prior) and labeled points to SAM
    masks = sam.segment_image(feature, ref_features, pos_points, neg_points)

    # Record the new masks and mark their points as segmented
    for m in masks:
      all_masks.append(m)
      segmented_points.update(m.points)

  # Return all masks
  return all_masks
 
# Text-specified counting  
def count_text(image, text):
 
  # Get coarse similarity map
  text_emb = clip.encode_text(text)
  sim_map_coarse = clip_surgery(image, text_emb)
 
  # Binarize coarse sim map
  bin_sim_map = otsu_threshold(sim_map_coarse)
 
  # Get largest connected component
  components = connected_components(bin_sim_map)
  largest_cc = get_largest_component(components)
 
  # Divide into sub-contours
  sub_contours = divide_contour(largest_cc)
 
  # Get bounding boxes for each sub-contour
  ref_boxes = get_boxes(sub_contours)
 
  # Apply NMS on boxes
  ref_boxes = nms(ref_boxes)
 
  # Get fine similarity map using refined references
  ref_features = sam.encode_image(crop_boxes(image, ref_boxes))
  sim_map_fine = cosine_similarity(sam.encode_image(image), ref_features)
 
  # Generate masks
  masks = generate_masks(image, ref_boxes, sim_map_fine)
 
  # Return mask count
  return len(masks)

The key aspects are:

  • Specific implementations for getting features, similarity maps, processing points/batches
  • Binarization, connected components and NMS for selecting good reference objects from text (a minimal NMS sketch follows below)
  • Reusing selected references to get a high quality similarity map
  • Calling mask generation function
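
Because the selected reference boxes carry no detector confidence scores, a simple greedy NMS can rank candidates by area instead; the area-based ranking is an illustrative assumption, not necessarily the paper’s exact rule:

import numpy as np

def nms(boxes, iou_thresh=0.5):
  """Greedy NMS over (x1, y1, x2, y2) boxes, using area in place of a score."""
  boxes = np.asarray(boxes, dtype=float)
  areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
  order = areas.argsort()[::-1]  # largest boxes first
  keep = []
  while order.size > 0:
    i = order[0]
    keep.append(i)
    # Intersection of the top box with all remaining boxes
    x1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
    y1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
    x2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
    y2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
    inter = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1)
    iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-8)
    # Drop boxes that overlap the kept box too much
    order = order[1:][iou <= iou_thresh]
  return boxes[keep].tolist()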
