Commit: Update handler.py

Changed file: handler.py (+4, -1)
@@ -70,7 +70,8 @@ class EndpointHandler:
         inputs = self.processor(images=images, texts=prompt, return_tensors="pt")
         inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
         inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
-        inputs.to("cuda").to(torch.bfloat16)
+        inputs = inputs.to("cuda").to(torch.bfloat16)
+        logger.info(f"Inputs contains {inputs=}")

         generation_args = {
             "max_new_tokens": data.get("max_new_tokens", data.get("max_tokens", 128)),
@@ -82,7 +83,9 @@ class EndpointHandler:
         logger.info(f"Running text generation with the following {generation_args=}")

         with torch.inference_mode():
+            logger.info(f"Inputs contains {inputs['input_ids']=}")
             generate_ids = self.model.generate(**inputs, **generation_args)
+            logger.info(f"Generate IDs contains {generate_ids=}")

         logger.info(f"Generated {generate_ids=}")
         generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]