dp2022/squad_transform.py

153 lines
5.8 KiB
Python
Raw Permalink Normal View History

import json
from dotenv import load_dotenv
2022-02-20 22:04:51 +00:00
from tqdm import tqdm
2022-02-20 20:34:12 +00:00
from squad_utils import print_squad
2022-02-20 22:04:51 +00:00
from translate_utils import translate_text
def sort_qas_by_answer_index(squad):
    """Reorder every paragraph's questions by the position of their first answer.

    A question is treated as answerable when ``is_impossible`` is false and it
    has at least one answer.  Answerable questions are sorted by the
    ``answer_start`` of their first answer.  On equal starts, the longer answer
    comes first, so a span that encloses another is marked before it.
    All other questions (impossible ones, or possible ones with an empty
    ``answers`` list) keep their relative order and are appended after them.

    Side effect: the first answer of every answerable question gets an
    ``answer_end`` key (exclusive end offset, ``answer_start + len(text)``).

    The ``squad`` dict is modified in place; nothing is returned.
    """
    for article in squad['data']:
        for paragraph in article['paragraphs']:
            answerable = []
            unanswerable = []
            for qas in paragraph['qas']:
                # An empty answers list would make answers[0] below raise.
                if not qas['is_impossible'] and qas['answers']:
                    answerable.append(qas)
                else:
                    unanswerable.append(qas)
            for qas in answerable:
                first = qas['answers'][0]
                first['answer_end'] = first['answer_start'] + len(first['text'])
            # Longer span first on equal starts so that enclosing answers are
            # processed (and wrapped in markers) before the answers they contain.
            answerable.sort(
                key=lambda qas: (
                    qas['answers'][0]['answer_start'],
                    -qas['answers'][0]['answer_end'],
                )
            )
            paragraph['qas'] = answerable + unanswerable
def transform_squad(squad):
    """Apply ``add_special_chars_to_paragraph`` to every paragraph of *squad*.

    Paragraphs are visited article by article, in order; the dataset is
    modified in place and nothing is returned.
    """
    all_paragraphs = (
        para
        for art in squad['data']
        for para in art['paragraphs']
    )
    for para in all_paragraphs:
        add_special_chars_to_paragraph(para)
def add_special_chars_to_paragraph(paragraph):
for counter, qas in enumerate(paragraph['qas']):
# Skip if impossible question
if qas["is_impossible"] == True: continue
if len(qas['answers']) > 1 or len(qas['answers']) == 0: continue
special_char = f"[{counter}]"
current = qas['answers'][0]
# Get start index
start = current['answer_start']
# Calculate end index
end = current['answer_end']
# Add special chars to context
context = paragraph['context']
paragraph['context'] = f"{context[:start]}{special_char} {context[start:end]} {special_char}{context[end:]}"
# Recalculate indexes
for q in paragraph['qas'][counter + 1:]: # Skip all answers before and current one
if q["is_impossible"] == True: continue
other = q['answers'][0]
if other['answer_start'] >= current['answer_start'] and other['answer_end'] <= current["answer_end"]: # Other is being enclosed by current
other['answer_start'] += len(special_char) +1
other['answer_end'] += 2*len(special_char) +2
elif other['answer_start'] < current['answer_end']: # Other is enclosing the current one
other['answer_start'] += len(special_char) +1
other['answer_end'] += len(special_char) +1
else: # Other is after current
other['answer_start'] += 2*len(special_char) +2
other['answer_end'] += 2*len(special_char) +2
# Fix indexes in current answer
other = paragraph['qas'][counter]['answers'][0]
if other == current: # Other answer is the one im working on
other['answer_start'] += len(special_char) +1
other['answer_end'] += len(special_char) +1
def detransform_squad(squad):
    """Remove the numbered answer markers and recover answer offsets, in place.

    Undoes the ``"[k] " ... " [k]"`` wrapping that ``add_special_chars_to_paragraph``
    writes, on both ``context`` and ``translated_context`` of every paragraph.
    For every question that is not impossible and has at least one answer, its
    first answer receives:

    * ``answer_start`` / ``answer_end`` -- offsets in the cleaned ``context``;
    * ``translated_answer_start`` / ``translated_answer_end`` -- offsets in the
      cleaned ``translated_context``;
    * ``translated_text`` -- the cleaned translated context sliced by those offsets.

    Offsets are computed after every *other* marker has been removed, so
    markers nested inside (or crossing) a span do not shift its offsets.  If a
    question's marker cannot be found (for example, a multi-answer question
    that was never marked, or a marker lost in translation), the offsets are
    set to -1 and ``translated_text`` to ``""``.
    """
    for article in squad['data']:
        for paragraph in article['paragraphs']:
            qas_list = paragraph['qas']
            # Only questions that may carry a marker: not impossible and with answers.
            markers = {
                counter: f"[{counter}]"
                for counter, qas in enumerate(qas_list)
                if not qas["is_impossible"] and qas['answers']
            }
            if not markers:
                continue
            context = paragraph['context']
            translated_context = paragraph['translated_context']
            clean_context = _strip_markers(context, markers.values())
            clean_translated = _strip_markers(translated_context, markers.values())
            for counter, marker in markers.items():
                others = [m for c, m in markers.items() if c != counter]
                current = qas_list[counter]['answers'][0]
                # English indexes
                start, end = _locate_marked_span(context, marker, others)
                current['answer_start'] = start
                current['answer_end'] = end
                # Slovak (translated) indexes
                t_start, t_end = _locate_marked_span(translated_context, marker, others)
                current['translated_answer_start'] = t_start
                current['translated_answer_end'] = t_end
                current['translated_text'] = clean_translated[t_start:t_end] if t_start != -1 else ""
            paragraph['context'] = clean_context
            paragraph['translated_context'] = clean_translated


