Commit 9e3752d · 1 Parent(s): 3bdae7b
Ole-Christian Galbo Engstrøm committed

Update to allow model to be loaded with transformers
Files changed (5)
  1. .gitignore +1 -0
  2. __init__.py +10 -0
  3. requirements.txt +2 -1
  4. unet_config.py +21 -0
  5. unet_hf.py +20 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
__init__.py ADDED
@@ -0,0 +1,10 @@
+ from transformers import AutoConfig, AutoModel
+
+ from .unet_config import UNetConfig
+ from .unet_hf import UNetModel
+
+ # Register with Hugging Face Auto classes
+ AutoConfig.register("unet", UNetConfig)
+ AutoModel.register(UNetConfig, UNetModel)
+
+ __all__ = ["UNetConfig", "UNetModel"]
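
Once this __init__.py is imported, the "unet" model type is known to the Auto classes, so the model can be built through the standard transformers entry points. A minimal sketch, assuming the repository is importable as a Python package (the package name "unet" below is an assumption, not shown in this diff):

    import unet  # runs __init__.py, which registers the "unet" model type

    from transformers import AutoConfig, AutoModel

    config = AutoConfig.for_model("unet", in_channels=3, out_channels=1)
    model = AutoModel.from_config(config)  # builds a randomly initialized UNetModel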
requirements.txt CHANGED
@@ -1 +1,2 @@
- torch >= 2.7.1
+ torch >= 2.7.1
+ transformers >= 4.55.2
unet_config.py ADDED
@@ -0,0 +1,21 @@
+ from transformers import PretrainedConfig
+
+
+ class UNetConfig(PretrainedConfig):
+     model_type = "unet"
+
+     def __init__(
+         self,
+         in_channels=3,
+         out_channels=1,
+         pad=True,
+         bilinear=True,
+         normalization=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.pad = pad
+         self.bilinear = bilinear
+         self.normalization = normalization
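
Because UNetConfig subclasses PretrainedConfig, it serializes to a config.json that records these fields, and model_type = "unet" is what links a saved checkpoint back to the registration above. A small sketch of the round trip (the directory name is arbitrary, and unet_config.py is assumed to be on the import path):

    from unet_config import UNetConfig

    cfg = UNetConfig(in_channels=3, out_channels=1, bilinear=False)
    cfg.save_pretrained("unet-checkpoint")             # writes unet-checkpoint/config.json
    cfg2 = UNetConfig.from_pretrained("unet-checkpoint")
    assert cfg2.bilinear is False and cfg2.model_type == "unet"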
unet_hf.py ADDED
@@ -0,0 +1,20 @@
+ from transformers import PreTrainedModel
+ from .unet import UNet
+ from .unet_config import UNetConfig
+
+
+ class UNetModel(PreTrainedModel):
+     config_class = UNetConfig
+
+     def __init__(self, config: UNetConfig):
+         super().__init__(config)
+         self.model = UNet(
+             in_channels=config.in_channels,
+             out_channels=config.out_channels,
+             pad=config.pad,
+             bilinear=config.bilinear,
+             normalization=config.normalization,
+         )
+
+     def forward(self, x):
+         return self.model(x)
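
The wrapper simply delegates to the existing UNet implementation in unet.py (not part of this commit), so the usual PreTrainedModel save/load machinery applies to the weights as well. A hedged sketch, assuming unet.py defines UNet with the constructor arguments used above:

    from unet_config import UNetConfig
    from unet_hf import UNetModel

    model = UNetModel(UNetConfig(in_channels=3, out_channels=1))
    model.save_pretrained("unet-checkpoint")            # writes config.json plus the weights
    reloaded = UNetModel.from_pretrained("unet-checkpoint")

Note that forward returns the raw UNet output tensor rather than a transformers ModelOutput, so utilities that expect fields such as .logits will not apply directly.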