[email protected] committed on

Commit b925209 · Parent(s): 07a7f16
Add benchmark evaluation scripts
Files changed:
- README.md +16 -0
- eval/conversation.py +492 -0
- eval/eval_dataset.py +850 -0
- eval/full_eval.yaml +188 -0
- eval/mmmu_utils.py +663 -0
- eval/requirements.txt +3 -0
- eval/vqa_utils.py +317 -0
- run_eval.py +702 -0
README.md
CHANGED

@@ -318,6 +318,22 @@ response = model.chat(tokenizer, pixel_values, question, generation_config)
 print(f'User: {question}\nAssistant: {response}')
 ```
 
+### Benchmark Evaluation
+
+To test our NVLM-1.0 model on the benchmark datasets, you can use the following command:
+
+```bash
+python run_eval.py --config-path eval/full_eval.yaml \
+    --result-save-path path/to/eval_results/ \
+    --zero-shot-eval-tasks chartqa coco_caption flickr30k_caption vqav2 mmmu textvqa mathvista mmbench chartqa docvqa realworldqa ocrbench ai2diagram ai2diagram_nomask mmmu_pro docvqa_test
+```
+
+Specifically,
+- `--config-path eval/full_eval.yaml` points to the configuration file, which contains the evaluation prompts, the evaluation dataset paths, and the generation hyper-parameters.
+- `--result-save-path path/to/eval_results/` specifies the directory where the evaluation results are saved.
+- `--zero-shot-eval-tasks` specifies the tasks to evaluate on.
+
+
 ## Software Integration
 **Runtime Engine(s)**
 * PyTorch <br>
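The README describes what `eval/full_eval.yaml` holds (prompts, dataset paths, generation hyper-parameters) without spelling out its schema. A minimal sketch for inspecting the shipped config before launching a run, assuming only that it is a standard YAML mapping (nothing about its key names is assumed here):

```python
# Minimal inspection sketch: list the top-level sections of the evaluation
# config before editing it. Assumes eval/full_eval.yaml is a plain YAML
# mapping; the printed keys come from the file itself, not from this snippet.
import yaml

with open("eval/full_eval.yaml") as f:
    cfg = yaml.safe_load(f)

for key in cfg:
    print(key, type(cfg[key]).__name__)
```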
    	
eval/conversation.py
ADDED

@@ -0,0 +1,492 @@
```python
# From https://github.com/haotian-liu/LLaVA/blob/main/llava/conversation.py

import dataclasses
from enum import auto, Enum
from typing import List, Tuple


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    real_sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0: message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            real_sep2=self.real_sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
                "real_sep2": self.real_sep2
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "real_sep2": self.real_sep2
        }


conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

### Used for llava-instruction-tuning stage
conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)



### Used for llava-pretraining
conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)


nvllm_8b_pretrain = Conversation(
    system="",
    roles=(),
    version="nvllm_8b",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="\n",
)

nvllm_8b_sft = Conversation(
    system="System: This is a chat between a user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("User", "Assistant"),
    version="nvllm_8b",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep="\n\n",
    sep2="\n\n\n",
    real_sep2="\n\n"
)

chatqa_sft = Conversation(
    system="System: This is a chat between a user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("User", "Assistant"),
    version="chatqa",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep="\n\n",
    sep2="\n\n",
    real_sep2="\n\n"
)

nvllm_8b_sft_noinstruction = Conversation(
    system="",
    roles=("User", "Assistant"),
    version="nvllm_8b",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep="\n\n",
    sep2="\n\n\n",
    real_sep2="\n\n"
)

# conv_yi = Conversation(
#     system="""<|im_start|>system
# A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
#     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
#     version="mpt",
#     messages=(),
#     offset=0,
#     sep_style=SeparatorStyle.MPT,
#     sep="<|im_end|>",
# )

conv_chatml = Conversation(
    system="""<|im_start|>system
Answer the questions.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

llama3_instruct = Conversation(
    system="<|start_header_id|>system<|end_header_id|>\n\nAnswer the questions.",
    roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|eot_id|>",
)

llama3_1_instruct = Conversation(
    system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nAnswer the questions.",
    roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|eot_id|>",
)

# default_conversation = conv_vicuna_v0
default_conversation = nvllm_8b_sft

# original_llava_pretraining = conv_llava_plain
# original_llava_sft = conv_vicuna_v1

conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "llama_2": conv_llama_2,

    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,

    "mpt": conv_mpt,
}


if __name__ == "__main__":

    print(default_conversation)
```
         
     | 
| 489 | 
         
            +
                print(default_conversation.roles[0])
         
     | 
| 490 | 
         
            +
             
     | 
| 491 | 
         
            +
             
     | 
| 492 | 
         
            +
                # print(default_conversation.get_prompt())
         
     | 
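The module closes by pinning `nvllm_8b_sft` as `default_conversation` and registering the reusable templates in `conv_templates`. As a minimal usage sketch — assuming the LLaVA-style `Conversation` helpers (`copy`, `append_message`) that these templates follow; only `roles` and `get_prompt()` are confirmed by the code above — a prompt would be assembled like this:

```python
# Hypothetical usage sketch: copy() and append_message() are assumed from
# the LLaVA-style Conversation API; only roles and get_prompt() are
# confirmed by the code above.
from eval import conversation as conversation_lib

conv = conversation_lib.default_conversation.copy()  # the nvllm_8b_sft template
conv.append_message(conv.roles[0], "<image>\nDescribe this image.")  # "User"
conv.append_message(conv.roles[1], None)  # leave the "Assistant" turn open
prompt = conv.get_prompt()  # turns joined with the "\n\n" separators
```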
    	
eval/eval_dataset.py ADDED
```python
import json
import os
import sys
import time
import yaml
import spacy
import ast
from PIL import Image
from glob import glob
from tqdm import tqdm
from collections import defaultdict
import pandas as pd
from io import BytesIO
import base64
from anls import anls_score
import torch
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torchvision.transforms as T
from eval import conversation as conversation_lib
from eval.mmmu_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, parse_multi_choice_response, parse_open_response, \
    process_single_sample, construct_prompt, mmmu_main_eval, process_single_sample_pro, construct_prompt_pro
from eval.mmmu_utils import evaluate as evaluate_mmmu
from torchvision.transforms.functional import InterpolationMode
from datasets import load_dataset, concatenate_datasets

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform
```
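`build_transform` is standard ImageNet preprocessing: each tile is converted to RGB, resized to a square `input_size`, and normalized with the mean/std constants above. A minimal sanity check (the blank image is just a placeholder):

```python
# Minimal sanity check: the transform maps any PIL image to a normalized
# (3, input_size, input_size) float tensor. The blank image is a placeholder.
from PIL import Image

transform = build_transform(input_size=448)
img = Image.new("RGB", (640, 480))
print(transform(img).shape)  # torch.Size([3, 448, 448])
```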
```python
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio
```
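`find_closest_aspect_ratio` picks the tiling grid whose aspect ratio best matches the image; on a tie, it only moves to the larger grid when the source image holds more than half that grid's pixel budget, which keeps small images from being upscaled into many tiles. A worked example for a 2:1 image:

```python
# Worked example: a 1000x500 image (aspect ratio 2.0) against all grids
# of at most 6 tiles; the (2, 1) grid matches exactly.
ratios = sorted(
    {(i, j) for n in range(1, 7) for i in range(1, n + 1) for j in range(1, n + 1)
     if 1 <= i * j <= 6},
    key=lambda x: x[0] * x[1])
best = find_closest_aspect_ratio(2.0, ratios, width=1000, height=500, image_size=448)
print(best)  # (2, 1)
```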
```python
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def load_image(image, input_size=448, max_num=6, decoded=False):
    if not decoded:
        image = Image.open(image).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
```
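Putting the pieces together, `load_image` tiles the image into up to `max_num` crops of `input_size`×`input_size` following the chosen grid, appends a global thumbnail whenever more than one crop is produced, and stacks the normalized tensors. Continuing the 2:1 example:

```python
# The 1000x500 image above picks the (2, 1) grid: 2 tiles + 1 thumbnail.
from PIL import Image

img = Image.new("RGB", (1000, 500))  # placeholder image
pixel_values = load_image(img, max_num=6, decoded=True)
print(pixel_values.shape)  # torch.Size([3, 3, 448, 448])
```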
```python
def levenshtein_distance(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]
```
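`levenshtein_distance` is the usual dynamic-programming edit distance, kept memory-light by rolling a single row over the shorter string:

```python
# Classic check: "kitten" -> "sitting" needs three single-character edits.
print(levenshtein_distance("kitten", "sitting"))  # 3
print(levenshtein_distance("denver", "denver."))  # 1
```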
```python
def get_anls_score(pred, gold_labels, threshold, llava_eval=False):
    values = []
    for answer in gold_labels:
        # preprocess both the answers - gt and prediction
        gt_answer = ' '.join(answer.strip().lower().split())
        det_answer = ' '.join(pred.strip().lower().split())

        dist = levenshtein_distance(gt_answer, det_answer)
        length = max(len(answer.upper()), len(pred.upper()))
        values.append(0.0 if length == 0 else float(dist) / float(length))

    question_result = 1 - min(values)

    if llava_eval:
        question_result = 1.0 if question_result >= threshold else 0.0
    else:
        if question_result < threshold:
            question_result = 0

    return question_result


def isNumber(n: str):
    try:
        float(n)
        return True
    except ValueError:
        return False
```
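`get_anls_score` implements ANLS (Average Normalized Levenshtein Similarity): it normalizes the edit distance to each gold answer by the longer string, keeps the best match, and zeroes scores below `threshold` (or binarizes them when `llava_eval=True`). A worked example:

```python
# Prediction "denver" vs. gold "Denver.": edit distance 1 over max length 7,
# so similarity = 1 - 1/7 ~ 0.857, which clears the usual 0.5 threshold.
print(round(get_anls_score("denver", ["Denver."], threshold=0.5), 3))  # 0.857
print(get_anls_score("denver", ["Denver."], threshold=0.5, llava_eval=True))  # 1.0
```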
```python
class COCOEvalDataset(Dataset):
    def __init__(self, args, img_dir, subset=None):
        self.args = args
        self.img_files = sorted(glob(os.path.join(img_dir, "*")))

        if subset:
            self.img_files = self.img_files[:subset]

        self.image_ids = [int(img_file.split("_")[-1].split(".")[0]) for img_file in self.img_files]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = self.img_files[idx]
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        return self.image_ids[idx], img
```
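Each dataset class returns one image's stacked tiles plus its metadata. Because the tile count varies with aspect ratio, samples cannot be batch-collated naively; a minimal loop (a sketch only — the real settings and generation calls live in `run_eval.py`, and the `img_dir` path is hypothetical) looks like:

```python
# Minimal sketch, not the actual runner: batch_size=1 sidesteps collate
# errors caused by per-image tile counts. The img_dir path is hypothetical.
from torch.utils.data import DataLoader

dataset = COCOEvalDataset(args=None, img_dir="data/coco/val2014")
loader = DataLoader(dataset, batch_size=1, shuffle=False)
for image_id, pixel_values in loader:
    pixel_values = pixel_values.squeeze(0)  # (num_tiles, 3, 448, 448)
    # generate a caption for pixel_values here, e.g. via model.chat(...)
```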
```python
class Flickr30KEvalDataset(Dataset):
    def __init__(self, args, img_dir, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.test_samples = json.load(open(os.path.join(img_dir, "flickr30k_test.json"), encoding='utf-8'))

        if subset:
            self.test_samples = self.test_samples[:subset]

    def __len__(self):
        return len(self.test_samples)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.test_samples[idx]["image"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        image_id = int(self.test_samples[idx]["image"].split("/")[-1].replace(".jpg", ""))

        return image_id, img


class VQAv2EvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]["image"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answer"]

        return img, question_id, question, answer
```
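The VQA-style loaders expect flat JSON annotation lists. From the accessors above, one `VQAv2EvalDataset` record needs at least the following fields (the values below are invented for illustration, and `answer` is passed through unchanged, so its exact type follows whatever the annotation file stores):

```python
# Hypothetical record shape for VQAv2EvalDataset; field names come from the
# accessors above, the values are made up for illustration.
record = {
    "image": "val2014/COCO_val2014_000000123456.jpg",
    "question_id": 1234560,
    "question": "What color is the bus?",
    "answer": ["yellow", "yellow", "gold"],
}
```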
| 216 | 
         
            +
            class TextVQAEvalDataset(Dataset):
         
     | 
| 217 | 
         
            +
                def __init__(self, args, img_dir, gt_path, subset=None):
         
     | 
| 218 | 
         
            +
                    self.args = args
         
     | 
| 219 | 
         
            +
                    self.img_dir = img_dir
         
     | 
| 220 | 
         
            +
                    self.gt = json.load(open(gt_path, encoding='utf-8'))['data']
         
     | 
| 221 | 
         
            +
             
     | 
| 222 | 
         
            +
                    if subset:
         
     | 
| 223 | 
         
            +
                        self.gt = self.gt[:subset]
         
     | 
| 224 | 
         
            +
             
     | 
| 225 | 
         
            +
                def __len__(self):
         
     | 
| 226 | 
         
            +
                    return len(self.gt)
         
     | 
| 227 | 
         
            +
             
     | 
| 228 | 
         
            +
                def __getitem__(self, idx):
         
     | 
| 229 | 
         
            +
                    img_path = os.path.join(self.img_dir, self.gt[idx]["image_id"] + '.jpg')
         
     | 
| 230 | 
         
            +
                    if not os.path.exists(img_path):
         
     | 
| 231 | 
         
            +
                        img_path = img_path.replace('.jpg', '.png')
         
     | 
| 232 | 
         
            +
                    img = load_image(img_path, max_num=6).to(torch.bfloat16)
         
     | 
| 233 | 
         
            +
             
     | 
| 234 | 
         
            +
                    question_id = self.gt[idx]["question_id"]
         
     | 
| 235 | 
         
            +
                    question = self.gt[idx]["question"]
         
     | 
| 236 | 
         
            +
                    answer = self.gt[idx]["answers"]
         
     | 
| 237 | 
         
            +
             
     | 
| 238 | 
         
            +
                    return img, question_id, question, answer
         
     | 
| 239 | 
         
            +
             
     | 
| 240 | 
         
            +
             
     | 
| 241 | 
         
            +
            class GQAEvalDataset(Dataset):
         
     | 
| 242 | 
         
            +
                def __init__(self, args, img_dir, gt_path, subset=None):
         
     | 
| 243 | 
         
            +
                    self.args = args
         
     | 
| 244 | 
         
            +
                    self.img_dir = img_dir
         
     | 
| 245 | 
         
            +
                    self.gt = json.load(open(gt_path, encoding='utf-8'))
         
     | 
| 246 | 
         
            +
                    self.gt = [{
         
     | 
| 247 | 
         
            +
                        "question_id": int(k),
         
     | 
| 248 | 
         
            +
                        "image": v['imageId'] + ".jpg",
         
     | 
| 249 | 
         
            +
                        "question": v['question'],
         
     | 
| 250 | 
         
            +
                        "answer": v['answer']
         
     | 
| 251 | 
         
            +
                    } for k, v in self.gt.items()]
         
     | 
| 252 | 
         
            +
             
     | 
| 253 | 
         
            +
                    if subset:
         
     | 
| 254 | 
         
            +
                        self.gt = self.gt[:subset]
         
     | 
| 255 | 
         
            +
             
     | 
| 256 | 
         
            +
                def __len__(self):
         
     | 
| 257 | 
         
            +
                    return len(self.gt)
         
     | 
| 258 | 
         
            +
             
     | 
| 259 | 
         
            +
                def __getitem__(self, idx):
         
     | 
| 260 | 
         
            +
                    img_path = os.path.join(self.img_dir, self.gt[idx]["image"])
         
     | 
| 261 | 
         
            +
                    img = load_image(img_path, max_num=6).to(torch.bfloat16)
         
     | 
| 262 | 
         
            +
             
     | 
| 263 | 
         
            +
                    question_id = self.gt[idx]["question_id"]
         
     | 
| 264 | 
         
            +
                    question = self.gt[idx]["question"]
         
     | 
| 265 | 
         
            +
                    answer = self.gt[idx]["answer"]
         
     | 
| 266 | 
         
            +
             
     | 
| 267 | 
         
            +
                    return img, question_id, question, [answer]
         
     | 
| 268 | 
         
            +
             
     | 
| 269 | 
         
            +
             
     | 
| 270 | 
         
            +
            class ChartQAEvalDataset(Dataset):
         
     | 
| 271 | 
         
            +
                def __init__(self, args, img_dir, gt_path, subset=None):
         
     | 
| 272 | 
         
            +
                    self.args = args
         
     | 
| 273 | 
         
            +
                    self.img_dir = img_dir
         
     | 
| 274 | 
         
            +
                    self.gt = json.load(open(gt_path, encoding='utf-8'))
         
     | 
| 275 | 
         
            +
                    for i in range(len(self.gt)):
         
     | 
| 276 | 
         
            +
                        self.gt[i]['question_id'] = i
         
     | 
| 277 | 
         
            +
             
     | 
| 278 | 
         
            +
                    if subset:
         
     | 
| 279 | 
         
            +
                        self.gt = self.gt[:subset]
         
     | 
| 280 | 
         
            +
             
     | 
| 281 | 
         
            +
                def __len__(self):
         
     | 
| 282 | 
         
            +
                    return len(self.gt)
         
     | 
| 283 | 
         
            +
             
     | 
| 284 | 
         
            +
                def __getitem__(self, idx):
         
     | 
| 285 | 
         
            +
                    img_path = os.path.join(self.img_dir, self.gt[idx]["imgname"])
         
     | 
| 286 | 
         
            +
                    img = load_image(img_path, max_num=6).to(torch.bfloat16)
         
     | 
| 287 | 
         
            +
             
     | 
| 288 | 
         
            +
                    question_id = self.gt[idx]["question_id"]
         
     | 
| 289 | 
         
            +
                    question = self.gt[idx]["query"]
         
     | 
| 290 | 
         
            +
                    answer = self.gt[idx]["label"]
         
     | 
| 291 | 
         
            +
             
     | 
| 292 | 
         
            +
                    return img, question_id, question, [answer]
         
     | 
| 293 | 
         
            +
             
     | 
| 294 | 
         
            +
             
     | 
| 295 | 
         
            +
            class OKVQAEvalDataset(Dataset):
         
     | 
| 296 | 
         
            +
                def __init__(self, args, img_dir, gt_path, question_path, subset=None):
         
     | 
| 297 | 
         
            +
                    self.args = args
         
     | 
| 298 | 
         
            +
                    self.img_dir = img_dir
         
     | 
| 299 | 
         
            +
                    self.gt = json.load(open(gt_path, encoding='utf-8'))['annotations']
         
     | 
| 300 | 
         
            +
                    self.questions = json.load(open(question_path, 'r'))['questions']
         
     | 
| 301 | 
         
            +
             
     | 
| 302 | 
         
            +
                    if subset:
         
     | 
| 303 | 
         
            +
                        self.gt = self.gt[:subset]
         
     | 
| 304 | 
         
            +
             
     | 
| 305 | 
         
            +
                    qid2q = {q['question_id']: q['question'] for q in self.questions}
         
     | 
| 306 | 
         
            +
             
     | 
| 307 | 
         
            +
                    for ann in self.gt:
         
     | 
| 308 | 
         
            +
                        ann['answers'] = [ans['answer'] for ans in ann['answers']]
         
     | 
| 309 | 
         
            +
                        ann['question'] = qid2q[ann['question_id']]
         
     | 
| 310 | 
         
            +
             
     | 
| 311 | 
         
            +
                def __len__(self):
         
     | 
| 312 | 
         
            +
                    return len(self.gt)
         
     | 
| 313 | 
         
            +
             
     | 
| 314 | 
         
            +
                def __getitem__(self, idx):
         
     | 
| 315 | 
         
            +
                    img_id = str(self.gt[idx]["image_id"])
         
     | 
| 316 | 
         
            +
                    img_id = '0' * (12 - len(img_id)) + img_id
         
     | 
| 317 | 
         
            +
                    img_file_name = f"COCO_val2014_{img_id}.jpg"
         
     | 
| 318 | 
         
            +
                    img_path = os.path.join(self.img_dir, img_file_name)
         
     | 
| 319 | 
         
            +
                    img = load_image(img_path, max_num=6).to(torch.bfloat16)
         
     | 
| 320 | 
         
            +
             
     | 
| 321 | 
         
            +
                    question_id = self.gt[idx]["question_id"]
         
     | 
| 322 | 
         
            +
                    question = self.gt[idx]["question"]
         
     | 
| 323 | 
         
            +
                    answer = self.gt[idx]["answers"]
         
     | 
| 324 | 
         
            +
             
     | 
| 325 | 
         
            +
                    return img, question_id, question, answer
         
     | 
| 326 | 
         
            +
             
     | 
| 327 | 
         
            +
             
     | 
| 328 | 
         
            +
            class DocVQAEvalDataset(Dataset):
         
     | 
| 329 | 
         
            +
                def __init__(self, args, img_dir, gt_path, split='val', subset=None):
         
     | 
| 330 | 
         
            +
                    self.args = args
         
     | 
| 331 | 
         
            +
                    self.img_dir = img_dir
         
     | 
| 332 | 
         
            +
                    self.gt = json.load(open(gt_path, encoding='utf-8'))['data']
         
     | 
| 333 | 
         
            +
             
     | 
| 334 | 
         
            +
                    if subset:
         
     | 
| 335 | 
         
            +
                        self.gt = self.gt[:subset]
         
     | 
| 336 | 
         
            +
             
     | 
| 337 | 
         
            +
                    self.split = split
         
     | 
| 338 | 
         
            +
             
     | 
| 339 | 
         
            +
                def __len__(self):
         
     | 
| 340 | 
         
            +
                    return len(self.gt)
         
     | 
| 341 | 
         
            +
             
     | 
| 342 | 
         
            +
                def __getitem__(self, idx):
         
     | 
| 343 | 
         
            +
                    img_path = os.path.join(self.img_dir, self.gt[idx]['image'].split('/')[-1])
         
     | 
| 344 | 
         
            +
                    img = load_image(img_path, max_num=6).to(torch.bfloat16)
         
     | 
| 345 | 
         
            +
             
     | 
| 346 | 
         
            +
                    question_id = self.gt[idx]["questionId"]
         
     | 
| 347 | 
         
            +
                    question = self.gt[idx]["question"]
         
     | 
| 348 | 
         
            +
             
     | 
| 349 | 
         
            +
                    if self.split == 'val':
         
     | 
| 350 | 
         
            +
                        answer = self.gt[idx]["answers"]
         
     | 
| 351 | 
         
            +
                    else:
         
     | 
| 352 | 
         
            +
                        answer = ['']
         
     | 
| 353 | 
         
            +
             
     | 
| 354 | 
         
            +
                    return img, question_id, question, answer
         
     | 
| 355 | 
         
            +
             
     | 
| 356 | 
         
            +
             
     | 
| 357 | 
         
            +
            class OCRBenchEvalDataset(Dataset):
         
     | 
| 358 | 
         
            +
                def __init__(self, args, img_dir, gt_path, subset=None):
         
     | 
| 359 | 
         
            +
                    self.args = args
         
     | 
| 360 | 
         
            +
                    self.img_dir = img_dir
         
     | 
| 361 | 
         
            +
                    self.gt = json.load(open(gt_path, encoding='utf-8'))
         
     | 
| 362 | 
         
            +
             
     | 
| 363 | 
         
            +
                    if subset:
         
     | 
| 364 | 
         
            +
                        self.gt = self.gt[:subset]
         
     | 
| 365 | 
         
            +
             
     | 
| 366 | 
         
            +
                def __len__(self):
         
     | 
| 367 | 
         
            +
                    return len(self.gt)
         
     | 
| 368 | 
         
            +
             
     | 
| 369 | 
         
            +
                def __getitem__(self, idx):
         
     | 
| 370 | 
         
            +
                    img_path = os.path.join(self.img_dir, self.gt[idx]['image_path'])
         
     | 
| 371 | 
         
            +
                    img = load_image(img_path, max_num=6).to(torch.bfloat16)
         
     | 
| 372 | 
         
            +
             
     | 
| 373 | 
         
            +
                    dataset_name = self.gt[idx]["dataset_name"]
         
     | 
| 374 | 
         
            +
                    question_id = f"{idx}"
         
     | 
| 375 | 
         
            +
                    question = self.gt[idx]["question"]
         
     | 
| 376 | 
         
            +
                    answer = self.gt[idx]["answers"]
         
     | 
| 377 | 
         
            +
                    data_type = self.gt[idx]["type"]
         
     | 
| 378 | 
         
            +
             
     | 
| 379 | 
         
            +
                    return img, question_id, question, answer, dataset_name, data_type
         
     | 
| 380 | 
         
            +
             
     | 
| 381 | 
         
            +
             
     | 
| 382 | 
         
            +
            class AI2DiagramEvalDataset(Dataset):
         
     | 
| 383 | 
         
            +
                def __init__(self, args, img_dir, gt_path, subset=None):
         
     | 
| 384 | 
         
            +
                    self.args = args
         
     | 
| 385 | 
         
            +
                    self.img_dir = img_dir
         
     | 
| 386 | 
         
            +
             
     | 
| 387 | 
         
            +
                    with open(gt_path, 'r') as json_file:
         
     | 
| 388 | 
         
            +
                        json_list = list(json_file)
         
     | 
| 389 | 
         
            +
                    self.gt = [json.loads(json_str) for json_str in json_list]
         
     | 
| 390 | 
         
            +
             
     | 
| 391 | 
         
            +
                    if subset:
         
     | 
| 392 | 
         
            +
                        self.gt = self.gt[:subset]
         
     | 
| 393 | 
         
            +
             
     | 
| 394 | 
         
            +
                def __len__(self):
         
     | 
| 395 | 
         
            +
                    return len(self.gt)
         
     | 
| 396 | 
         
            +
             
     | 
| 397 | 
         
            +
                def __getitem__(self, idx):
         
     | 
| 398 | 
         
            +
                    img_path = os.path.join(self.img_dir, self.gt[idx]['image'])
         
     | 
| 399 | 
         
            +
                    img = load_image(img_path, max_num=6).to(torch.bfloat16)
         
     | 
| 400 | 
         
            +
             
     | 
| 401 | 
         
            +
                    question_id = self.gt[idx]["question_id"]
         
     | 
| 402 | 
         
            +
                    question = self.gt[idx]["question"]
         
     | 
| 403 | 
         
            +
                    answer = self.gt[idx]["answer"]
         
     | 
| 404 | 
         
            +
             
     | 
| 405 | 
         
            +
                    return img, question_id, question, answer
         
     | 
| 406 | 
         
            +
             
     | 
| 407 | 
         
            +
             
     | 
class AI2DiagramNoMaskEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir

        with open(gt_path, 'r') as json_file:
            json_list = list(json_file)
        self.gt = [json.loads(json_str) for json_str in json_list]

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_file_name = self.gt[idx]['image'].replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES")
        img_path = os.path.join(self.img_dir, img_file_name)
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answer"]

        return img, question_id, question, answer

class RealworldQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]['image'])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = int(self.gt[idx]['image'].replace(".webp", ""))
        question = self.gt[idx]["question"]

        if self.gt[idx]['question_type'] == "multi-choice":
            choices = self.gt[idx]["choices"]
            start_chr = 'A'
            choices_str = ''
            index2ans = {}
            all_choices = []
            for choice in choices:
                all_choices.append(start_chr)
                index2ans[start_chr] = choice
                choices_str += f"{start_chr}. {choice}\n"
                start_chr = chr(ord(start_chr) + 1)

            question = question + '\n' + choices_str
            question = question + "Answer with the option's letter from the given choices directly."
            answer = chr(ord('A') + self.gt[idx]['correct_choice_index'])
        else:
            question = question + "\nAnswer the question using a single word or phrase."
            answer = self.gt[idx]['answer']

        return img, question_id, question, [answer]

class MathVistaEvalDataset(Dataset):
    def __init__(self, args, task_cfg, gt_path=None):
        self.args = args
        self.task_cfg = task_cfg
        self.dataset = load_dataset("AI4Math/MathVista")['testmini']

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = self.dataset[idx]['decoded_image']
        img = load_image(img.convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question_id = self.dataset[idx]["pid"]
        question = self.dataset[idx]["question"]
        question_type = self.dataset[idx]["question_type"]  # free_form or multi_choice
        query = self.dataset[idx]["query"]
        choices = self.dataset[idx]["choices"]
        answer = self.dataset[idx]["answer"]

        if question_type == 'multi_choice':
            start_chr = 'A'
            choices_str = ''
            index2ans = {}
            all_choices = []
            for choice in choices:
                all_choices.append(start_chr)
                index2ans[start_chr] = choice
                choices_str += f"{start_chr}. {choice}\n"
                start_chr = chr(ord(start_chr) + 1)

            question = question + '\n' + choices_str
            question = question + "Answer with the option's letter from the given choices directly."
            answer = chr(ord('A') + choices.index(answer))
        else:
            question = query.replace("Hint: ", "")
            index2ans = {}
            all_choices = []

        return img, question_id, question_type, question, answer, str(index2ans), str(all_choices)

def construct_prompt_for_fewshot(sample):
    config = {
        "task_instructions": "",
        "multi_choice_example_format": "{}\n{}Answer with the option's letter from the given choices directly.",
        "short_ans_example_format": "{}\nAnswer the question using a single word or phrase."
    }

    question = sample['question'].strip()

    # 'options' is stored as a stringified Python list, so it is eval'ed back into a list here
    options = eval(sample['options'])
    example = ""
    if sample['question_type'] == 'multiple-choice':
        start_chr = 'A'
        prediction_range = []
        index2ans = {}
        for option in options:
            prediction_range.append(start_chr)
            example += f"({start_chr}) {option}\n"
            index2ans[start_chr] = option
            start_chr = chr(ord(start_chr) + 1)
        empty_prompt_sample_structure = config['multi_choice_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question, example)
        res_dict = {'type': 'multichoice'}
        res_dict['index2ans'] = index2ans
        res_dict['correct_choice'] = sample['answer']
        res_dict['all_choices'] = prediction_range
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt

        res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
    else:
        empty_prompt_sample_structure = config['short_ans_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question)
        res_dict = {'type': 'open'}
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt
        res_dict['gt_content'] = sample['answer']

    res_dict.update(sample)
    return res_dict

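# Illustrative example (not part of the evaluation logic): for a multiple-choice
# sample with question "Which animal is shown?" and options ["cat", "dog"],
# construct_prompt_for_fewshot returns a dict whose 'final_input_prompt' is
#   "Which animal is shown?\n(A) cat\n(B) dog\nAnswer with the option's letter from the given choices directly."
# with 'all_choices' == ['A', 'B'] and 'index2ans' == {'A': 'cat', 'B': 'dog'}.
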
def process_image_tag(q):
    q = q.strip()

    # heuristic way of removing <image 1>
    if q == '<image 1>':
        q = 'Answer the question in the image.'
    elif ':<image 1>' in q:
        q = q.replace(':<image 1>', ' in the image. ')
        q = q.strip()
    elif ': <image 1>' in q:
        q = q.replace(': <image 1>', ' in the image. ')
        q = q.strip()
    elif '.<image 1>' in q or '. <image 1>' in q:
        q_list = q.split('<image 1>')
        q_list = [part.strip() for part in q_list if part.strip() != '']
        q = ' '.join(q_list)
    elif q.startswith('<image 1> '):
        if q[10].isupper():
            q = q.replace('<image 1>', '')
        else:
            q = q.replace('<image 1>', 'The image')
        q = q.strip()
    elif q.startswith('<image 1>'):
        q = q.replace('<image 1>', '')
    elif q.endswith('<image 1>?'):
        q = q.replace('<image 1>', 'the image')
    elif q.endswith('?<image 1>') or q.endswith('? <image 1>') or q.endswith('\n<image 1>'):
        q = q.replace('<image 1>', '')
        q = q.strip()
    elif ' <image 1> ' in q:
        q = q.replace('<image 1>', 'the image')
    elif ' <image 1>' in q:
        q = q.replace('<image 1>', 'the image')
    elif '()<image 1>' in q:
        q = q.replace('()<image 1>', '')
    elif '(<image 1>)' in q:
        q = q.replace('(<image 1>)', '')
    elif '<image 1>.' in q:
        q = q.replace("<image 1>.", ". ")
    else:
        q = q.replace("<image 1>", ". ")
        q = q.strip()

    # remove any remaining <image 2> to <image 7> tags
    for i in range(2, 8):
        q = q.replace(f"<image {i}>", "")

    return q

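# Illustrative before/after pairs for the heuristic above (added for clarity;
# not part of the original evaluation logic):
#   "What is shown in <image 1>?"      -> "What is shown in the image?"
#   "<image 1> Describe the scene."    -> "Describe the scene."
#   "Compare <image 1> and <image 2>." -> "Compare the image and ."
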
class MMMUProEvalDataset(Dataset):
    def __init__(self, args, task_cfg, subset=None):
        self.args = args
        self.task_cfg = task_cfg

        MMMU_path = "MMMU/MMMU_Pro"

        _split = "test"

        self.dataset = load_dataset(MMMU_path, "standard", split=_split)
        if subset:
            self.dataset = self.dataset[:subset]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # ===== single-image =====
        sample = self.dataset[idx]
        sample = process_single_sample_pro(sample)
        sample = construct_prompt_pro(sample, self.task_cfg)
        img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question_id = sample['id']
        question = sample['final_input_prompt']
        answer = sample['answer']

        question = process_image_tag(question)
        question = self.task_cfg['default_image_token'] + '\n' + question

        if sample['question_type'] == 'multiple-choice':
            index2ans = sample['index2ans']
            all_choices = sample['all_choices']
        else:
            index2ans = {}
            all_choices = []

        return img, question_id, sample['subfield'], sample['question_type'], question, answer, str(index2ans), str(all_choices)

class MMMUEvalDataset(Dataset):
    def __init__(self, args, task_cfg, subset=None, start_idx=None):
        self.args = args
        self.task_cfg = task_cfg
        sub_dataset_list = []
        # load_dataset will throw an error if split is 'dev':
        # 'dev' is part of 'validation', so it is split out manually below.

        MMMU_path = "MMMU/MMMU"

        _split = "test" if task_cfg["split"] == "test" else "validation"
        for subject in CAT_SHORT2LONG.values():
            sub_dataset = load_dataset(
                MMMU_path, subject,
                split=_split,
            )
            sub_dataset_list.append(sub_dataset)

        dataset = concatenate_datasets(sub_dataset_list)

        if task_cfg["split"] != "test":
            dataset = [s for s in dataset if s['id'].startswith(task_cfg["split"])]

        self.dataset = dataset

        if subset:
            start_idx = start_idx or 0  # default to the beginning of the dataset if no start index is given
            self.dataset = [dataset[i] for i in range(start_idx, min(start_idx + subset, len(dataset)))]
            print(f"Evaluating a subset of dataset: {len(self.dataset)} from {start_idx} to {start_idx + subset}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # ===== single-image =====
        sample = self.dataset[idx]
        sample = process_single_sample(sample)
        sample = construct_prompt(sample, self.task_cfg)

        img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question_id = sample['id']
        question = sample['final_input_prompt']
        answer = sample['answer']

        question = process_image_tag(question)
        question = self.task_cfg['default_image_token'] + '\n' + question

        if sample['question_type'] == 'multiple-choice':
            index2ans = sample['index2ans']
            all_choices = sample['all_choices']
        else:
            index2ans = {}
            all_choices = []

        return img, question_id, sample['subfield'], sample['question_type'], question, answer, str(index2ans), str(all_choices)

class VizWizEvalDataset(Dataset):
    def __init__(self, args, img_dir, question_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.questions = json.load(open(question_path, encoding='utf-8'))

        if subset:
            self.questions = self.questions[:subset]

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.questions[idx]["image"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        question = self.questions[idx]["question"]
        question_id = self.questions[idx]["image"]

        return img, question_id, question

class MMBenchEvalDataset(Dataset):
    def __init__(self, args, gt_path, subset=None):
        self.args = args
        df = pd.read_csv(gt_path, sep='\t')
        self.dataset = []
        for i, row in df.iterrows():
            choices = []
            for choice in ['A', 'B', 'C', 'D']:
                if str(row[choice]) != 'nan':
                    choices.append(row[choice])

            this_sample = {
                'index': row['index'],
                'question': row['question'],
                'hint': row['hint'],
                'category': row['category'],
                'image': Image.open(BytesIO(base64.b64decode(row['image']))),
                'choices': choices
            }

            # Only the dev set provides the ground-truth answer
            if 'answer' in row.keys():
                this_sample['answer'] = row['answer']
            else:
                this_sample['answer'] = ''

            self.dataset.append(this_sample)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = load_image(self.dataset[idx]["image"].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question = self.dataset[idx]["question"]
        hint = self.dataset[idx]["hint"]
        question_id = self.dataset[idx]["index"]
        choices = self.dataset[idx]["choices"]
        answer = self.dataset[idx]["answer"]

        start_chr = 'A'
        choices_str = ''
        index2ans = {}
        all_choices = []
        for choice in choices:
            all_choices.append(start_chr)
            index2ans[start_chr] = choice
            choices_str += f"{start_chr}. {choice}\n"
            start_chr = chr(ord(start_chr) + 1)

        question = question + '\n' + choices_str

        return img, question_id, question, answer, str(index2ans), str(all_choices), self.dataset[idx]["question"]

def get_task_dataloader(task_name, task_cfg, args):
    subset = task_cfg.get("subset")

    if task_name == "coco_caption":
        dataset = COCOEvalDataset(args, task_cfg["image_dir"], subset)
    elif task_name == "flickr30k_caption":
        dataset = Flickr30KEvalDataset(args, task_cfg["image_dir"], subset)
    elif task_name == "vqav2":
        dataset = VQAv2EvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "textvqa":
        dataset = TextVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "gqa":
        dataset = GQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "chartqa":
        dataset = ChartQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "okvqa":
        dataset = OKVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], task_cfg["question_path"], subset)
    elif task_name == "vizwiz":
        dataset = VizWizEvalDataset(args, task_cfg["image_dir"], task_cfg["question_path"], subset)
    elif task_name == "docvqa":
        dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='val', subset=subset)
    elif task_name == "docvqa_test":
        dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='test', subset=subset)
    elif task_name == "realworldqa":
        dataset = RealworldQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "mmmu":
        dataset = MMMUEvalDataset(args, task_cfg, subset=args.subset, start_idx=args.start_idx)
    elif task_name == "mmmu_pro":
        dataset = MMMUProEvalDataset(args, task_cfg)
    elif task_name == "mathvista":
        dataset = MathVistaEvalDataset(args, task_cfg)
    elif task_name == "mmbench":
        dataset = MMBenchEvalDataset(args, task_cfg["gt_path"])
    elif task_name == 'ocrbench':
        dataset = OCRBenchEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == 'ai2diagram':
        dataset = AI2DiagramEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == 'ai2diagram_nomask':
        dataset = AI2DiagramNoMaskEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    else:
        raise NotImplementedError(f"Task {task_name} is not supported yet.")

    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )

    return dataloader
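# Minimal usage sketch (illustrative only; the real entry point is run_eval.py,
# which reads the per-task configs from eval/full_eval.yaml, and `args` here
# stands in for its parsed command-line arguments):
#
#   import yaml
#   task_cfgs = yaml.safe_load(open("eval/full_eval.yaml"))["datasets"]
#   loader = get_task_dataloader("chartqa", task_cfgs["chartqa"], args)
#   for batch in loader:
#       ...  # unpack the per-task tuple, run the model, and score the output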
    	
eval/full_eval.yaml ADDED
@@ -0,0 +1,188 @@
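# NOTE: the image_dir/gt_path values below are placeholders; point them at your
# local copies of each benchmark before running run_eval.py.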
datasets:
  coco_caption:
    image_dir: "path/to/image"
    gt_path: "path/to/ground_truth"
    prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\nGive a brief description of this image in one sentence.<|im_end|><|im_start|>assistant\n"
    beam_search: True
    beam_size: 1
    output_max_len: 30
    top_k: 3
    temperature: 1.0

  flickr30k_caption:
    image_dir: "path/to/image"
    gt_path: "path/to/ground_truth"
    prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\nGive a brief description of this image in one sentence.<|im_end|><|im_start|>assistant\n"
    beam_search: True
    beam_size: 1
    output_max_len: 30
    top_k: 3
    temperature: 1.0

  vqav2:
    image_dir: "path/to/image"
    gt_path: "path/to/ground_truth"
    prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word or phrase.<|im_end|><|im_start|>assistant\n"
    beam_search: True
    beam_size: 1
    top_k: 1
    top_p: 0.0
    output_max_len: 8
    temperature: 1.0

  mmmu:
    split: "validation"
    beam_search: True
    beam_size: 1
    top_k: 1
    top_p: 0.0
    output_max_len: 1024
    temperature: 1.0
    apply_lemmatizer: False
    task_instructions: ""
    multi_choice_example_format: "{}\n{}\nAnswer with the option's letter from the given choices directly."
    short_ans_example_format: "{}\nAnswer the question using a single word or phrase."
    use_chat_format: True
    conv_format: "yi_nous_sft"
    default_image_token: "<image>"
    prompt_offset: 4
    answer_dict: "path/to/answer_dict_val.json"

  textvqa:
    split: "val"
    image_dir: "path/to/image"
    gt_path: "path/to/ground_truth"
    prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word, phrase, or number.<|im_end|><|im_start|>assistant\n"
    beam_search: True
    beam_size: 1
    top_k: 1
    top_p: 0.0
    output_max_len: 10
    temperature: 1.0

  mathvista:
    split: "testmini"
    prompt: "<|im_start|>system\nYou are a math expert. Use your math knowledge to calculate the answer.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word, phrase, or number.<|im_end|><|im_start|>assistant\n"
    beam_search: True
    beam_size: 1
    top_k: 1
    top_p: 0.0
    output_max_len: 1024
    temperature: 1.0

  mmbench:
    split: "dev"
    gt_path: "path/to/ground_truth"
    prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}Answer with the option's letter from the given choices directly.<|im_end|><|im_start|>assistant\n"
    beam_search: True
    beam_size: 1
    top_k: 1
    top_p: 0.0
    output_max_len: 10
    temperature: 1.0
    submission: False

  chartqa:
    split: "test"
    image_dir: "path/to/image"
    gt_path: "path/to/ground_truth"
    prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"
    beam_search: True
    beam_size: 1
    top_k: 1
    top_p: 0.0
    output_max_len: 20
    temperature: 1.0

  docvqa:
    split: "val"
    image_dir: "path/to/image"
    gt_path: "path/to/ground_truth"
    prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"
    beam_search: True
    beam_size: 1
    top_k: 1
    top_p: 0.0
    output_max_len: 20
    temperature: 1.0

  realworldqa:
    split: "test"
    image_dir: "path/to/image"
    gt_path: "path/to/ground_truth"
    prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"
    beam_search: True
    beam_size: 1
    top_k: 1
    top_p: 0.0
    output_max_len: 20
    temperature: 1.0
    submission: False

  ocrbench:
    split: "test"
    image_dir: "path/to/image"
    gt_path: "path/to/ground_truth"
    prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"
    beam_search: True
    beam_size: 1
    top_k: 1
    top_p: 0.0
    output_max_len: 70
    temperature: 1.0
    submission: False

  ai2diagram:
    split: "test"
    image_dir: "path/to/image"
    gt_path: "path/to/ground_truth"
    prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word, phrase, or number.<|im_end|><|im_start|>assistant\n"
    beam_search: True
    beam_size: 1
    top_k: 1
    top_p: 0.0
    output_max_len: 20
    temperature: 1.0

  ai2diagram_nomask:
    split: "test"
    image_dir: "path/to/image"
         
     | 
| 151 | 
         
            +
                gt_path: "path/to/ground_truth"
         
     | 
| 152 | 
         
            +
                prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word, phrase, or number.<|im_end|><|im_start|>assistant\n"
         
     | 
| 153 | 
         
            +
                beam_search: True
         
     | 
| 154 | 
         
            +
                beam_size: 1
         
     | 
| 155 | 
         
            +
                top_k: 1
         
     | 
| 156 | 
         
            +
                top_p: 0.0
         
     | 
| 157 | 
         
            +
                output_max_len: 20
         
     | 
| 158 | 
         
            +
                temperature: 1.0
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
              mmmu_pro:
         
     | 
| 161 | 
         
            +
                split: "validation"
         
     | 
| 162 | 
         
            +
                beam_search: True
         
     | 
| 163 | 
         
            +
                beam_size: 1
         
     | 
| 164 | 
         
            +
                top_k: 1
         
     | 
| 165 | 
         
            +
                top_p: 0.0
         
     | 
| 166 | 
         
            +
                output_max_len: 10
         
     | 
| 167 | 
         
            +
                temperature: 1.0
         
     | 
| 168 | 
         
            +
                apply_lemmatizer: False
         
     | 
| 169 | 
         
            +
                task_instructions: ""
         
     | 
| 170 | 
         
            +
                multi_choice_example_format: "{}\n{}\nAnswer with the option's letter from the given choices directly."
         
     | 
| 171 | 
         
            +
                short_ans_example_format: "{}\nAnswer the question using a single word or phrase."
         
     | 
| 172 | 
         
            +
                use_chat_format: True
         
     | 
| 173 | 
         
            +
                conv_format: "yi_nous_sft"
         
     | 
| 174 | 
         
            +
                default_image_token: "<image>"
         
     | 
| 175 | 
         
            +
                prompt_offset: 4
         
     | 
| 176 | 
         
            +
                answer_dict: "path/to/answer_dict.json"
         
     | 
| 177 | 
         
            +
             
     | 
| 178 | 
         
            +
              docvqa_test:
         
     | 
| 179 | 
         
            +
                split: "test"
         
     | 
| 180 | 
         
            +
                image_dir: "path/to/image"
         
     | 
| 181 | 
         
            +
                gt_path: "path/to/ground_truth"
         
     | 
| 182 | 
         
            +
                prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|>\n<|im_start|>user\n<image>\n{}\nAnswer this question using the text in the image directly.<|im_end|>\n<|im_start|>assistant\n"
         
     | 
| 183 | 
         
            +
                beam_search: True
         
     | 
| 184 | 
         
            +
                beam_size: 1
         
     | 
| 185 | 
         
            +
                top_k: 1
         
     | 
| 186 | 
         
            +
                top_p: 0.0
         
     | 
| 187 | 
         
            +
                output_max_len: 20
         
     | 
| 188 | 
         
            +
                temperature: 1.0
         
     | 
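Each task block above shares the same schema: dataset paths, a chat-template `prompt` whose `{}` slot receives the question text, and greedy-decoding settings (`beam_size: 1`, `top_k: 1`). The sketch below is illustrative only; it assumes the task blocks are reachable at the top level of the parsed YAML (in the full file they may sit under a parent key) and is not the actual `run_eval.py` logic.

```python
import yaml

# Illustrative sketch only; the real loading logic lives in run_eval.py.
with open("eval/full_eval.yaml") as f:
    config = yaml.safe_load(f)

task_cfg = config["chartqa"]  # assumed access path; may be nested under a parent key
question = "What is the highest value in the chart?"

# The "{}" slot in the prompt template receives the raw question text.
prompt = task_cfg["prompt"].format(question)

# beam_size 1 / top_k 1 / temperature 1.0 amounts to greedy decoding,
# capped at output_max_len new tokens.
generation_config = dict(
    num_beams=task_cfg["beam_size"],
    top_k=task_cfg["top_k"],
    top_p=task_cfg["top_p"],
    temperature=task_cfg["temperature"],
    max_new_tokens=task_cfg["output_max_len"],
)
```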
    	
eval/mmmu_utils.py ADDED
@@ -0,0 +1,663 @@
# Adapted from https://github.com/MMMU-Benchmark/MMMU/blob/main/mmmu/utils/data_utils.py

"""Utils for data load, save, and process (e.g., prompt construction)"""

import os
import json
import yaml
import re

DOMAIN_CAT2SUB_CAT = {
    'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'],
    'Business': ['Accounting', 'Economics', 'Finance', 'Manage', 'Marketing'],
    'Science': ['Biology', 'Chemistry', 'Geography', 'Math', 'Physics'],
    'Health and Medicine': ['Basic_Medical_Science', 'Clinical_Medicine', 'Diagnostics_and_Laboratory_Medicine',
                            'Pharmacy', 'Public_Health'],
    'Humanities and Social Science': ['History', 'Literature', 'Sociology', 'Psychology'],
    'Tech and Engineering': ['Agriculture', 'Architecture_and_Engineering', 'Computer_Science', 'Electronics',
                             'Energy_and_Power', 'Materials', 'Mechanical_Engineering'],
}

CAT_SHORT2LONG = {
    'acc': 'Accounting',
    'agri': 'Agriculture',
    'arch': 'Architecture_and_Engineering',
    'art': 'Art',
    'art_theory': 'Art_Theory',
    'bas_med': 'Basic_Medical_Science',
    'bio': 'Biology',
    'chem': 'Chemistry',
    'cli_med': 'Clinical_Medicine',
    'cs': 'Computer_Science',
    'design': 'Design',
    'diag_med': 'Diagnostics_and_Laboratory_Medicine',
    'econ': 'Economics',
    'elec': 'Electronics',
    'ep': 'Energy_and_Power',
    'fin': 'Finance',
    'geo': 'Geography',
    'his': 'History',
    'liter': 'Literature',
    'manage': 'Manage',
    'mark': 'Marketing',
    'mate': 'Materials',
    'math': 'Math',
    'mech': 'Mechanical_Engineering',
    'music': 'Music',
    'phar': 'Pharmacy',
    'phys': 'Physics',
    'psy': 'Psychology',
    'pub_health': 'Public_Health',
    'socio': 'Sociology'
}

# DATA SAVING
def save_json(filename, ds):
    with open(filename, 'w') as f:
        json.dump(ds, f, indent=4)


def get_multi_choice_info(options):
    """
    Given the list of options for a multiple-choice question,
    return index2ans and all_choices.
    """
    start_chr = 'A'
    all_choices = []
    index2ans = {}
    for i, option in enumerate(options):
        index2ans[chr(ord(start_chr) + i)] = option
        all_choices.append(chr(ord(start_chr) + i))

    return index2ans, all_choices

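# Illustrative usage: options are enumerated as consecutive letters from 'A', e.g.
#   get_multi_choice_info(["cat", "dog", "bird"])
#   -> ({'A': 'cat', 'B': 'dog', 'C': 'bird'}, ['A', 'B', 'C'])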

def load_yaml(file_path):
    yaml_dict = None  # avoid an UnboundLocalError if parsing fails
    with open(file_path, 'r') as stream:
        try:
            yaml_dict = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)

    return yaml_dict


def parse_img_path(text):
    matches = re.findall("<img='(.*?)'>", text)
    return matches

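# Illustrative usage: extracts embedded image paths from option text, e.g.
#   parse_img_path("<img='diagram_1.png'> or <img='diagram_2.png'>")
#   -> ['diagram_1.png', 'diagram_2.png']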
def process_single_sample(data):
    question = data['question']
    o_imgs_paths = []
    for option in data['options']:
        current_o_imgs_paths = parse_img_path(option)
        for img_path in current_o_imgs_paths:
            o_imgs_paths.append(img_path)

    if len(o_imgs_paths) > 1:  # multiple images in options, used for random selection
        return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
                'image': None, 'question_type': data['question_type'], 'subfield': data['subfield']}
    else:
        return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
                'image': data['image_1'], 'question_type': data['question_type'], 'subfield': data['subfield']}


def process_single_sample_pro(data):
    question = data['question']
    o_imgs_paths = []
    for option in data['options']:
        current_o_imgs_paths = parse_img_path(option)
        for img_path in current_o_imgs_paths:
            o_imgs_paths.append(img_path)

    if len(o_imgs_paths) > 1:  # multiple images in options, used for random selection
        return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
                'image': None, 'question_type': 'multiple-choice', 'subfield': data['subject']}
    else:
        return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
                'image': data['image_1'], 'question_type': 'multiple-choice', 'subfield': data['subject']}

# DATA SAVING
def save_jsonl(filename, data):
    """
    Save a dictionary of data to a JSON Lines file with the filename as key and caption as value.

    Args:
        filename (str): The path to the file where the data should be saved.
        data (dict): The dictionary containing the data to save, where the key is the image path and the value is the caption.
    """
    with open(filename, 'w', encoding='utf-8') as f:
        for img_path, caption in data.items():
            # Extract the base filename from the image path
            base_filename = os.path.basename(img_path)
            # Create a JSON object with the filename as the key and caption as the value
            json_record = json.dumps({base_filename: caption}, ensure_ascii=False)
            # Write the JSON object to the file, one per line
            f.write(json_record + '\n')

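# Illustrative usage: each dict entry becomes one JSON line keyed by the base filename, e.g.
#   save_jsonl("captions.jsonl", {"images/0001.png": "a bar chart of revenue"})
# writes the single line: {"0001.png": "a bar chart of revenue"}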
def save_args(args, path_dir):
    argsDict = args.__dict__
    with open(path_dir + 'setting.txt', 'w') as f:
        f.writelines('------------------ start ------------------' + '\n')
        for eachArg, value in argsDict.items():
            f.writelines(eachArg + ' : ' + str(value) + '\n')
        f.writelines('------------------- end -------------------')

# DATA PROCESSING
def construct_prompt(sample, config):
    question = sample['question'].strip()

    # for i in range(8):
    #     question = question.replace(f" <image {i}> ", " ")
    #     question = question.replace(f" <image {i}>", " ")
    #     question = question.replace(f"<image {i}> ", " ")
    #     question = question.replace(f"<image {i}>", " ")
    #     question = question.strip()

    options = eval(sample['options'])  # options are stored as a stringified list
    example = ""
    if sample['question_type'] == 'multiple-choice':
        start_chr = 'A'
        prediction_range = []
        index2ans = {}
        for option in options:
            prediction_range.append(start_chr)
            example += f"({start_chr}) {option}\n"
            # example += f"{start_chr}. {option}\n"
            index2ans[start_chr] = option
            start_chr = chr(ord(start_chr) + 1)
        # example = example.rstrip()
        empty_prompt_sample_structure = config['multi_choice_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question, example)
        res_dict = {'type': 'multichoice'}
        res_dict['index2ans'] = index2ans
        res_dict['correct_choice'] = sample['answer']
        res_dict['all_choices'] = prediction_range
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt

        res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
    else:
        empty_prompt_sample_structure = config['short_ans_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question)
        res_dict = {'type': 'open'}
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt
        res_dict['gt_content'] = sample['answer']

    res_dict.update(sample)
    return res_dict

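# Illustrative example, using the multi_choice_example_format from eval/full_eval.yaml
# ("{}\n{}\nAnswer with the option's letter from the given choices directly.") and
# empty task_instructions; a hypothetical sample
#   {'question': 'Which animal is shown?', 'options': "['cat', 'dog']",
#    'answer': 'A', 'question_type': 'multiple-choice'}
# yields final_input_prompt:
#   "Which animal is shown?\n(A) cat\n(B) dog\n\nAnswer with the option's letter from the given choices directly."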
def construct_prompt_pro(sample, config):
    question = sample['question'].strip()

    # for i in range(8):
    #     question = question.replace(f" <image {i}> ", " ")
    #     question = question.replace(f" <image {i}>", " ")
    #     question = question.replace(f"<image {i}> ", " ")
    #     question = question.replace(f"<image {i}>", " ")
    #     question = question.strip()

    options = eval(sample['options'])

    if len(options) == 1:
        print("This is wrongly formatted; correcting to options[0].")
        options = options[0]

    example = ""
    if sample['question_type'] == 'multiple-choice':
        start_chr = 'A'
        prediction_range = []
        index2ans = {}
        for option in options:
            prediction_range.append(start_chr)
            example += f"({start_chr}) {option}\n"
            # example += f"{start_chr}. {option}\n"
            index2ans[start_chr] = option
            start_chr = chr(ord(start_chr) + 1)
        # example = example.rstrip()
        empty_prompt_sample_structure = config['multi_choice_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question, example)
        res_dict = {'type': 'multichoice'}
        res_dict['index2ans'] = index2ans
        res_dict['correct_choice'] = sample['answer']
        res_dict['all_choices'] = prediction_range
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt

        res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
    else:
        empty_prompt_sample_structure = config['short_ans_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question)
        res_dict = {'type': 'open'}
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt
        res_dict['gt_content'] = sample['answer']

    res_dict.update(sample)
    return res_dict

"""Response Parsing and Evaluation for various models"""
from typing import Dict

import re
import random

import numpy as np

# ----------- Process Multi-choice -------------
def parse_multi_choice_response(response, all_choices, index2ans):
    """
    Parse the prediction from the generated response.
    Return the predicted index, e.g., A, B, C, D.
    """
    for char in [',', '.', '!', '?', ';', ':', "'"]:
        response = response.strip(char)
    response = " " + response + " "  # add space to avoid partial match

    index_ans = True
    ans_with_brack = False
    candidates = []
    for choice in all_choices:  # e.g., (A) (B) (C) (D) A) B) C) D)
        if f'({choice})' in response or f'{choice})' in response:
            candidates.append(choice)
            ans_with_brack = True

    if len(candidates) == 0:
        for choice in all_choices:  # e.g., A B C D
            if f' {choice} ' in response:
                candidates.append(choice)

    # if the above yields no candidates and the response is longer than 5 tokens,
    # try to match the option contents themselves
    if len(candidates) == 0 and len(response.split()) > 5:
        for index, ans in index2ans.items():
            if ans.lower() in response.lower():
                candidates.append(index)
                index_ans = False  # it's a content answer, not a letter

    if len(candidates) == 0:  # still no answer found; fall back to the first choice
        pred_index = all_choices[0]
    elif len(candidates) > 1:
        start_indexes = []
        if index_ans:
            if ans_with_brack:
                for can in candidates:
                    index = response.rfind(f'({can})')
                    start_indexes.append(index)  # -1 will be ignored anyway
                # start_indexes = [generated_response.index(f'({can})') for can in candidates]
            else:
                for can in candidates:
                    index = response.rfind(f" {can} ")
                    start_indexes.append(index)
        else:
            for can in candidates:
                index = response.lower().rfind(index2ans[can].lower())
                start_indexes.append(index)
        # keep the candidate that appears last in the response
        pred_index = candidates[np.argmax(start_indexes)]
    else:  # if only one candidate, use it.
        pred_index = candidates[0]

    return pred_index

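# Illustrative behavior:
#   parse_multi_choice_response("The answer is (B).", ['A', 'B', 'C', 'D'], index2ans)   -> 'B'
#   parse_multi_choice_response("I think B is correct", ['A', 'B', 'C', 'D'], index2ans) -> 'B'
# If no choice letter or option text is found, the function falls back to the
# first choice rather than guessing randomly.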

# ----------- Process Open -------------
def check_is_number(string):
    """
    Check if the given string is a number.
    """
    try:
        float(string.replace(',', ''))
        return True
    except ValueError:
        return False


def normalize_str(string):
    """
    Normalize the string to lower case, and convert it to a float number if possible.
    """
    # if it is a number, numerize it.
    string = string.strip()

    is_number = check_is_number(string)

    if is_number:
        string = string.replace(',', '')
        string = float(string)
        # keep 2 decimals
        string = round(string, 2)
        return [string]
    else:  # it's likely to be a string
        # lower it
        string = string.lower()
        if len(string) == 1:
            return [" " + string, string + " "]  # avoid trivial matches
        return [string]

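# Illustrative behavior:
#   normalize_str("1,234.50")  -> [1234.5]       # numeric: commas stripped, rounded to 2 decimals
#   normalize_str("Paris")     -> ['paris']      # string: lower-cased
#   normalize_str("A")         -> [' a', 'a ']   # single char: padded to avoid trivial substring matches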

def extract_numbers(string):
    """
    Extract all forms of numbers from a string with regex.
    """
    # Pattern for numbers with commas
    pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b'
    # Pattern for scientific notation
    pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
    # Pattern for simple numbers without commas
    pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])'

    # Extract numbers with commas
    numbers_with_commas = re.findall(pattern_commas, string)
    # Extract numbers in scientific notation
    numbers_scientific = re.findall(pattern_scientific, string)
    # Extract simple numbers without commas
    numbers_simple = re.findall(pattern_simple, string)

    # Combine all extracted numbers
    all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
    return all_numbers

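# Illustrative behavior (comma-separated and scientific-notation numbers are
# collected by their own patterns and listed first):
#   extract_numbers("The result is 3.14 and -42")  -> ['3.14', '-42']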
| 389 | 
         
            +
             
     | 
| 390 | 
         
            +
def parse_open_response(response):
    """
    Parse the prediction from the generated response.
    Return a list of predicted strings or numbers.
    """
    def get_key_subresponses(response):
        response = response.strip().strip(".").lower()
        sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response)
        indicators_of_keys = ['could be ', 'so ', 'is ',
                              'thus ', 'therefore ', 'final ', 'answer ', 'result ']
        key_responses = []
        for index, resp in enumerate(sub_responses):
            # if it is the last one, also accept an equation (the entire response can be a single sentence with an equation)
            if index == len(sub_responses) - 1:
                indicators_of_keys.extend(['='])
            shortest_key_response = None  # the shortest response that may contain the answer (tail part of the response)
            for indicator in indicators_of_keys:
                if indicator in resp:
                    if not shortest_key_response:
                        shortest_key_response = resp.split(indicator)[-1].strip()
                    else:
                        if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
                            shortest_key_response = resp.split(indicator)[-1].strip()

            if shortest_key_response:
                # and it's not trivial
                if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
                    key_responses.append(shortest_key_response)
        if len(key_responses) == 0:  # did not find any
            return [response]
        return key_responses

    key_responses = get_key_subresponses(response)

    pred_list = key_responses.copy()  # keep the original string response
    for resp in key_responses:
        pred_list.extend(extract_numbers(resp))

    tmp_pred_list = []
    for i in range(len(pred_list)):
        tmp_pred_list.extend(normalize_str(pred_list[i]))
    pred_list = tmp_pred_list

    # remove duplicates
    pred_list = list(set(pred_list))

    return pred_list

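# End-to-end example (illustrative): the indicator 'is ' selects the tail "42",
# which extract_numbers/normalize_str then turn into a float:
#   parse_open_response("The answer is 42.") -> [42.0]
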
# ----------- Evaluation -------------

def eval_multi_choice(gold_i, pred_i):
    """
    Evaluate a multiple-choice instance.
    """
    correct = False
    # only an exact match counts as correct
    if isinstance(gold_i, list):
        for answer in gold_i:
            if answer == pred_i:
                correct = True
                break
    else:  # gold_i is a string
        if gold_i == pred_i:
            correct = True
    return correct

def eval_open(gold_i, pred_i):
    """
    Evaluate an open-question instance.
    """
    correct = False
    if isinstance(gold_i, list):
        # use float to avoid trivial matches
        norm_answers = []
        for answer in gold_i:
            norm_answers.extend(normalize_str(answer))
    else:
        norm_answers = normalize_str(gold_i)
    for pred in pred_i:  # pred is already normalized in the parse-response phase
        if isinstance(pred, str):  # if it's a string, check whether an answer occurs in the pred
            for norm_ans in norm_answers:
                # only compare string answers against the string pred
                if isinstance(norm_ans, str) and norm_ans in pred:
                    if not correct:
                        correct = True
                    break
        else:  # it's a float number
            if pred in norm_answers:
                if not correct:
                    correct = True
                break
    return correct

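# Usage sketch (illustrative):
#   eval_multi_choice("B", "B")                       -> True  (exact option match)
#   eval_open("42", parse_open_response("It is 42"))  -> True  (42.0 is in the normalized golds)
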
# ----------- Batch Evaluation -------------
def evaluate(samples):
    """
    Batch evaluation for multiple-choice and open questions.
    """
    if len(samples) == 0:
        # return a consistent (judge_dict, metrics) tuple so callers can always unpack
        return {}, {'acc': 0}

    pred_correct = 0
    judge_dict = dict()
    for sample in samples:
        gold_i = sample['answer']
        pred_i = sample['parsed_pred']
        if sample['question_type'] == 'multiple-choice':
            correct = eval_multi_choice(gold_i, pred_i)
        else:  # open question
            correct = eval_open(gold_i, pred_i)

        if correct:
            judge_dict[sample['id']] = 'Correct'
            pred_correct += 1
        else:
            judge_dict[sample['id']] = 'Wrong'

    return judge_dict, {'acc': pred_correct / len(samples)}

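# Batch example (illustrative):
#   samples = [
#       {'id': 'q1', 'question_type': 'multiple-choice', 'answer': 'A', 'parsed_pred': 'A'},
#       {'id': 'q2', 'question_type': 'open', 'answer': '42', 'parsed_pred': [42.0]},
#   ]
#   evaluate(samples) -> ({'q1': 'Correct', 'q2': 'Correct'}, {'acc': 1.0})
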
# ----------- Calculate Accuracy -------------
def calculate_ins_level_acc(results: Dict):
    """Calculate the instance-level accuracy over the given subject results."""
    acc = 0
    ins_num = 0
    for cat_results in results.values():
        acc += cat_results['acc'] * cat_results['num_example']
        ins_num += cat_results['num_example']
    if ins_num == 0:
        return 0
    return acc / ins_num

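# Worked example (illustrative): per-category accuracies are weighted by example counts,
#   calculate_ins_level_acc({'Art':  {'acc': 0.50, 'num_example': 10},
#                            'Math': {'acc': 0.80, 'num_example': 30}})
#   -> (0.50 * 10 + 0.80 * 30) / 40 = 0.725
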
def mmmu_main_eval(output_dict, task_cfg):
    answer_dict = json.load(open(task_cfg["answer_dict"]))

    # group predictions by category
    output_dict_w_cat = {}
    for data_id, parsed_pred in output_dict.items():
        category = "_".join(data_id.split("_")[1:-1])
        if category not in output_dict_w_cat:
            output_dict_w_cat.update({category: {}})
        output_dict_w_cat[category].update({data_id: parsed_pred})

    # group answers by category
    answer_dict_w_cat = {}
    for data_id, parsed_pred in answer_dict.items():
        category = "_".join(data_id.split("_")[1:-1])
        if category not in answer_dict_w_cat:
            answer_dict_w_cat.update({category: {}})
        answer_dict_w_cat[category].update({data_id: parsed_pred})

    evaluation_result = {}

    for category in CAT_SHORT2LONG.values():
        # get cat_outputs and cat_answers
        try:
            cat_outputs = output_dict_w_cat[category]
            cat_answers = answer_dict_w_cat[category]
        except KeyError:
            print("Skipping {}: no examples found".format(category))
            continue

        examples_to_eval = []
        for data_id, parsed_pred in cat_outputs.items():
            question_type = cat_answers[data_id]['question_type']
            if question_type != 'multiple-choice':
                parsed_pred = parse_open_response(parsed_pred)  # mainly for type consistency (make it a number, etc.)

            examples_to_eval.append({
                "id": data_id,
                "question_type": question_type,
                "answer": cat_answers[data_id]['ground_truth'],
                "parsed_pred": parsed_pred
            })

        judge_dict, metric_dict = evaluate(examples_to_eval)
        metric_dict.update({"num_example": len(examples_to_eval)})

        evaluation_result[category] = metric_dict

    printable_results = {}
    # add domain-level (subject) results
    for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
        in_domain_cat_results = {}
        for cat_name in in_domain_cats:  # use the order in DOMAIN_CAT2SUB_CAT
            if cat_name in evaluation_result.keys():
                in_domain_cat_results[cat_name] = evaluation_result[cat_name]
        in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
        in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()])
        printable_results['Overall-' + domain] = {"num": int(in_domain_data_num),
                                                  "acc": round(in_domain_ins_acc, 4)
                                                  }
        # add sub-category results
        for cat_name, cat_results in in_domain_cat_results.items():
            printable_results[cat_name] = {"num": int(cat_results['num_example']),
                                           "acc": round(cat_results['acc'], 4)
                                           }

    all_ins_acc = calculate_ins_level_acc(evaluation_result)
    printable_results['Overall'] = {
        "num": sum([cat_results['num_example'] for cat_results in evaluation_result.values()]),
        "acc": round(all_ins_acc, 4)
    }

    return printable_results

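# The returned dict maps 'Overall', 'Overall-<domain>' for each domain in
# DOMAIN_CAT2SUB_CAT, and each evaluated sub-category to {'num': ..., 'acc': ...},
# e.g. (illustrative keys and values only):
#   {'Overall-Science': {'num': 150, 'acc': 0.4467}, 'Physics': {'num': 30, 'acc': 0.4},
#    'Overall': {'num': 900, 'acc': 0.51}, ...}
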
if __name__ == '__main__':
    tasks = yaml.safe_load(open("eval_config/eval_mmmu_yi.yaml"))['datasets']
    print(tasks)

    with open("/lustre/fsw/portfolios/llmservice/users/boxinw/eval_mmmu_iter6000_merged.0.53.json") as f:
        merged_results = json.load(f)

    eval_samples = []
    eval_output_dict = {}
    for res in merged_results:
        pred_ans = res["answer"].upper()
        gt_ans = res['gt_answer']
        if res['question_type'] == 'multiple-choice':
            parsed_pred = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
            if pred_ans != parsed_pred:
                print(f"MC: Original: {pred_ans}, Parsed: {parsed_pred}")
            eval_samples.append(
                {
                    'id': res['question_id'],
                    'question_type': res['question_type'],
                    'answer': res['gt_answer'],  # the content of the option, not the answer index
                    'response': pred_ans,
                    'parsed_pred': parsed_pred,
                    'index2ans': res['index2ans'],
                }
            )
            eval_output_dict[res['question_id']] = parsed_pred
        else:
            parsed_pred = parse_open_response(pred_ans)
            if pred_ans != parsed_pred:
                print(f"Open: Original: {pred_ans}, Parsed: {parsed_pred}")
            eval_samples.append(
                {
                    'id': res['question_id'],
                    'question_type': res['question_type'],
                    'answer': res['gt_answer'],
                    'response': pred_ans,
                    'parsed_pred': parsed_pred,
                }
            )
            eval_output_dict[res['question_id']] = pred_ans

    json.dump(eval_output_dict, open("validation_mmmu_iter6000_merged.0.53.sorted.json", "w"), indent=4, sort_keys=True)

    x = mmmu_main_eval(eval_output_dict, task_cfg=tasks['mmmu'])
    print(x)

eval/requirements.txt ADDED

@@ -0,0 +1,3 @@
anls
datasets
pycocoevalcap
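
These dependencies can be installed ahead of the evaluation run, e.g.:

```bash
pip install -r eval/requirements.txt
```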
    	
eval/vqa_utils.py ADDED

@@ -0,0 +1,317 @@
# coding=utf-8

__author__ = "aagrawal"

# This code is based on the code written by Tsung-Yi Lin for the MSCOCO Python API,
# available at the following link:
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
import re


class VQAEval:
    def __init__(self, vqa=None, vqaRes=None, n=2):
        self.n = n
        self.accuracy = {}
        self.evalQA = {}
        self.evalQuesType = {}
        self.evalAnsType = {}
        self.vqa = vqa
        self.vqaRes = vqaRes
        if vqa is not None:
            self.params = {"question_id": vqa.getQuesIds()}
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            "somebody'd": "somebodyd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }
        self.manualMap = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }
        self.articles = ["a", "an", "the"]

        # strip periods that are not followed by a digit (keeps decimals like "1.5")
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        # detect commas inside numbers like "1,234"
        self.commaStrip = re.compile(r"(\d)(,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def evaluate(self, quesIds=None):
        if quesIds is None:
            quesIds = [quesId for quesId in self.params["question_id"]]
        gts = {}
        res = {}
        for quesId in quesIds:
            gts[quesId] = self.vqa.qa[quesId]
            res[quesId] = self.vqaRes.qa[quesId]

        # =================================================
        # Compute accuracy
        # =================================================
        accQA = []
        accQuesType = {}
        accAnsType = {}
        print("computing accuracy")
        step = 0
        for quesId in quesIds:
            resAns = res[quesId]["answer"]
            resAns = resAns.replace("\n", " ")
            resAns = resAns.replace("\t", " ")
            resAns = resAns.strip()
            resAns = self.processPunctuation(resAns)
            resAns = self.processDigitArticle(resAns)
            gtAcc = []
            gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
            if len(set(gtAnswers)) > 1:
                for ansDic in gts[quesId]["answers"]:
                    ansDic["answer"] = self.processPunctuation(ansDic["answer"])
            for gtAnsDatum in gts[quesId]["answers"]:
                otherGTAns = [
                    item for item in gts[quesId]["answers"] if item != gtAnsDatum
                ]
                matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
                # standard VQA accuracy: min(#matching human answers / 3, 1)
                acc = min(1, float(len(matchingAns)) / 3)
                gtAcc.append(acc)
            quesType = gts[quesId]["question_type"]
            ansType = gts[quesId]["answer_type"]
            avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
            accQA.append(avgGTAcc)
            if quesType not in accQuesType:
                accQuesType[quesType] = []
            accQuesType[quesType].append(avgGTAcc)
            if ansType not in accAnsType:
                accAnsType[ansType] = []
            accAnsType[ansType].append(avgGTAcc)
            self.setEvalQA(quesId, avgGTAcc)
            self.setEvalQuesType(quesId, quesType, avgGTAcc)
            self.setEvalAnsType(quesId, ansType, avgGTAcc)
            if step % 100 == 0:
                self.updateProgress(step / float(len(quesIds)))
            step = step + 1

        self.setAccuracy(accQA, accQuesType, accAnsType)
        print("Done computing accuracy")

    def processPunctuation(self, inText):
        outText = inText
        for p in self.punct:
            if (p + " " in inText or " " + p in inText) or (
                re.search(self.commaStrip, inText) is not None
            ):
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        outText = self.periodStrip.sub("", outText)
        return outText

    def processDigitArticle(self, inText):
        outText = []
        tempText = inText.lower().split()
        for word in tempText:
            word = self.manualMap.setdefault(word, word)
            if word not in self.articles:
                outText.append(word)
        for wordId, word in enumerate(outText):
            if word in self.contractions:
                outText[wordId] = self.contractions[word]
        outText = " ".join(outText)
        return outText

| 269 | 
         
            +
                def setAccuracy(self, accQA, accQuesType, accAnsType):
         
     | 
| 270 | 
         
            +
                    self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
         
     | 
| 271 | 
         
            +
                    self.accuracy["perQuestionType"] = {
         
     | 
| 272 | 
         
            +
                        quesType: round(
         
     | 
| 273 | 
         
            +
                            100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
         
     | 
| 274 | 
         
            +
                            self.n,
         
     | 
| 275 | 
         
            +
                        )
         
     | 
| 276 | 
         
            +
                        for quesType in accQuesType
         
     | 
| 277 | 
         
            +
                    }
         
     | 
| 278 | 
         
            +
                    self.accuracy["perAnswerType"] = {
         
     | 
| 279 | 
         
            +
                        ansType: round(
         
     | 
| 280 | 
         
            +
                            100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
         
     | 
| 281 | 
         
            +
                        )
         
     | 
| 282 | 
         
            +
                        for ansType in accAnsType
         
     | 
| 283 | 
         
            +
                    }
         
     | 
| 284 | 
         
            +
             
     | 
| 285 | 
         
            +
                def setEvalQA(self, quesId, acc):
         
     | 
| 286 | 
         
            +
                    self.evalQA[quesId] = round(100 * acc, self.n)
         
     | 
| 287 | 
         
            +
             
     | 
| 288 | 
         
            +
                def setEvalQuesType(self, quesId, quesType, acc):
         
     | 
| 289 | 
         
            +
                    if quesType not in self.evalQuesType:
         
     | 
| 290 | 
         
            +
                        self.evalQuesType[quesType] = {}
         
     | 
| 291 | 
         
            +
                    self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
         
     | 
| 292 | 
         
            +
             
     | 
| 293 | 
         
            +
                def setEvalAnsType(self, quesId, ansType, acc):
         
     | 
| 294 | 
         
            +
                    if ansType not in self.evalAnsType:
         
     | 
| 295 | 
         
            +
                        self.evalAnsType[ansType] = {}
         
     | 
| 296 | 
         
            +
                    self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
         
     | 
| 297 | 
         
            +
             
     | 
| 298 | 
         
            +
                def updateProgress(self, progress):
         
     | 
| 299 | 
         
            +
                    barLength = 20
         
     | 
| 300 | 
         
            +
                    status = ""
         
     | 
| 301 | 
         
            +
                    if isinstance(progress, int):
         
     | 
| 302 | 
         
            +
                        progress = float(progress)
         
     | 
| 303 | 
         
            +
                    if not isinstance(progress, float):
         
     | 
| 304 | 
         
            +
                        progress = 0
         
     | 
| 305 | 
         
            +
                        status = "error: progress var must be float\r\n"
         
     | 
| 306 | 
         
            +
                    if progress < 0:
         
     | 
| 307 | 
         
            +
                        progress = 0
         
     | 
| 308 | 
         
            +
                        status = "Halt...\r\n"
         
     | 
| 309 | 
         
            +
                    if progress >= 1:
         
     | 
| 310 | 
         
            +
                        progress = 1
         
     | 
| 311 | 
         
            +
                        status = "Done...\r\n"
         
     | 
| 312 | 
         
            +
                    block = int(round(barLength * progress))
         
     | 
| 313 | 
         
            +
                    text = "\rFinshed Percent: [{0}] {1}% {2}".format(
         
     | 
| 314 | 
         
            +
                        "#" * block + "-" * (barLength - block), int(progress * 100), status
         
     | 
| 315 | 
         
            +
                    )
         
     | 
| 316 | 
         
            +
                    sys.stdout.write(text)
         
     | 
| 317 | 
         
            +
                    sys.stdout.flush()
         
     | 
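Both normalization helpers above feed the soft-accuracy loop earlier in this file, so it is worth seeing what they actually do to an answer string. A minimal sketch (the lookup tables `manualMap`, `articles`, and `contractions` are the standard VQA-eval tables set up in this class's `__init__`; the example strings are invented):

```python
from eval.vqa_utils import VQAEval

vqa_tool = VQAEval()

# ',' is adjacent to a space, so it is deleted; '?' is not, so it becomes a space.
print(vqa_tool.processPunctuation("a dog, maybe?"))   # 'a dog maybe '

# Articles are dropped and number words are mapped to digits.
print(vqa_tool.processDigitArticle("a dog maybe"))    # 'dog maybe'
print(vqa_tool.processDigitArticle("there are two"))  # 'there are 2'
```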
    	
run_eval.py
ADDED
@@ -0,0 +1,702 @@
import ast
import json

import torch
from transformers import AutoTokenizer, AutoModel
import math
import time
import argparse
import yaml
import os

from eval.eval_dataset import get_task_dataloader, isNumber, get_anls_score
from eval.mmmu_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, parse_multi_choice_response, parse_open_response, \
    process_single_sample, construct_prompt, mmmu_main_eval, process_single_sample_pro, construct_prompt_pro
from eval.mmmu_utils import evaluate as evaluate_mmmu
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from anls import anls_score
import pandas as pd
from eval.vqa_utils import VQAEval


def split_model():
    device_map = {}
    world_size = torch.cuda.device_count()
    num_layers = 80
    # Since the first GPU will be used for the ViT, treat it as half a GPU.
    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
    num_layers_per_gpu = [num_layers_per_gpu] * world_size
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            device_map[f'language_model.model.layers.{layer_cnt}'] = i
            layer_cnt += 1
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.lm_head'] = 0
    # Keep the last decoder layer on GPU 0, next to the output head.
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0

    return device_map
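`split_model` spreads the 80 decoder layers across all visible GPUs while reserving roughly half of GPU 0 for the vision encoder and embeddings. A quick sketch of the same arithmetic for a hypothetical 8-GPU node (the GPU count is an assumption for illustration; no CUDA required):

```python
import math

world_size = 8   # assumed GPU count, for illustration only
num_layers = 80  # matches the value hard-coded in split_model()

# GPU 0 also hosts the vision encoder, so it counts as half a GPU.
per_gpu = math.ceil(num_layers / (world_size - 0.5))  # ceil(80 / 7.5) = 11
shares = [per_gpu] * world_size
shares[0] = math.ceil(shares[0] * 0.5)                # GPU 0 takes 6 layers

print(shares)       # [6, 11, 11, 11, 11, 11, 11, 11]
print(sum(shares))  # 83 >= 80, so every decoder layer gets a slot
```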
         
            +
            def generate_task_results(
         
     | 
| 49 | 
         
            +
                    task_name,
         
     | 
| 50 | 
         
            +
                    task_cfg,
         
     | 
| 51 | 
         
            +
                    args,
         
     | 
| 52 | 
         
            +
                    model,
         
     | 
| 53 | 
         
            +
                    tokenizer,
         
     | 
| 54 | 
         
            +
                    generation_config,
         
     | 
| 55 | 
         
            +
                    dataloader,
         
     | 
| 56 | 
         
            +
                    results_save_dir,
         
     | 
| 57 | 
         
            +
            ):
         
     | 
| 58 | 
         
            +
                results = []
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
                if "prompt" in task_cfg:
         
     | 
| 61 | 
         
            +
                    prompt = task_cfg["prompt"]
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
                if task_name == "mmmu" or task_name == 'mmmu_pro':
         
     | 
| 64 | 
         
            +
                    for i, (img, question_id, subfield, question_type, question, answer, index2ans, all_choices) in enumerate(
         
     | 
| 65 | 
         
            +
                            dataloader):
         
     | 
| 66 | 
         
            +
                        if img.dim() == 5:
         
     | 
| 67 | 
         
            +
                            img = img.squeeze(0)
         
     | 
| 68 | 
         
            +
                        question_id = question_id[0]
         
     | 
| 69 | 
         
            +
                        question_type = question_type[0]
         
     | 
| 70 | 
         
            +
                        question = question[0]
         
     | 
| 71 | 
         
            +
                        answer = answer[0]
         
     | 
| 72 | 
         
            +
                        subfield = subfield[0]
         
     | 
| 73 | 
         
            +
                        index2ans = ast.literal_eval(index2ans[0])
         
     | 
| 74 | 
         
            +
                        all_choices = ast.literal_eval(all_choices[0])
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
                        this_prompt = question
         
     | 
| 77 | 
         
            +
             
     | 
| 78 | 
         
            +
                        if task_name == 'mmmu':
         
     | 
| 79 | 
         
            +
                            categories = list(CAT_SHORT2LONG.values())
         
     | 
| 80 | 
         
            +
                            new_sys_msg = "You are expert in {} ({}). Read the image and use your knowledge in {} ({}) to answer the question."
         
     | 
| 81 | 
         
            +
                            for c in categories:
         
     | 
| 82 | 
         
            +
                                if c in question_id:
         
     | 
| 83 | 
         
            +
                                    cat = c.lower().replace('_', ' ')
         
     | 
| 84 | 
         
            +
                                    new_sys_msg = new_sys_msg.format(cat, subfield, cat, subfield)
         
     | 
| 85 | 
         
            +
                                    break
         
     | 
| 86 | 
         
            +
                        else:
         
     | 
| 87 | 
         
            +
                            new_sys_msg = f"You are expert in {subfield}. Read the image and use your knowledge in {subfield} to answer the question."
         
     | 
| 88 | 
         
            +
                        model.system_message = new_sys_msg
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                        generated_answer = model.chat(tokenizer, img, question, generation_config)
         
     | 
| 91 | 
         
            +
                        print("generated_answer", generated_answer)
         
     | 
| 92 | 
         
            +
             
     | 
| 93 | 
         
            +
                        results.append({
         
     | 
| 94 | 
         
            +
                            "question_id": question_id,
         
     | 
| 95 | 
         
            +
                            "question_type": question_type,
         
     | 
| 96 | 
         
            +
                            "answer": generated_answer,
         
     | 
| 97 | 
         
            +
                            "gt_answer": answer,
         
     | 
| 98 | 
         
            +
                            "index2ans": index2ans,
         
     | 
| 99 | 
         
            +
                            "all_choices": all_choices
         
     | 
| 100 | 
         
            +
                        })
         
     | 
| 101 | 
         
            +
                        print("idx:", i, results[-1])
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
             
     | 
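The per-category system message above is plain string formatting. For a made-up MMMU-style sample (`question_id` "validation_Art_12", subfield "Fine Arts"), the loop would match the "Art" category and produce:

```python
template = ("You are expert in {} ({}). Read the image and use your "
            "knowledge in {} ({}) to answer the question.")
cat, subfield = "art", "Fine Arts"  # hypothetical sample values
print(template.format(cat, subfield, cat, subfield))
# You are expert in art (Fine Arts). Read the image and use your knowledge
# in art (Fine Arts) to answer the question.
```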
    elif task_name == "coco_caption" or task_name == "flickr30k_caption":

        for i, (image_id, img,) in enumerate(dataloader):

            if img.dim() == 5:
                img = img.squeeze(0)

            generated_answer = model.chat_without_chat_prompt(tokenizer, img, prompt, generation_config)
            print("generated_answer", generated_answer)

            results.append({"image_id": image_id.item(), "caption": generated_answer})
            print("idx:", i, results[-1])

    elif task_name in ["vqav2", "gqa", "okvqa", "textvqa", "chartqa", "docvqa", "realworldqa",
                       "ai2diagram", "ai2diagram_nomask", "docvqa_test"]:
        for i, (img, question_id, question, answer,) in enumerate(dataloader):
            if img.dim() == 5:
                img = img.squeeze(0)

            # Only works for batch size = 1; a custom collate function is needed
            # before batch sizes > 1 can be supported (see the sketch after this branch).
            question = question[0]
            if task_name == 'ai2diagram' or task_name == 'ai2diagram_nomask':
                question_id = question_id[0]
            else:
                question_id = question_id[0].item()

            if type(answer) == list:
                answer = [ans[0] for ans in answer]
            else:
                answer = [answer[0]]

            # Need to change if using batch size > 1 in the future.
            this_prompt = prompt.format(question)

            generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
            print("generated_answer", generated_answer)

            results.append({"question_id": question_id, "answer": generated_answer, "gt_answer": answer})
            print("idx:", i, results[-1])

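The `[0]` / `.item()` unwrapping in this branch exists because PyTorch's default collate function turns a batch of one sample into singleton containers: strings come back as 1-element lists and integers as 1-element tensors. A minimal, self-contained illustration (the sample values are invented):

```python
from torch.utils.data import DataLoader

samples = [("what is shown in the chart?", 42)]  # (question, question_id)
loader = DataLoader(samples, batch_size=1)

question, question_id = next(iter(loader))
print(question)                            # ['what is shown in the chart?']
print(question_id)                         # tensor([42])
print(question[0], question_id[0].item())  # what is shown in the chart? 42
```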
    elif task_name == "ocrbench":
        for i, (img, question_id, question, answer, dataset_name, data_type,) in enumerate(
                dataloader):
            if img.dim() == 5:
                img = img.squeeze(0)

            question_id = question_id[0]
            question = question[0]
            answer = answer[0]
            dataset_name = dataset_name[0]
            data_type = data_type[0]

            this_prompt = prompt.format(question)

            generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
            print("generated_answer", generated_answer)

            results.append({"question_id": question_id, "answer": generated_answer, "gt_answer": answer,
                            "dataset_name": dataset_name, "type": data_type})
            print("idx:", i, results[-1])

    elif task_name == "mathvista":
        for i, (
                img, question_id, question_type, question, answer, index2ans, all_choices,) in enumerate(
            dataloader):
            if img.dim() == 5:
                img = img.squeeze(0)

            question_id = question_id[0]
            question = question[0]
            question_type = question_type[0]
            answer = answer[0]

            index2ans = ast.literal_eval(index2ans[0])
            all_choices = ast.literal_eval(all_choices[0])

            this_prompt = prompt.format(question)

            generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
            print("generated_answer", generated_answer)

            results.append({
                "question_id": question_id,
                "question_type": question_type,
                "answer": generated_answer,
                "gt_answer": answer,
                "index2ans": index2ans,
                "all_choices": all_choices
            })
            print("idx:", i, results[-1])

    elif task_name == "mmbench":
        for i, (
                img, question_id, question, answer, index2ans, all_choices, original_q,) in enumerate(
            dataloader):
            if img.dim() == 5:
                img = img.squeeze(0)

            question_id = question_id[0]
            question = question[0]
            answer = answer[0]
            original_q = original_q[0]

            index2ans = ast.literal_eval(index2ans[0])
            all_choices = ast.literal_eval(all_choices[0])

            this_prompt = prompt.format(question)

            generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
            print("generated_answer", generated_answer)

            results.append({
                "question_id": question_id.item(),
                "question": original_q,
                "answer": generated_answer,
                "gt_answer": answer,
                "index2ans": index2ans,
                "all_choices": all_choices
            })
            print("idx:", i, results[-1])

    elif task_name == "vizwiz":
        for i, (img, question_id, question,) in enumerate(dataloader):
            if img.dim() == 5:
                img = img.squeeze(0)

            question = question[0]
            question_id = question_id[0]

            this_prompt = prompt.format(question)

            generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
            print("generated_answer", generated_answer)

            results.append({"image": question_id, "answer": generated_answer})
            print("idx:", i, results[-1])

    else:
        raise NotImplementedError(f"Task {task_name} is not supported yet.")

    os.makedirs(results_save_dir, exist_ok=True)
    if args.subset is not None:
        results_save_path = os.path.join(results_save_dir,
                                         f"eval_{task_name}_subset_{args.subset}_start_{args.start_idx}.json")
    else:
        results_save_path = os.path.join(results_save_dir,
                                         f"eval_{task_name}.json")
    print("Saving to ", results_save_path)
    json.dump(results, open(results_save_path, 'w'))

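Each task writes its predictions as a flat JSON list, and `calc_task_metrics` below reloads the same file to score it. For the VQA-style tasks a record looks like this (all values invented for illustration):

```python
import json

results = [{"question_id": 262148000, "answer": "2", "gt_answer": ["2", "2", "two"]}]
json.dump(results, open("eval_vqav2.json", "w"))
print(json.load(open("eval_vqav2.json")))  # what calc_task_metrics reads back
```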
def calc_task_metrics(args, task_name, task_cfg, results_save_dir):
    if args.subset is not None:
        merged_results_path = os.path.join(results_save_dir,
                                           f"eval_{task_name}_subset_{args.subset}_start_{args.start_idx}.json")
    else:
        merged_results_path = os.path.join(results_save_dir,
                                           f"eval_{task_name}.json")
    merged_results = json.load(open(merged_results_path, "r"))

    if task_name == "coco_caption" or task_name == "flickr30k_caption":
        # Calculate scores
        coco = COCO(task_cfg["gt_path"])
        coco_result = coco.loadRes(merged_results_path)
        coco_eval = COCOEvalCap(coco, coco_result)
        coco_eval.params["image_id"] = coco_result.getImgIds()
        coco_eval.evaluate()

        # Print and save scores
        print(f"====== {task_name} scores ======")
        with open(os.path.join(results_save_dir, "scores.txt"), "a") as f:
            f.write(f"{task_name} scores:\n")
            for k, v in coco_eval.eval.items():
                msg = f"{k} = {v * 100:.3f}"
                print(msg)
                f.write(msg + "\n")
            f.write("\n")
        return coco_eval.eval["CIDEr"]

    elif task_name == "vqav2" or task_name == "okvqa":
        vqa_tool = VQAEval()
        all_acc = []
        for res in merged_results:
            pred_ans = res["answer"]
            pred_ans = vqa_tool.processPunctuation(pred_ans)
            pred_ans = vqa_tool.processDigitArticle(pred_ans)

            gt_ans = res["gt_answer"]
            gt_ans = [vqa_tool.processPunctuation(ans) for ans in gt_ans]
            gt_ans = [vqa_tool.processDigitArticle(ans) for ans in gt_ans]

            # Soft VQA accuracy: full credit once three annotators agree
            # (worked example below).
            num_match = sum([pred_ans == ans for ans in gt_ans])
            acc = min(1.0, num_match / 3.0)
            all_acc.append(acc)
        acc_avg = sum(all_acc) / len(all_acc) * 100
        print(f"===== {task_name} Accuracy {acc_avg:.2f}% =====")
        with open(os.path.join(results_save_dir, "scores.txt"), "a") as f:
            f.write(f"{task_name} Accuracy = {acc_avg:.2f}%\n\n")
        return acc_avg

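A worked example of that soft accuracy (the ten annotator answers are invented): a prediction earns full credit once three annotators agree, and one-third credit per agreeing annotator below that.

```python
gt_ans = ["2", "2", "2", "2", "3", "2", "2", "1", "2", "2"]  # 10 annotators

num_match = sum("2" == ans for ans in gt_ans)  # 8 matches
print(min(1.0, num_match / 3.0))               # 1.0 -> full credit

num_match = sum("3" == ans for ans in gt_ans)  # 1 match
print(min(1.0, num_match / 3.0))               # 0.333... -> partial credit
```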
    elif task_name == "textvqa":
        vqa_tool = VQAEval()
        all_acc = []
        for res in merged_results:
            pred_ans = res["answer"]
            pred_ans = vqa_tool.processPunctuation(pred_ans)
            pred_ans = vqa_tool.processDigitArticle(pred_ans)

            gt_ans = res["gt_answer"]
            gt_ans = [vqa_tool.processPunctuation(ans) for ans in gt_ans]
            gt_ans = [vqa_tool.processDigitArticle(ans) for ans in gt_ans]

            num_match = sum([pred_ans == ans for ans in gt_ans])
            acc = min(1.0, num_match / 3.0)
            all_acc.append(acc)
        acc_avg = sum(all_acc) / len(all_acc) * 100
        print(
            f"===== {task_name} Accuracy {acc_avg:.2f}% (need to submit to EvalAI for the accurate accuracy) =====")
        with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
            f.write(
                f"{task_name} Accuracy = {acc_avg:.2f}% (need to submit to EvalAI for the accurate accuracy)\n\n")
        return acc_avg

    elif task_name in ["gqa", "realworldqa", "ai2diagram", "ai2diagram_nomask"]:
        vqa_tool = VQAEval()
        acc = 0
        for res in merged_results:
            pred_ans = res["answer"]
            pred_ans = vqa_tool.processPunctuation(pred_ans)
            pred_ans = vqa_tool.processDigitArticle(pred_ans)

            gt_ans = res["gt_answer"][0]
            gt_ans = vqa_tool.processPunctuation(gt_ans)
            gt_ans = vqa_tool.processDigitArticle(gt_ans)

            if pred_ans == gt_ans:
                acc += 1
        acc = acc / len(merged_results) * 100
        print(f"===== {task_name} Accuracy {acc:.2f}% =====")
        with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
            f.write(f"{task_name} Accuracy = {acc:.2f}%\n\n")
        return acc

    elif task_name == "docvqa" or task_name == "docvqa_test":
        vqa_tool = VQAEval()
        anls = 0
        for res in merged_results:
            pred_ans = res["answer"]
            gt_ans = res["gt_answer"]
            anls += get_anls_score(pred=pred_ans, gold_labels=gt_ans, threshold=0.5)
        anls = anls / len(merged_results) * 100
        print(f"===== {task_name} ANLS {anls:.2f}% =====")
        with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
            f.write(f"{task_name} ANLS = {anls:.2f}%\n\n")
        return anls

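`get_anls_score` comes from `eval/eval_dataset.py` and is not shown in this commit excerpt; the sketch below assumes it follows the standard ANLS definition: one minus the normalized Levenshtein distance to the closest gold label, zeroed out below the threshold.

```python
def levenshtein(a, b):
    # classic dynamic-programming edit distance
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]

def anls_single(pred, golds, threshold=0.5):
    best = max(
        1 - levenshtein(pred.lower(), g.lower()) / max(len(pred), len(g), 1)
        for g in golds
    )
    return best if best >= threshold else 0.0

print(anls_single("recipt", ["receipt"]))  # ~0.857: one edit over seven characters
```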
    elif task_name == "chartqa":
        vqa_tool = VQAEval()
        acc = 0
        for res in merged_results:
            pred_ans = res["answer"]
            pred_ans = vqa_tool.processPunctuation(pred_ans)
            pred_ans = vqa_tool.processDigitArticle(pred_ans)

            gt_ans = res["gt_answer"][0]
            gt_ans = vqa_tool.processPunctuation(gt_ans)
            gt_ans = vqa_tool.processDigitArticle(gt_ans)

            # ChartQA uses relaxed accuracy:
            # "We consider an answer to be correct if it is within 5% of the gold answer.
            #  For non-numeric answers, we still need an exact match to consider an answer to be correct."
            if isNumber(pred_ans) and isNumber(gt_ans):
                pred_ans = float(pred_ans)
                gt_ans = float(gt_ans)
                if gt_ans * 0.95 <= pred_ans <= gt_ans * 1.05:
                    acc += 1
            elif pred_ans == gt_ans:
                acc += 1
        acc = acc / len(merged_results) * 100
        print(f"===== {task_name} Accuracy {acc:.2f}% =====")
        with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
            f.write(f"{task_name} Accuracy = {acc:.2f}%\n\n")
        return acc

| 392 | 
         
            +
    elif task_name == 'ocrbench':
        OCRBench_score = {"Regular Text Recognition": 0, "Irregular Text Recognition": 0,
                          "Artistic Text Recognition": 0, "Handwriting Recognition": 0,
                          "Digit String Recognition": 0, "Non-Semantic Text Recognition": 0,
                          "Scene Text-centric VQA": 0, "Doc-oriented VQA": 0,
                          "Key Information Extraction": 0, "Handwritten Mathematical Expression Recognition": 0}

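        # Scoring is substring containment: a sample counts when one of its gold answers
        # appears in the prediction (whitespace-stripped for HME100k LaTeX expressions,
        # lower-cased for everything else), and each sample earns at most one point.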
        for res in merged_results:
            predict = res["answer"]
            answers = res["gt_answer"]

            dataset_name = res["dataset_name"]
            ocr_type = res["type"]

            if dataset_name == "HME100k":
                if isinstance(answers, list):
                    for j in range(len(answers)):
                        answer = answers[j].strip().replace("\n", " ").replace(" ", "")
                        predict = predict.strip().replace("\n", " ").replace(" ", "")
                        if answer in predict:
                            OCRBench_score[ocr_type] += 1
                            break  # count each sample at most once
                else:
                    answers = answers.strip().replace("\n", " ").replace(" ", "")
                    predict = predict.strip().replace("\n", " ").replace(" ", "")
                    if answers in predict:
                        OCRBench_score[ocr_type] += 1
            else:
                if isinstance(answers, list):
                    for j in range(len(answers)):
                        answer = answers[j].lower().strip().replace("\n", " ")
                        predict = predict.lower().strip().replace("\n", " ")
                        if answer in predict:
                            OCRBench_score[ocr_type] += 1
                            break  # count each sample at most once
                else:
                    answers = answers.lower().strip().replace("\n", " ")
                    predict = predict.lower().strip().replace("\n", " ")
                    if answers in predict:
                        OCRBench_score[ocr_type] += 1

        recognition_score = OCRBench_score['Regular Text Recognition'] + OCRBench_score['Irregular Text Recognition'] + \
                            OCRBench_score['Artistic Text Recognition'] + OCRBench_score['Handwriting Recognition'] + \
                            OCRBench_score['Digit String Recognition'] + OCRBench_score['Non-Semantic Text Recognition']
        Final_score = recognition_score + OCRBench_score['Scene Text-centric VQA'] + \
                      OCRBench_score['Doc-oriented VQA'] + OCRBench_score['Key Information Extraction'] + \
                      OCRBench_score['Handwritten Mathematical Expression Recognition']
        result_log = f"""
###########################OCRBench##############################
Text Recognition (Total 300): {recognition_score}
------------------Details of Recognition Score-------------------
Regular Text Recognition (Total 50): {OCRBench_score['Regular Text Recognition']}
Irregular Text Recognition (Total 50): {OCRBench_score['Irregular Text Recognition']}
Artistic Text Recognition (Total 50): {OCRBench_score['Artistic Text Recognition']}
Handwriting Recognition (Total 50): {OCRBench_score['Handwriting Recognition']}
Digit String Recognition (Total 50): {OCRBench_score['Digit String Recognition']}
Non-Semantic Text Recognition (Total 50): {OCRBench_score['Non-Semantic Text Recognition']}
----------------------------------------------------------------
Scene Text-centric VQA (Total 200): {OCRBench_score['Scene Text-centric VQA']}
----------------------------------------------------------------
Doc-oriented VQA (Total 200): {OCRBench_score['Doc-oriented VQA']}
----------------------------------------------------------------
Key Information Extraction (Total 200): {OCRBench_score['Key Information Extraction']}
----------------------------------------------------------------
Handwritten Mathematical Expression Recognition (Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}
----------------------Final Score-------------------------------
Final Score (Total 1000): {Final_score}
"""

        # OCRBench reports a point total out of 1000, not a percentage.
        print(f"===== {task_name} Final Score (Total 1000): {Final_score} ===== {result_log}")
        with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
            f.write(f"{task_name} Final Score = {Final_score} / 1000 {result_log}\n\n")
        return Final_score

                elif task_name == "mathvista":
         
     | 
| 467 | 
         
            +
                    import re
         
     | 
| 468 | 
         
            +
                    def extra_processing(text):
         
     | 
| 469 | 
         
            +
             
     | 
| 470 | 
         
            +
                        ## max decimal point capped to 2 decimal point
         
     | 
| 471 | 
         
            +
                        regex = re.compile(r'^\d+\.\d+$')
         
     | 
| 472 | 
         
            +
                        decimal = regex.findall(text)
         
     | 
| 473 | 
         
            +
             
     | 
| 474 | 
         
            +
                        if len(decimal) > 0:
         
     | 
| 475 | 
         
            +
                            non_decimal = len(decimal[0].split(".")[0])
         
     | 
| 476 | 
         
            +
             
     | 
| 477 | 
         
            +
                            # if decimal values are all 0, trim them
         
     | 
| 478 | 
         
            +
             
     | 
| 479 | 
         
            +
                            decimal_digits = [int(d) for d in decimal[0].split(".")[1]]
         
     | 
| 480 | 
         
            +
                            if sum(decimal_digits) == 0:
         
     | 
| 481 | 
         
            +
                                text = decimal[0][:non_decimal]
         
     | 
| 482 | 
         
            +
                            else:
         
     | 
| 483 | 
         
            +
                                text = decimal[0][:non_decimal + 3]
         
     | 
| 484 | 
         
            +
             
     | 
| 485 | 
         
            +
                        ## remove %
         
     | 
| 486 | 
         
            +
                        text = text.replace("%", "")
         
     | 
| 487 | 
         
            +
                        try:
         
     | 
| 488 | 
         
            +
                            if text[-1] == ".":
         
     | 
| 489 | 
         
            +
                                text = text[:-1]
         
     | 
| 490 | 
         
            +
                        except:
         
     | 
| 491 | 
         
            +
                            print(text)
         
     | 
| 492 | 
         
            +
             
     | 
| 493 | 
         
            +
                        return text
         
     | 
| 494 | 
         
            +
             
     | 
| 495 | 
         
            +
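        # E.g., extra_processing("3.14159") -> "3.14", extra_processing("12.00") -> "12",
        # extra_processing("45%") -> "45".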
        def extract_answer(text):
            alphabet = re.findall(r'[a-zA-Z]+', text)
            if len(alphabet) > 0 and "e+" not in text:
                # Pull the number out of an "answer is <number>" template.
                template_1 = re.findall(r'answer is -*\d+\.*\d*', text)
                if len(template_1) > 0:
                    text = template_1[0]
                    numbers = re.findall(r'-*\d+\.*\d*', text)
                    text = numbers[0] if len(numbers) > 0 else text

            return text

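        # E.g., extract_answer("The answer is 42.") -> "42"; purely numeric strings and
        # scientific notation such as "1e+5" pass through unchanged.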
        vqa_tool = VQAEval()
        acc = 0
        for res in merged_results:
            pred_ans = res["answer"]
            if res['question_type'] == 'multi_choice':
                pred_ans = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
            else:
                pred_ans = vqa_tool.processPunctuation(pred_ans)
                pred_ans = vqa_tool.processDigitArticle(pred_ans)
                pred_ans = extra_processing(pred_ans)
                pred_ans = extract_answer(pred_ans)

            gt_ans = res["gt_answer"]
            if res['question_type'] != 'multi_choice':
                gt_ans = vqa_tool.processPunctuation(gt_ans)
                gt_ans = vqa_tool.processDigitArticle(gt_ans)
                gt_ans = extra_processing(gt_ans)

            if pred_ans == gt_ans:
                acc += 1
        acc = acc / len(merged_results) * 100
        print(f"===== {task_name} Accuracy {acc:.2f}% =====")
        with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
            f.write(f"{task_name} Accuracy = {acc:.2f}%\n\n")
        return acc

                elif task_name == "mmbench":
         
     | 
| 536 | 
         
            +
                    if task_cfg['split'] == 'dev':
         
     | 
| 537 | 
         
            +
                        acc = 0
         
     | 
| 538 | 
         
            +
                        for res in merged_results:
         
     | 
| 539 | 
         
            +
                            gt_ans = res["gt_answer"]
         
     | 
| 540 | 
         
            +
                            pred_ans = res["answer"]
         
     | 
| 541 | 
         
            +
                            pred_ans = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
         
     | 
| 542 | 
         
            +
             
     | 
| 543 | 
         
            +
                            if pred_ans == gt_ans:
         
     | 
| 544 | 
         
            +
                                acc += 1
         
     | 
| 545 | 
         
            +
                        acc = acc / len(merged_results) * 100
         
     | 
| 546 | 
         
            +
                        print(f"===== {task_name} Accuracy {acc:.2f}% =====")
         
     | 
| 547 | 
         
            +
                        with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
         
     | 
| 548 | 
         
            +
                            f.write(f"{task_name} Accuracy = {acc:.2f}%\n\n")
         
     | 
| 549 | 
         
            +
             
     | 
| 550 | 
         
            +
                        return acc
         
     | 
| 551 | 
         
            +
             
     | 
| 552 | 
         
            +
        # Generate submission.xlsx for test-set online evaluation at
        # https://mmbench.opencompass.org.cn/mmbench-submission
        if task_cfg['submission']:
            res_df = pd.DataFrame(
                {
                    "index": [r['question_id'] for r in merged_results],
                    "question": [r['question'] for r in merged_results],
                    "A": [r['index2ans'].get('A') for r in merged_results],
                    "B": [r['index2ans'].get('B') for r in merged_results],
                    "C": [r['index2ans'].get('C') for r in merged_results],
                    "D": [r['index2ans'].get('D') for r in merged_results],
                    "prediction": [parse_multi_choice_response(r['answer'], r['all_choices'], r['index2ans'])
                                   for r in merged_results],
                },
                columns=['index', 'question', 'A', 'B', 'C', 'D', 'prediction']
            )
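            # Writing .xlsx via DataFrame.to_excel requires an Excel engine such as
            # openpyxl to be installed.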
            res_df.to_excel(os.path.join(results_save_dir, "submission.xlsx"))

                elif task_name == "vizwiz":
         
     | 
| 571 | 
         
            +
                    print(
         
     | 
| 572 | 
         
            +
                        f"VizWiz result file is saved at: {merged_results_path}\n"
         
     | 
| 573 | 
         
            +
                        f"Upload manually or use this CLI `python evaluation/upload_vizwiz.py --result {merged_results_path} --token <your_eval_user_token>`."
         
     | 
| 574 | 
         
            +
                    )
         
     | 
| 575 | 
         
            +
             
     | 
| 576 | 
         
            +
                elif task_name == "mmmu" or task_name == 'mmmu_pro':
         
     | 
| 577 | 
         
            +
                    def extract_answer(text):
         
     | 
| 578 | 
         
            +
                        import re
         
     | 
| 579 | 
         
            +
                        # Regular expression to find content inside \answer{xxx}
         
     | 
| 580 | 
         
            +
                        match = re.search(r'\\answer\{(.*?)\}', text)
         
     | 
| 581 | 
         
            +
                        if match:
         
     | 
| 582 | 
         
            +
                            return match.group(1)  # Return the content inside the braces
         
     | 
| 583 | 
         
            +
                        return text  # Return the original string if no match is found
         
     | 
| 584 | 
         
            +
             
     | 
| 585 | 
         
            +
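        # Multiple-choice predictions are mapped to an option letter with
        # parse_multi_choice_response; open-ended ones go through parse_open_response,
        # and mmmu_main_eval later aggregates per-category and overall accuracy.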
        eval_samples = []
        eval_output_dict = {}
        for res in merged_results:
            pred_ans = res["answer"]
            if res['question_type'] == 'multiple-choice':
                parsed_pred = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
                eval_samples.append(
                    {
                        'id': res['question_id'],
                        'question_type': res['question_type'],
                        'answer': res['gt_answer'],  # the option content, not the answer index
                        'response': pred_ans,
                        'parsed_pred': parsed_pred,
                        'index2ans': res['index2ans'],
                    }
                )
                eval_output_dict[res['question_id']] = parsed_pred
            else:
                # Some SFT checkpoints wrap open-ended answers in \answer{...}; unwrap them first.
                pred_ans = extract_answer(pred_ans)
                parsed_pred = parse_open_response(pred_ans)
                eval_samples.append(
                    {
                        'id': res['question_id'],
                        'question_type': res['question_type'],
                        'answer': res['gt_answer'],
                        'response': pred_ans,
                        'parsed_pred': parsed_pred,
                    }
                )
                eval_output_dict[res['question_id']] = pred_ans
        if args.subset is not None:
            eval_output_dict_path = os.path.join(
                results_save_dir,
                f"eval_output_dict_{task_name}_subset_{args.subset}_start_{args.start_idx}.json")
        else:
            eval_output_dict_path = os.path.join(results_save_dir, f"eval_output_dict_{task_name}.json")
        with open(eval_output_dict_path, "w") as f:
            json.dump(eval_output_dict, f, indent=4, sort_keys=True)

        mmmu_results = mmmu_main_eval(eval_output_dict, task_cfg)
        with open(os.path.join(results_save_dir, "scores.txt"), "a") as f:
            f.write(f"{task_name} {task_cfg['split']}:\n")
            for cat, cat_val in mmmu_results.items():
                if 'Overall' in cat:
                    cat = cat.replace("Overall-", "")
                    print(f'{cat}: {cat_val["acc"] * 100:.2f}')
                    f.write(f'{cat}: {cat_val["acc"] * 100:.2f}\n')
        return mmmu_results['Overall']['acc']

    else:
        raise NotImplementedError(f"Task {task_name} is not supported yet.")


def evaluate(args, model, tokenizer, tasks):
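    """Generate predictions and compute metrics for each task in args.zero_shot_eval_tasks.

    `tasks` maps task names to their YAML-derived configs; per-task results and a
    cumulative scores.txt are written under args.result_save_path.
    """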
    start_time = time.time()

    results_save_dir = args.result_save_path
    os.makedirs(results_save_dir, exist_ok=True)

    print(f"Evaluating tasks: {args.zero_shot_eval_tasks}")

    with torch.no_grad():
        for task_name in args.zero_shot_eval_tasks:
            task_cfg = tasks[task_name]
            print(f"Preparing dataloader for {task_name}...")

            dataloader = get_task_dataloader(task_name, task_cfg, args)

            # Greedy decoding keeps the benchmark numbers deterministic.
            generation_config = dict(max_new_tokens=1024, do_sample=False)

            print("Start generating...")
            generate_task_results(task_name, task_cfg, args, model, tokenizer, generation_config,
                                  dataloader, results_save_dir)

            print("Start calculating task metrics...")
            calc_task_metrics(args, task_name, task_cfg, results_save_dir)

    end_time = time.time()

    print(f"Evaluation took {(end_time - start_time) / 60:.1f} minutes in total!")


if __name__ == '__main__':
    path = "nvidia/NVLM-D-72B"
    print("Loading model from", path)
    device_map = split_model()

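    # split_model() (defined or imported earlier in this file) builds a device_map that
    # spreads the 72B checkpoint's layers across the available GPUs for inference.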
    start = time.time()
    model = AutoModel.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=True,
        device_map=device_map,
        trust_remote_code=True).eval()
    end = time.time()
    print(f"Loading the model took {end - start:.1f} seconds.")

    print(model)

    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

    parser = argparse.ArgumentParser()
    parser.add_argument('--config-path', type=str, required=True,
                        help='YAML file configuring evaluation datasets')
    # No leading slash here: os.path.join discards everything before an absolute component.
    parser.add_argument('--result-save-path', type=str, default=os.path.join(path, 'eval_results'))
    parser.add_argument('--zero-shot-eval-tasks', nargs='+', type=str, default=['mmmu'])
    parser.add_argument('--start-idx', type=int, default=0)
    parser.add_argument('--subset', type=int, default=None)

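    # --start-idx / --subset select a slice of the benchmark (presumably consumed by
    # get_task_dataloader via `args`), which is handy for sharding long runs across jobs;
    # the slice is recorded in the eval_output_dict_* filenames above.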
    args = parser.parse_args()

    with open(args.config_path) as f:
        tasks = yaml.safe_load(f)['datasets']

    evaluate(args, model, tokenizer, tasks)