bakalarka_praca/telegram_bot/bot.py

import os
import torch
import logging
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from telegram import Update
from telegram.ext import Application, MessageHandler, filters, CommandHandler, CallbackContext
from dotenv import load_dotenv
import mysql.connector
# Load environment variables from .env file
load_dotenv()
TOKEN = os.getenv("TOKEN")
# Path to the model
MODEL_PATH = "./hate_speech_model/final_model"
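# The directory is expected to hold a checkpoint saved with save_pretrained()
# (config.json, the model weights, and the tokenizer files).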

# Database configuration. The host "mysql" assumes a Docker Compose service
# of that name; in production the credentials below should come from the
# environment rather than being hard-coded.
db_config = {
    "host": "mysql",
    "user": "root",
    "password": "0674998280tanya",
    "database": "telegram_bot",
}
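
# Sketch of the table save_violator() writes to (assumed DDL; the real
# schema may differ):
#   CREATE TABLE IF NOT EXISTS violators (
#       id INT AUTO_INCREMENT PRIMARY KEY,
#       username VARCHAR(255),
#       message TEXT,
#       created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
#   );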

def save_violator(username, message):
    """Saves the violator's message to the database"""
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        query = "INSERT INTO violators (username, message) VALUES (%s, %s)"
        cursor.execute(query, (username, message))
        conn.commit()
        cursor.close()
        conn.close()
    except mysql.connector.Error as err:
        logging.error(f"MySQL error: {err}")
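
# A new connection per message keeps things simple; under heavier traffic a
# pool (mysql.connector.pooling.MySQLConnectionPool) would be the usual choice.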

# Logging setup
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[
        logging.FileHandler("log.txt"),
        logging.StreamHandler(),
    ],
)

# Load tokenizer and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
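
# The checkpoint is assumed to be a binary classifier in which label 1 means
# "hate speech"; classify_text() below interprets the argmax that way.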

def classify_text(text):
    """Classifies the input text"""
    model.eval()
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128,
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    pred = torch.argmax(logits, dim=-1).item()
    return "🛑 Hate speech detected" if pred == 1 else "✅ OK"

async def check_message(update: Update, context: CallbackContext):
    """Checks chat messages and reacts to toxic content"""
    message_text = update.message.text
    result = classify_text(message_text)
    if result == "🛑 Hate speech detected":
        username = update.message.from_user.username or "unknown"
        await update.message.reply_text("⚠️ Warning! Please maintain respectful communication.")
        await update.message.delete()  # Automatically deletes the toxic message
        # Log the toxic message
        logging.warning(f"Toxic message from {username}: {message_text}")
        save_violator(username, message_text)
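
# NOTE: in group chats, delete() only succeeds if the bot is an administrator
# with the "Delete messages" right; otherwise python-telegram-bot raises an error.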

async def start(update: Update, context: CallbackContext):
    """Sends a welcome message on /start"""
    await update.message.reply_text("Hi! I'm monitoring the chat for hate speech!")

def main():
    """Starts the bot"""
    if not TOKEN:
        raise RuntimeError("TOKEN is not set; add it to the .env file")
    app = Application.builder().token(TOKEN).build()
    # Register command and message handlers
    app.add_handler(CommandHandler("start", start))
    app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, check_message))
    # Start polling
    app.run_polling()

if __name__ == "__main__":
    main()