sushmanth committed
Commit 2542a21 · 1 Parent(s): 3961515

Update README.md

Files changed (1): README.md (+207 −169)

README.md CHANGED

---
license: apache-2.0
tags:
- vision
---

# Model Card for Segment Anything Model in High Quality (SAM-HQ)

<p align="center">
  <img src="https://huggingface.co/sushmanth/sam_hq_vit_b/resolve/main/assets/arc.png" alt="SAM-HQ Architecture">
  <em>Architecture of SAM-HQ compared to the original SAM model, showing the HQ-Output Token and Global-local Feature Fusion components.</em>
</p>

# Table of Contents

0. [TL;DR](#tldr)
1. [Model Details](#model-details)
2. [Usage](#usage)
3. [Citation](#citation)

# TL;DR

SAM-HQ (Segment Anything in High Quality) is an enhanced version of the Segment Anything Model (SAM) that produces higher-quality object masks from input prompts such as points or boxes. While SAM was trained on a dataset of 11 million images and 1.1 billion masks, its mask prediction quality falls short in many cases, particularly for objects with intricate structures. SAM-HQ addresses these limitations with minimal additional parameters and computation cost.

The model excels at generating high-quality segmentation masks, even for objects with complex boundaries and thin structures where the original SAM model often struggles. SAM-HQ maintains SAM's original promptable design, efficiency, and zero-shot generalizability while significantly improving mask quality.

# Model Details

SAM-HQ builds upon the original SAM architecture with two key innovations while preserving SAM's pretrained weights:

1. **High-Quality Output Token**: A learnable token injected into SAM's mask decoder that is responsible for predicting high-quality masks. Unlike SAM's original output tokens, this token and its associated MLP layers are specifically trained to produce highly accurate segmentation masks.

2. **Global-local Feature Fusion**: Instead of applying the HQ-Output Token only to the mask-decoder features, SAM-HQ first fuses those features with early and final ViT features for improved mask detail, combining high-level semantic context with low-level boundary information for more accurate segmentation (see the illustrative sketch below).
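
To make the fusion idea concrete, here is a minimal, illustrative PyTorch sketch. It is **not** the actual SAM-HQ implementation: the layer choices, channel widths, and feature shapes below are assumptions made purely for illustration.

```python
import torch
import torch.nn as nn

class GlobalLocalFusionSketch(nn.Module):
    """Illustrative only: fuse early ViT features (local detail), final ViT
    features (global context), and SAM's mask-decoder features."""

    def __init__(self, vit_dim=768, decoder_dim=32):
        super().__init__()
        # Assumed compression layers: project ViT features to the decoder's
        # channel width and upsample them to the decoder's spatial resolution.
        self.compress_early = nn.Sequential(
            nn.ConvTranspose2d(vit_dim, decoder_dim, kernel_size=2, stride=2), nn.GELU()
        )
        self.compress_final = nn.Sequential(
            nn.ConvTranspose2d(vit_dim, decoder_dim, kernel_size=2, stride=2), nn.GELU()
        )

    def forward(self, early_vit_feat, final_vit_feat, decoder_feat):
        # decoder_feat: (B, decoder_dim, H, W); ViT features: (B, vit_dim, H/2, W/2)
        return (
            decoder_feat
            + self.compress_early(early_vit_feat)
            + self.compress_final(final_vit_feat)
        )

# Toy shapes for illustration: 64x64 ViT tokens fused into a 128x128 decoder feature map
early = torch.randn(1, 768, 64, 64)
final = torch.randn(1, 768, 64, 64)
decoder = torch.randn(1, 32, 128, 128)
print(GlobalLocalFusionSketch()(early, final, decoder).shape)  # torch.Size([1, 32, 128, 128])
```

In SAM-HQ, a fused feature map of this kind is what the HQ-Output Token uses to predict the high-quality mask.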
 
 
SAM-HQ was trained on a carefully curated dataset of 44K fine-grained masks (HQSeg-44K), compiled from several sources with extremely accurate annotations. Training takes only 4 hours on 8 GPUs and introduces less than 0.5% additional parameters compared to the original SAM model.

The model has been evaluated on a suite of 10 diverse segmentation datasets across different downstream tasks, 8 of them under a zero-shot transfer protocol. The results show that SAM-HQ produces significantly better masks than the original SAM model while maintaining its zero-shot generalization capabilities.

SAM-HQ addresses two key problems with the original SAM model:

1. Coarse mask boundaries, which often neglect thin object structures
2. Incorrect predictions, broken masks, or large errors in challenging cases

These improvements make SAM-HQ particularly valuable for applications that require highly accurate masks, such as automated annotation and image/video editing.

# Usage

## Prompted-Mask-Generation

```python
from PIL import Image
import requests
from transformers import SamHQModel, SamHQProcessor

model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to("cuda")
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")

img_url = "https://raw.githubusercontent.com/SysCV/sam-hq/refs/heads/main/demo/input_imgs/example1.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_boxes = [[[306, 132, 925, 893]]]  # one bounding box for the single input image
```

```python
inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to("cuda")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
```
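
For a box prompt, the model returns several candidate masks together with predicted IoU scores. A quick way to keep only the highest-scoring candidate (a sketch that assumes the output shapes match those of the original SAM model: one image, one box, three candidate masks):

```python
best_idx = scores.squeeze().argmax().item()  # index of the highest predicted IoU
best_mask = masks[0].squeeze()[best_idx]     # (H, W) boolean mask for that candidate
```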

Among other arguments for generating masks, you can pass 2D point locations at the approximate position of your object of interest, a bounding box wrapping the object of interest (given as the x, y coordinates of its top-left and bottom-right corners), or a segmentation mask. At the time of writing, passing text as input is not supported by the official model, according to the official repository. For more details, refer to this notebook, which provides a walkthrough of how to use the model, with a visual example.
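
For example, a point prompt can be passed in much the same way as a box. This is a sketch under the assumption that `SamHQProcessor` accepts `input_points` in the same nested format as the original SAM processor; the coordinates below are made up for illustration:

```python
import torch

input_points = [[[600, 500]]]  # hypothetical (x, y) location inside the object of interest

inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
```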

## Automatic-Mask-Generation

The model can be used to generate segmentation masks in a "zero-shot" fashion, given only an input image. The model is automatically prompted with a grid of `1024` points, which are all fed to the model.

The `mask-generation` pipeline handles automatic mask generation. The following snippet demonstrates how easily you can run it (on any device; simply pass the appropriate `points_per_batch` argument):

```python
from transformers import pipeline

generator = pipeline("mask-generation", model="sushmanth/sam_hq_vit_b", device=0, points_per_batch=256)
image_url = "https://raw.githubusercontent.com/SysCV/sam-hq/refs/heads/main/demo/input_imgs/example1.png"
outputs = generator(image_url, points_per_batch=256)
```
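
The pipeline returns a dictionary containing the generated masks and their scores. A quick way to inspect it (assuming the output structure matches the original SAM `mask-generation` pipeline, i.e. binary masks as NumPy arrays plus per-mask scores):

```python
print(len(outputs["masks"]))       # number of generated masks
print(outputs["masks"][0].shape)   # each mask has the height and width of the input image
print(outputs["scores"][0])        # confidence score of the first mask
```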

Now, to display the image with the generated masks overlaid:

```python
import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image

# Reload the input image so this snippet is self-contained
raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
```

## Complete Example with Visualization

Here is a complete example showing how to use SAM-HQ with the image-embedding workflow and how to visualize the results:

```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import requests
from transformers import SamHQModel, SamHQProcessor

# 1. Load model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")

# 2. Load and display the image
img_url = "https://raw.githubusercontent.com/SysCV/sam-hq/refs/heads/main/demo/input_imgs/example1.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
plt.figure(figsize=(10, 10))
plt.imshow(raw_image)
plt.axis("off")
plt.show()

# 3. Compute image embeddings (done once, then reused for every prompt)
inputs = processor(raw_image, return_tensors="pt").to(device)
image_embeddings, intermediate_embeddings = model.get_image_embeddings(inputs["pixel_values"])

# 4. Define a bounding box and visualize it
input_boxes = [[[306, 132, 925, 893]]]  # [x1, y1, x2, y2] in original image coordinates

# Helper function to display a bounding box
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))

plt.figure(figsize=(10, 10))
plt.imshow(raw_image)
for box in input_boxes[0]:
    show_box(box, plt.gca())
plt.axis("on")
plt.title("Input Image with Bounding Box")
plt.show()

# 5. Run inference with the bounding box
# Let the processor rescale the box to the model's input resolution,
# then swap in the precomputed image embeddings.
inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)
inputs.pop("pixel_values", None)
inputs.update({"image_embeddings": image_embeddings})
inputs.update({"intermediate_embeddings": intermediate_embeddings})

# Run inference
with torch.no_grad():
    outputs = model(**inputs)

# 6. Post-process the masks
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
scores = outputs.iou_scores

# 7. Visualize the results
# Helper function to show masks
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# Show all masks with their scores
if len(masks[0].shape) == 4:
    masks_to_show = masks[0].squeeze()
else:
    masks_to_show = masks[0]

if scores.shape[0] == 1:
    scores_to_show = scores.squeeze()
else:
    scores_to_show = scores

# Create a figure with one subplot per mask
nb_predictions = scores_to_show.shape[-1]
fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))

# Handle the case where there is only one mask
if nb_predictions == 1:
    axes = [axes]

for i, (mask, score) in enumerate(zip(masks_to_show, scores_to_show)):
    mask = mask.cpu().detach()
    axes[i].imshow(np.array(raw_image))
    show_mask(mask, axes[i])
    axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
    axes[i].axis("off")
plt.tight_layout()
plt.show()

# Show all masks overlaid on a single image
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(np.array(raw_image))
for i, (mask, score) in enumerate(zip(masks_to_show, scores_to_show)):
    if len(mask.shape) > 2:
        mask = mask.squeeze()
    show_mask(mask, ax, random_color=True)
ax.set_title("All Masks Overlaid")
ax.axis("off")
plt.tight_layout()
plt.show()
```

This example demonstrates the complete workflow for using SAM-HQ with the `sushmanth/sam_hq_vit_b` model: it computes the image embeddings once and then reuses them for inference with a bounding box prompt. The resulting masks are visualized both individually, with their confidence scores, and overlaid on a single image in different colors.
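
If you need to persist the predicted masks, for example for an annotation workflow, they can be written out as binary PNG images. A small sketch, assuming the boolean `masks_to_show` tensor from the example above and a hypothetical output filename pattern:

```python
from PIL import Image
import numpy as np

for i, mask in enumerate(masks_to_show):
    mask_np = mask.cpu().numpy().astype(np.uint8) * 255  # boolean mask -> 0/255 grayscale image
    Image.fromarray(mask_np).save(f"mask_{i}.png")        # hypothetical output path
```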

# Citation

```
@inproceedings{sam_hq,
  title={Segment Anything in High Quality},
  author={Ke, Lei and Ye, Mingqiao and Danelljan, Martin and Liu, Yifan and Tai, Yu-Wing and Tang, Chi-Keung and Yu, Fisher},
  booktitle={NeurIPS},
  year={2023}
}
```