Add commands192024.py

This commit is contained in:
Andrii Pervashov 2024-09-01 16:18:06 +00:00
parent 5eac948712
commit 7b333c201c

49
commands192024.py Normal file
View File

@ -0,0 +1,49 @@
pip install happytransformer
pip install datasets
import csv
from happytransformer import HappyTextToText
from happytransformer import TTSettings
from datasets import load_dataset
happy_tt = HappyTextToText("T5", "t5-base")
train_dataset = load_dataset("jfleg", split='validation[:]')
eval_dataset = load_dataset("jfleg", split='test[:]')
def generate_csv(csv_path, dataset):
with open(csv_path, 'w', newline='') as csvfile:
writter = csv.writer(csvfile)
writter.writerow(["input", "target"])
for case in dataset:
# Adding the task's prefix to input
input_text = "grammar: " + case["sentence"]
for correction in case["corrections"]:
# a few of the cases contain blank strings.
if input_text and correction:
writter.writerow([input_text, correction])
generate_csv("train.csv", train_dataset)
generate_csv("eval.csv", eval_dataset)
evalsettings = TTTrainArgs(max_input_length=4096, max_output_length=4096)
before_result = happy_tt.eval("eval.csv", args=evalsettings)
print("Before loss:", before_result.loss)
from happytransformer import TTTrainArgs
args = TTTrainArgs(batch_size=8, max_input_length = 10000, max_output_length = 10000)
happy_tt.train("train.csv", args=args)
after_loss = happy_tt.eval("eval.csv", args=evalsettings)
print("After loss: ", after_loss.loss)
beam_settings = TTSettings(num_beams=5, min_length=1, max_length=20)
example_1 = "grammar: This sentences, has bads grammar and spelling!"
result_1 = happy_tt.generate_text(example_1, args=beam_settings)
print(result_1.text)
example_2 = "grammar: I am enjoys, writtings articles ons AI."
result_2 = happy_tt.generate_text(example_2, args=beam_settings)
print(result_2.text)