tg_ws_proxy.py
This commit is contained in:
parent
4ae7cb92f7
commit
e0f01614a3
|
|
@ -22,6 +22,9 @@ _RECV_BUF = 256 * 1024
|
||||||
_SEND_BUF = 256 * 1024
|
_SEND_BUF = 256 * 1024
|
||||||
_WS_POOL_SIZE = 4
|
_WS_POOL_SIZE = 4
|
||||||
_WS_POOL_MAX_AGE = 120.0
|
_WS_POOL_MAX_AGE = 120.0
|
||||||
|
_TCP_ONLY_PORTS = {5222}
|
||||||
|
_WS_ONLY_PORTS = {443}
|
||||||
|
_DYN_IP_CACHE_MAX = 256
|
||||||
|
|
||||||
_TG_RANGES = [
|
_TG_RANGES = [
|
||||||
# 185.76.151.0/24
|
# 185.76.151.0/24
|
||||||
|
|
@ -43,13 +46,14 @@ _IP_TO_DC: Dict[str, Tuple[int, bool]] = {
|
||||||
# DC1
|
# DC1
|
||||||
'149.154.175.50': (1, False), '149.154.175.51': (1, False),
|
'149.154.175.50': (1, False), '149.154.175.51': (1, False),
|
||||||
'149.154.175.53': (1, False), '149.154.175.54': (1, False),
|
'149.154.175.53': (1, False), '149.154.175.54': (1, False),
|
||||||
'149.154.175.52': (1, True),
|
'149.154.175.52': (1, True), '149.154.175.211': (1, False),
|
||||||
# DC2
|
# DC2
|
||||||
'149.154.167.41': (2, False), '149.154.167.50': (2, False),
|
'149.154.167.41': (2, False), '149.154.167.50': (2, False),
|
||||||
'149.154.167.51': (2, False), '149.154.167.220': (2, False),
|
'149.154.167.51': (2, False), '149.154.167.220': (2, False),
|
||||||
'95.161.76.100': (2, False),
|
'95.161.76.100': (2, False),
|
||||||
'149.154.167.151': (2, True), '149.154.167.222': (2, True),
|
'149.154.167.151': (2, True), '149.154.167.222': (2, True),
|
||||||
'149.154.167.223': (2, True), '149.154.162.123': (2, True),
|
'149.154.167.223': (2, True), '149.154.162.123': (2, True),
|
||||||
|
'149.154.167.35': (2, False), '149.154.167.255': (2, True),
|
||||||
# DC3
|
# DC3
|
||||||
'149.154.175.100': (3, False), '149.154.175.101': (3, False),
|
'149.154.175.100': (3, False), '149.154.175.101': (3, False),
|
||||||
'149.154.175.102': (3, True),
|
'149.154.175.102': (3, True),
|
||||||
|
|
@ -74,6 +78,7 @@ _DC_OVERRIDES: Dict[int, int] = {
|
||||||
}
|
}
|
||||||
|
|
||||||
_dc_opt: Dict[int, Optional[str]] = {}
|
_dc_opt: Dict[int, Optional[str]] = {}
|
||||||
|
_prefer_tcp_for_media = False
|
||||||
|
|
||||||
# DCs where WS is known to fail (302 redirect)
|
# DCs where WS is known to fail (302 redirect)
|
||||||
# Raw TCP fallback will be used instead
|
# Raw TCP fallback will be used instead
|
||||||
|
|
@ -82,7 +87,10 @@ _ws_blacklist: Set[Tuple[int, bool]] = set()
|
||||||
|
|
||||||
# Rate-limit re-attempts per (dc, is_media)
|
# Rate-limit re-attempts per (dc, is_media)
|
||||||
_dc_fail_until: Dict[Tuple[int, bool], float] = {}
|
_dc_fail_until: Dict[Tuple[int, bool], float] = {}
|
||||||
_DC_FAIL_COOLDOWN = 30.0 # seconds to keep reduced WS timeout after failure
|
_dc_fail_count: Dict[Tuple[int, bool], int] = {}
|
||||||
|
_domain_success: Dict[Tuple[int, bool], str] = {}
|
||||||
|
_DC_FAIL_COOLDOWN = 15.0 # base seconds to keep reduced WS timeout after failure
|
||||||
|
_DC_FAIL_COOLDOWN_MAX = 120.0
|
||||||
_WS_FAIL_TIMEOUT = 2.0 # quick-retry timeout after a recent WS failure
|
_WS_FAIL_TIMEOUT = 2.0 # quick-retry timeout after a recent WS failure
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -151,6 +159,16 @@ class RawWebSocket:
|
||||||
self.writer = writer
|
self.writer = writer
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
|
||||||
|
def is_usable(self) -> bool:
|
||||||
|
if self._closed:
|
||||||
|
return False
|
||||||
|
if self.writer.is_closing():
|
||||||
|
return False
|
||||||
|
transport = self.writer.transport
|
||||||
|
if transport is None or transport.is_closing():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def connect(ip: str, domain: str, path: str = '/apiws',
|
async def connect(ip: str, domain: str, path: str = '/apiws',
|
||||||
timeout: float = 10.0) -> 'RawWebSocket':
|
timeout: float = 10.0) -> 'RawWebSocket':
|
||||||
|
|
@ -472,8 +490,15 @@ class _MsgSplitter:
|
||||||
def _ws_domains(dc: int, is_media) -> List[str]:
|
def _ws_domains(dc: int, is_media) -> List[str]:
|
||||||
dc = _DC_OVERRIDES.get(dc, dc)
|
dc = _DC_OVERRIDES.get(dc, dc)
|
||||||
if is_media is None or is_media:
|
if is_media is None or is_media:
|
||||||
return [f'kws{dc}-1.web.telegram.org', f'kws{dc}.web.telegram.org']
|
domains = [f'kws{dc}-1.web.telegram.org', f'kws{dc}.web.telegram.org']
|
||||||
return [f'kws{dc}.web.telegram.org', f'kws{dc}-1.web.telegram.org']
|
else:
|
||||||
|
domains = [f'kws{dc}.web.telegram.org', f'kws{dc}-1.web.telegram.org']
|
||||||
|
key = (dc, bool(is_media))
|
||||||
|
preferred = _domain_success.get(key)
|
||||||
|
if preferred in domains:
|
||||||
|
domains.remove(preferred)
|
||||||
|
domains.insert(0, preferred)
|
||||||
|
return domains
|
||||||
|
|
||||||
|
|
||||||
class Stats:
|
class Stats:
|
||||||
|
|
@ -488,6 +513,13 @@ class Stats:
|
||||||
self.bytes_down = 0
|
self.bytes_down = 0
|
||||||
self.pool_hits = 0
|
self.pool_hits = 0
|
||||||
self.pool_misses = 0
|
self.pool_misses = 0
|
||||||
|
self.media_ws_success = 0
|
||||||
|
self.media_ws_fail = 0
|
||||||
|
self.media_tcp_fallback = 0
|
||||||
|
self.media_unknown_dc = 0
|
||||||
|
self.media_init_patched = 0
|
||||||
|
self.port_tcp_only_hits = 0
|
||||||
|
self.port_ws_attempts = 0
|
||||||
|
|
||||||
def summary(self) -> str:
|
def summary(self) -> str:
|
||||||
return (f"total={self.connections_total} ws={self.connections_ws} "
|
return (f"total={self.connections_total} ws={self.connections_ws} "
|
||||||
|
|
@ -495,6 +527,11 @@ class Stats:
|
||||||
f"http_skip={self.connections_http_rejected} "
|
f"http_skip={self.connections_http_rejected} "
|
||||||
f"pass={self.connections_passthrough} "
|
f"pass={self.connections_passthrough} "
|
||||||
f"err={self.ws_errors} "
|
f"err={self.ws_errors} "
|
||||||
|
f"media(ws={self.media_ws_success},fail={self.media_ws_fail},"
|
||||||
|
f"tcp={self.media_tcp_fallback},unk={self.media_unknown_dc},"
|
||||||
|
f"patch={self.media_init_patched}) "
|
||||||
|
f"ports(tcp_only={self.port_tcp_only_hits},"
|
||||||
|
f"ws_try={self.port_ws_attempts}) "
|
||||||
f"pool={self.pool_hits}/{self.pool_hits+self.pool_misses} "
|
f"pool={self.pool_hits}/{self.pool_hits+self.pool_misses} "
|
||||||
f"up={_human_bytes(self.bytes_up)} "
|
f"up={_human_bytes(self.bytes_up)} "
|
||||||
f"down={_human_bytes(self.bytes_down)}")
|
f"down={_human_bytes(self.bytes_down)}")
|
||||||
|
|
@ -503,6 +540,95 @@ class Stats:
|
||||||
_stats = Stats()
|
_stats = Stats()
|
||||||
|
|
||||||
|
|
||||||
|
def _remember_ip_mapping(ip: str, dc: int, is_media: bool):
|
||||||
|
if not ip or ':' in ip:
|
||||||
|
return
|
||||||
|
current = _IP_TO_DC.get(ip)
|
||||||
|
if current == (dc, is_media):
|
||||||
|
return
|
||||||
|
if current is None and len(_IP_TO_DC) >= (128 + _DYN_IP_CACHE_MAX):
|
||||||
|
return
|
||||||
|
_IP_TO_DC[ip] = (dc, is_media)
|
||||||
|
log.debug("learned IP mapping %s -> DC%d%s", ip, dc, 'm' if is_media else '')
|
||||||
|
|
||||||
|
|
||||||
|
def _target_ip_for(dc: int, dst: str) -> Optional[str]:
|
||||||
|
target = _dc_opt.get(dc)
|
||||||
|
if target:
|
||||||
|
return target
|
||||||
|
if _is_telegram_ip(dst):
|
||||||
|
return dst
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _guess_dc_candidates(dst: str) -> List[Tuple[int, bool, str]]:
|
||||||
|
candidates: List[Tuple[int, bool, str]] = []
|
||||||
|
seen: Set[Tuple[int, bool, str]] = set()
|
||||||
|
|
||||||
|
mapped = _IP_TO_DC.get(dst)
|
||||||
|
if mapped is not None:
|
||||||
|
dc, is_media = mapped
|
||||||
|
target = _target_ip_for(dc, dst)
|
||||||
|
if target:
|
||||||
|
item = (dc, bool(is_media), target)
|
||||||
|
seen.add(item)
|
||||||
|
candidates.append(item)
|
||||||
|
|
||||||
|
prefixes = []
|
||||||
|
parts = dst.split('.')
|
||||||
|
if len(parts) == 4:
|
||||||
|
prefixes = ['.'.join(parts[:3]) + '.', '.'.join(parts[:2]) + '.']
|
||||||
|
for ip, (dc, is_media) in list(_IP_TO_DC.items()):
|
||||||
|
target = _target_ip_for(dc, dst)
|
||||||
|
if not target:
|
||||||
|
continue
|
||||||
|
if any(ip.startswith(prefix) for prefix in prefixes):
|
||||||
|
item = (dc, bool(is_media), target)
|
||||||
|
if item not in seen:
|
||||||
|
seen.add(item)
|
||||||
|
candidates.append(item)
|
||||||
|
|
||||||
|
for dc, target in _dc_opt.items():
|
||||||
|
if not target:
|
||||||
|
continue
|
||||||
|
for is_media in (False, True):
|
||||||
|
item = (dc, is_media, target)
|
||||||
|
if item not in seen:
|
||||||
|
seen.add(item)
|
||||||
|
candidates.append(item)
|
||||||
|
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
|
||||||
|
def _ws_mode_for(port: int, is_media: bool) -> str:
|
||||||
|
if port in _TCP_ONLY_PORTS:
|
||||||
|
return 'tcp'
|
||||||
|
if port in _WS_ONLY_PORTS:
|
||||||
|
if is_media and _prefer_tcp_for_media:
|
||||||
|
return 'tcp'
|
||||||
|
return 'ws'
|
||||||
|
if is_media and _prefer_tcp_for_media:
|
||||||
|
return 'tcp'
|
||||||
|
return 'ws'
|
||||||
|
|
||||||
|
|
||||||
|
def _register_ws_success(dc_key: Tuple[int, bool], domain: Optional[str] = None):
|
||||||
|
_dc_fail_until.pop(dc_key, None)
|
||||||
|
_dc_fail_count.pop(dc_key, None)
|
||||||
|
if domain:
|
||||||
|
_domain_success[dc_key] = domain
|
||||||
|
|
||||||
|
|
||||||
|
def _register_ws_failure(dc_key: Tuple[int, bool], redirect_only: bool):
|
||||||
|
fails = _dc_fail_count.get(dc_key, 0) + 1
|
||||||
|
_dc_fail_count[dc_key] = fails
|
||||||
|
cooldown = min(_DC_FAIL_COOLDOWN * (2 ** (fails - 1)), _DC_FAIL_COOLDOWN_MAX)
|
||||||
|
_dc_fail_until[dc_key] = time.monotonic() + cooldown
|
||||||
|
if redirect_only:
|
||||||
|
_ws_blacklist.add(dc_key)
|
||||||
|
return cooldown
|
||||||
|
|
||||||
|
|
||||||
class _WsPool:
|
class _WsPool:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._idle: Dict[Tuple[int, bool], list] = {}
|
self._idle: Dict[Tuple[int, bool], list] = {}
|
||||||
|
|
@ -518,7 +644,7 @@ class _WsPool:
|
||||||
while bucket:
|
while bucket:
|
||||||
ws, created = bucket.pop(0)
|
ws, created = bucket.pop(0)
|
||||||
age = now - created
|
age = now - created
|
||||||
if age > _WS_POOL_MAX_AGE or ws._closed:
|
if age > _WS_POOL_MAX_AGE or not ws.is_usable():
|
||||||
asyncio.create_task(self._quiet_close(ws))
|
asyncio.create_task(self._quiet_close(ws))
|
||||||
continue
|
continue
|
||||||
_stats.pool_hits += 1
|
_stats.pool_hits += 1
|
||||||
|
|
@ -550,9 +676,11 @@ class _WsPool:
|
||||||
self._connect_one(target_ip, domains)))
|
self._connect_one(target_ip, domains)))
|
||||||
for t in tasks:
|
for t in tasks:
|
||||||
try:
|
try:
|
||||||
ws = await t
|
result = await t
|
||||||
if ws:
|
if result:
|
||||||
|
ws, domain = result
|
||||||
bucket.append((ws, time.monotonic()))
|
bucket.append((ws, time.monotonic()))
|
||||||
|
_domain_success[(dc, is_media)] = domain
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
log.debug("WS pool refilled DC%d%s: %d ready",
|
log.debug("WS pool refilled DC%d%s: %d ready",
|
||||||
|
|
@ -561,12 +689,12 @@ class _WsPool:
|
||||||
self._refilling.discard(key)
|
self._refilling.discard(key)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _connect_one(target_ip, domains) -> Optional[RawWebSocket]:
|
async def _connect_one(target_ip, domains) -> Optional[Tuple[RawWebSocket, str]]:
|
||||||
for domain in domains:
|
for domain in domains:
|
||||||
try:
|
try:
|
||||||
ws = await RawWebSocket.connect(
|
ws = await RawWebSocket.connect(
|
||||||
target_ip, domain, timeout=8)
|
target_ip, domain, timeout=8)
|
||||||
return ws
|
return ws, domain
|
||||||
except WsHandshakeError as exc:
|
except WsHandshakeError as exc:
|
||||||
if exc.is_redirect:
|
if exc.is_redirect:
|
||||||
continue
|
continue
|
||||||
|
|
@ -631,6 +759,7 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
||||||
except (asyncio.CancelledError, ConnectionError, OSError):
|
except (asyncio.CancelledError, ConnectionError, OSError):
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
ws._closed = True
|
||||||
log.debug("[%s] tcp->ws ended: %s", label, e)
|
log.debug("[%s] tcp->ws ended: %s", label, e)
|
||||||
|
|
||||||
async def ws_to_tcp():
|
async def ws_to_tcp():
|
||||||
|
|
@ -651,6 +780,7 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
||||||
except (asyncio.CancelledError, ConnectionError, OSError):
|
except (asyncio.CancelledError, ConnectionError, OSError):
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
ws._closed = True
|
||||||
log.debug("[%s] ws->tcp ended: %s", label, e)
|
log.debug("[%s] ws->tcp ended: %s", label, e)
|
||||||
|
|
||||||
tasks = [asyncio.create_task(tcp_to_ws()),
|
tasks = [asyncio.create_task(tcp_to_ws()),
|
||||||
|
|
@ -698,7 +828,9 @@ async def _bridge_tcp(reader, writer, remote_reader, remote_writer,
|
||||||
else:
|
else:
|
||||||
_stats.bytes_down += len(data)
|
_stats.bytes_down += len(data)
|
||||||
dst_w.write(data)
|
dst_w.write(data)
|
||||||
await dst_w.drain()
|
buf = dst_w.transport.get_write_buffer_size()
|
||||||
|
if buf > _SEND_BUF:
|
||||||
|
await dst_w.drain()
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -808,8 +940,13 @@ async def _handle_client(reader, writer):
|
||||||
dlen = (await reader.readexactly(1))[0]
|
dlen = (await reader.readexactly(1))[0]
|
||||||
dst = (await reader.readexactly(dlen)).decode()
|
dst = (await reader.readexactly(dlen)).decode()
|
||||||
elif atyp == 4: # IPv6
|
elif atyp == 4: # IPv6
|
||||||
raw = await reader.readexactly(16)
|
await reader.readexactly(16)
|
||||||
dst = _socket.inet_ntop(_socket.AF_INET6, raw)
|
await reader.readexactly(2)
|
||||||
|
log.debug("[%s] IPv6 SOCKS request rejected", label)
|
||||||
|
writer.write(_socks5_reply(0x08))
|
||||||
|
await writer.drain()
|
||||||
|
writer.close()
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
writer.write(_socks5_reply(0x08))
|
writer.write(_socks5_reply(0x08))
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
|
|
@ -819,12 +956,8 @@ async def _handle_client(reader, writer):
|
||||||
port = struct.unpack('!H', await reader.readexactly(2))[0]
|
port = struct.unpack('!H', await reader.readexactly(2))[0]
|
||||||
|
|
||||||
if ':' in dst:
|
if ':' in dst:
|
||||||
log.error(
|
log.debug("[%s] rejected non-IPv4 destination %s:%d", label, dst, port)
|
||||||
"[%s] IPv6 address detected: %s:%d — "
|
writer.write(_socks5_reply(0x08))
|
||||||
"IPv6 addresses are not supported; "
|
|
||||||
"disable IPv6 to continue using the proxy.",
|
|
||||||
label, dst, port)
|
|
||||||
writer.write(_socks5_reply(0x05))
|
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
writer.close()
|
writer.close()
|
||||||
return
|
return
|
||||||
|
|
@ -881,29 +1014,86 @@ async def _handle_client(reader, writer):
|
||||||
# -- Extract DC ID --
|
# -- Extract DC ID --
|
||||||
dc, is_media = _dc_from_init(init)
|
dc, is_media = _dc_from_init(init)
|
||||||
init_patched = False
|
init_patched = False
|
||||||
|
|
||||||
# Android (may be ios too) with useSecret=0 has random dc_id bytes — patch it
|
# Android (may be ios too) with useSecret=0 has random dc_id bytes — patch it
|
||||||
if dc is None and dst in _IP_TO_DC:
|
if dc is None and dst in _IP_TO_DC:
|
||||||
dc, is_media = _IP_TO_DC.get(dst)
|
dc, is_media = _IP_TO_DC.get(dst)
|
||||||
if dc in _dc_opt:
|
signed_dc = -dc if is_media else dc
|
||||||
init = _patch_init_dc(init, dc if is_media else -dc)
|
init = _patch_init_dc(init, signed_dc)
|
||||||
init_patched = True
|
init_patched = True
|
||||||
|
if is_media:
|
||||||
|
_stats.media_init_patched += 1
|
||||||
|
|
||||||
if dc is None or dc not in _dc_opt:
|
if dc is None:
|
||||||
log.warning("[%s] unknown DC%s for %s:%d -> TCP passthrough",
|
guessed = _guess_dc_candidates(dst)
|
||||||
label, dc, dst, port)
|
if guessed:
|
||||||
await _tcp_fallback(reader, writer, dst, port, init, label)
|
log.info("[%s] unknown DC for %s:%d -> trying guessed WS candidates: %s",
|
||||||
return
|
label, dst, port, ', '.join(f'DC{gdc}{'m' if gmedia else ""}@{gtarget}' for gdc, gmedia, gtarget in guessed[:6]))
|
||||||
|
last_media = False
|
||||||
|
for gdc, gis_media, gtarget in guessed:
|
||||||
|
dc = gdc
|
||||||
|
is_media = gis_media
|
||||||
|
last_media = gis_media
|
||||||
|
signed_dc = -gdc if gis_media else gdc
|
||||||
|
patched_init = _patch_init_dc(init, signed_dc)
|
||||||
|
_remember_ip_mapping(dst, gdc, gis_media)
|
||||||
|
target = _target_ip_for(gdc, dst) or gtarget
|
||||||
|
if target:
|
||||||
|
init = patched_init
|
||||||
|
init_patched = True
|
||||||
|
if gis_media:
|
||||||
|
_stats.media_init_patched += 1
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
dc = None
|
||||||
|
is_media = last_media
|
||||||
|
if dc is None:
|
||||||
|
if is_media:
|
||||||
|
_stats.media_unknown_dc += 1
|
||||||
|
log.warning("[%s] unknown DC for %s:%d -> TCP passthrough",
|
||||||
|
label, dst, port)
|
||||||
|
await _tcp_fallback(reader, writer, dst, port, init, label)
|
||||||
|
return
|
||||||
|
|
||||||
dc_key = (dc, is_media if is_media is not None else True)
|
_remember_ip_mapping(dst, dc, bool(is_media))
|
||||||
|
dc_key = (dc, bool(is_media))
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
media_tag = (" media" if is_media
|
media_tag = (" media" if is_media
|
||||||
else (" media?" if is_media is None else ""))
|
else (" media?" if is_media is None else ""))
|
||||||
|
target = _target_ip_for(dc, dst)
|
||||||
|
mode = _ws_mode_for(port, bool(is_media))
|
||||||
|
|
||||||
|
if target is None:
|
||||||
|
if is_media:
|
||||||
|
_stats.media_unknown_dc += 1
|
||||||
|
log.warning("[%s] DC%d%s has no target IP for %s:%d -> TCP passthrough",
|
||||||
|
label, dc, media_tag, dst, port)
|
||||||
|
await _tcp_fallback(reader, writer, dst, port, init, label,
|
||||||
|
dc=dc, is_media=is_media)
|
||||||
|
return
|
||||||
|
|
||||||
|
if mode == 'tcp':
|
||||||
|
_stats.port_tcp_only_hits += 1
|
||||||
|
if is_media:
|
||||||
|
_stats.media_tcp_fallback += 1
|
||||||
|
log.info("[%s] DC%d%s port %d policy -> TCP %s:%d",
|
||||||
|
label, dc, media_tag, port, dst, port)
|
||||||
|
ok = await _tcp_fallback(reader, writer, dst, port, init,
|
||||||
|
label, dc=dc, is_media=is_media)
|
||||||
|
if ok:
|
||||||
|
log.info("[%s] DC%d%s TCP policy session closed",
|
||||||
|
label, dc, media_tag)
|
||||||
|
return
|
||||||
|
|
||||||
|
_stats.port_ws_attempts += 1
|
||||||
|
|
||||||
# -- WS blacklist check --
|
# -- WS blacklist check --
|
||||||
if dc_key in _ws_blacklist:
|
if dc_key in _ws_blacklist:
|
||||||
log.debug("[%s] DC%d%s WS blacklisted -> TCP %s:%d",
|
log.debug("[%s] DC%d%s WS blacklisted -> TCP %s:%d",
|
||||||
label, dc, media_tag, dst, port)
|
label, dc, media_tag, dst, port)
|
||||||
|
if is_media:
|
||||||
|
_stats.media_ws_fail += 1
|
||||||
|
_stats.media_tcp_fallback += 1
|
||||||
ok = await _tcp_fallback(reader, writer, dst, port, init,
|
ok = await _tcp_fallback(reader, writer, dst, port, init,
|
||||||
label, dc=dc, is_media=is_media)
|
label, dc=dc, is_media=is_media)
|
||||||
if ok:
|
if ok:
|
||||||
|
|
@ -916,15 +1106,16 @@ async def _handle_client(reader, writer):
|
||||||
ws_timeout = _WS_FAIL_TIMEOUT if now < fail_until else 10.0
|
ws_timeout = _WS_FAIL_TIMEOUT if now < fail_until else 10.0
|
||||||
|
|
||||||
domains = _ws_domains(dc, is_media)
|
domains = _ws_domains(dc, is_media)
|
||||||
target = _dc_opt[dc]
|
|
||||||
ws = None
|
ws = None
|
||||||
ws_failed_redirect = False
|
ws_failed_redirect = False
|
||||||
all_redirects = True
|
all_redirects = True
|
||||||
|
selected_domain = None
|
||||||
|
|
||||||
ws = await _ws_pool.get(dc, is_media, target, domains)
|
ws = await _ws_pool.get(dc, bool(is_media), target, domains)
|
||||||
if ws:
|
if ws:
|
||||||
log.info("[%s] DC%d%s (%s:%d) -> pool hit via %s",
|
log.info("[%s] DC%d%s (%s:%d) -> pool hit via %s",
|
||||||
label, dc, media_tag, dst, port, target)
|
label, dc, media_tag, dst, port, target)
|
||||||
|
selected_domain = _domain_success.get(dc_key)
|
||||||
else:
|
else:
|
||||||
for domain in domains:
|
for domain in domains:
|
||||||
url = f'wss://{domain}/apiws'
|
url = f'wss://{domain}/apiws'
|
||||||
|
|
@ -934,6 +1125,7 @@ async def _handle_client(reader, writer):
|
||||||
ws = await RawWebSocket.connect(target, domain,
|
ws = await RawWebSocket.connect(target, domain,
|
||||||
timeout=ws_timeout)
|
timeout=ws_timeout)
|
||||||
all_redirects = False
|
all_redirects = False
|
||||||
|
selected_domain = domain
|
||||||
break
|
break
|
||||||
except WsHandshakeError as exc:
|
except WsHandshakeError as exc:
|
||||||
_stats.ws_errors += 1
|
_stats.ws_errors += 1
|
||||||
|
|
@ -962,18 +1154,18 @@ async def _handle_client(reader, writer):
|
||||||
|
|
||||||
# -- WS failed -> fallback --
|
# -- WS failed -> fallback --
|
||||||
if ws is None:
|
if ws is None:
|
||||||
|
cooldown = _register_ws_failure(dc_key, ws_failed_redirect and all_redirects)
|
||||||
if ws_failed_redirect and all_redirects:
|
if ws_failed_redirect and all_redirects:
|
||||||
_ws_blacklist.add(dc_key)
|
|
||||||
log.warning(
|
log.warning(
|
||||||
"[%s] DC%d%s blacklisted for WS (all 302)",
|
"[%s] DC%d%s blacklisted for WS (all 302)",
|
||||||
label, dc, media_tag)
|
label, dc, media_tag)
|
||||||
elif ws_failed_redirect:
|
|
||||||
_dc_fail_until[dc_key] = now + _DC_FAIL_COOLDOWN
|
|
||||||
else:
|
else:
|
||||||
_dc_fail_until[dc_key] = now + _DC_FAIL_COOLDOWN
|
|
||||||
log.info("[%s] DC%d%s WS cooldown for %ds",
|
log.info("[%s] DC%d%s WS cooldown for %ds",
|
||||||
label, dc, media_tag, int(_DC_FAIL_COOLDOWN))
|
label, dc, media_tag, int(cooldown))
|
||||||
|
|
||||||
|
if is_media:
|
||||||
|
_stats.media_ws_fail += 1
|
||||||
|
_stats.media_tcp_fallback += 1
|
||||||
log.info("[%s] DC%d%s -> TCP fallback to %s:%d",
|
log.info("[%s] DC%d%s -> TCP fallback to %s:%d",
|
||||||
label, dc, media_tag, dst, port)
|
label, dc, media_tag, dst, port)
|
||||||
ok = await _tcp_fallback(reader, writer, dst, port, init,
|
ok = await _tcp_fallback(reader, writer, dst, port, init,
|
||||||
|
|
@ -984,8 +1176,10 @@ async def _handle_client(reader, writer):
|
||||||
return
|
return
|
||||||
|
|
||||||
# -- WS success --
|
# -- WS success --
|
||||||
_dc_fail_until.pop(dc_key, None)
|
_register_ws_success(dc_key, selected_domain)
|
||||||
_stats.connections_ws += 1
|
_stats.connections_ws += 1
|
||||||
|
if is_media:
|
||||||
|
_stats.media_ws_success += 1
|
||||||
|
|
||||||
splitter = None
|
splitter = None
|
||||||
if init_patched:
|
if init_patched:
|
||||||
|
|
@ -1023,6 +1217,28 @@ _server_instance = None
|
||||||
_server_stop_event = None
|
_server_stop_event = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _probe_startup(dc_opt: Dict[int, Optional[str]]):
|
||||||
|
for dc, target_ip in dc_opt.items():
|
||||||
|
if not target_ip:
|
||||||
|
continue
|
||||||
|
for is_media in (False, True):
|
||||||
|
domains = _ws_domains(dc, is_media)
|
||||||
|
ok = False
|
||||||
|
used = None
|
||||||
|
for domain in domains:
|
||||||
|
try:
|
||||||
|
ws = await RawWebSocket.connect(target_ip, domain, timeout=4.0)
|
||||||
|
used = domain
|
||||||
|
await ws.close()
|
||||||
|
ok = True
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
log.info("startup probe DC%d%s via %s: %s",
|
||||||
|
dc, 'm' if is_media else '',
|
||||||
|
target_ip, used if ok else 'FAIL')
|
||||||
|
|
||||||
|
|
||||||
async def _run(port: int, dc_opt: Dict[int, Optional[str]],
|
async def _run(port: int, dc_opt: Dict[int, Optional[str]],
|
||||||
stop_event: Optional[asyncio.Event] = None,
|
stop_event: Optional[asyncio.Event] = None,
|
||||||
host: str = '127.0.0.1'):
|
host: str = '127.0.0.1'):
|
||||||
|
|
@ -1063,6 +1279,7 @@ async def _run(port: int, dc_opt: Dict[int, Optional[str]],
|
||||||
asyncio.create_task(log_stats())
|
asyncio.create_task(log_stats())
|
||||||
|
|
||||||
await _ws_pool.warmup(dc_opt)
|
await _ws_pool.warmup(dc_opt)
|
||||||
|
asyncio.create_task(_probe_startup(dc_opt))
|
||||||
|
|
||||||
if stop_event:
|
if stop_event:
|
||||||
async def wait_stop():
|
async def wait_stop():
|
||||||
|
|
@ -1109,6 +1326,19 @@ def run_proxy(port: int, dc_opt: Dict[int, str],
|
||||||
asyncio.run(_run(port, dc_opt, stop_event, host))
|
asyncio.run(_run(port, dc_opt, stop_event, host))
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_port_set(value: str) -> Set[int]:
|
||||||
|
ports: Set[int] = set()
|
||||||
|
for part in value.split(','):
|
||||||
|
part = part.strip()
|
||||||
|
if not part:
|
||||||
|
continue
|
||||||
|
p = int(part)
|
||||||
|
if not (1 <= p <= 65535):
|
||||||
|
raise ValueError(f"Invalid port {p}")
|
||||||
|
ports.add(p)
|
||||||
|
return ports
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
ap = argparse.ArgumentParser(
|
ap = argparse.ArgumentParser(
|
||||||
description='Telegram Desktop WebSocket Bridge Proxy')
|
description='Telegram Desktop WebSocket Bridge Proxy')
|
||||||
|
|
@ -1120,6 +1350,12 @@ def main():
|
||||||
default=[],
|
default=[],
|
||||||
help='Target IP for a DC, e.g. --dc-ip 1:149.154.175.205'
|
help='Target IP for a DC, e.g. --dc-ip 1:149.154.175.205'
|
||||||
' --dc-ip 2:149.154.167.220')
|
' --dc-ip 2:149.154.167.220')
|
||||||
|
ap.add_argument('--tcp-only-ports', type=str, default='5222',
|
||||||
|
help='Comma-separated Telegram destination ports that should always use direct TCP (default 5222)')
|
||||||
|
ap.add_argument('--ws-ports', type=str, default='443',
|
||||||
|
help='Comma-separated Telegram destination ports that should prefer WebSocket bridge (default 443)')
|
||||||
|
ap.add_argument('--prefer-tcp-for-media', action='store_true',
|
||||||
|
help='Route media Telegram sessions over direct TCP when possible')
|
||||||
ap.add_argument('-v', '--verbose', action='store_true',
|
ap.add_argument('-v', '--verbose', action='store_true',
|
||||||
help='Debug logging')
|
help='Debug logging')
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
|
|
@ -1127,12 +1363,17 @@ def main():
|
||||||
if not args.dc_ip:
|
if not args.dc_ip:
|
||||||
args.dc_ip = ['2:149.154.167.220', '4:149.154.167.220']
|
args.dc_ip = ['2:149.154.167.220', '4:149.154.167.220']
|
||||||
|
|
||||||
|
global _prefer_tcp_for_media, _TCP_ONLY_PORTS, _WS_ONLY_PORTS
|
||||||
try:
|
try:
|
||||||
dc_opt = parse_dc_ip_list(args.dc_ip)
|
dc_opt = parse_dc_ip_list(args.dc_ip)
|
||||||
|
_TCP_ONLY_PORTS = _parse_port_set(args.tcp_only_ports)
|
||||||
|
_WS_ONLY_PORTS = _parse_port_set(args.ws_ports)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
log.error(str(e))
|
log.error(str(e))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
_prefer_tcp_for_media = args.prefer_tcp_for_media
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.DEBUG if args.verbose else logging.INFO,
|
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||||
format='%(asctime)s %(levelname)-5s %(message)s',
|
format='%(asctime)s %(levelname)-5s %(message)s',
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue