alvarobartt HF Staff committed on
Commit
cb5f6e9
·
verified ·
1 Parent(s): 0aea239

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -1
handler.py CHANGED
@@ -70,7 +70,8 @@ class EndpointHandler:
70
  inputs = self.processor(images=images, texts=prompt, return_tensors="pt")
71
  inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
72
  inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
73
- inputs.to("cuda").to(torch.bfloat16)
 
74
 
75
  generation_args = {
76
  "max_new_tokens": data.get("max_new_tokens", data.get("max_tokens", 128)),
@@ -82,7 +83,9 @@ class EndpointHandler:
82
  logger.info(f"Running text generation with the following {generation_args=}")
83
 
84
  with torch.inference_mode():
 
85
  generate_ids = self.model.generate(**inputs, **generation_args)
 
86
 
87
  logger.info(f"Generated {generate_ids=}")
88
  generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
 
70
  inputs = self.processor(images=images, texts=prompt, return_tensors="pt")
71
  inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
72
  inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
73
+ inputs = inputs.to("cuda").to(torch.bfloat16)
74
+ logger.info(f"Inputs contains {inputs=}")
75
 
76
  generation_args = {
77
  "max_new_tokens": data.get("max_new_tokens", data.get("max_tokens", 128)),
 
83
  logger.info(f"Running text generation with the following {generation_args=}")
84
 
85
  with torch.inference_mode():
86
+ logger.info(f"Inputs contains {inputs['input_ids']=}")
87
  generate_ids = self.model.generate(**inputs, **generation_args)
88
+ logger.info(f"Generate IDs contains {generate_ids=}")
89
 
90
  logger.info(f"Generated {generate_ids=}")
91
  generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]