sushmanth committed
Commit f8f9a2c · Parent(s): 2542a21

Update README.md

Files changed (1): README.md (+101 -90)
README.md CHANGED
@@ -108,123 +108,134 @@ plt.show()
 
 ## Complete Example with Visualization
 
-Here's a complete example showing how to use SAM-HQ with the image embedding workflow and how to visualize the results:
-
 ```python
-import torch
 import numpy as np
 import matplotlib.pyplot as plt
-from PIL import Image
-import requests
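+# Matplotlib helpers for drawing masks, boxes, and point prompts on an image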
+def show_mask(mask, ax, random_color=False):
+    if random_color:
+        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+    else:
+        color = np.array([30/255, 144/255, 255/255, 0.6])
+    h, w = mask.shape[-2:]
+    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+    ax.imshow(mask_image)
+
+def show_box(box, ax):
+    x0, y0 = box[0], box[1]
+    w, h = box[2] - box[0], box[3] - box[1]
+    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
+
+def show_boxes_on_image(raw_image, boxes):
+    plt.figure(figsize=(10, 10))
+    plt.imshow(raw_image)
+    for box in boxes:
+        show_box(box, plt.gca())
+    plt.axis('on')
+    plt.show()
+
+def show_points_on_image(raw_image, input_points, input_labels=None):
+    plt.figure(figsize=(10, 10))
+    plt.imshow(raw_image)
+    input_points = np.array(input_points)
+    if input_labels is None:
+        labels = np.ones_like(input_points[:, 0])
+    else:
+        labels = np.array(input_labels)
+    show_points(input_points, labels, plt.gca())
+    plt.axis('on')
+    plt.show()
+
+def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
+    plt.figure(figsize=(10, 10))
+    plt.imshow(raw_image)
+    input_points = np.array(input_points)
+    if input_labels is None:
+        labels = np.ones_like(input_points[:, 0])
+    else:
+        labels = np.array(input_labels)
+    show_points(input_points, labels, plt.gca())
+    for box in boxes:
+        show_box(box, plt.gca())
+    plt.axis('on')
+    plt.show()
+
+def show_points(coords, labels, ax, marker_size=375):
+    pos_points = coords[labels == 1]
+    neg_points = coords[labels == 0]
+    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+
+def show_masks_on_image(raw_image, masks, scores):
+    if len(masks.shape) == 4:
+        masks = masks.squeeze()
+    if scores.shape[0] == 1:
+        scores = scores.squeeze()
+    nb_predictions = scores.shape[-1]
+    fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))
+    if nb_predictions == 1:
+        axes = [axes]  # plt.subplots returns a bare Axes when there is a single column
+    for i, (mask, score) in enumerate(zip(masks, scores)):
+        mask = mask.cpu().detach()
+        axes[i].imshow(np.array(raw_image))
+        show_mask(mask, axes[i])
+        axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
+        axes[i].axis("off")
+    plt.show()
+
+def show_masks_on_single_image(raw_image, masks, scores):
+    if len(masks.shape) == 4:
+        masks = masks.squeeze()
+    if scores.shape[0] == 1:
+        scores = scores.squeeze()
+    # Convert the image to a NumPy array if it is not one already
+    image_np = np.array(raw_image)
+    fig, ax = plt.subplots(figsize=(8, 8))
+    ax.imshow(image_np)
+    # Overlay all masks on the same image
+    for i, (mask, score) in enumerate(zip(masks, scores)):
+        mask = mask.cpu().detach().numpy()
+        show_mask(mask, ax)
+    ax.set_title("Overlaid Masks with Scores")
+    ax.axis("off")
+    plt.show()
+
+import torch
 from transformers import SamHQModel, SamHQProcessor
 
-# 1. Load model and processor
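+# Load the model and processor onto the available device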
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
 processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
 
-# 2. Load and display image
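+# Load the example image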
+from PIL import Image
+import requests
 img_url = "https://raw.githubusercontent.com/SysCV/sam-hq/refs/heads/main/demo/input_imgs/example1.png"
 raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
-plt.figure(figsize=(10, 10))
 plt.imshow(raw_image)
-plt.axis('off')
-plt.show()
 
-# 3. Compute image embeddings
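+# Compute the image embeddings once so they can be reused across prompts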
 inputs = processor(raw_image, return_tensors="pt").to(device)
 image_embeddings, intermediate_embeddings = model.get_image_embeddings(inputs["pixel_values"])
 
-# 4. Define bounding box and visualize it
-input_boxes = [[[306, 132, 925, 893]]]  # Define bounding box [x1, y1, x2, y2]
+input_boxes = [[[306, 132, 925, 893]]]
+show_boxes_on_image(raw_image, input_boxes[0])
 
-# Helper function to display the bounding box
-def show_box(box, ax):
-    x0, y0 = box[0], box[1]
-    w, h = box[2] - box[0], box[3] - box[1]
-    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
-
-plt.figure(figsize=(10, 10))
-plt.imshow(raw_image)
-for box in input_boxes[0]:
-    show_box(box, plt.gca())
-plt.axis('on')
-plt.title("Input Image with Bounding Box")
-plt.show()
-
-# 5. Run inference with the bounding box
-# First update the inputs with the image embeddings
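+# Swap the raw pixel_values for the precomputed embeddings and attach the box prompt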
 inputs.pop("pixel_values", None)
 inputs.update({"image_embeddings": image_embeddings})
 inputs.update({"intermediate_embeddings": intermediate_embeddings})
 inputs.update({"input_boxes": torch.tensor(input_boxes).to(device)})
-
-# Run inference
 with torch.no_grad():
     outputs = model(**inputs)
-
-# 6. Post-process the masks
-masks = processor.image_processor.post_process_masks(
-    outputs.pred_masks.cpu(),
-    inputs["original_sizes"].cpu(),
-    inputs["reshaped_input_sizes"].cpu()
-)
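+# Upscale the predicted masks to the original image resolution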
+masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
 scores = outputs.iou_scores
 
-# 7. Visualize results
-# Helper function to show masks
-def show_mask(mask, ax, random_color=False):
-    if random_color:
-        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
-    else:
-        color = np.array([30/255, 144/255, 255/255, 0.6])
-    h, w = mask.shape[-2:]
-    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
-    ax.imshow(mask_image)
-
-# Show all masks with scores
-if len(masks[0].shape) == 4:
-    masks_to_show = masks[0].squeeze()
-else:
-    masks_to_show = masks[0]
-
-if scores.shape[0] == 1:
-    scores_to_show = scores.squeeze()
-else:
-    scores_to_show = scores
-
-# Create a figure with subplots for each mask
-nb_predictions = scores_to_show.shape[-1]
-fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))
-
-# Handle the case where there's only one mask
-if nb_predictions == 1:
-    axes = [axes]
-
-for i, (mask, score) in enumerate(zip(masks_to_show, scores_to_show)):
-    mask = mask.cpu().detach()
-    axes[i].imshow(np.array(raw_image))
-    show_mask(mask, axes[i])
-    axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
-    axes[i].axis("off")
-plt.tight_layout()
-plt.show()
-
-# Show all masks overlaid on a single image
-fig, ax = plt.subplots(figsize=(10, 10))
-ax.imshow(np.array(raw_image))
-for i, (mask, score) in enumerate(zip(masks_to_show, scores_to_show)):
-    if len(mask.shape) > 2:
-        mask = mask.squeeze()
-    show_mask(mask, ax, random_color=True)
-ax.set_title("All Masks Overlaid")
-ax.axis("off")
-plt.tight_layout()
-plt.show()
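+# Visualize the results: all masks overlaid on one image, then each mask with its IoU score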
+show_masks_on_single_image(raw_image, masks[0], scores)
+
+show_masks_on_image(raw_image, masks[0], scores)
 ```
 
-This example demonstrates the complete workflow of using SAM-HQ with the "sushmanth/sam_hq_vit_b" model. It computes image embeddings once and then uses them for inference with a bounding box prompt. The resulting masks are visualized both individually with their confidence scores and overlaid on a single image with different colors.
-
 # Citation
 
 ```