Training-free Object Counting with Prompts
Authors: Zenglin Shi, Ying Sun, Mengmi Zhang
Abstract: This paper tackles the problem of object counting in images. Existing approaches rely on extensive training data with point annotations for each object, making data collection labor-intensive and time-consuming. To overcome this, we propose a training-free object counter that treats the counting task as a segmentation problem. Our approach leverages the Segment Anything Model (SAM), known for its high-quality masks and zero-shot segmentation capability. However, the vanilla mask generation method of SAM lacks class-specific information in the masks, resulting in inferior counting accuracy. To overcome this limitation, we introduce a prior-guided mask generation method that incorporates three types of priors into the segmentation process, enhancing efficiency and accuracy. Additionally, we tackle the issue of counting objects specified through free-form text by proposing a two-stage approach that combines reference object selection and prior-guided mask generation. Extensive experiments on standard datasets demonstrate the competitive performance of our training-free counter compared to learning-based approaches. This paper presents a promising solution for counting objects in various scenarios without the need for extensive data collection and model training. Code is available at trainingfree-object-counter.
What, Why and How
Here is a summary of the key points in this paper:
What:
- This paper proposes a training-free object counting method that can count objects specified by prompts like points, boxes or text descriptions.
Why:
- Existing object counting methods rely on extensive training data which is laborious to collect. This limits their applicability to new scenarios.
- A training-free method eliminates the need for collecting training data and retraining models, making object counting more accessible.
How:
- Formulate counting as a segmentation task using the Segment Anything Model (SAM); a baseline sketch of this formulation follows this section.
- Propose a prior-guided mask generation method to improve SAM’s segmentation by incorporating three priors: similarity, segment and semantic priors.
- For text-specified objects, propose a two-stage approach:
  - Select reference objects from the image based on an initial text-image similarity map.
  - Use the selected references with prior-guided mask generation.
- Experiments show the method achieves competitive performance compared to learning-based approaches, demonstrating the promise of training-free counting.
In summary, this paper tackles the limitations of existing learning-based counting methods by developing a training-free approach. It cleverly formulates counting as segmentation and enhances SAM’s masks using prior information. The training-free nature makes counting more accessible without requiring extensive data collection and model retraining.
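As background for the counting-as-segmentation formulation, the snippet below shows the vanilla baseline the paper improves upon: SAM’s off-the-shelf automatic mask generator segments everything in the image and the number of masks is read off as the count, with no class-specific information yet. This is only an illustrative sketch using the public segment_anything package; the checkpoint path and image file are placeholders, not values from the paper.

```python
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# Load a SAM model (checkpoint path and variant are placeholders)
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(sam)

# Vanilla class-agnostic "counting": segment everything, count the masks
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)   # list of mask records, one per segment
print("naive count:", len(masks))        # counts ALL segments, not a specific class
```

Because this baseline counts every segmented region regardless of class, the paper introduces the priors and reference selection described below to restrict counting to the prompted object category.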
Main Contributions
Here are the main contributions of this paper:
- Proposes a training-free class-agnostic object counting approach by formulating counting as a segmentation task using the Segment Anything Model (SAM).
- Introduces a prior-guided mask generation method that improves SAM’s segmentation by incorporating three types of priors:
  - Similarity prior using a similarity map
  - Segment prior using existing segmented regions
  - Semantic prior using reference object features
- Presents a two-stage approach to handle counting objects specified by free-form text:
  - Stage 1: Select reference objects from the image based on an initial text-image similarity map
  - Stage 2: Use the selected reference objects with prior-guided mask generation
- Conducts extensive experiments on standard datasets that demonstrate the proposed training-free counting approach achieves competitive performance compared to learning-based methods.
- Shows that the training-free nature makes object counting more accessible, without requiring extensive data collection and model retraining for new scenarios.
In summary, the main contributions are: formulating counting as segmentation, proposing prior-guided mask generation, handling text-specified counting, and experimentally validating the effectiveness of training-free counting.
Method Section
Here is a summary of the method section from the paper:
The paper tackles class-agnostic object counting where the objects to count are specified by prompts like points, boxes or text.
The key idea is to formulate counting as a segmentation task using the Segment Anything Model (SAM). SAM can generate high-quality masks from prompts and performs zero-shot segmentation.
However, directly using SAM for counting lacks class-specific information and leads to poor results. To address this, the paper proposes a prior-guided mask generation method that incorporates three types of priors:
- Similarity prior: Compute a similarity map between reference and image features. Use this to distinguish between positive and negative points when prompting SAM (a minimal sketch follows this list).
- Segment prior: Maintain an overall segment map to avoid redundant segmentation. Use existing segments to guide SAM’s predictions.
- Semantic prior: Incorporate reference object features as context when generating masks. This helps SAM identify the target objects.
For text-specified counting, a two-stage approach is proposed:
- Stage 1: Select reference objects from the image based on an initial text-image similarity map from CLIP-Surgery. Refine the references using binarization and bounding boxes.
- Stage 2: Use the refined references with prior-guided mask generation.
In summary, the key ideas are: formulating counting as segmentation, enhancing SAM’s masks using similarity, segment and semantic priors, and a two-stage approach to handle text-specified counting.
High-Level Pseudo Code
Here is the high-level pseudo code for the key components of the paper:
```python
# Prior-guided mask generation
def generate_masks(image, prompts, sim_map=None):
    # Get the image feature
    feature = get_image_feature(image)
    # Get reference masks and features from the prompts
    ref_masks, ref_features = get_reference(prompts)
    # Compute the similarity map unless a precomputed one is supplied
    if sim_map is None:
        sim_map = get_similarity_map(feature, ref_features)
    # Initialize the overall segment map (empty before any mask is produced)
    segment_map = init_segment_map(image)
    # Divide the point grid into batches
    point_batches = divide_points_into_batches(image)
    # Process each batch
    for batch in point_batches:
        # Segment prior: remove points already covered by existing segments
        batch = remove_segmented_points(batch, segment_map)
        # Similarity prior: label points as positive/negative using sim_map
        batch = label_points(batch, sim_map)
        # Semantic prior: pass image feature, reference features and labeled points to SAM
        masks = sam(feature, ref_features, batch)
        # Update the overall segment map
        segment_map.update(masks)
    # Return all masks collected in the segment map
    return get_all_masks(segment_map)


# Text-specified counting
def count_text(image, text):
    # Get a coarse text-image similarity map using CLIP-Surgery
    sim_map_coarse = clip_surgery(image, text)
    # Select reference objects from the coarse map
    ref_boxes = select_references(sim_map_coarse)
    # Get a fine similarity map using the selected references
    sim_map_fine = get_similarity_map(image, ref_boxes)
    # Generate masks with prior-guided mask generation
    masks = generate_masks(image, ref_boxes, sim_map_fine)
    # Return the number of masks as the object count
    return len(masks)
```
The key steps are:
- Incorporate priors during mask generation
- For text, select good references using a coarse similarity map
- Pass selected references to mask generation
- Count masks as object count
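The pseudo code above assumes helpers for building the uniform point grid and splitting it into batches. A plausible implementation is sketched below; the grid density and batch size are illustrative choices, not values taken from the paper.

```python
import numpy as np


def get_point_grid(image, points_per_side=32):
    """Uniform grid of (x, y) prompt points over an H x W x 3 image array."""
    h, w = image.shape[:2]
    xs = np.linspace(0, w - 1, points_per_side)
    ys = np.linspace(0, h - 1, points_per_side)
    return [(int(x), int(y)) for y in ys for x in xs]


def divide_into_batches(points, batch_size=64):
    """Split the point list into fixed-size batches processed sequentially."""
    return [points[i:i + batch_size] for i in range(0, len(points), batch_size)]
```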
Detailed Pseudo Code
Here is a more detailed pseudo code to implement the key components of the paper:
```python
# Imports: SAM, CLIP-Surgery and standard image/tensor utilities are assumed to be available.

# Prior-guided mask generation
def generate_masks(image, prompts, sim_map=None):
    # Get the dense image feature from SAM's image encoder
    feature = sam.encode_image(image)
    # Get reference masks and boxes from the prompts
    ref_masks, ref_boxes = get_reference(prompts)
    # Semantic prior: reference features pooled from the image feature over the reference masks
    ref_features = feature * ref_masks
    # Similarity prior: cosine similarity between image and reference features,
    # unless a precomputed similarity map is supplied
    if sim_map is None:
        sim_map = cosine_similarity(feature, ref_features)
    # Binarize the similarity map
    bin_sim_map = otsu_threshold(sim_map)
    # Segment prior: track points already covered by a mask, and collect all masks
    segmented_points = set()
    all_masks = []
    # Get the uniform point grid over the image
    point_grid = get_point_grid(image)
    # Divide the points into batches
    point_batches = divide_into_batches(point_grid)
    # Process each batch
    for batch in point_batches:
        # Skip points that fall inside already-segmented regions
        batch = [p for p in batch if p not in segmented_points]
        # Label points as positive or negative using the binarized similarity map
        pos_points, neg_points = [], []
        for p in batch:
            if bin_sim_map[p] > 0:
                pos_points.append(p)
            else:
                neg_points.append(p)
        # Pass the image feature, reference features and labeled points to SAM
        masks = sam.segment_image(feature, ref_features, pos_points, neg_points)
        # Update the segment map with the new masks
        for m in masks:
            all_masks.append(m)
            segmented_points.update(m.points)
    # Return all generated masks
    return all_masks


# Text-specified counting
def count_text(image, text):
    # Stage 1: select reference objects from a coarse text-image similarity map
    text_emb = clip.encode_text(text)
    sim_map_coarse = clip_surgery(image, text_emb)
    # Binarize the coarse similarity map
    bin_sim_map = otsu_threshold(sim_map_coarse)
    # Keep the largest connected component
    components = connected_components(bin_sim_map)
    largest_cc = get_largest_component(components)
    # Divide it into sub-contours
    sub_contours = divide_contour(largest_cc)
    # Get a bounding box for each sub-contour
    ref_boxes = get_boxes(sub_contours)
    # Apply NMS to remove overlapping boxes
    ref_boxes = nms(ref_boxes)
    # Stage 2: build a fine similarity map from the selected references
    ref_features = sam.encode_image(crop_boxes(image, ref_boxes))
    sim_map_fine = cosine_similarity(sam.encode_image(image), ref_features)
    # Prior-guided mask generation with the refined references and fine similarity map
    masks = generate_masks(image, ref_boxes, sim_map_fine)
    # Return the number of masks as the object count
    return len(masks)
```
The key aspects are:
- Specific implementations for getting features, similarity maps, processing points/batches
- Binarization, connected components and NMS for selecting good reference objects from text (a concrete sketch follows this list)
- Reusing selected references to get a high quality similarity map
- Calling mask generation function
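As a concrete illustration of the reference-selection step, the helper below uses skimage for Otsu thresholding and connected components, and torchvision for NMS. It is a simplified sketch: it treats every connected component as a candidate box rather than splitting the largest component into sub-contours as the summary describes, and the NMS score (mean similarity inside each box) is an assumption, since the summary does not specify one.

```python
import torch
from skimage.filters import threshold_otsu
from skimage.measure import label, regionprops
from torchvision.ops import nms


def select_reference_boxes(sim_map_coarse, iou_thresh=0.5):
    """Turn a coarse (H, W) text-image similarity map (numpy array) into reference boxes."""
    # Binarize the similarity map with Otsu's threshold
    binary = sim_map_coarse > threshold_otsu(sim_map_coarse)

    # Connected components of the binary map; each component is a candidate reference region
    components = label(binary)

    boxes, scores = [], []
    for region in regionprops(components):
        # regionprops bbox is (min_row, min_col, max_row, max_col) -> convert to (x1, y1, x2, y2)
        r1, c1, r2, c2 = region.bbox
        boxes.append([c1, r1, c2, r2])
        # Assumed score: mean similarity inside the box
        scores.append(float(sim_map_coarse[r1:r2, c1:c2].mean()))

    if not boxes:
        return []

    boxes_t = torch.tensor(boxes, dtype=torch.float32)
    keep = nms(boxes_t, torch.tensor(scores), iou_thresh)  # drop heavily overlapping boxes
    return boxes_t[keep].tolist()
```

The surviving boxes would then serve as the refined references passed to prior-guided mask generation in stage 2.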