"""Flask API that classifies short texts with a hate-speech model.

Endpoints:
    POST /api/predict        -- classify the ``text`` field of a JSON body; each
                                successful prediction is appended to HISTORY_FILE.
    GET  /api/history        -- return the stored history as pretty-printed JSON.
    GET  /api/history/raw    -- return the history file contents verbatim.
    POST /api/history/reset  -- replace the history with an empty list.

User-facing error messages are in Slovak.
"""

import hashlib
import json
import os
import re
import threading
from datetime import datetime

import pytz
import torch
from flask import Flask, Response, jsonify, request
from flask_caching import Cache
from flask_cors import CORS
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = Flask(__name__)
CORS(app)
# SimpleCache keeps entries in this process's memory; it must be configured
# before Cache(app) is constructed.
app.config["CACHE_TYPE"] = "SimpleCache"
cache = Cache(app)

model_path = "tetianamohorian/hate_speech_model"
HISTORY_FILE = "history.json"

# Maximum accepted input length, in characters (not tokens).
MAX_TEXT_LENGTH = 512

TOXIC_LABEL = "Pravdepodobne toxický"
NEUTRAL_LABEL = "Neutrálny text"

# Whole Cyrillic (U+0400–U+04FF) and Cyrillic Supplement (U+0500–U+052F) blocks.
# A narrower class such as [а-яА-ЯёЁ] covers only the basic Russian alphabet
# and would let letters like Ukrainian і/ї/є/ґ or Serbian ђ/ј through.
CYRILLIC_RE = re.compile(r"[\u0400-\u052F]")

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()  # inference mode (disables dropout etc.)

# Serialises read-modify-write access to HISTORY_FILE. app.run() serves
# requests on multiple threads by default, and two unsynchronised appends
# would silently drop one entry. This does NOT protect against several
# worker *processes* writing the same file.
_history_lock = threading.Lock()


def generate_text_hash(text):
    """Return the hex MD5 digest of *text* (UTF-8 encoded).

    Used only as the prediction-cache key, not for any security purpose.
    """
    return hashlib.md5(text.encode("utf-8")).hexdigest()


def get_current_time():
    """Return the current Europe/Bratislava time as ``DD.MM.YYYY HH:MM:SS``."""
    tz = pytz.timezone("Europe/Bratislava")
    now = datetime.now(tz)
    return now.strftime("%d.%m.%Y %H:%M:%S")


def _read_history():
    """Return the list stored in HISTORY_FILE, or ``[]`` if the file is absent.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    try:
        with open(HISTORY_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return []


def _write_history(history):
    """Replace HISTORY_FILE with *history* serialised as indented UTF-8 JSON.

    The data is written to a sibling temp file and then moved into place with
    os.replace(), so an interrupted write never leaves a truncated history file.
    A truncated file would make every later json.load() fail.
    Callers must hold ``_history_lock``.
    """
    tmp_path = HISTORY_FILE + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, HISTORY_FILE)


def save_to_history(text, prediction_label):
    """Append ``{text, prediction, timestamp}`` to the history file.

    Raises:
        json.JSONDecodeError: if the existing history file is not valid JSON.
        OSError: if the file cannot be read or written.
    """
    entry = {
        "text": text,
        "prediction": prediction_label,
        "timestamp": get_current_time(),
    }
    with _history_lock:
        history = _read_history()
        history.append(entry)
        _write_history(history)


def _classify(text):
    """Run the model on *text*; class index 1 maps to TOXIC_LABEL, anything else to NEUTRAL_LABEL."""
    # A character limit does not bound the token count, so keep truncation on
    # to stay within the model's maximum sequence length.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return TOXIC_LABEL if predicted_class == 1 else NEUTRAL_LABEL


@app.route("/api/predict", methods=["POST"])
def predict():
    """Classify the ``text`` field of the JSON request body.

    Returns:
        200 ``{"prediction": label}`` on success (cached or freshly computed);
        400 ``{"error": ...}`` for a missing/empty/non-string/too-long text or
        one containing Cyrillic; 500 ``{"error": ...}`` on an unexpected failure.
    """
    try:
        # silent=True: a missing, malformed or non-JSON body yields None instead
        # of raising, so it is reported as a client error (400) below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        text = data.get("text") or ""

        if not isinstance(text, str):
            return jsonify({"error": "Text musí byť reťazec."}), 400
        if not text.strip():
            return jsonify({"error": "Text nesmie byť prázdny."}), 400
        if len(text) > MAX_TEXT_LENGTH:
            return jsonify(
                {"error": f"Text je príliš dlhý. Maximálne {MAX_TEXT_LENGTH} znakov."}
            ), 400
        if CYRILLIC_RE.search(text):
            return jsonify({"error": "Text nesmie obsahovať azbuku (cyriliku)."}), 400

        text_hash = generate_text_hash(text)
        prediction_label = cache.get(text_hash)
        if prediction_label is None:
            prediction_label = _classify(text)
            cache.set(text_hash, prediction_label)

        save_to_history(text, prediction_label)
        return jsonify({"prediction": prediction_label}), 200
    except Exception as e:
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500


@app.route("/api/history", methods=["GET"])
def get_history():
    """Return the stored prediction history as pretty-printed JSON (``[]`` if none)."""
    try:
        # No lock needed: writes are atomic (os.replace), so a reader never
        # observes a partially written file.
        history = _read_history()
        # json.dumps with ensure_ascii=False keeps Slovak diacritics readable
        # instead of emitting \uXXXX escapes.
        return Response(
            json.dumps(history, ensure_ascii=False, indent=2),
            mimetype="application/json",
        )
    except Exception as e:
        app.logger.exception("Reading history failed")
        return jsonify({"error": str(e)}), 500


@app.route("/api/history/raw", methods=["GET"])
def get_raw_history():
    """Return the history file contents verbatim, or 404 if it does not exist."""
    try:
        with open(HISTORY_FILE, "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        return jsonify({"error": "history.json not found"}), 404
    except Exception as e:
        app.logger.exception("Reading raw history failed")
        return jsonify({"error": str(e)}), 500
    return Response(content, mimetype="application/json")


@app.route("/api/history/reset", methods=["POST"])
def reset_history():
    """Replace the history with an empty JSON list."""
    # NOTE(review): no authentication — any client that can reach the API can
    # wipe the history. Confirm this is intended.
    try:
        with _history_lock:
            _write_history([])
        return jsonify({"message": "History reset successful."}), 200
    except Exception as e:
        app.logger.exception("Resetting history failed")
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 5000))
    app.run(host="0.0.0.0", port=port)