|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""pp ops."""
|
|
|
|
from big_vision.pp.registry import Registry
|
|
import tensorflow as tf
|
|
|
|
|
|
@Registry.register('preprocess_ops.sci_qa_choices_shuffle')
def sci_qa_choices_shuffle(
    choice_str_inkey='choices',
    ans_inkey='answer',
    indexed_choices_outkey='indexed_choices',
    indexed_answer_outkey='indexed_answer',
):
  """Randomly shuffles the sci_qa choices on the fly.

  Example: with choices ['apple', 'banana'] and answer 1 ('banana'), if the
  shuffle yields ['banana', 'apple'] the outputs are
  "(A) banana, (B) apple" and 'A'.

  Args:
    choice_str_inkey: key of the original choice list from sci_qa, e.g.
      ['apple', 'banana', ...]. Assumed to hold a 1-D string tensor with at
      most 26 entries (one per alphabet letter) -- TODO confirm with caller.
    ans_inkey: key of the original answer, a 0-based index into the choice
      list, e.g. 1.
    indexed_choices_outkey: key under which the shuffled choices are stored
      as a single string, each choice prefixed by its letter label and the
      entries joined with ', ', e.g. "(A) banana, (B) apple".
    indexed_answer_outkey: key under which the letter of the answer's position
      after shuffling is stored, e.g. 'A'.

  Returns:
    A function that takes a `data` dict, writes the two output keys into it
    (mutating it in place) and returns it.
  """
  alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
  option_labels = [f'({letter})' for letter in alphabet]  # '(A)', '(B)', ...
  answer_labels = list(alphabet)  # 'A', 'B', ...

  def _template(data):
    choices = data[choice_str_inkey]

    # Use tf.shape rather than Python len(): len() on a symbolic tensor only
    # works when AutoGraph rewrites it, whereas tf.shape works in eager mode
    # and in any graph, including when the length is only known at run time.
    num_choices = tf.shape(choices)[0]
    positions = tf.range(num_choices)

    # shuffled_order[i] is the original index of the choice now at slot i.
    shuffled_order = tf.random.shuffle(positions)
    shuffled_choices = tf.gather(choices, shuffled_order)

    # One '(X)' label per choice, in display order.
    labels = tf.gather(tf.constant(option_labels), positions)
    data[indexed_choices_outkey] = tf.strings.reduce_join(
        tf.strings.join([labels, shuffled_choices], separator=' '),
        separator=', ',
    )

    # Match the answer's integer dtype to the indices so that tf.equal does
    # not fail on an int32/int64 mismatch.
    answer = tf.cast(data[ans_inkey], shuffled_order.dtype)

    # Slot(s) whose original index equals the answer; tf.where yields a
    # [num_matches, 1] tensor, so the gathered letters are joined back into a
    # single scalar string.
    answer_slot = tf.where(tf.equal(shuffled_order, answer))
    answer_letter = tf.gather(tf.constant(answer_labels), answer_slot)
    data[indexed_answer_outkey] = tf.strings.reduce_join(answer_letter)
    return data

  return _template
|
|
|