ydshieh HF Staff commited on
Commit
19969a2
·
1 Parent(s): 93bcb8d

Delete pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +0 -48
pipeline.py DELETED
@@ -1,48 +0,0 @@
1
- import os
2
- from typing import Dict, List, Any
3
- from PIL import Image
4
- import jax
5
- from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
6
-
7
-
8
- class PreTrainedPipeline():
9
-
10
- def __init__(self, path=""):
11
-
12
- model_dir = path
13
-
14
- self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
15
- self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
16
- self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
17
-
18
- max_length = 16
19
- num_beams = 4
20
- self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
21
-
22
- @jax.jit
23
- def _generate(pixel_values):
24
-
25
- output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
26
- return output_ids
27
-
28
- self.generate = _generate
29
-
30
- # compile the model
31
- image_path = os.path.join(path, 'val_000000039769.jpg')
32
- image = Image.open(image_path)
33
- self(image)
34
- image.close()
35
-
36
- def __call__(self, inputs: "Image.Image") -> List[str]:
37
- """
38
- Args:
39
- Return:
40
- """
41
-
42
- pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values
43
-
44
- output_ids = self.generate(pixel_values)
45
- preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
46
- preds = [pred.strip() for pred in preds]
47
-
48
- return preds