even more spring cleaning
This commit is contained in:
parent
e5c2794c06
commit
0e76056e0d
1
data/squad-v2-dev-small.json
Normal file
1
data/squad-v2-dev-small.json
Normal file
File diff suppressed because one or more lines are too long
@ -72,13 +72,13 @@ def add_special_chars_to_paragraph(paragraph):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
with open("./squad-test.json", "r") as f:
|
with open("./data/squad-v2-dev-small.json", "r") as f:
|
||||||
squad = json.load(f)
|
squad = json.load(f)
|
||||||
|
|
||||||
sort_qas_by_answer_index(squad)
|
sort_qas_by_answer_index(squad)
|
||||||
transform_squad(squad)
|
transform_squad(squad)
|
||||||
print_squad(squad)
|
print_squad(squad)
|
||||||
|
|
||||||
with open("./squad-test-translated.json", "w") as f:
|
with open("./data/squad-v2-dev-small-translated.json", "w") as f:
|
||||||
json.dump(squad, f, indent=2)
|
json.dump(squad, f, indent=2)
|
||||||
|
|
||||||
|
@ -1,106 +0,0 @@
|
|||||||
import json
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
import six
|
|
||||||
from google.cloud import translate_v2 as translate
|
|
||||||
|
|
||||||
|
|
||||||
def translate_text(text):
|
|
||||||
"""Translates text into the target language.
|
|
||||||
|
|
||||||
Target must be an ISO 639-1 language code.
|
|
||||||
See https://g.co/cloud/translate/v2/translate-reference#supported_languages
|
|
||||||
"""
|
|
||||||
|
|
||||||
translate_client = translate.Client()
|
|
||||||
|
|
||||||
if isinstance(text, six.binary_type):
|
|
||||||
text = text.decode("utf-8")
|
|
||||||
|
|
||||||
# Text can also be a sequence of strings, in which case this method
|
|
||||||
# will return a sequence of results for each text.
|
|
||||||
result = translate_client.translate(text, target_language="sk")
|
|
||||||
|
|
||||||
print(u"Text: {}".format(result["input"]))
|
|
||||||
print(u"Translation: {}".format(result["translatedText"]))
|
|
||||||
print(u"Detected source language: {}".format(result["detectedSourceLanguage"]))
|
|
||||||
|
|
||||||
|
|
||||||
def sort_qas_by_answer_index(squad):
|
|
||||||
for article in squad['data']:
|
|
||||||
for paragraph in article['paragraphs']:
|
|
||||||
impossible_qas = list(filter(lambda qas: qas['is_impossible'] == True, paragraph['qas']))
|
|
||||||
possible_qas = list(filter(lambda qas: qas['is_impossible'] == False, paragraph['qas']))
|
|
||||||
sorted_qas = sorted(possible_qas, key=lambda qas: qas['answers'][0]['answer_start'])
|
|
||||||
|
|
||||||
for qas in sorted_qas:
|
|
||||||
a = qas['answers'][0]
|
|
||||||
a['answer_end'] = a['answer_start'] + len(a['text'])
|
|
||||||
|
|
||||||
paragraph['qas'] = sorted_qas + impossible_qas
|
|
||||||
|
|
||||||
|
|
||||||
def transform_squad(squad):
|
|
||||||
for article in squad['data']:
|
|
||||||
for paragraph in article['paragraphs']:
|
|
||||||
add_special_chars_to_paragraph(paragraph)
|
|
||||||
|
|
||||||
|
|
||||||
def add_special_chars_to_paragraph(paragraph):
|
|
||||||
for counter, qas in enumerate(paragraph['qas']):
|
|
||||||
# Skip if impossible question
|
|
||||||
if qas["is_impossible"] == True: continue
|
|
||||||
|
|
||||||
special_char = f"[{counter}]"
|
|
||||||
|
|
||||||
if len(qas['answers']) > 1 or len(qas['answers']) == 0: continue
|
|
||||||
|
|
||||||
current = qas['answers'][0]
|
|
||||||
|
|
||||||
# Get start index
|
|
||||||
start = current['answer_start']
|
|
||||||
# Calculate end index
|
|
||||||
end = current['answer_end']
|
|
||||||
# Add special chars to context
|
|
||||||
context = paragraph['context']
|
|
||||||
paragraph['context'] = f"{context[:start]}{special_char} {context[start:end]} {special_char}{context[end:]}"
|
|
||||||
|
|
||||||
# Recalculate indexes
|
|
||||||
for q in paragraph['qas'][counter + 1:]: # Skip all answers before and current one
|
|
||||||
if q["is_impossible"] == True: continue
|
|
||||||
|
|
||||||
other = q['answers'][0]
|
|
||||||
|
|
||||||
if other['answer_start'] >= current['answer_start'] and other['answer_end'] <= current["answer_end"]: # Other is being enclosed by current
|
|
||||||
other['answer_start'] += len(special_char) +1
|
|
||||||
other['answer_end'] += 2*len(special_char) +2
|
|
||||||
|
|
||||||
elif other['answer_start'] < current['answer_end']: # Other is enclosing the current one
|
|
||||||
other['answer_start'] += len(special_char) +1
|
|
||||||
other['answer_end'] += len(special_char) +1
|
|
||||||
|
|
||||||
else: # Other is after current
|
|
||||||
other['answer_start'] += 2*len(special_char) +2
|
|
||||||
other['answer_end'] += 2*len(special_char) +2
|
|
||||||
|
|
||||||
# Fix indexes in current answer
|
|
||||||
other = paragraph['qas'][counter]['answers'][0]
|
|
||||||
|
|
||||||
if other == current: # Other answer is the one im working on
|
|
||||||
other['answer_start'] += len(special_char) +1
|
|
||||||
other['answer_end'] += len(special_char) +1
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
with open("./squad-test.json", "r") as f:
|
|
||||||
squad = json.load(f)
|
|
||||||
|
|
||||||
sort_qas_by_answer_index(squad)
|
|
||||||
transform_squad(squad)
|
|
||||||
print_squad(squad)
|
|
||||||
|
|
||||||
with open("./squad-test-out.json", "w") as f:
|
|
||||||
json.dump(squad, f, indent=2)
|
|
||||||
|
|
@ -69,7 +69,7 @@ def print_squad(squad, article_limit=100, paragraph_limit=100, qas_limit=100):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
squad = None
|
squad = None
|
||||||
|
|
||||||
with open("squad-v2-dev.json", "r", encoding="utf-8") as f:
|
with open("./data/squad-v2-dev.json", "r", encoding="utf-8") as f:
|
||||||
squad = json.load(f)
|
squad = json.load(f)
|
||||||
|
|
||||||
calculate_chars(squad)
|
calculate_chars(squad)
|
||||||
|
27
translate_utils.py
Normal file
27
translate_utils.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import json
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
import six
|
||||||
|
from google.cloud import translate_v2 as translate
|
||||||
|
|
||||||
|
|
||||||
|
def translate_text(text):
|
||||||
|
"""Translates text into the target language.
|
||||||
|
|
||||||
|
Target must be an ISO 639-1 language code.
|
||||||
|
See https://g.co/cloud/translate/v2/translate-reference#supported_languages
|
||||||
|
"""
|
||||||
|
|
||||||
|
translate_client = translate.Client()
|
||||||
|
|
||||||
|
if isinstance(text, six.binary_type):
|
||||||
|
text = text.decode("utf-8")
|
||||||
|
|
||||||
|
# Text can also be a sequence of strings, in which case this method
|
||||||
|
# will return a sequence of results for each text.
|
||||||
|
result = translate_client.translate(text, target_language="sk")
|
||||||
|
|
||||||
|
print(u"Text: {}".format(result["input"]))
|
||||||
|
print(u"Translation: {}".format(result["translatedText"]))
|
||||||
|
print(u"Detected source language: {}".format(result["detectedSourceLanguage"]))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user