From ad05bc1fe2132d495f1a49640f893bc33a99cfc1 Mon Sep 17 00:00:00 2001
From: Tetiana Mohorian
Date: Mon, 28 Apr 2025 22:59:28 +0000
Subject: [PATCH] Aktualizovat sk1/backend/app.py

---
 sk1/backend/app.py | 227 ++++++++++++++++++++++++++-------------------
 1 file changed, 132 insertions(+), 95 deletions(-)

diff --git a/sk1/backend/app.py b/sk1/backend/app.py
index 4aeb5a6..7b260c5 100644
--- a/sk1/backend/app.py
+++ b/sk1/backend/app.py
@@ -1,95 +1,132 @@
-from flask import Flask, request, jsonify
-from flask_cors import CORS
-import json
-
-import torch
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
-from flask_caching import Cache
-import hashlib
-
-from datetime import datetime
-import os
-from flask import Response
-
-app = Flask(__name__)
-CORS(app)
-
-app.config['CACHE_TYPE'] = 'SimpleCache'
-cache = Cache(app)
-
-model_path = "tetianamohorian/hate_speech_model"
-HISTORY_FILE = "history.json"
-
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-model = AutoModelForSequenceClassification.from_pretrained(model_path)
-model.eval()
-
-
-def generate_text_hash(text):
-    return hashlib.md5(text.encode('utf-8')).hexdigest()
-
-def save_to_history(text, prediction_label):
-    entry = {
-        "text": text,
-        "prediction": prediction_label,
-        "timestamp": datetime.now().strftime("%d.%m.%Y %H:%M:%S")
-    }
-
-    if os.path.exists(HISTORY_FILE):
-        with open(HISTORY_FILE, "r", encoding="utf-8") as f:
-            history = json.load(f)
-    else:
-        history = []
-
-    history.append(entry)
-    with open(HISTORY_FILE, "w", encoding="utf-8") as f:
-        json.dump(history, f, ensure_ascii=False, indent=2)
-
-@app.route("/api/predict", methods=["POST"])
-def predict():
-    try:
-        data = request.json
-        text = data.get("text", "")
-
-        text_hash = generate_text_hash(text)
-        cached_result = cache.get(text_hash)
-        if cached_result:
-            save_to_history(text, cached_result)
-            return jsonify({"prediction": cached_result}), 200
-
-        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-
-        with torch.no_grad():
-            outputs = model(**inputs)
-            predictions = torch.argmax(outputs.logits, dim=1).item()
-
-        prediction_label = "Pravdepodobne toxický" if predictions == 1 else "Neutrálny text"
-        cache.set(text_hash, prediction_label)
-
-        save_to_history(text, prediction_label)
-
-        return jsonify({"prediction": prediction_label}), 200
-
-    except Exception as e:
-        return jsonify({"error": str(e)}), 500
-
-
-
-@app.route("/api/history", methods=["GET"])
-def get_history():
-    try:
-        if os.path.exists(HISTORY_FILE):
-            with open(HISTORY_FILE, "r", encoding="utf-8") as f:
-                history = json.load(f)
-            return Response(
-                json.dumps(history, ensure_ascii=False, indent=2),
-                mimetype="application/json"
-            )
-        else:
-            return jsonify([]), 200
-    except Exception as e:
-        return jsonify({"error": str(e)}), 500
-
-if __name__ == "__main__":
-    port = int(os.environ.get("PORT", 5000))
-    app.run(host="0.0.0.0", port=port)
+from flask import Flask, request, jsonify
+from flask_cors import CORS
+import json
+
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from flask_caching import Cache
+import hashlib
+import re
+from datetime import datetime
+import os
+from flask import Response
+import pytz
+
+app = Flask(__name__)
+CORS(app)
+
+app.config['CACHE_TYPE'] = 'SimpleCache'
+cache = Cache(app)
+
+model_path = "tetianamohorian/hate_speech_model"
+HISTORY_FILE = "history.json"
+
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+model = AutoModelForSequenceClassification.from_pretrained(model_path)
+model.eval()
+
+
+def generate_text_hash(text):
+    return hashlib.md5(text.encode('utf-8')).hexdigest()
+
+def get_current_time():
+    tz = pytz.timezone('Europe/Bratislava')
+    now = datetime.now(tz)
+    return now.strftime("%d.%m.%Y %H:%M:%S")
+
+def save_to_history(text, prediction_label):
+    entry = {
+        "text": text,
+        "prediction": prediction_label,
+        "timestamp": get_current_time()
+    }
+
+    if os.path.exists(HISTORY_FILE):
+        with open(HISTORY_FILE, "r", encoding="utf-8") as f:
+            history = json.load(f)
+    else:
+        history = []
+
+    history.append(entry)
+    with open(HISTORY_FILE, "w", encoding="utf-8") as f:
+        json.dump(history, f, ensure_ascii=False, indent=2)
+
+@app.route("/api/predict", methods=["POST"])
+def predict():
+    try:
+        data = request.get_json(silent=True) or {}
+        text = data.get("text", "")
+
+        if not text:
+            return jsonify({"error": "Text nesmie byť prázdny."}), 400
+        if len(text) > 512:
+            return jsonify({"error": "Text je príliš dlhý. Maximálne 512 znakov."}), 400
+        if re.search(r"[а-яА-ЯёЁ]", text):
+            return jsonify({"error": "Text nesmie obsahovať azbuku (cyriliku)."}), 400
+
+
+        text_hash = generate_text_hash(text)
+        cached_result = cache.get(text_hash)
+        if cached_result:
+            save_to_history(text, cached_result)
+            return jsonify({"prediction": cached_result}), 200
+
+        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+            predictions = torch.argmax(outputs.logits, dim=1).item()
+
+        prediction_label = "Pravdepodobne toxický" if predictions == 1 else "Neutrálny text"
+        cache.set(text_hash, prediction_label)
+
+        save_to_history(text, prediction_label)
+
+        return jsonify({"prediction": prediction_label}), 200
+
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
+
+
+
+@app.route("/api/history", methods=["GET"])
+def get_history():
+    try:
+        if os.path.exists(HISTORY_FILE):
+            with open(HISTORY_FILE, "r", encoding="utf-8") as f:
+                history = json.load(f)
+            return Response(
+                json.dumps(history, ensure_ascii=False, indent=2),
+                mimetype="application/json"
+            )
+        else:
+            return jsonify([]), 200
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
+
+
+@app.route("/api/history/raw", methods=["GET"])
+def get_raw_history():
+    try:
+        if os.path.exists(HISTORY_FILE):
+            with open(HISTORY_FILE, "r", encoding="utf-8") as f:
+                content = f.read()
+            return Response(content, mimetype="application/json")
+        else:
+            return jsonify({"error": "history.json not found"}), 404
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
+
+
+@app.route("/api/history/reset", methods=["POST"])
+def reset_history():
+    try:
+        with open(HISTORY_FILE, "w", encoding="utf-8") as f:
+            json.dump([], f, ensure_ascii=False, indent=2)
+        return jsonify({"message": "History reset successful."}), 200
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
+
+if __name__ == "__main__":
+    port = int(os.environ.get("PORT", 5000))
+    app.run(host="0.0.0.0", port=port)