import requests | |
import torch | |
from PIL import Image | |
from safetensors import safe_open | |
def download_sample_image() -> Image.Image:
    """Download chest X-ray with CC license.

    The image is fetched from Wikimedia Commons and decoded from an
    in-memory copy of the response body, so the returned image does not
    depend on an open network connection.

    Returns:
        The downloaded image, opened with PIL.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeout.
    """
    from io import BytesIO

    base_url = "https://upload.wikimedia.org/wikipedia/commons"
    path = "2/20/Chest_X-ray_in_influenza_and_Haemophilus_influenzae.jpg"
    image_url = f"{base_url}/{path}"
    # NOTE(review): presumably an explicit User-Agent is sent because
    # Wikimedia may refuse anonymous default clients - confirm.
    headers = {"User-Agent": "RAD-DINO"}
    # A timeout keeps the call from blocking forever on a stalled server.
    response = requests.get(image_url, headers=headers, timeout=30)
    # Fail loudly on HTTP errors instead of handing an error page to PIL.
    response.raise_for_status()
    # Decode from the fully downloaded body rather than the raw socket
    # stream: `.content` has transfer encodings applied and stays valid
    # after the connection is released.
    return Image.open(BytesIO(response.content))
def safetensors_to_state_dict(checkpoint_path: str) -> dict[str, torch.Tensor]:
    """Read every tensor stored in a safetensors file into a dictionary.

    Args:
        checkpoint_path: Path of the safetensors checkpoint to open.

    Returns:
        A mapping from each key in the checkpoint to the tensor stored
        under that key.
    """
    with safe_open(checkpoint_path, framework="pt") as handle:
        # Tensors are fetched while the file is still open.
        tensors = {name: handle.get_tensor(name) for name in handle.keys()}
    return tensors