alvarobartt (HF Staff) committed · Commit 6d88bf9 (verified) · Parent(s): 35a687f

Update handler.py

Files changed (1)
  1. handler.py +20 -5
handler.py CHANGED
@@ -1,7 +1,7 @@
-from copy import deepcopy
 from typing import Any, Dict
 
 import torch
+from huggingface_inference_toolkit.logging import logger
 from transformers import AutoModelForCausalLM, AutoProcessor
 from transformers.image_utils import load_image
 
@@ -23,20 +23,30 @@ class EndpointHandler:
         )
 
     def __call__(self, data: Dict[str, Any]) -> Any:
-        if "messages" not in data:
+        logger.info(f"Received payload with {data}")
+        if "inputs" not in data:
             raise ValueError(
-                "The request body must contain a key 'messages' with a list of messages."
+                "The request body must contain a key 'inputs' with a list of messages."
             )
 
+        logger.info("Processing the messages...")
         messages, images = [], []
-        for message in data["messages"]:
+        for message in data["inputs"]:
+            logger.info(f"Processing {message=}...")
             if isinstance(message["content"], list):
                 new_message = {"role": message["role"], "content": ""}
                 for content in message["content"]:
+                    logger.info(f"{message=} is of type {content['type']}")
                     if content["type"] == "text":
                         new_message["content"] += content["text"]
                     elif content["type"] == "image_url":
                         images.append(load_image(content["image_url"]["url"]))
+                        logger.info(
+                            "Loaded image using `transformers.image_utils.load_image`"
+                        )
+                logger.info(
+                    f"Current {new_message['content']} text if any contains {new_message['content'].count(IMAGE_TOKENS)} image tokens"
+                )
                 if new_message["content"].count(
                     f"{IMAGE_TOKENS}{SEPARATOR}"
                 ) < len(images):
@@ -48,12 +58,14 @@ class EndpointHandler:
                     {"role": message["role"], "content": message["content"]}
                 )
 
-        data.pop("messages")
+        data.pop("inputs")
 
+        logger.info(f"Applying chat template to {messages=}")
         prompt = self.processor.tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
 
+        logger.info(f"Processing {len(images)} images...")
         inputs = self.processor(images=images, texts=prompt, return_tensors="pt")
         inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
         inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
@@ -67,13 +79,16 @@ class EndpointHandler:
             "num_beams": 1,
         }
         generation_args.update(data)
+        logger.info(f"Running text generation with the following {generation_args=}")
 
         with torch.inference_mode():
             generate_ids = self.model.generate(**inputs, **generation_args)
 
+        logger.info(f"Generated {generate_ids=}")
         generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
         response = self.processor.decode(
             generate_ids[0], skip_special_tokens=True
         ).strip()
+        logger.info(f"Generated the {response=}")
 
         return {"generated_text": response}