# Listing metadata: 90 lines, 4.0 KiB, Python.
from __future__ import annotations
|
|
import logging
|
|
import redis.asyncio as aioredis
|
|
from app.config import Config
|
|
from app.providers import SendResult
|
|
from app.providers.base import BaseProvider
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Redis key templates; ``{phone}`` is filled in with the normalised phone number.
# RATE_KEY holds the per-phone attempt counter used for rate limiting.
RATE_KEY = "sms:rate:{phone}"
# CODE_KEY holds the code most recently stored for that phone.
CODE_KEY = "sms:code:{phone}"
|
|
|
|
class RateLimitExceeded(Exception):
    """Raised when a phone number has used up its allowed send attempts.

    Attributes:
        retry_after: Number of seconds the caller should wait before retrying.
    """

    def __init__(self, retry_after: int) -> None:
        """Record *retry_after* and build the human-readable message.

        Args:
            retry_after: Seconds until the caller may retry.
        """
        self.retry_after = retry_after
        super().__init__(f"Rate limit exceeded, retry after {retry_after}s")
|
|
|
|
class SmsService:
    """Send SMS codes through named providers and keep the sent codes in Redis.

    The service rate-limits sends per phone number, picks a primary and an
    optional fallback provider via ``config.resolve_provider`` and stores the
    successfully sent code under ``CODE_KEY`` with a TTL taken from
    ``config.settings.code_ttl_seconds``.
    """

    def __init__(self, config: Config, providers: dict[str, BaseProvider], redis: aioredis.Redis) -> None:
        """Store the collaborators; nothing else happens here.

        Args:
            config: Application configuration (settings and provider routing).
            providers: Provider instances keyed by provider name.
            redis: Async Redis client used for rate counters and stored codes.
        """
        self.config = config
        self.providers = providers
        self.redis = redis

    @staticmethod
    def _normalize_phone(phone_number: str) -> str:
        """Return *phone_number* with a leading ``+`` added when it is missing."""
        return phone_number if phone_number.startswith("+") else f"+{phone_number}"

    async def send_code(self, phone_number: str, code: str | None = None) -> SendResult:
        """Send a code to *phone_number*, falling back to a second provider on failure.

        Args:
            phone_number: Destination number; a leading ``+`` is added if absent.
            code: Optional code forwarded unchanged to the provider
                (presumably the provider generates one when ``None`` - confirm
                against ``BaseProvider.send``).

        Returns:
            The ``SendResult`` of the last provider that was tried.

        Raises:
            RateLimitExceeded: If the number is over its attempt limit.
        """
        normalized = self._normalize_phone(phone_number)
        await self._check_rate_limit(normalized)

        primary_name, fallback_name = self.config.resolve_provider(normalized)
        result = await self._try_send(primary_name, normalized, code=code)

        if not result.success and fallback_name:
            logger.warning(
                "Провайдер %s недоступен для %s, пробуем fallback: %s",
                primary_name, normalized, fallback_name,
            )
            result = await self._try_send(fallback_name, normalized, code=code)

        # Only a successful send that reports a code is remembered.
        if result.success and result.code:
            ttl = self.config.settings.code_ttl_seconds
            key = CODE_KEY.format(phone=normalized)
            await self.redis.set(key, result.code, ex=ttl)
            # Logging the code itself is opt-in through settings.
            if self.config.settings.log_codes:
                logger.info("Код сохранён: phone=%s code=%s provider=%s", normalized, result.code, result.provider)

        return result

    async def _check_rate_limit(self, phone: str) -> None:
        """Count one attempt for *phone* and reject it when over the limit.

        Args:
            phone: Normalised phone number used to build the counter key.

        Raises:
            RateLimitExceeded: If the attempt count within the current window
                exceeds ``rate_limit.max_attempts``.
        """
        rl = self.config.settings.rate_limit
        if not rl.enabled:
            return

        key = RATE_KEY.format(phone=phone)
        pipe = self.redis.pipeline()
        pipe.incr(key)
        pipe.ttl(key)
        count, ttl = await pipe.execute()

        # A negative TTL means the counter has no expiry. That is normal for a
        # brand-new key (count == 1), but it also happens when an earlier
        # EXPIRE call was lost after the INCR (crash, cancellation, dropped
        # connection). Re-arm the window in both cases; otherwise the counter
        # would live forever and the number would stay blocked permanently.
        if count == 1 or ttl < 0:
            await self.redis.expire(key, rl.window_seconds)
            ttl = rl.window_seconds

        if count > rl.max_attempts:
            retry_after = ttl if ttl > 0 else rl.window_seconds
            logger.warning("Rate limit для %s: попытка %d/%d, retry_after=%ds", phone, count, rl.max_attempts, retry_after)
            raise RateLimitExceeded(retry_after=retry_after)

    async def _try_send(self, provider_name: str, phone: str, code: str | None = None) -> SendResult:
        """Send through the provider registered as *provider_name*.

        Args:
            provider_name: Key into ``self.providers``.
            phone: Normalised destination number.
            code: Optional code forwarded to the provider.

        Returns:
            The provider's ``SendResult``, or a failed ``SendResult`` when no
            provider with that name is registered.
        """
        provider = self.providers.get(provider_name)
        if provider is None:
            logger.error("Провайдер не найден: %s", provider_name)
            return SendResult(success=False, provider=provider_name, error=f"Provider '{provider_name}' not found")
        # NOTE(review): an exception raised by provider.send propagates to the
        # caller and skips the fallback provider - confirm that providers
        # report failures via SendResult rather than by raising.
        return await provider.send(phone, code=code)

    async def get_pending_code(self, phone_number: str) -> str | None:
        """Return the code currently stored for *phone_number* without removing it.

        Args:
            phone_number: Phone number; a leading ``+`` is added if absent.

        Returns:
            Whatever Redis holds under the code key, or ``None`` if nothing is stored.
        """
        key = CODE_KEY.format(phone=self._normalize_phone(phone_number))
        return await self.redis.get(key)

    async def consume_code(self, phone_number: str) -> str | None:
        """Return the code stored for *phone_number* and delete it.

        GET and DELETE are queued on one pipeline and executed together.

        Args:
            phone_number: Phone number; a leading ``+`` is added if absent.

        Returns:
            The stored code, or ``None`` if nothing was stored.
        """
        key = CODE_KEY.format(phone=self._normalize_phone(phone_number))
        pipe = self.redis.pipeline()
        pipe.get(key)
        pipe.delete(key)
        code, _ = await pipe.execute()
        return code

    async def list_pending_codes(self) -> list[dict]:
        """List every stored code with its phone number and remaining lifetime.

        Returns:
            One dict per stored code with the keys ``phone``, ``code`` and
            ``expires_in`` (remaining TTL, clamped to be non-negative).
        """
        pattern = CODE_KEY.format(phone="*")
        # Derive the prefix from the template so it cannot drift from CODE_KEY.
        prefix = CODE_KEY.format(phone="")
        pending: list[dict] = []

        # NOTE(review): the string handling below assumes the client returns
        # str keys (decode_responses=True) - confirm where the client is built.
        async for key in self.redis.scan_iter(pattern):
            pipe = self.redis.pipeline()
            pipe.get(key)
            pipe.ttl(key)
            code, ttl = await pipe.execute()
            if not code:
                # Nothing stored under the key any more; skip it.
                continue
            # Strip only the leading prefix, never an occurrence inside the key.
            phone = key[len(prefix):] if key.startswith(prefix) else key
            pending.append({"phone": phone, "code": code, "expires_in": max(ttl, 0)})

        return pending