Add commands192024.py
This commit is contained in:
parent
5eac948712
commit
7b333c201c
49
commands192024.py
Normal file
49
commands192024.py
Normal file
@ -0,0 +1,49 @@
|
||||
pip install happytransformer
|
||||
pip install datasets
|
||||
|
||||
|
||||
import csv
|
||||
from happytransformer import HappyTextToText
|
||||
from happytransformer import TTSettings
|
||||
from datasets import load_dataset
|
||||
happy_tt = HappyTextToText("T5", "t5-base")
|
||||
|
||||
train_dataset = load_dataset("jfleg", split='validation[:]')
|
||||
eval_dataset = load_dataset("jfleg", split='test[:]')
|
||||
|
||||
|
||||
def generate_csv(csv_path, dataset):
|
||||
with open(csv_path, 'w', newline='') as csvfile:
|
||||
writter = csv.writer(csvfile)
|
||||
writter.writerow(["input", "target"])
|
||||
for case in dataset:
|
||||
# Adding the task's prefix to input
|
||||
input_text = "grammar: " + case["sentence"]
|
||||
for correction in case["corrections"]:
|
||||
# a few of the cases contain blank strings.
|
||||
if input_text and correction:
|
||||
writter.writerow([input_text, correction])
|
||||
|
||||
|
||||
|
||||
generate_csv("train.csv", train_dataset)
|
||||
generate_csv("eval.csv", eval_dataset)
|
||||
|
||||
evalsettings = TTTrainArgs(max_input_length=4096, max_output_length=4096)
|
||||
before_result = happy_tt.eval("eval.csv", args=evalsettings)
|
||||
print("Before loss:", before_result.loss)
|
||||
|
||||
from happytransformer import TTTrainArgs
|
||||
args = TTTrainArgs(batch_size=8, max_input_length = 10000, max_output_length = 10000)
|
||||
happy_tt.train("train.csv", args=args)
|
||||
|
||||
after_loss = happy_tt.eval("eval.csv", args=evalsettings)
|
||||
print("After loss: ", after_loss.loss)
|
||||
|
||||
beam_settings = TTSettings(num_beams=5, min_length=1, max_length=20)
|
||||
example_1 = "grammar: This sentences, has bads grammar and spelling!"
|
||||
result_1 = happy_tt.generate_text(example_1, args=beam_settings)
|
||||
print(result_1.text)
|
||||
example_2 = "grammar: I am enjoys, writtings articles ons AI."
|
||||
result_2 = happy_tt.generate_text(example_2, args=beam_settings)
|
||||
print(result_2.text)
|
Loading…
Reference in New Issue
Block a user