def _strip_markers(text, markers):
    """Return *text* with every given marker token removed.

    Opening markers are followed by a space (``"[k] "``) and closing ones are
    preceded by one (``" [k]"``).  The ``"[k] "`` form is removed first, which
    also consumes a closing marker that happens to be followed by a space.  The
    second pass handles closing markers followed by punctuation or ending the
    text.
    """
    for marker in markers:
        text = text.replace(f"{marker} ", "").replace(f" {marker}", "")
    return text


def _locate_marked_span(text, marker, other_markers):
    """Return ``(start, end)`` of the span wrapped by *marker*, or ``(-1, -1)``.

    The other markers are stripped first, so the offsets are valid for *text*
    once all markers are removed.  ``(-1, -1)`` is returned when *marker* does
    not occur at least twice, or when its two occurrences are too close together
    to form ``"[k] ... [k]"``.
    """
    text = _strip_markers(text, other_markers)
    start = text.find(marker)
    close = text.rfind(marker)
    if start == -1 or close == start:
        return -1, -1
    # close - 1 is the space before the closing marker; the opening "[k] "
    # (len(marker) + 1 chars) before the span disappears as well.
    end = close - len(marker) - 2
    if end < start:
        return -1, -1
    return start, end
2022-02-20 22:04:51 +00:00
def translate_paragraphs(squad):
    """Attach machine translations to every paragraph and question, in place.

    Each paragraph gains a ``translated_context`` key and each of its questions
    a ``translated_question`` key, both produced by ``translate_text``.  The
    context is translated before its questions.  A tqdm progress bar advances
    once per article.
    """
    for article in tqdm(squad["data"]):
        for paragraph in article["paragraphs"]:
            paragraph['translated_context'] = translate_text(paragraph["context"])
            for question in paragraph['qas']:
                question['translated_question'] = translate_text(question['question'])
2022-02-20 22:04:51 +00:00
if __name__ == "__main__":
    # Pipeline: sort questions -> insert answer markers -> translate -> remove
    # markers and recover answer offsets in both languages.
    load_dotenv()
    # Explicit UTF-8: SQuAD and the translations contain non-ASCII text, and
    # open()'s default encoding depends on the platform locale.
    with open("./data/squad-v2-dev-small.json", "r", encoding="utf-8") as f:
        squad = json.load(f)
    sort_qas_by_answer_index(squad)
    transform_squad(squad)
    translate_paragraphs(squad)
    # Checkpoint the marked + translated data so the detransform step can be
    # re-run from this file (see the commented lines below).
    # ensure_ascii=False keeps translated text readable instead of \u escapes.
    with open("./data/squad-v2-dev-small-transformed.json", "w", encoding="utf-8") as f:
        json.dump(squad, f, indent=2, ensure_ascii=False)
    # To re-run only the detransform step, load the checkpoint instead:
    # with open("./data/squad-v2-dev-small-transformed.json", "r", encoding="utf-8") as f:
    #     squad = json.load(f)
    detransform_squad(squad)
    with open("./data/squad-v2-dev-small-translated.json", "w", encoding="utf-8") as f:
        json.dump(squad, f, indent=2, ensure_ascii=False)