chansung commited on
Commit
99a3462
·
1 Parent(s): c9e354c

update custom handler

Browse files
__pycache__/handler.cpython-38.pyc CHANGED
Binary files a/__pycache__/handler.cpython-38.pyc and b/__pycache__/handler.cpython-38.pyc differ
 
handler.py CHANGED
@@ -1,16 +1,18 @@
1
  from typing import Dict, List, Any
2
 
 
3
  import base64
4
  import math
5
  import numpy as np
6
  import tensorflow as tf
7
  from tensorflow import keras
8
 
9
- from keras_cv.models.generative.stable_diffusion.constants import _ALPHAS_CUMPROD
10
- from keras_cv.models.generative.stable_diffusion.diffusion_model import DiffusionModel
 
11
 
12
  class EndpointHandler():
13
- def __init__(self, path=""):
14
  self.seed = None
15
 
16
  img_height = 512
@@ -19,12 +21,31 @@ class EndpointHandler():
19
  self.img_width = round(img_width / 128) * 128
20
 
21
  self.MAX_PROMPT_LENGTH = 77
22
- self.diffusion_model = DiffusionModel(self.img_height, self.img_width, self.MAX_PROMPT_LENGTH)
23
- diffusion_model_weights_fpath = keras.utils.get_file(
24
- origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5",
25
- file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe",
26
- )
27
- self.diffusion_model.load_weights(diffusion_model_weights_fpath)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def _get_initial_diffusion_noise(self, batch_size, seed):
30
  if seed is not None:
@@ -60,11 +81,17 @@ class EndpointHandler():
60
 
61
  context = base64.b64decode(contexts[0])
62
  context = np.frombuffer(context, dtype="float32")
63
- context = np.reshape(context, (batch_size, 77, 768))
 
 
 
64
 
65
  unconditional_context = base64.b64decode(contexts[1])
66
  unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
67
- unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 768))
 
 
 
68
 
69
  num_steps = data.pop("num_steps", 25)
70
  unconditional_guidance_scale = data.pop("unconditional_guidance_scale", 7.5)
 
1
  from typing import Dict, List, Any
2
 
3
+ import sys
4
  import base64
5
  import math
6
  import numpy as np
7
  import tensorflow as tf
8
  from tensorflow import keras
9
 
10
+ from keras_cv.models.stable_diffusion.constants import _ALPHAS_CUMPROD
11
+ from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
12
+ from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModelV2
13
 
14
  class EndpointHandler():
15
+ def __init__(self, path="", version="2"):
16
  self.seed = None
17
 
18
  img_height = 512
 
21
  self.img_width = round(img_width / 128) * 128
22
 
23
  self.MAX_PROMPT_LENGTH = 77
24
+
25
+ self.version = version
26
+ self.diffusion_model = self._instantiate_diffusion_model(version)
27
+ if isinstance(self.diffusion_model, str):
28
+ sys.exit(self.diffusion_model)
29
+
30
+ def _instantiate_diffusion_model(self, version: str):
31
+ if version == "1.4":
32
+ diffusion_model_weights_fpath = keras.utils.get_file(
33
+ origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5",
34
+ file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe",
35
+ )
36
+ diffusion_model = DiffusionModel(self.img_height, self.img_width, self.MAX_PROMPT_LENGTH)
37
+ diffusion_model.load_weights(diffusion_model_weights_fpath)
38
+ return diffusion_model
39
+ elif version == "2":
40
+ diffusion_model_weights_fpath = keras.utils.get_file(
41
+ origin="https://huggingface.co/ianstenbit/keras-sd2.1/resolve/main/diffusion_model_v2_1.h5",
42
+ file_hash="c31730e91111f98fe0e2dbde4475d381b5287ebb9672b1821796146a25c5132d",
43
+ )
44
+ diffusion_model = DiffusionModelV2(self.img_height, self.img_width, self.MAX_PROMPT_LENGTH)
45
+ diffusion_model.load_weights(diffusion_model_weights_fpath)
46
+ return diffusion_model
47
+ else:
48
+ return f"v{version} is not supported"
49
 
50
  def _get_initial_diffusion_noise(self, batch_size, seed):
51
  if seed is not None:
 
81
 
82
  context = base64.b64decode(contexts[0])
83
  context = np.frombuffer(context, dtype="float32")
84
+ if self.version == "1.4":
85
+ context = np.reshape(context, (batch_size, 77, 768))
86
+ else:
87
+ context = np.reshape(context, (batch_size, 77, 1024))
88
 
89
  unconditional_context = base64.b64decode(contexts[1])
90
  unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
91
+ if self.version == "1.4":
92
+ unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 768))
93
+ else:
94
+ unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 1024))
95
 
96
  num_steps = data.pop("num_steps", 25)
97
  unconditional_guidance_scale = data.pop("unconditional_guidance_scale", 7.5)
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- keras-cv==0.3.4
2
  tensorflow==2.11
3
- tensorflow_datasets
 
1
+ keras-cv==0.4
2
  tensorflow==2.11
3
+ tensorflow_datasets