Upload 4 files

- evaluation/README.md +8 -0
- evaluation/arguments.py +31 -0
- evaluation/dataset.py +74 -0
- evaluation/evaluate.py +102 -0
evaluation/README.md
ADDED
@@ -0,0 +1,8 @@
### Commands for running evaluation

```console
python evaluate.py --eval-dataset doc2dial
python evaluate.py --eval-dataset quac
python evaluate.py --eval-dataset qrecc
```
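These commands assume the benchmark data sits at the default relative paths defined in `arguments.py`. If the ChatRAG Bench data lives elsewhere, the `--data-folder` flag points the script at it; a minimal sketch, where `/path/to/chatrag-bench` is a placeholder rather than a path from this repo:

```console
python evaluate.py --data-folder /path/to/chatrag-bench --eval-dataset doc2dial
```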
evaluation/arguments.py
ADDED
@@ -0,0 +1,31 @@
import argparse
import os

def get_args():
    parser = argparse.ArgumentParser(description="Dragon-multiturn")

    parser.add_argument('--query-encoder-path', type=str, default='nvidia/dragon-multiturn-query-encoder')
    parser.add_argument('--context-encoder-path', type=str, default='nvidia/dragon-multiturn-context-encoder')

    parser.add_argument('--data-folder', type=str, default='', help='path to the datafolder of ChatRAG Bench')
    parser.add_argument('--eval-dataset', type=str, default='', help='evaluation dataset (e.g., doc2dial)')

    parser.add_argument('--doc2dial-datapath', type=str, default='doc2dial/test.json')
    parser.add_argument('--doc2dial-docpath', type=str, default='doc2dial/documents.json')

    parser.add_argument('--quac-datapath', type=str, default='quac/test.json')
    parser.add_argument('--quac-docpath', type=str, default='quac/documents.json')

    parser.add_argument('--qrecc-datapath', type=str, default='qrecc/test.json')
    parser.add_argument('--qrecc-docpath', type=str, default='qrecc/documents.json')

    parser.add_argument('--topiocqa-datapath', type=str, default='topiocqa/dev.json')
    parser.add_argument('--topiocqa-docpath', type=str, default='')

    parser.add_argument('--inscit-datapath', type=str, default='inscit/dev.json')
    parser.add_argument('--inscit-docpath', type=str, default='')

    args = parser.parse_args()

    return args
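Each dataset also has relative `*-datapath`/`*-docpath` defaults that `evaluate.py` joins onto `--data-folder`, and any of them can be overridden on the command line. A hedged example, where `my_split/test.json` is an illustrative path rather than a file shipped with the benchmark:

```console
python evaluate.py --eval-dataset doc2dial --doc2dial-datapath my_split/test.json
```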
evaluation/dataset.py
ADDED
@@ -0,0 +1,74 @@
import json


def get_query(messages, num_turns=5):
    ## convert the query into the following format:
    ## user: {user}\nagent: {agent}\nuser: {user}
    query = ""
    for item in messages[-num_turns:]:
        item['role'] = item['role'].replace("assistant", "agent")
        query += "{}: {}\n".format(item['role'], item['content'])
    query = query.strip()

    return query


def get_query_with_topic(messages, topic, num_turns=3):
    ## convert the query into the following format:
    ## user: this is a question about {topic}. {user}\nagent: {agent}\nuser: this is a question about {topic}. {user}
    query = ""
    for item in messages[-num_turns:]:
        item['role'] = item['role'].replace("assistant", "agent")
        if item['role'] == 'user':
            query += "{}: this is a question about {}. {}\n".format(item['role'], topic, item['content'])
        else:
            query += "{}: {}\n".format(item['role'], item['content'])
    query = query.strip()

    return query


def get_data_for_evaluation(input_datapath, document_datapath, dataset_name):

    print('reading evaluation data from %s' % input_datapath)
    with open(input_datapath, "r") as f:
        input_list = json.load(f)

    print('reading documents from %s' % document_datapath)
    with open(document_datapath, "r") as f:
        documents = json.load(f)

    eval_data = {}
    for item in input_list:
        """
        For the topiocqa and inscit datasets, we incorporate topic information into the query:
        query = get_query_with_topic(item['messages'], item['topic'])
        """
        query = get_query(item['messages'])

        doc_id = item['document']
        gold_idx = item['ground_truth_ctx']['index']

        if dataset_name == 'qrecc':
            """
            The 'gold context' for the qrecc dataset is obtained based on the word
            overlap between the gold answer and each context in the document, so it might
            not be the real gold context.

            To improve the evaluation quality of this dataset,
            we further append the answer of the query to the 'gold context'
            to ensure the 'gold context' is the chunk most relevant to the query.

            Note that this is only for the retrieval evaluation; we do not
            add the answer to the context for the ChatRAG evaluation.
            """
            answer = item['answers'][0]
            documents[doc_id][gold_idx] += " || " + answer

        if doc_id not in eval_data:
            eval_data[doc_id] = [{"query": query, "gold_idx": gold_idx}]
        else:
            eval_data[doc_id].append({"query": query, "gold_idx": gold_idx})

    return eval_data, documents
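To make the query format described in the comments above concrete, here is a minimal sketch; the toy conversation is invented for illustration and does not come from any benchmark file:

```python
from dataset import get_query, get_query_with_topic

# a tiny, made-up conversation in the {"role", "content"} message format
messages = [
    {"role": "user", "content": "What is the late fee?"},
    {"role": "assistant", "content": "It is $10 per week."},
    {"role": "user", "content": "And who do I pay it to?"},
]

print(get_query(messages))
# user: What is the late fee?
# agent: It is $10 per week.
# user: And who do I pay it to?

print(get_query_with_topic(messages, topic="library cards"))
# user: this is a question about library cards. What is the late fee?
# agent: It is $10 per week.
# user: this is a question about library cards. And who do I pay it to?
```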
evaluation/evaluate.py
ADDED
@@ -0,0 +1,102 @@
from transformers import AutoModel, AutoTokenizer
from dataset import get_data_for_evaluation
from arguments import get_args
from tqdm import tqdm
import torch
import os


def run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer, max_seq_len=512):

    ranked_indices_list = []
    gold_index_list = []
    for doc_id in tqdm(eval_data):
        context_list = documents[doc_id]

        with torch.no_grad():
            ## get chunk embeddings
            context_embs = []
            for chunk in context_list:
                chunk_ids = tokenizer(chunk, max_length=max_seq_len, truncation=True, return_tensors="pt").to("cuda")

                c_emb = context_encoder(input_ids=chunk_ids.input_ids, attention_mask=chunk_ids.attention_mask)
                c_emb = c_emb.last_hidden_state[:, 0, :]
                context_embs.append(c_emb)
            context_embs = torch.cat(context_embs, dim=0)  # (num_chunk, hidden_dim)

            ## get query embeddings
            sample_list = eval_data[doc_id]
            query_embs = []
            for item in sample_list:
                gold_idx = item['gold_idx']
                gold_index_list.append(gold_idx)

                query = item['query']
                query_ids = tokenizer(query, max_length=max_seq_len, truncation=True, return_tensors="pt").to("cuda")
                q_emb = query_encoder(input_ids=query_ids.input_ids, attention_mask=query_ids.attention_mask)
                q_emb = q_emb.last_hidden_state[:, 0, :]
                query_embs.append(q_emb)
            query_embs = torch.cat(query_embs, dim=0)  # (num_query, hidden_dim)

            ## rank chunks by dot-product similarity with each query
            similarities = query_embs.matmul(context_embs.transpose(0, 1))  # (num_query, num_chunk)
            ranked_results = torch.argsort(similarities, dim=-1, descending=True)  # (num_query, num_chunk)
            ranked_indices_list.extend(ranked_results.tolist())

    return ranked_indices_list, gold_index_list


def calculate_recall(ranked_indices_list, gold_index_list, topk):
    hit = 0
    for ranked_indices, gold_index in zip(ranked_indices_list, gold_index_list):
        for idx in ranked_indices[:topk]:
            if idx == gold_index:
                hit += 1
                break
    recall = hit / len(ranked_indices_list)

    print("top-%d recall score: %.4f" % (topk, recall))


def main():
    args = get_args()

    ## get tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.query_encoder_path)

    ## get retriever model
    query_encoder = AutoModel.from_pretrained(args.query_encoder_path)
    context_encoder = AutoModel.from_pretrained(args.context_encoder_path)
    query_encoder.to("cuda"), query_encoder.eval()
    context_encoder.to("cuda"), context_encoder.eval()

    ## get evaluation data
    if args.eval_dataset == "doc2dial":
        input_datapath = os.path.join(args.data_folder, args.doc2dial_datapath)
        input_docpath = os.path.join(args.data_folder, args.doc2dial_docpath)
    elif args.eval_dataset == "quac":
        input_datapath = os.path.join(args.data_folder, args.quac_datapath)
        input_docpath = os.path.join(args.data_folder, args.quac_docpath)
    elif args.eval_dataset == "qrecc":
        input_datapath = os.path.join(args.data_folder, args.qrecc_datapath)
        input_docpath = os.path.join(args.data_folder, args.qrecc_docpath)
    elif args.eval_dataset == "topiocqa" or args.eval_dataset == "inscit":
        raise Exception("We have prepared the functions to build queries for these datasets, but a Wikipedia corpus needs to be downloaded first")
    else:
        raise Exception("Please input a correct eval_dataset name!")

    eval_data, documents = get_data_for_evaluation(input_datapath, input_docpath, args.eval_dataset)

    ## run retrieval
    ranked_indices_list, gold_index_list = run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer)
    print("number of total test samples: %d" % len(ranked_indices_list))

    ## calculate recall scores
    print("evaluating on %s" % args.eval_dataset)
    topk_list = [1, 5, 20]
    for topk in topk_list:
        calculate_recall(ranked_indices_list, gold_index_list, topk=topk)


if __name__ == "__main__":
    main()
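As a sanity check on the recall computation, a minimal sketch with made-up rankings rather than real retrieval output: with two queries over a three-chunk document, gold indices `[2, 0]` and rankings `[[2, 1, 0], [1, 0, 2]]`, only the first query has its gold chunk at rank 1, so top-1 recall is 0.5 while top-5 recall is 1.0.

```python
from evaluate import calculate_recall

# toy rankings for two queries over a three-chunk document (illustrative only)
ranked_indices_list = [[2, 1, 0], [1, 0, 2]]
gold_index_list = [2, 0]

calculate_recall(ranked_indices_list, gold_index_list, topk=1)  # prints: top-1 recall score: 0.5000
calculate_recall(ranked_indices_list, gold_index_list, topk=5)  # prints: top-5 recall score: 1.0000
```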