Aktualizovat sk1/backend/app.py

This commit is contained in:
Tetiana Mohorian 2025-04-28 22:59:28 +00:00
parent 0ec9a0efb6
commit ad05bc1fe2

View File

@@ -1,95 +1,132 @@
from flask import Flask, request, jsonify from flask import Flask, request, jsonify
from flask_cors import CORS from flask_cors import CORS
import json import json
import torch import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers import AutoModelForSequenceClassification, AutoTokenizer
from flask_caching import Cache from flask_caching import Cache
import hashlib import hashlib
import re
from datetime import datetime from datetime import datetime
import os import os
from flask import Response from flask import Response
import pytz
# Flask application setup: JSON API with CORS enabled for browser clients
# and a simple in-process cache for repeated classification requests.
app = Flask(__name__)
CORS(app)

app.config['CACHE_TYPE'] = 'SimpleCache'
cache = Cache(app)

# Hugging Face model used for toxicity classification; loaded once at startup.
model_path = "tetianamohorian/hate_speech_model"
HISTORY_FILE = "history.json"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()  # inference only — disable dropout/training-mode behavior
def generate_text_hash(text):
    """Return the hex MD5 digest of *text* (UTF-8), used as a cache key.

    MD5 is acceptable here because the hash is a cache key,
    not a security measure.
    """
    return hashlib.md5(text.encode('utf-8')).hexdigest()
def get_current_time():
    """Return the current Europe/Bratislava time as 'DD.MM.YYYY HH:MM:SS'."""
    tz = pytz.timezone('Europe/Bratislava')
    return datetime.now(tz).strftime("%d.%m.%Y %H:%M:%S")
def save_to_history(text, prediction_label):
    """Append one prediction record to the JSON history file.

    Each entry stores the analyzed text, the predicted label and a
    localized timestamp. The whole file is re-read and rewritten on
    every call — fine for small histories, but not concurrency-safe.
    """
    entry = {
        "text": text,
        "prediction": prediction_label,
        "timestamp": get_current_time(),
    }

    if os.path.exists(HISTORY_FILE):
        with open(HISTORY_FILE, "r", encoding="utf-8") as f:
            history = json.load(f)
    else:
        history = []

    history.append(entry)
    with open(HISTORY_FILE, "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)
@app.route("/api/predict", methods=["POST"])
def predict():
    """Classify POSTed text as toxic or neutral.

    Expects JSON of the form {"text": ...}. Validates the input
    (non-empty, max 512 characters, no Cyrillic), serves repeated texts
    from the cache, and otherwise runs the transformer model. Every
    result — cached or fresh — is appended to the history file.
    """
    try:
        # request.json is None when the body is missing/not JSON; guard it.
        data = request.json or {}
        text = data.get("text", "")

        # --- input validation -------------------------------------------
        if not text:
            return jsonify({"error": "Text nesmie byť prázdny."}), 400
        if len(text) > 512:
            return jsonify({"error": "Text je príliš dlhý. Maximálne 512 znakov."}), 400
        if re.search(r"[а-яА-ЯёЁ]", text):
            return jsonify({"error": "Text nesmie obsahovať azbuku (cyriliku)."}), 400

        # --- cache lookup -----------------------------------------------
        text_hash = generate_text_hash(text)
        cached_result = cache.get(text_hash)
        if cached_result:
            save_to_history(text, cached_result)
            return jsonify({"prediction": cached_result}), 200

        # --- model inference --------------------------------------------
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=1).item()

        # Label index 1 means toxic; anything else is treated as neutral.
        prediction_label = "Pravdepodobne toxický" if predictions == 1 else "Neutrálny text"
        cache.set(text_hash, prediction_label)
        save_to_history(text, prediction_label)

        return jsonify({"prediction": prediction_label}), 200

    except Exception as e:
        # Boundary handler: surface unexpected failures as HTTP 500.
        return jsonify({"error": str(e)}), 500
@app.route("/api/history", methods=["GET"])
def get_history():
    """Return the full prediction history as pretty-printed JSON."""
    try:
        if os.path.exists(HISTORY_FILE):
            with open(HISTORY_FILE, "r", encoding="utf-8") as f:
                history = json.load(f)
            # A manual Response keeps non-ASCII characters unescaped,
            # unlike jsonify which would escape them.
            return Response(
                json.dumps(history, ensure_ascii=False, indent=2),
                mimetype="application/json",
            )
        else:
            return jsonify([]), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/api/history/raw", methods=["GET"])
def get_raw_history():
    """Serve the history file's raw bytes without re-serializing them."""
    try:
        # Guard clause: no history file means nothing to serve.
        if not os.path.exists(HISTORY_FILE):
            return jsonify({"error": "history.json not found"}), 404
        with open(HISTORY_FILE, "r", encoding="utf-8") as f:
            raw_json = f.read()
        return Response(raw_json, mimetype="application/json")
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/api/history/reset", methods=["POST"])
def reset_history():
    """Replace the history file's contents with an empty JSON list."""
    try:
        with open(HISTORY_FILE, "w", encoding="utf-8") as f:
            json.dump([], f, ensure_ascii=False, indent=2)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    return jsonify({"message": "History reset successful."}), 200
if __name__ == "__main__":
    # Bind on all interfaces; port comes from the PORT env var (default 5000).
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 5000)))