доработка
This commit is contained in:
parent
56489de4f2
commit
af78fc633f
@ -3,6 +3,8 @@ pydantic-settings>=2.1.0
|
||||
python-dotenv>=1.0.0
|
||||
aiogram>=3.10.0
|
||||
sqlalchemy>=2.0.0
|
||||
aiosqlite>=0.19.0
|
||||
httpx>=0.25.0
|
||||
yookassa>=2.4.0
|
||||
fastapi>=0.104.0
|
||||
uvicorn>=0.24.0
|
||||
|
||||
@ -1,29 +1,67 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from tg_bot.infrastructure.database.models import UserModel
|
||||
|
||||
|
||||
class UserService:
|
||||
|
||||
def __init__(self, session: Session):
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_user_by_telegram_id(self, telegram_id: int) -> Optional[UserModel]:
|
||||
result = await self.session.execute(
|
||||
select(UserModel).filter_by(telegram_id=str(telegram_id))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_or_create_user(
|
||||
self,
|
||||
telegram_id: int,
|
||||
username: str = "",
|
||||
first_name: str = "",
|
||||
last_name: str = ""
|
||||
) -> UserModel:
|
||||
user = await self.get_user_by_telegram_id(telegram_id)
|
||||
if not user:
|
||||
user = UserModel(
|
||||
telegram_id=str(telegram_id),
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name
|
||||
)
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
else:
|
||||
user.username = username
|
||||
user.first_name = first_name
|
||||
user.last_name = last_name
|
||||
await self.session.commit()
|
||||
return user
|
||||
|
||||
async def update_user_questions(self, telegram_id: int) -> bool:
|
||||
user = await self.get_user_by_telegram_id(telegram_id)
|
||||
if user:
|
||||
user.questions_used += 1
|
||||
await self.session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def activate_premium(self, telegram_id: int) -> bool:
|
||||
try:
|
||||
user = self.session.query(UserModel) \
|
||||
.filter(UserModel.telegram_id == str(telegram_id)) \
|
||||
.first()
|
||||
user = await self.get_user_by_telegram_id(telegram_id)
|
||||
if user:
|
||||
user.is_premium = True
|
||||
if user.premium_until and user.premium_until > datetime.now():
|
||||
user.premium_until = user.premium_until + timedelta(days=30)
|
||||
else:
|
||||
user.premium_until = datetime.now() + timedelta(days=30)
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Error activating premium: {e}")
|
||||
self.session.rollback()
|
||||
await self.session.rollback()
|
||||
return False
|
||||
|
||||
@ -1,15 +1,19 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
||||
from tg_bot.config.settings import settings
|
||||
|
||||
engine = create_engine(
|
||||
settings.DATABASE_URL,
|
||||
database_url = settings.DATABASE_URL
|
||||
if database_url.startswith("sqlite:///"):
|
||||
database_url = database_url.replace("sqlite:///", "sqlite+aiosqlite:///")
|
||||
|
||||
engine = create_async_engine(
|
||||
database_url,
|
||||
echo=settings.DEBUG
|
||||
)
|
||||
|
||||
SessionLocal = sessionmaker(bind=engine)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
def create_tables():
|
||||
async def create_tables():
|
||||
from .models import Base
|
||||
Base.metadata.create_all(bind=engine)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
print(f"Таблицы созданы: {settings.DATABASE_URL}")
|
||||
@ -32,6 +32,7 @@ async def create_bot() -> tuple[Bot, Dispatcher]:
|
||||
|
||||
|
||||
async def start_bot():
|
||||
bot = None
|
||||
try:
|
||||
bot, dp = await create_bot()
|
||||
|
||||
@ -54,4 +55,5 @@ async def start_bot():
|
||||
logger.error(f"Ошибка запуска: {e}")
|
||||
raise
|
||||
finally:
|
||||
if bot:
|
||||
await bot.session.close()
|
||||
@ -4,8 +4,10 @@ from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton
|
||||
from decimal import Decimal
|
||||
from tg_bot.config.settings import settings
|
||||
from tg_bot.payment.yookassa.client import yookassa_client
|
||||
from tg_bot.infrastructure.database.database import SessionLocal
|
||||
from tg_bot.infrastructure.database.models import PaymentModel, UserModel
|
||||
from tg_bot.infrastructure.database.database import AsyncSessionLocal
|
||||
from tg_bot.infrastructure.database.models import PaymentModel
|
||||
from tg_bot.domain.services.user_service import UserService
|
||||
from sqlalchemy import select
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
@ -17,11 +19,10 @@ async def cmd_buy(message: Message):
|
||||
user_id = message.from_user.id
|
||||
username = message.from_user.username or f"user_{user_id}"
|
||||
|
||||
session = SessionLocal()
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
user = session.query(UserModel).filter_by(
|
||||
telegram_id=str(user_id)
|
||||
).first()
|
||||
user_service = UserService(session)
|
||||
user = await user_service.get_user_by_telegram_id(user_id)
|
||||
|
||||
if user and user.is_premium and user.premium_until and user.premium_until > datetime.now():
|
||||
days_left = (user.premium_until - datetime.now()).days
|
||||
@ -33,8 +34,8 @@ async def cmd_buy(message: Message):
|
||||
f"Новая подписка будет добавлена к текущей.",
|
||||
parse_mode="HTML"
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await message.answer(
|
||||
"*Создаю ссылку для оплаты...*\n\n"
|
||||
@ -49,7 +50,7 @@ async def cmd_buy(message: Message):
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
session = SessionLocal()
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
payment = PaymentModel(
|
||||
payment_id=str(uuid.uuid4()),
|
||||
@ -61,13 +62,11 @@ async def cmd_buy(message: Message):
|
||||
description="Оплата подписки VibeLawyerBot"
|
||||
)
|
||||
session.add(payment)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
print(f"Платёж сохранён в БД: {payment.payment_id}")
|
||||
except Exception as e:
|
||||
print(f"Ошибка сохранения платежа в БД: {e}")
|
||||
session.rollback()
|
||||
finally:
|
||||
session.close()
|
||||
await session.rollback()
|
||||
|
||||
keyboard = InlineKeyboardMarkup(
|
||||
inline_keyboard=[
|
||||
@ -140,29 +139,22 @@ async def check_payment_status(callback_query: types.CallbackQuery):
|
||||
payment = YooPayment.find_one(yookassa_id)
|
||||
|
||||
if payment.status == "succeeded":
|
||||
session = SessionLocal()
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
db_payment = session.query(PaymentModel).filter_by(
|
||||
yookassa_payment_id=yookassa_id
|
||||
).first()
|
||||
result = await session.execute(
|
||||
select(PaymentModel).filter_by(yookassa_payment_id=yookassa_id)
|
||||
)
|
||||
db_payment = result.scalar_one_or_none()
|
||||
|
||||
if db_payment:
|
||||
db_payment.status = "succeeded"
|
||||
user = session.query(UserModel).filter_by(
|
||||
telegram_id=str(user_id)
|
||||
).first()
|
||||
|
||||
if user:
|
||||
user.is_premium = True
|
||||
if user.premium_until and user.premium_until > datetime.now():
|
||||
user.premium_until = user.premium_until + timedelta(days=30)
|
||||
else:
|
||||
user.premium_until = datetime.now() + timedelta(days=30)
|
||||
|
||||
session.commit()
|
||||
user = session.query(UserModel).filter_by(
|
||||
telegram_id=str(user_id)
|
||||
).first()
|
||||
user_service = UserService(session)
|
||||
success = await user_service.activate_premium(user_id)
|
||||
if success:
|
||||
user = await user_service.get_user_by_telegram_id(user_id)
|
||||
await session.commit()
|
||||
if not user:
|
||||
user = await user_service.get_user_by_telegram_id(user_id)
|
||||
|
||||
await callback_query.message.answer(
|
||||
"<b>Оплата подтверждена!</b>\n\n"
|
||||
@ -181,8 +173,8 @@ async def check_payment_status(callback_query: types.CallbackQuery):
|
||||
"Пожалуйста, обратитесь к администратору.",
|
||||
parse_mode="HTML"
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
except Exception as e:
|
||||
print(f"Ошибка обработки платежа: {e}")
|
||||
|
||||
elif payment.status == "pending":
|
||||
await callback_query.message.answer(
|
||||
@ -212,16 +204,16 @@ async def check_payment_status(callback_query: types.CallbackQuery):
|
||||
parse_mode="HTML"
|
||||
)
|
||||
|
||||
|
||||
@router.message(Command("mypayments"))
|
||||
async def cmd_my_payments(message: Message):
|
||||
user_id = message.from_user.id
|
||||
|
||||
session = SessionLocal()
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
payments = session.query(PaymentModel).filter_by(
|
||||
user_id=user_id
|
||||
).order_by(PaymentModel.created_at.desc()).limit(10).all()
|
||||
result = await session.execute(
|
||||
select(PaymentModel).filter_by(user_id=user_id).order_by(PaymentModel.created_at.desc()).limit(10)
|
||||
)
|
||||
payments = result.scalars().all()
|
||||
|
||||
if not payments:
|
||||
await message.answer(
|
||||
@ -248,9 +240,8 @@ async def cmd_my_payments(message: Message):
|
||||
"\n".join(response),
|
||||
parse_mode="HTML"
|
||||
)
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
except Exception as e:
|
||||
print(f"Ошибка получения платежей: {e}")
|
||||
|
||||
|
||||
@router.message(Command("testcards"))
|
||||
|
||||
@ -3,8 +3,9 @@ from aiogram.types import Message
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
from tg_bot.config.settings import settings
|
||||
from tg_bot.infrastructure.database.database import SessionLocal
|
||||
from tg_bot.infrastructure.database.database import AsyncSessionLocal
|
||||
from tg_bot.infrastructure.database.models import UserModel
|
||||
from tg_bot.domain.services.user_service import UserService
|
||||
from tg_bot.application.services.rag_service import RAGService
|
||||
|
||||
router = Router()
|
||||
@ -18,29 +19,25 @@ async def handle_question(message: Message):
|
||||
if question_text.startswith('/'):
|
||||
return
|
||||
|
||||
session = SessionLocal()
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
user = session.query(UserModel).filter_by(
|
||||
telegram_id=str(user_id)
|
||||
).first()
|
||||
user_service = UserService(session)
|
||||
user = await user_service.get_user_by_telegram_id(user_id)
|
||||
|
||||
if not user:
|
||||
user = UserModel(
|
||||
telegram_id=str(user_id),
|
||||
username=message.from_user.username or "",
|
||||
first_name=message.from_user.first_name or "",
|
||||
last_name=message.from_user.last_name or ""
|
||||
user = await user_service.get_or_create_user(
|
||||
user_id,
|
||||
message.from_user.username or "",
|
||||
message.from_user.first_name or "",
|
||||
message.from_user.last_name or ""
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
await ensure_user_in_backend(str(user_id), message.from_user)
|
||||
|
||||
if user.is_premium:
|
||||
await process_premium_question(message, user, question_text, session)
|
||||
await process_premium_question(message, user, question_text, user_service)
|
||||
|
||||
elif user.questions_used < settings.FREE_QUESTIONS_LIMIT:
|
||||
await process_free_question(message, user, question_text, session)
|
||||
await process_free_question(message, user, question_text, user_service)
|
||||
|
||||
else:
|
||||
await handle_limit_exceeded(message, user)
|
||||
@ -51,8 +48,6 @@ async def handle_question(message: Message):
|
||||
"Произошла ошибка. Попробуйте позже.",
|
||||
parse_mode="HTML"
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
async def ensure_user_in_backend(telegram_id: str, telegram_user):
|
||||
@ -74,9 +69,8 @@ async def ensure_user_in_backend(telegram_id: str, telegram_user):
|
||||
print(f"Error creating user in backend: {e}")
|
||||
|
||||
|
||||
async def process_premium_question(message: Message, user: UserModel, question_text: str, session):
|
||||
user.questions_used += 1
|
||||
session.commit()
|
||||
async def process_premium_question(message: Message, user: UserModel, question_text: str, user_service: UserService):
|
||||
await user_service.update_user_questions(user.telegram_id)
|
||||
|
||||
await message.bot.send_chat_action(message.chat.id, "typing")
|
||||
|
||||
@ -135,10 +129,10 @@ async def process_premium_question(message: Message, user: UserModel, question_t
|
||||
await message.answer(response, parse_mode="HTML")
|
||||
|
||||
|
||||
async def process_free_question(message: Message, user: UserModel, question_text: str, session):
|
||||
user.questions_used += 1
|
||||
async def process_free_question(message: Message, user: UserModel, question_text: str, user_service: UserService):
|
||||
await user_service.update_user_questions(user.telegram_id)
|
||||
user = await user_service.get_user_by_telegram_id(user.telegram_id)
|
||||
remaining = settings.FREE_QUESTIONS_LIMIT - user.questions_used
|
||||
session.commit()
|
||||
|
||||
await message.bot.send_chat_action(message.chat.id, "typing")
|
||||
|
||||
|
||||
@ -4,8 +4,8 @@ from aiogram.types import Message
|
||||
from datetime import datetime
|
||||
|
||||
from tg_bot.config.settings import settings
|
||||
from tg_bot.infrastructure.database.database import SessionLocal
|
||||
from tg_bot.infrastructure.database.models import UserModel
|
||||
from tg_bot.infrastructure.database.database import AsyncSessionLocal
|
||||
from tg_bot.domain.services.user_service import UserService
|
||||
|
||||
router = Router()
|
||||
|
||||
@ -16,33 +16,22 @@ async def cmd_start(message: Message):
|
||||
username = message.from_user.username or ""
|
||||
first_name = message.from_user.first_name or ""
|
||||
last_name = message.from_user.last_name or ""
|
||||
session = SessionLocal()
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
user = session.query(UserModel).filter_by(
|
||||
telegram_id=str(user_id)
|
||||
).first()
|
||||
|
||||
if not user:
|
||||
user = UserModel(
|
||||
telegram_id=str(user_id),
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name
|
||||
user_service = UserService(session)
|
||||
existing_user = await user_service.get_user_by_telegram_id(user_id)
|
||||
user = await user_service.get_or_create_user(
|
||||
user_id,
|
||||
username,
|
||||
first_name,
|
||||
last_name
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
if not existing_user:
|
||||
print(f"Новый пользователь: {user_id}")
|
||||
else:
|
||||
user.username = username
|
||||
user.first_name = first_name
|
||||
user.last_name = last_name
|
||||
session.commit()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Ошибка сохранения пользователя: {e}")
|
||||
session.rollback()
|
||||
finally:
|
||||
session.close()
|
||||
await session.rollback()
|
||||
welcome_text = (
|
||||
f"<b>Привет, {first_name}!</b>\n\n"
|
||||
f"Я <b>VibeLawyerBot</b> - ваш помощник в юридических вопросах.\n\n"
|
||||
|
||||
@ -4,8 +4,8 @@ from aiogram.filters import Command
|
||||
from aiogram.types import Message
|
||||
|
||||
from tg_bot.config.settings import settings
|
||||
from tg_bot.infrastructure.database.database import SessionLocal
|
||||
from tg_bot.infrastructure.database.models import UserModel
|
||||
from tg_bot.infrastructure.database.database import AsyncSessionLocal
|
||||
from tg_bot.domain.services.user_service import UserService
|
||||
|
||||
router = Router()
|
||||
|
||||
@ -14,11 +14,10 @@ router = Router()
|
||||
async def cmd_stats(message: Message):
|
||||
user_id = message.from_user.id
|
||||
|
||||
session = SessionLocal()
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
user = session.query(UserModel).filter_by(
|
||||
telegram_id=str(user_id)
|
||||
).first()
|
||||
user_service = UserService(session)
|
||||
user = await user_service.get_user_by_telegram_id(user_id)
|
||||
|
||||
if user:
|
||||
stats_text = (
|
||||
@ -42,7 +41,6 @@ async def cmd_stats(message: Message):
|
||||
f"• Осталось вопросов: {remaining}\n"
|
||||
f"• Для безлимита: /buy\n\n"
|
||||
)
|
||||
|
||||
else:
|
||||
stats_text = (
|
||||
f"<b>Добро пожаловать!</b>\n\n"
|
||||
@ -61,5 +59,3 @@ async def cmd_stats(message: Message):
|
||||
f"Попробуйте позже.",
|
||||
parse_mode="HTML"
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
@ -19,24 +19,26 @@ async def handle_yookassa_webhook(request: Request):
|
||||
try:
|
||||
from tg_bot.config.settings import settings
|
||||
from tg_bot.domain.services.user_service import UserService
|
||||
from tg_bot.infrastructure.database.database import SessionLocal
|
||||
from tg_bot.infrastructure.database.database import AsyncSessionLocal
|
||||
from tg_bot.infrastructure.database.models import UserModel
|
||||
from sqlalchemy import select
|
||||
from aiogram import Bot
|
||||
|
||||
session = SessionLocal()
|
||||
if event_type == "payment.succeeded":
|
||||
payment = data.get("object", {})
|
||||
user_id = payment.get("metadata", {}).get("user_id")
|
||||
|
||||
if user_id:
|
||||
async with AsyncSessionLocal() as session:
|
||||
user_service = UserService(session)
|
||||
success = await user_service.activate_premium(int(user_id))
|
||||
if success:
|
||||
print(f"Premium activated for user {user_id}")
|
||||
|
||||
user = session.query(UserModel).filter_by(
|
||||
telegram_id=str(user_id)
|
||||
).first()
|
||||
result = await session.execute(
|
||||
select(UserModel).filter_by(telegram_id=str(user_id))
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user and settings.TELEGRAM_BOT_TOKEN:
|
||||
try:
|
||||
@ -59,7 +61,6 @@ async def handle_yookassa_webhook(request: Request):
|
||||
print(f"Error sending notification: {e}")
|
||||
else:
|
||||
print(f"User {user_id} not found")
|
||||
session.close()
|
||||
|
||||
except ImportError as e:
|
||||
print(f"Import error: {e}")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user