add testing models files
This commit is contained in:
parent
4e0499ff05
commit
ff01567cb4
80
Backend/index-server-es.py
Normal file
@@ -0,0 +1,80 @@
from elasticsearch import Elasticsearch
from langchain.embeddings import HuggingFaceEmbeddings
from elasticsearch.helpers import bulk
import json

# Configure the Elasticsearch connection with authentication and HTTPS
es = Elasticsearch(
    [{'host': 'localhost', 'port': 9200, 'scheme': 'https'}],
    http_auth=('elastic', 'S7DoO3ma=G=9USBPbqq3'),  # replace with your password
    verify_certs=False  # disable SSL certificate verification when using a self-signed certificate
)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

def create_index():
    # Define the mapping for the index
    mapping = {
        "mappings": {
            "properties": {
                "text": {
                    "type": "text",
                    "analyzer": "standard"
                },
                "vector": {
                    "type": "dense_vector",
                    "dims": 384  # dimensionality of the embedding vectors
                },
                "full_data": {
                    "type": "object",
                    "enabled": False  # do not index the nested data
                }
            }
        }
    }
    es.indices.create(index='drug_docs', body=mapping, ignore=400)

def load_drug_data(json_path):
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

def index_documents(data):
    actions = []
    total_docs = len(data)
    for i, item in enumerate(data, start=1):
        doc_text = f"{item['link']} {item.get('pribalovy_letak', '')} {item.get('spc', '')}"

        vector = embeddings.embed_query(doc_text)

        action = {
            "_index": "drug_docs",
            "_id": i,
            "_source": {
                'text': doc_text,
                'vector': vector,
                'full_data': item
            }
        }
        actions.append(action)

        # Show progress
        print(f"Indexing document {i}/{total_docs}", end='\r')

        # Optionally: index in batches of N documents
        if i % 100 == 0 or i == total_docs:
            bulk(es, actions)
            actions = []

    # Index any remaining documents
    if actions:
        bulk(es, actions)

    print("\nIndexing finished.")

if __name__ == "__main__":
    create_index()
    data_path = "/home/poiasnik/esDB/cleaned_general_info_additional.json"
    drug_data = load_drug_data(data_path)
    index_documents(drug_data)
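For context, a minimal retrieval sketch (not part of this commit) showing how the dense_vector field created above could be queried with a script_score cosine similarity. It reuses the es client and embeddings model defined in this file; the query string is only illustrative, and it assumes an Elasticsearch version that supports cosineSimilarity over dense_vector fields:

# Hypothetical retrieval example for the 'drug_docs' index built above.
query_vec = embeddings.embed_query("čo piť pri horúčke")  # illustrative query
resp = es.search(
    index="drug_docs",
    body={
        "size": 3,
        "query": {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
                    "params": {"query_vector": query_vec},
                },
            }
        },
    },
)
for hit in resp["hits"]["hits"]:
    print(hit["_score"], hit["_source"]["text"][:80])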
77
Backend/qwen72-test.py
Normal file
@@ -0,0 +1,77 @@
import torch
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer
from elasticsearch import Elasticsearch

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Connect to Elasticsearch
es = Elasticsearch(
    ["https://localhost:9200"],
    basic_auth=("elastic", "S7DoO3ma=G=9USBPbqq3"),  # your password
    verify_certs=False
)
index_name = 'drug_docs'

# Load the tokenizer and the model
model_name = "Qwen/Qwen2.5-7B-Instruct"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Make sure a pad_token is set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def text_search(query, k=10, max_doc_length=300, max_docs=3):
    try:
        es_results = es.search(
            index=index_name,
            body={"size": k, "query": {"match": {"text": query}}}
        )
        text_documents = [hit['_source'].get('text', '') for hit in es_results['hits']['hits']]
        text_documents = [doc[:max_doc_length] for doc in text_documents[:max_docs]]
        return text_documents
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        return []

# Example search query
query = "čo piť pri horúčke"
text_documents = text_search(query)

# Truncate the text if it exceeds the model's token limit
max_tokens_per_input = 1024  # use a lower value for max_tokens
context_text = ' '.join(text_documents)
input_text = (
    f"Informácie o liekoch: {context_text[:max_tokens_per_input]}\n"
    "Uveďte tri konkrétne lieky alebo riešenia s veľmi krátkym vysvetlením pre každý z nich.\n"
    "Odpoveď v slovenčine:"
)

# Tokenize the input text
inputs = tokenizer(input_text, return_tensors="pt", max_length=max_tokens_per_input, truncation=True).to(device)

try:
    generated_ids = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=300,  # reduced value
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        do_sample=False,  # sampling disabled for deterministic output
        pad_token_id=tokenizer.pad_token_id
    )
    response = tokenizer.decode(generated_ids[0], skip_special_tokens=True, errors='ignore')
    print("Generated text:", response)
except RuntimeError as e:
    print(f"An error occurred during generation: {e}")
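As a possible follow-up (not in this commit), the retrieval and generation steps above could be folded into one helper so other scripts can reuse them; a minimal sketch, assuming the es client, tokenizer, model, device, and text_search already defined in this file:

# Hypothetical helper that chains Elasticsearch retrieval and Qwen generation.
def generate_answer(query: str, max_context_chars: int = 1024, max_new_tokens: int = 300) -> str:
    context = ' '.join(text_search(query))[:max_context_chars]
    prompt = (
        f"Informácie o liekoch: {context}\n"
        "Uveďte tri konkrétne lieky alebo riešenia s veľmi krátkym vysvetlením pre každý z nich.\n"
        "Odpoveď v slovenčine:"
    )
    enc = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(device)
    out = model.generate(
        enc.input_ids,
        attention_mask=enc.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
    return tokenizer.decode(out[0], skip_special_tokens=True)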
308
Backend/qwen7b-test.py
Normal file
@@ -0,0 +1,308 @@
"""A simple command-line interactive chat demo for Qwen2.5-Instruct model with left-padding using bos_token."""

import argparse
import os
import platform
import shutil
from copy import deepcopy
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.trainer_utils import set_seed

DEFAULT_CKPT_PATH = "Qwen/Qwen2.5-7B-Instruct"

_WELCOME_MSG = """\
Welcome to use Qwen2.5-Instruct model, type text to start chat, type :h to show command help.
"""
_HELP_MSG = """\
Commands:
    :help / :h              Show this help message
    :exit / :quit / :q      Exit the demo
    :clear / :cl            Clear screen
    :clear-history / :clh   Clear history
    :history / :his         Show history
    :seed                   Show current random seed
    :seed <N>               Set random seed to <N>
    :conf                   Show current generation config
    :conf <key>=<value>     Change generation config
    :reset-conf             Reset generation config
"""

_ALL_COMMAND_NAMES = [
    "help",
    "h",
    "exit",
    "quit",
    "q",
    "clear",
    "cl",
    "clear-history",
    "clh",
    "history",
    "his",
    "seed",
    "conf",
    "reset-conf",
]


def _setup_readline():
    try:
        import readline
    except ImportError:
        return

    _matches = []

    def _completer(text, state):
        nonlocal _matches

        if state == 0:
            _matches = [
                cmd_name for cmd_name in _ALL_COMMAND_NAMES if cmd_name.startswith(text)
            ]
        if 0 <= state < len(_matches):
            return _matches[state]
        return None

    readline.set_completer(_completer)
    readline.parse_and_bind("tab: complete")


def _load_model_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        resume_download=True,
    )

    # Set bos_token for left-padding
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.bos_token

    device_map = "cpu" if args.cpu_only else "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        torch_dtype="auto",
        device_map=device_map,
        resume_download=True,
    ).eval()

    # Conservative generation config
    model.generation_config.max_new_tokens = 256
    model.generation_config.temperature = 0.7
    model.generation_config.top_k = 50
    model.generation_config.top_p = 0.9
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.generation_config.eos_token_id = tokenizer.eos_token_id
    model.generation_config.do_sample = False

    return model, tokenizer


def _gc():
    import gc

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _clear_screen():
    if platform.system() == "Windows":
        os.system("cls")
    else:
        os.system("clear")


def _print_history(history):
    terminal_width = shutil.get_terminal_size()[0]
    print(f"History ({len(history)})".center(terminal_width, "="))
    for index, (query, response) in enumerate(history):
        print(f"User[{index}]: {query}")
        print(f"Qwen[{index}]: {response}")
    print("=" * terminal_width)


def _get_input() -> str:
    while True:
        try:
            message = input("User> ").strip()
        except UnicodeDecodeError:
            print("[ERROR] Encoding error in input")
            continue
        except KeyboardInterrupt:
            exit(1)
        if message:
            return message
        print("[ERROR] Query is empty")


def _chat_stream(model, tokenizer, query, history):
    conversation = []
    for query_h, response_h in history:
        conversation.append({"role": "user", "content": query_h})
        conversation.append({"role": "assistant", "content": response_h})
    conversation.append({"role": "user", "content": query})
    input_text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=False,
    )
    # Perform left-padding with bos_token
    inputs = tokenizer(
        [input_text],
        return_tensors="pt",
        padding="longest",
        truncation=True,
        pad_to_multiple_of=8,
        max_length=1024,
        add_special_tokens=False
    ).to(model.device)

    # Update attention_mask for left-padding compatibility
    inputs["attention_mask"] = inputs["attention_mask"].flip(dims=[1])

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer, skip_prompt=True, timeout=60.0, skip_special_tokens=True
    )
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for new_text in streamer:
        yield new_text


def main():
    parser = argparse.ArgumentParser(
        description="Qwen2.5-Instruct command-line interactive chat demo."
    )
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default=DEFAULT_CKPT_PATH,
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run demo with CPU only"
    )
    args = parser.parse_args()

    history, response = [], ""

    model, tokenizer = _load_model_tokenizer(args)
    orig_gen_config = deepcopy(model.generation_config)

    _setup_readline()

    _clear_screen()
    print(_WELCOME_MSG)

    seed = args.seed

    while True:
        query = _get_input()

        # Process commands.
        if query.startswith(":"):
            command_words = query[1:].strip().split()
            if not command_words:
                command = ""
            else:
                command = command_words[0]

            if command in ["exit", "quit", "q"]:
                break
            elif command in ["clear", "cl"]:
                _clear_screen()
                print(_WELCOME_MSG)
                _gc()
                continue
            elif command in ["clear-history", "clh"]:
                print(f"[INFO] All {len(history)} history cleared")
                history.clear()
                _gc()
                continue
            elif command in ["help", "h"]:
                print(_HELP_MSG)
                continue
            elif command in ["history", "his"]:
                _print_history(history)
                continue
            elif command in ["seed"]:
                if len(command_words) == 1:
                    print(f"[INFO] Current random seed: {seed}")
                    continue
                else:
                    new_seed_s = command_words[1]
                    try:
                        new_seed = int(new_seed_s)
                    except ValueError:
                        print(
                            f"[WARNING] Fail to change random seed: {new_seed_s!r} is not a valid number"
                        )
                    else:
                        print(f"[INFO] Random seed changed to {new_seed}")
                        seed = new_seed
                    continue
            elif command in ["conf"]:
                if len(command_words) == 1:
                    print(model.generation_config)
                else:
                    for key_value_pairs_str in command_words[1:]:
                        eq_idx = key_value_pairs_str.find("=")
                        if eq_idx == -1:
                            print("[WARNING] format: <key>=<value>")
                            continue
                        conf_key, conf_value_str = (
                            key_value_pairs_str[:eq_idx],
                            key_value_pairs_str[eq_idx + 1 :],
                        )
                        try:
                            conf_value = eval(conf_value_str)
                        except Exception as e:
                            print(e)
                            continue
                        else:
                            print(
                                f"[INFO] Change config: model.generation_config.{conf_key} = {conf_value}"
                            )
                            setattr(model.generation_config, conf_key, conf_value)
                continue
            elif command in ["reset-conf"]:
                print("[INFO] Reset generation config")
                model.generation_config = deepcopy(orig_gen_config)
                print(model.generation_config)
                continue

        # Run chat.
        set_seed(seed)
        _clear_screen()
        print(f"\nUser: {query}")
        print(f"\nQwen: ", end="")
        try:
            partial_text = ""
            for new_text in _chat_stream(model, tokenizer, query, history):
                print(new_text, end="", flush=True)
                partial_text += new_text
            response = partial_text
            print()

        except KeyboardInterrupt:
            print("[WARNING] Generation interrupted")
            continue

        history.append((query, response))


if __name__ == "__main__":
    main()
34
Backend/test_flant5.py
Normal file
@@ -0,0 +1,34 @@
import requests


API_TOKEN = "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"
API_URL = "https://api-inference.huggingface.co/models/google/flan-t5-large"


headers = {
    "Authorization": f"Bearer {API_TOKEN}",
    "Content-Type": "application/json"
}

def query_flan_t5(prompt):
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_length": 250,
            "do_sample": True,
            "temperature": 0.9,
            "top_p": 0.95,
            "top_k": 50
        }
    }
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()


prompt = "Ako sa máš? Daj odpoved v slovencine"
result = query_flan_t5(prompt)

if isinstance(result, list) and len(result) > 0:
    print("Flan-T5 response:", result[0]['generated_text'])
else:
    print("Error getting a response:", result)
47
Backend/test_mt5_base.py
Normal file
@@ -0,0 +1,47 @@
# import requests
#
# API_TOKEN = "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"
# API_URL = "https://api-inference.huggingface.co/models/google/mt5-base"
#
# headers = {
#     "Authorization": f"Bearer {API_TOKEN}",
#     "Content-Type": "application/json"
# }
#
# def query_mT5(prompt):
#     payload = {
#         "inputs": prompt,
#         "parameters": {
#             "max_length": 100,
#             "do_sample": True,
#             "temperature": 0.7
#         }
#     }
#     response = requests.post(API_URL, headers=headers, json=payload)
#     return response.json()
#
# # Usage example
# result = query_mT5("Aké sú účinné lieky na horúčku?")
# print("mT5 response:", result)

from transformers import AutoTokenizer, MT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")

# training
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
logits = outputs.logits

# inference

input_ids = tokenizer(
    "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
).input_ids  # Batch size 1
outputs = model.generate(input_ids, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# studies have shown that owning a dog is good for you.
19
Backend/test_slovakbert-skquad.py
Normal file
@@ -0,0 +1,19 @@
from sentence_transformers import SentenceTransformer, util

# Load the model from Hugging Face
model = SentenceTransformer("TUKE-DeutscheTelekom/slovakbert-skquad-mnlr")  # replace with the ID of the model you need

# Example sentences in Slovak
sentences = [
    "Prvý most cez Zlatý roh nechal vybudovať cisár Justinián I. V roku 1502 vypísal sultán Bajezid II. súťaž na nový most.",
    "V ktorom roku vznikol druhý drevený most cez záliv Zlatý roh?",
    "Aká je priemerná dĺžka života v Eritrei?"
]

# Compute an embedding for each sentence
embeddings = model.encode(sentences)
print("Shape of embeddings:", embeddings.shape)  # embedding shape, e.g. (3, 768)

# Compute pairwise similarity between the sentences
similarities = util.cos_sim(embeddings, embeddings)
print("Similarity matrix:\n", similarities)
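A natural extension (not part of this commit) is to use the same model for question-to-passage retrieval; a minimal sketch with sentence_transformers' util.semantic_search, where the first sentence above stands in for a one-document corpus and the two questions act as queries:

# Hypothetical retrieval example: rank corpus passages by cosine similarity.
corpus = [sentences[0]]
queries = sentences[1:]
corpus_emb = model.encode(corpus, convert_to_tensor=True)
query_emb = model.encode(queries, convert_to_tensor=True)
hits = util.semantic_search(query_emb, corpus_emb, top_k=1)
for query, query_hits in zip(queries, hits):
    best = query_hits[0]
    print(f"{query} -> corpus[{best['corpus_id']}] (score {best['score']:.3f})")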