"""Translate a SQuAD v2 file while keeping track of the answer spans.

Every single-answer span is wrapped in a numbered marker (``[0] text [0]``)
before the context is passed to ``translate_text``; afterwards the markers
are located again in both the original and the translated context to recover
the answer indexes, and are then removed.
"""
import json

from dotenv import load_dotenv
from tqdm import tqdm

from squad_utils import print_squad
from translate_utils import translate_text


def _has_single_answer(qas: dict) -> bool:
    """Return True when *qas* is answerable and carries exactly one answer.

    Only such questions get markers inserted, so the marking and the
    un-marking step must both use this same test.
    """
    return not qas["is_impossible"] and len(qas['answers']) == 1


def _locate_marked_span(text: str, special_char: str):
    """Find the span wrapped by *special_char* in *text*.

    Returns ``(start, end)`` as the span will be indexed once the opening
    ``"<marker> "`` and the closing ``" <marker>"`` have been removed, or
    ``None`` when the marker does not occur at least twice.
    """
    first = text.find(special_char)
    last = text.rfind(special_char)
    if first == -1 or first == last:
        return None
    # The closing marker is preceded by one space (-1) and everything after
    # the opening marker moves left by len(marker) + 1 once it is removed.
    return first, last - len(special_char) - 2


def _strip_marker(text: str, special_char: str) -> str:
    """Remove *special_char* together with its separating space from *text*."""
    # Opening markers are followed by a space. Closing markers are preceded
    # by one and may be directly followed by punctuation or the end of text.
    return text.replace(f"{special_char} ", "").replace(f" {special_char}", "")


def sort_qas_by_answer_index(squad: dict) -> None:
    """Order each paragraph's questions by the start of their first answer.

    Answerable questions are sorted by ``answers[0]['answer_start']`` and
    placed first; all other questions follow in their original relative
    order. The first answer of every sorted question also receives an
    exclusive ``answer_end`` (``answer_start + len(text)``).

    The *squad* structure is modified in place.
    """
    for article in squad['data']:
        for paragraph in article['paragraphs']:
            sortable, unsortable = [], []
            for qas in paragraph['qas']:
                # A question without any answer has no position to sort by;
                # keep it with the impossible ones instead of failing with
                # an IndexError in the sort key.
                if qas['is_impossible'] or not qas['answers']:
                    unsortable.append(qas)
                else:
                    sortable.append(qas)

            sortable.sort(key=lambda qas: qas['answers'][0]['answer_start'])
            for qas in sortable:
                first = qas['answers'][0]
                first['answer_end'] = first['answer_start'] + len(first['text'])

            paragraph['qas'] = sortable + unsortable


def transform_squad(squad: dict) -> None:
    """Insert answer markers into the context of every paragraph, in place."""
    for article in squad['data']:
        for paragraph in article['paragraphs']:
            add_special_chars_to_paragraph(paragraph)


def add_special_chars_to_paragraph(paragraph: dict) -> None:
    """Wrap every single-answer span of *paragraph* in a numbered marker.

    The marker for a question is ``[<index>]`` where ``<index>`` is the
    question's position in ``paragraph['qas']``; the context becomes
    ``... [i] answer text [i] ...``. Impossible questions and questions that
    do not have exactly one answer are left unmarked.

    Expects ``answer_end`` to be present on the first answer of each
    answerable question and the questions to be ordered by ``answer_start``
    (both are done by ``sort_qas_by_answer_index``). ``paragraph['context']``
    and the answer indexes are modified in place.
    """
    for counter, qas in enumerate(paragraph['qas']):
        if not _has_single_answer(qas):
            continue

        special_char = f"[{counter}]"
        # Each inserted marker comes with exactly one separating space.
        shift = len(special_char) + 1

        current = qas['answers'][0]
        start = current['answer_start']
        end = current['answer_end']

        context = paragraph['context']
        paragraph['context'] = (
            f"{context[:start]}{special_char} {context[start:end]}"
            f" {special_char}{context[end:]}"
        )

        # Only the questions after the current one still need valid indexes
        # into the context; move them past the text that was just inserted.
        for q in paragraph['qas'][counter + 1:]:
            if q["is_impossible"] or not q['answers']:
                continue
            other = q['answers'][0]
            if other['answer_start'] >= start and other['answer_end'] <= end:
                # Nested inside the current answer: only the opening marker
                # lies in front of it, the closing marker comes after its end.
                other['answer_start'] += shift
                other['answer_end'] += shift
            elif other['answer_start'] < end:
                # Starts inside the current answer but runs past its end, so
                # the closing marker falls inside the other span.
                other['answer_start'] += shift
                other['answer_end'] += 2 * shift
            else:
                # Entirely after the current answer: both markers precede it.
                other['answer_start'] += 2 * shift
                other['answer_end'] += 2 * shift

        # The current answer itself now sits right behind its opening marker.
        current['answer_start'] += shift
        current['answer_end'] += shift


