chansung commited on
Commit
abe0476
·
1 Parent(s): c4a9b02

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +15 -11
handler.py CHANGED
@@ -8,17 +8,21 @@ from keras_cv.models.generative.stable_diffusion.decoder import Decoder
8
 
9
  class EndpointHandler():
10
  def __init__(self, path=""):
11
- img_height = 512
12
- img_width = 512
13
- img_height = round(img_height / 128) * 128
14
- img_width = round(img_width / 128) * 128
15
-
16
- self.decoder = Decoder(img_height, img_width)
17
- decoder_weights_fpath = keras.utils.get_file(
18
- origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5",
19
- file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962",
20
- )
21
- self.decoder.load_weights(decoder_weights_fpath)
 
 
 
 
22
 
23
  def __call__(self, data: Dict[str, Any]) -> str:
24
  # get inputs
 
8
 
9
  class EndpointHandler():
10
  def __init__(self, path=""):
11
+ gpus = tf.config.list_physical_devices(device_type = 'GPU')
12
+ for gpu in gpus:
13
+ tf.config.experimental.set_memory_growth(gpu, True)
14
+
15
+ img_height = 512
16
+ img_width = 512
17
+ img_height = round(img_height / 128) * 128
18
+ img_width = round(img_width / 128) * 128
19
+
20
+ self.decoder = Decoder(img_height, img_width)
21
+ decoder_weights_fpath = keras.utils.get_file(
22
+ origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5",
23
+ file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962",
24
+ )
25
+ self.decoder.load_weights(decoder_weights_fpath)
26
 
27
  def __call__(self, data: Dict[str, Any]) -> str:
28
  # get inputs