# Python · 48 lines · 1.5 KiB
# Disabled alternative: query mT5 through the Hugging Face hosted Inference API
# instead of running the model locally (see the live code below).
#
# import os
# import requests
#
# # Never hard-code access tokens in source; read them from the environment.
# # The token previously committed on this line is exposed and should be revoked.
# API_TOKEN = os.environ["HF_API_TOKEN"]
# API_URL = "https://api-inference.huggingface.co/models/google/mt5-base"
#
# headers = {
#     "Authorization": f"Bearer {API_TOKEN}",
#     "Content-Type": "application/json",
# }
#
# def query_mT5(prompt):
#     payload = {
#         "inputs": prompt,
#         "parameters": {
#             "max_length": 100,
#             "do_sample": True,
#             "temperature": 0.7,
#         },
#     }
#     response = requests.post(API_URL, headers=headers, json=payload)
#     return response.json()
#
# # Example usage
# result = query_mT5("Aké sú účinné lieky na horúčku?")
# print("Ответ от mT5:", result)

from transformers import AutoTokenizer, MT5ForConditionalGeneration

# Single checkpoint name shared by the tokenizer and the model so the two can
# never silently drift apart (a vocabulary/weights mismatch garbles output).
MODEL_NAME = "google/mt5-small"

# NOTE(review): per the Hugging Face mT5 docs, google/mt5-* checkpoints are
# only pre-trained (span corruption on mC4, no supervised tasks) and need
# fine-tuning before they give useful downstream output.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# --- Training-style forward pass (span-corruption objective) ---
# The <extra_id_N> sentinel tokens mark masked spans in the input; the labels
# list each sentinel followed by the text it replaced. Passing `labels` makes
# the model return a loss alongside the logits.
# NOTE: this is a single forward pass only. No backward() or optimizer step is
# taken, so the weights are not updated. from_pretrained() is documented to
# return the model in eval mode (dropout off); call model.train() first if a
# real training step is intended.
input_ids = tokenizer(
    "The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt"
).input_ids
labels = tokenizer(
    "<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt"
).input_ids

outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
logits = outputs.logits  # presumably (batch, label_len, vocab_size) — confirm

# --- Inference ---
# "summarize:" is a T5 task prefix. mT5 checkpoints were never trained on
# supervised tasks, so an un-fine-tuned mt5-small should not be expected to
# return a real summary.
inputs = tokenizer(
    "summarize: studies have shown that owning a dog is good for you",
    return_tensors="pt",
)  # batch size 1
input_ids = inputs.input_ids

# Forward the attention mask along with the ids (via **inputs) so generation
# stays correct if the prompt is ever batched with padding.
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# NOTE(review): the output previously noted here,
#   "studies have shown that owning a dog is good for you."
# matches the Transformers docstring example and is not verified for the
# un-fine-tuned mt5-small checkpoint — run and confirm before relying on it.