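"""Telegram bot that monitors chat messages for hate speech using a local
Hugging Face sequence-classification model, warns and deletes violators'
messages, and records them in MySQL."""
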
import logging
import os

import mysql.connector
import torch
from dotenv import load_dotenv
from telegram import Update
from telegram.ext import Application, CallbackContext, CommandHandler, MessageHandler, filters
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load environment variables from the .env file
load_dotenv()

TOKEN = os.getenv("TOKEN")

# Path to the model
MODEL_PATH = "./hate_speech_model/final_model"

# Database configuration
db_config = {
    "host": "mysql",
    "user": "root",
    "password": "0674998280tanya",
    "database": "telegram_bot",
}

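# Note: the host "mysql" suggests a containerized database service (an
# assumption based on the name); the credentials above could also be
# loaded from the same .env file as TOKEN instead of being hardcoded.
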
def save_violator(username, message):
    """Saves the violator's username and message to the database."""
    conn = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        query = "INSERT INTO violators (username, message) VALUES (%s, %s)"
        cursor.execute(query, (username, message))
        conn.commit()
        cursor.close()
    except mysql.connector.Error as err:
        logging.error(f"MySQL error: {err}")
    finally:
        # Close the connection even if the insert failed
        if conn is not None and conn.is_connected():
            conn.close()

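# The INSERT above assumes a `violators` table along these lines; the
# column types are a sketch, not taken from the original project:
#
#   CREATE TABLE IF NOT EXISTS violators (
#       id INT AUTO_INCREMENT PRIMARY KEY,
#       username VARCHAR(255) NOT NULL,
#       message TEXT NOT NULL,
#       created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
#   );
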
# Logging setup: write to log.txt and mirror to the console
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[
        logging.FileHandler("log.txt"),
        logging.StreamHandler(),
    ],
)

# Load the tokenizer and model, using the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)

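# MODEL_PATH is expected to hold a standard Hugging Face checkpoint
# (config, tokenizer files, weights) with a two-label classification
# head, matching the 0/1 labels used in classify_text below.
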
def classify_text(text):
    """Classifies the input text: label 1 is hate speech, label 0 is clean."""
    model.eval()
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128,
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    pred = torch.argmax(logits, dim=-1).item()

    return "🛑 Hate speech detected" if pred == 1 else "✅ OK"

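# If a confidence score is ever needed, the logits can be converted to
# probabilities with a softmax; a minimal sketch (not part of the
# original bot):
#
#   probs = torch.softmax(logits, dim=-1)
#   confidence = probs[0, pred].item()
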
async def check_message(update: Update, context: CallbackContext):
    """Checks chat messages and reacts to toxic content."""
    message_text = update.message.text
    result = classify_text(message_text)

    if result == "🛑 Hate speech detected":
        username = update.message.from_user.username or "unknown"
        await update.message.reply_text("⚠️ Warning! Please maintain respectful communication.")
        # Deleting group messages requires the bot to have the
        # "delete messages" admin right
        await update.message.delete()

        # Log the toxic message and record the violator
        logging.warning(f"Toxic message from {username}: {message_text}")
        save_violator(username, message_text)

async def start(update: Update, context: CallbackContext):
    """Sends a welcome message on /start."""
    await update.message.reply_text("Hi! I'm monitoring the chat for hate speech!")

def main():
    """Starts the bot."""
    if not TOKEN:
        raise RuntimeError("TOKEN is not set; add it to the .env file")

    app = Application.builder().token(TOKEN).build()

    # Register command and message handlers
    app.add_handler(CommandHandler("start", start))
    app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, check_message))

    # Start polling for updates (blocks until the process is stopped)
    app.run_polling()


if __name__ == "__main__":
    main()