def detransform_squad(squad: dict) -> None:
    """Recover answer indexes from the markers and remove the markers again.

    For every question that was marked by ``add_special_chars_to_paragraph``
    the first answer gets ``answer_start``/``answer_end`` recomputed from the
    English context and ``translated_answer_start``/``translated_answer_end``
    plus ``translated_text`` from ``paragraph['translated_context']``. Both
    contexts are cleaned of that question's marker. Everything is modified
    in place.

    When the marker pair cannot be found in the translated context, the
    translated indexes are set to ``-1`` and ``translated_text`` to ``""``.
    """
    for article in squad['data']:
        for paragraph in article['paragraphs']:
            for counter, qas in enumerate(paragraph['qas']):
                # Questions that never received a marker must be left alone;
                # searching for a marker that is not there would overwrite
                # their valid indexes with values derived from -1.
                if not _has_single_answer(qas):
                    continue

                special_char = f"[{counter}]"
                current = qas['answers'][0]

                # NOTE(review): markers of later questions nested inside this
                # span are still present at this point, so `end` overshoots
                # by their length for nested answers.

                # Fix english indexes
                span = _locate_marked_span(paragraph['context'], special_char)
                if span is not None:
                    current['answer_start'], current['answer_end'] = span

                # Fix slovak indexes
                span = _locate_marked_span(
                    paragraph['translated_context'], special_char
                )
                found_translation = span is not None
                if not found_translation:
                    span = (-1, -1)
                current['translated_answer_start'], current['translated_answer_end'] = span

                # Fix english context
                paragraph['context'] = _strip_marker(
                    paragraph['context'], special_char
                )
                # Fix slovak context
                paragraph['translated_context'] = _strip_marker(
                    paragraph['translated_context'], special_char
                )

                # Add translated_text to qas
                if found_translation:
                    start, end = span
                    current['translated_text'] = paragraph['translated_context'][start:end]
                else:
                    current['translated_text'] = ""


def translate_paragraphs(squad: dict) -> None:
    """Translate every context and question with ``translate_text``.

    Results are stored next to the originals as
    ``paragraph['translated_context']`` and ``qas['translated_question']``.
    Progress over the articles is shown with ``tqdm``.
    """
    for article in tqdm(squad["data"]):
        for paragraph in article["paragraphs"]:
            # Translate context
            paragraph['translated_context'] = translate_text(paragraph["context"])

            # Translate questions
            for qas in paragraph['qas']:
                qas['translated_question'] = translate_text(qas['question'])


def main() -> None:
    """Run the whole pipeline on the small SQuAD v2 dev file."""
    load_dotenv()

    # JSON files are UTF-8; do not depend on the platform default encoding.
    with open("./data/squad-v2-dev-small.json", "r", encoding="utf-8") as f:
        squad = json.load(f)

    sort_qas_by_answer_index(squad)
    transform_squad(squad)
    translate_paragraphs(squad)

    # Intermediate result with the markers still in place. To skip the
    # translation on a later run, load this file here instead of running
    # the three steps above.
    with open("./data/squad-v2-dev-small-transformed.json", "w", encoding="utf-8") as f:
        json.dump(squad, f, indent=2)

    detransform_squad(squad)

    with open("./data/squad-v2-dev-small-translated.json", "w", encoding="utf-8") as f:
        json.dump(squad, f, indent=2)


if __name__ == "__main__":
    main()