To load and initialize the `Generator` model from the repository, follow these steps:

1. **Install Required Packages**: Ensure you have the necessary Python packages installed:
```bash
pip install torch omegaconf huggingface_hub
```
2. **Download Model Files**: Retrieve the `generator.pth`, `config.json`, and `model.py` files from the Hugging Face repository. You can use the `huggingface_hub` library for this:
```python
from huggingface_hub import hf_hub_download

repo_id = "Kiwinicki/sat2map-generator"

# Download the weights, config, and model definition; each call returns a local path in the HF cache
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
```
3. **Load the Model**: Add the directory containing the downloaded `model.py` to the Python path so the `Generator` class can be imported, then load the configuration and the model's state dictionary:
```python
import json
import sys
from pathlib import Path

import torch
from omegaconf import OmegaConf

# Make the downloaded model.py importable (it lives in the local HF cache)
sys.path.append(str(Path(model_path).parent))
from model import Generator

# Load the configuration
with open(config_path, "r") as f:
    config_dict = json.load(f)
cfg = OmegaConf.create(config_dict)

# Initialize the generator and load the trained weights
generator = Generator(cfg)
generator.load_state_dict(torch.load(generator_path, map_location="cpu"))
generator.eval()

# Smoke test on a random input tensor
x = torch.randn([1, cfg['channels'], 256, 256])
out = generator(x)
```
Here, `generator` is the initialized model ready for inference.
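The smoke test above uses a random tensor. For a real satellite image, the input typically needs to be resized to 256×256 and normalized to the range the generator was trained on. The sketch below assumes an RGB input normalized to [-1, 1] and a hypothetical file name `satellite.jpg`; adjust both to match your data and the model's actual training setup.

```python
import torch
from PIL import Image
from torchvision import transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
generator = generator.to(device)

# Assumed preprocessing: resize to 256x256 and normalize to [-1, 1]
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

image = Image.open("satellite.jpg").convert("RGB")  # hypothetical input file
x = preprocess(image).unsqueeze(0).to(device)

with torch.no_grad():
    out = generator(x)

# Rescale the output from [-1, 1] back to [0, 1] and save it as an image
map_tensor = (out.squeeze(0).cpu().clamp(-1, 1) + 1) / 2
transforms.ToPILImage()(map_tensor).save("map.png")
```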