zkt25/sk1/backend/app.py

132 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import hashlib
import json
import os
import re
from datetime import datetime
from zoneinfo import ZoneInfo

import pytz
import torch
from flask import Flask, Response, jsonify, request
from flask_caching import Cache
from flask_cors import CORS
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Flask app with permissive CORS (API is called from a separate frontend)
# and an in-process SimpleCache for memoizing predictions.
app = Flask(__name__)
CORS(app)
app.config['CACHE_TYPE'] = 'SimpleCache'
cache = Cache(app)
# Hugging Face model id for the sequence classifier; presumably a binary
# hate-speech model (label 1 = toxic — see predict()) — TODO confirm labels.
model_path = "tetianamohorian/hate_speech_model"
# Local JSON file that accumulates every prediction made by this process.
HISTORY_FILE = "history.json"
# Load tokenizer and model once at import time; eval() switches off
# training-only behavior (dropout etc.) for inference.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()
def generate_text_hash(text):
    """Return the hex MD5 digest of *text*, used as a cache key."""
    digest = hashlib.md5(text.encode("utf-8"))
    return digest.hexdigest()
def get_current_time():
    """Return the current Europe/Bratislava time as 'DD.MM.YYYY HH:MM:SS'.

    Uses the stdlib ``zoneinfo`` module (Python 3.9+) instead of the legacy
    third-party ``pytz``; ``datetime.now(tz)`` already yields a correctly
    localized aware timestamp, so no ``localize`` step is needed.
    """
    now = datetime.now(ZoneInfo("Europe/Bratislava"))
    return now.strftime("%d.%m.%Y %H:%M:%S")
def save_to_history(text, prediction_label):
    """Append one prediction record to HISTORY_FILE (a JSON array on disk).

    Each entry stores the input text, the human-readable prediction label
    and a localized timestamp. A missing, unreadable or corrupted history
    file is treated as an empty history instead of raising, so one bad
    write can no longer break every subsequent request.
    """
    entry = {
        "text": text,
        "prediction": prediction_label,
        "timestamp": get_current_time(),
    }
    history = []
    if os.path.exists(HISTORY_FILE):
        try:
            with open(HISTORY_FILE, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            # Guard against a file that holds valid JSON but not a list.
            if isinstance(loaded, list):
                history = loaded
        except (json.JSONDecodeError, OSError):
            history = []  # corrupted or unreadable file: start fresh
    history.append(entry)
    with open(HISTORY_FILE, "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)
@app.route("/api/predict", methods=["POST"])
def predict():
    """Classify POSTed text as toxic or neutral.

    Expects a JSON body ``{"text": "..."}`` and returns
    ``{"prediction": label}`` (200) or ``{"error": msg}`` (400/500).
    Validation: text must be non-empty, at most 512 characters, and must
    not contain Cyrillic letters. Results are cached by MD5 of the text
    and every served prediction is appended to the history file.
    """
    try:
        # get_json(silent=True) returns None for a missing/invalid JSON body
        # instead of raising, so malformed requests get the intended 400
        # below rather than an unhandled error or a generic 500.
        data = request.get_json(silent=True) or {}
        text = data.get("text", "")
        if not text:
            return jsonify({"error": "Text nesmie byť prázdny."}), 400
        if len(text) > 512:
            return jsonify({"error": "Text je príliš dlhý. Maximálne 512 znakov."}), 400
        if re.search(r"[а-яА-ЯёЁ]", text):
            return jsonify({"error": "Text nesmie obsahovať azbuku (cyriliku)."}), 400

        text_hash = generate_text_hash(text)
        cached_result = cache.get(text_hash)
        if cached_result:
            save_to_history(text, cached_result)
            return jsonify({"prediction": cached_result}), 200

        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
        predicted_class = torch.argmax(outputs.logits, dim=1).item()
        prediction_label = "Pravdepodobne toxický" if predicted_class == 1 else "Neutrálny text"
        cache.set(text_hash, prediction_label)
        save_to_history(text, prediction_label)
        return jsonify({"prediction": prediction_label}), 200
    except Exception as e:  # boundary handler: surface unexpected failures as 500
        return jsonify({"error": str(e)}), 500
@app.route("/api/history", methods=["GET"])
def get_history():
    """Return the stored prediction history as pretty-printed JSON.

    An absent history file is reported as an empty list, not an error.
    """
    try:
        if not os.path.exists(HISTORY_FILE):
            return jsonify([]), 200
        with open(HISTORY_FILE, "r", encoding="utf-8") as f:
            history = json.load(f)
        # Re-serialize manually so non-ASCII text stays readable in the output.
        body = json.dumps(history, ensure_ascii=False, indent=2)
        return Response(body, mimetype="application/json")
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/api/history/raw", methods=["GET"])
def get_raw_history():
    """Serve the history file's contents verbatim, without re-serializing."""
    try:
        if not os.path.exists(HISTORY_FILE):
            return jsonify({"error": "history.json not found"}), 404
        with open(HISTORY_FILE, "r", encoding="utf-8") as f:
            raw = f.read()
        return Response(raw, mimetype="application/json")
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/api/history/reset", methods=["POST"])
def reset_history():
    """Replace the history file with an empty JSON list."""
    try:
        empty_history = []
        with open(HISTORY_FILE, "w", encoding="utf-8") as f:
            json.dump(empty_history, f, ensure_ascii=False, indent=2)
        return jsonify({"message": "History reset successful."}), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Honor the PORT environment variable (container/PaaS convention),
    # defaulting to 5000; bind on all interfaces.
    listen_port = int(os.environ.get("PORT", 5000))
    app.run(host="0.0.0.0", port=listen_port)