Update scripts/eval_mteb.py
Browse files- scripts/eval_mteb.py +21 -7
scripts/eval_mteb.py
CHANGED
|
@@ -119,7 +119,6 @@ CMTEB_TASK_LIST = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'Onl
|
|
| 119 |
'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
|
| 120 |
'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC', 'STS22']
|
| 121 |
|
| 122 |
-
|
| 123 |
MTEB_PL = [
|
| 124 |
"CBD","PolEmo2.0-IN","PolEmo2.0-OUT","AllegroReviews","PAC","MassiveIntentClassification","MassiveScenarioClassification",
|
| 125 |
"SICK-E-PL","PPC","CDSC-E","PSC","8TagsClustering","SICK-R-PL","CDSC-R","STS22",
|
|
@@ -406,9 +405,9 @@ class Wrapper:
|
|
| 406 |
self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 407 |
self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
|
| 408 |
self.instruction = instruction
|
| 409 |
-
self.default_query = default_query
|
|
|
|
| 410 |
self.force_default = force_default
|
| 411 |
-
|
| 412 |
if self.tokenizer.padding_side != 'right':
|
| 413 |
logger.warning(f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right")
|
| 414 |
self.tokenizer.padding_side = 'right'
|
|
@@ -675,13 +674,15 @@ class Wrapper:
|
|
| 675 |
def main(args):
|
| 676 |
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
| 677 |
encoder = Encoder(args.model, args.pooling)
|
|
|
|
| 678 |
model = Wrapper(
|
| 679 |
tokenizer, encoder,
|
| 680 |
batch_size=args.batch_size,
|
| 681 |
max_seq_len=args.max_seq_len,
|
| 682 |
-
normalize_embeddings=args.norm
|
|
|
|
| 683 |
)
|
| 684 |
-
|
| 685 |
if args.task == 'mteb':
|
| 686 |
task_names = MTEB_TASK_LIST
|
| 687 |
lang = ['en']
|
|
@@ -709,8 +710,21 @@ def main(args):
|
|
| 709 |
eval_splits = task_cls.description['eval_splits']
|
| 710 |
else:
|
| 711 |
eval_splits = ["test"]
|
| 712 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 714 |
print('\n')
|
| 715 |
|
| 716 |
|
|
@@ -729,4 +743,4 @@ if __name__ == "__main__":
|
|
| 729 |
)
|
| 730 |
_PARSER.add_argument("--norm", action="store_true")
|
| 731 |
_ARGS = _PARSER.parse_args()
|
| 732 |
-
main(_ARGS)
|
|
|
|
| 119 |
'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
|
| 120 |
'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC', 'STS22']
|
| 121 |
|
|
|
|
| 122 |
MTEB_PL = [
|
| 123 |
"CBD","PolEmo2.0-IN","PolEmo2.0-OUT","AllegroReviews","PAC","MassiveIntentClassification","MassiveScenarioClassification",
|
| 124 |
"SICK-E-PL","PPC","CDSC-E","PSC","8TagsClustering","SICK-R-PL","CDSC-R","STS22",
|
|
|
|
| 405 |
self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 406 |
self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
|
| 407 |
self.instruction = instruction
|
| 408 |
+
self.default_query = default_query
|
| 409 |
+
self.sep = sep
|
| 410 |
self.force_default = force_default
|
|
|
|
| 411 |
if self.tokenizer.padding_side != 'right':
|
| 412 |
logger.warning(f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right")
|
| 413 |
self.tokenizer.padding_side = 'right'
|
|
|
|
| 674 |
def main(args):
|
| 675 |
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
| 676 |
encoder = Encoder(args.model, args.pooling)
|
| 677 |
+
default_query = args.default_type == 'query'
|
| 678 |
model = Wrapper(
|
| 679 |
tokenizer, encoder,
|
| 680 |
batch_size=args.batch_size,
|
| 681 |
max_seq_len=args.max_seq_len,
|
| 682 |
+
normalize_embeddings=args.norm,
|
| 683 |
+
default_query=default_query
|
| 684 |
)
|
| 685 |
+
sym_retrievals = ['QuoraRetrieval', 'ArguAna', 'CQADupstack']
|
| 686 |
if args.task == 'mteb':
|
| 687 |
task_names = MTEB_TASK_LIST
|
| 688 |
lang = ['en']
|
|
|
|
| 710 |
eval_splits = task_cls.description['eval_splits']
|
| 711 |
else:
|
| 712 |
eval_splits = ["test"]
|
| 713 |
+
sym = False
|
| 714 |
+
for name in sym_retrievals:
|
| 715 |
+
if task.startswith(name):
|
| 716 |
+
sym = True
|
| 717 |
+
break
|
| 718 |
+
else:
|
| 719 |
+
sym = False
|
| 720 |
+
if sym:
|
| 721 |
+
logger.info(f"Switch to symmetric mode for {task}, all as {'query' if default_query else 'doc'}.")
|
| 722 |
+
model.force_default = True
|
| 723 |
evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
|
| 724 |
+
|
| 725 |
+
if sym:
|
| 726 |
+
logger.info(f"Switch back.")
|
| 727 |
+
model.force_default = force_default_ori
|
| 728 |
print('\n')
|
| 729 |
|
| 730 |
|
|
|
|
| 743 |
)
|
| 744 |
_PARSER.add_argument("--norm", action="store_true")
|
| 745 |
_ARGS = _PARSER.parse_args()
|
| 746 |
+
main(_ARGS)
|