diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 7945e70..71f9f4e 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -2,9 +2,9 @@ name: Rust on: push: - branches: [ main ] + branches: [ "*" ] pull_request: - branches: [ main ] + branches: [ "*" ] env: CARGO_TERM_COLOR: always diff --git a/.kilocode/rules-architect/AGENTS.md b/.kilocode/rules-architect/AGENTS.md new file mode 100644 index 0000000..84e8808 --- /dev/null +++ b/.kilocode/rules-architect/AGENTS.md @@ -0,0 +1,58 @@ +# Architect Mode Rules for Telemt + +## Architecture Overview + +```mermaid +graph TB + subgraph Entry + Client[Clients] --> Listener[TCP/Unix Listener] + end + + subgraph Proxy Layer + Listener --> ClientHandler[ClientHandler] + ClientHandler --> Handshake[Handshake Validator] + Handshake --> |Valid| Relay[Relay Layer] + Handshake --> |Invalid| Masking[Masking/TLS Fronting] + end + + subgraph Transport + Relay --> MiddleProxy[Middle-End Proxy Pool] + Relay --> DirectRelay[Direct DC Relay] + MiddleProxy --> TelegramDC[Telegram DCs] + DirectRelay --> TelegramDC + end +``` + +## Module Dependencies +- [`src/main.rs`](src/main.rs) - Entry point, spawns all async tasks +- [`src/config/`](src/config/) - Configuration loading with auto-migration +- [`src/error.rs`](src/error.rs) - Error types, must be used by all modules +- [`src/crypto/`](src/crypto/) - AES, SHA, random number generation +- [`src/protocol/`](src/protocol/) - MTProto constants, frame encoding, obfuscation +- [`src/stream/`](src/stream/) - Stream wrappers, buffer pool, frame codecs +- [`src/proxy/`](src/proxy/) - Client handling, handshake, relay logic +- [`src/transport/`](src/transport/) - Upstream management, middle-proxy, SOCKS support +- [`src/stats/`](src/stats/) - Statistics and replay protection +- [`src/ip_tracker.rs`](src/ip_tracker.rs) - Per-user IP tracking + +## Key Architectural Constraints + +### Middle-End Proxy Mode +- Requires public IP on interface OR 1:1 NAT with STUN probing +- Uses separate `proxy-secret` from Telegram (NOT user secrets) +- Falls back to direct mode automatically on STUN mismatch + +### TLS Fronting +- Invalid handshakes are transparently proxied to `mask_host` +- This is critical for DPI evasion - do not change this behavior +- `mask_unix_sock` and `mask_host` are mutually exclusive + +### Stream Architecture +- Buffer pool is shared globally via Arc - prevents allocation storms +- Frame codecs implement tokio-util Encoder/Decoder traits +- State machine in [`src/stream/state.rs`](src/stream/state.rs) manages stream transitions + +### Configuration Migration +- [`ProxyConfig::load()`](src/config/mod.rs:641) mutates config in-place +- New fields must have sensible defaults +- DC203 override is auto-injected for CDN/media support diff --git a/.kilocode/rules-code/AGENTS.md b/.kilocode/rules-code/AGENTS.md new file mode 100644 index 0000000..df9f664 --- /dev/null +++ b/.kilocode/rules-code/AGENTS.md @@ -0,0 +1,23 @@ +# Code Mode Rules for Telemt + +## Error Handling +- Always use [`ProxyError`](src/error.rs:168) from [`src/error.rs`](src/error.rs) for proxy operations +- [`HandshakeResult`](src/error.rs:292) returns streams on bad client - these MUST be returned for masking, never dropped +- Use [`Recoverable`](src/error.rs:110) trait to check if errors are retryable + +## Configuration Changes +- [`ProxyConfig::load()`](src/config/mod.rs:641) auto-mutates config - new fields should have defaults +- DC203 override is auto-injected if missing - do not remove this behavior +- When adding config fields, add migration logic in [`ProxyConfig::load()`](src/config/mod.rs:641) + +## Crypto Code +- [`SecureRandom`](src/crypto/random.rs) from [`src/crypto/random.rs`](src/crypto/random.rs) must be used for all crypto operations +- Never use `rand::thread_rng()` directly - use the shared `Arc` + +## Stream Handling +- Buffer pool [`BufferPool`](src/stream/buffer_pool.rs) is shared via Arc - always use it instead of allocating +- Frame codecs in [`src/stream/frame_codec.rs`](src/stream/frame_codec.rs) implement tokio-util's Encoder/Decoder traits + +## Testing +- Tests are inline in modules using `#[cfg(test)]` +- Use `cargo test --lib ` to run tests for specific modules diff --git a/.kilocode/rules-debug/AGENTS.md b/.kilocode/rules-debug/AGENTS.md new file mode 100644 index 0000000..9d390b1 --- /dev/null +++ b/.kilocode/rules-debug/AGENTS.md @@ -0,0 +1,27 @@ +# Debug Mode Rules for Telemt + +## Logging +- `RUST_LOG` environment variable takes absolute priority over all config log levels +- Log levels: `trace`, `debug`, `info`, `warn`, `error` +- Use `RUST_LOG=debug cargo run` for detailed operational logs +- Use `RUST_LOG=trace cargo run` for full protocol-level debugging + +## Middle-End Proxy Debugging +- Set `ME_DIAG=1` environment variable for high-precision cryptography diagnostics +- STUN probe results are logged at startup - check for mismatch between local and reflected IP +- If Middle-End fails, check `proxy_secret_path` points to valid file from https://core.telegram.org/getProxySecret + +## Connection Issues +- DC connectivity is logged at startup with RTT measurements +- If DC ping fails, check `dc_overrides` for custom addresses +- Use `prefer_ipv6=false` in config if IPv6 is unreliable + +## TLS Fronting Issues +- Invalid handshakes are proxied to `mask_host` - check this host is reachable +- `mask_unix_sock` and `mask_host` are mutually exclusive - only one can be set +- If `mask_unix_sock` is set, socket must exist before connections arrive + +## Common Errors +- `ReplayAttack` - client replayed a handshake nonce, potential attack +- `TimeSkew` - client clock is off, can disable with `ignore_time_skew=true` +- `TgHandshakeTimeout` - upstream DC connection failed, check network diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..dc582ae --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,40 @@ +# AGENTS.md + +** Use general system promt from AGENTS_SYSTEM_PROMT.md ** +** Additional techiques and architectury details are here ** + +This file provides guidance to agents when working with code in this repository. + +## Build & Test Commands +```bash +cargo build --release # Production build +cargo test # Run all tests +cargo test --lib error # Run tests for specific module (error module) +cargo bench --bench crypto_bench # Run crypto benchmarks +cargo clippy -- -D warnings # Lint with clippy +``` + +## Project-Specific Conventions + +### Rust Edition +- Uses **Rust edition 2024** (not 2021) - specified in Cargo.toml + +### Error Handling Pattern +- Custom [`Recoverable`](src/error.rs:110) trait distinguishes recoverable vs fatal errors +- [`HandshakeResult`](src/error.rs:292) returns streams on bad client for masking - do not drop them +- Always use [`ProxyError`](src/error.rs:168) from [`src/error.rs`](src/error.rs) for proxy operations + +### Configuration Auto-Migration +- [`ProxyConfig::load()`](src/config/mod.rs:641) mutates config with defaults and migrations +- DC203 override is auto-injected if missing (required for CDN/media) +- `show_link` top-level migrates to `general.links.show` + +### Middle-End Proxy Requirements +- Requires public IP on interface OR 1:1 NAT with STUN probing +- Falls back to direct mode on STUN/interface mismatch unless `stun_iface_mismatch_ignore=true` +- Proxy-secret from Telegram is separate from user secrets + +### TLS Fronting Behavior +- Invalid handshakes are transparently proxied to `mask_host` for DPI evasion +- `fake_cert_len` is randomized at startup (1024-4096 bytes) +- `mask_unix_sock` and `mask_host` are mutually exclusive diff --git a/AGENTS_SYSTEM_PROMT.md b/AGENTS_SYSTEM_PROMT.md new file mode 100644 index 0000000..fd2815c --- /dev/null +++ b/AGENTS_SYSTEM_PROMT.md @@ -0,0 +1,112 @@ +## System Prompt - Modification and Architecture Guidelines + +You are working on a production-grade Rust codebase: follow these rules strictly! + +### 1. Comments and Documentation + +- All comments MUST be written in English. +- Comments MUST be concise, precise, and technical. +- Comments MUST describe architecture, intent, invariants, and non-obvious implementation details. +- DO NOT add decorative, conversational, or redundant comments. +- DO NOT add trailing comments at the end of code lines. +- Place comments on separate lines above the relevant code. + +Correct example: + +```rust +// Handles MTProto client authentication and establishes encrypted session state. +fn handle_authenticated_client(...) { ... } +``` + +Incorrect example: + +```rust +let x = 5; // set x to 5 lol +``` + +--- + +### 2. File Size and Module Structure + +- DO NOT create files larger than 350–550 lines. +- If a file exceeds this limit, split it into submodules. +- Organize submodules logically by responsibility (e.g., protocol, transport, state, handlers). +- Parent modules MUST declare and describe submodules. +- Use local git for versioning and diffs, write CORRECT and FULL comments to commits with descriptions + +Correct example: + +```rust +// Client connection handling logic. +// Submodules: +// - handshake: MTProto handshake implementation +// - relay: traffic forwarding logic +// - state: client session state machine + +pub mod handshake; +pub mod relay; +pub mod state; +``` + +* Maintain clear architectural boundaries between modules. + +--- + +### 3. Formatting + +- DO NOT run `cargo fmt`. +- DO NOT reformat existing code unless explicitly instructed. +- Preserve the existing formatting style of the project. + +--- + +### 4. Change Safety and Validation + +- DO NOT guess intent, behavior, or missing requirements. +- If anything is unclear, STOP and ask questions. +- Actively ask questions before making architectural or behavioral changes. +- Prefer clarification over assumptions. + +--- + +### 5. Warnings and Unused Code + +- DO NOT fix warnings unless explicitly instructed. +- DO NOT remove: + + - unused variables + - unused functions + - unused imports + - dead code + +These may be intentional. + +--- + +### 6. Architectural Integrity + +- Preserve existing architecture unless explicitly instructed to refactor. +- DO NOT introduce hidden behavioral changes. +- DO NOT introduce implicit refactors. +- Keep changes minimal, isolated, and intentional. + +--- + +### 7. When Modifying Code + +You MUST: + +- Maintain architectural consistency. +- Document non-obvious logic. +- Avoid unrelated changes. +- Avoid speculative improvements. + +You MUST NOT: + +- Refactor unrelated code. +- Rename symbols without explicit reason. +- Change formatting globally. + +--- + +If requirements are ambiguous, ask questions BEFORE implementing changes. diff --git a/Cargo.lock b/Cargo.lock index 6054188..38047bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1066,6 +1066,25 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -1723,7 +1742,7 @@ dependencies = [ [[package]] name = "telemt" -version = "1.2.0" +version = "3.0.0" dependencies = [ "aes", "base64", @@ -1741,6 +1760,8 @@ dependencies = [ "libc", "lru", "md-5", + "num-bigint", + "num-traits", "parking_lot", "proptest", "rand", diff --git a/Cargo.toml b/Cargo.toml index 4e5a82d..4bb5172 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ libc = "0.2" # Async runtime tokio = { version = "1.42", features = ["full", "tracing"] } -tokio-util = { version = "0.7", features = ["codec"] } +tokio-util = { version = "0.7", features = ["full"] } # Crypto aes = "0.8" @@ -50,6 +50,10 @@ num-traits = "0.2" # HTTP reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false } +hyper = { version = "1", features = ["server", "http1"] } +hyper-util = { version = "0.1", features = ["tokio", "server-auto"] } +http-body-util = "0.1" +httpdate = "1.0" [dev-dependencies] tokio-test = "0.4" diff --git a/Dockerfile b/Dockerfile index 9e600c2..662ec22 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,6 +37,7 @@ RUN chown -R telemt:telemt /app USER telemt EXPOSE 443 +EXPOSE 9090 ENTRYPOINT ["/app/telemt"] CMD ["config.toml"] \ No newline at end of file diff --git a/README.md b/README.md index 37e81d2..89040c4 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,28 @@ - с новым подходом к безопасности и асинхронности - с высокоточной диагностикой криптографии через `ME_DIAG` -Для использования нужно указать: +Для использования нужно: + +1. Версия `telemt` ≥3.0.0 +2. Выполнение любого из наборов условий: + - публичный IP для исходящих соединений установлен на интерфейса инстанса с `telemt` + - ЛИБО + - вы используете NAT 1:1 + включили STUN-пробинг +3. В конфиге, в секции `[general]` указать: ```toml use_middle_proxy = true ``` -в версии `telemt` 3.0.0 и последующих. +Если условия из пункта 1 не выполняются: +1. Выключите ME-режим: + - установите `use_middle_proxy = false` + - ЛИБО + - Middle-End Proxy будет выключен автоматически по таймауту, но это займёт больше времени при запуске +2. В конфиге, добавьте в конец: +```toml +[dc_overrides] +"203" = "91.105.192.100:443" +``` Если у вас есть компетенции в асинхронных сетевых приложениях, анализе трафика, реверс-инжиниринге или сетевых расследованиях — мы открыты к идеям и pull requests. @@ -38,11 +54,28 @@ On February 15, we released `telemt 3` with support for Middle-End Proxy, which - new approach to security and asynchronicity - high-precision cryptography diagnostics via `ME_DIAG` -To use it, set: +To use this feature, the following requirements must be met: +1. `telemt` version ≥ 3.0.0 +2. One of the following conditions satisfied: + - the instance running `telemt` has a public IP address assigned to its network interface for outbound connections + - OR + - you are using 1:1 NAT and have STUN probing enabled +3. In the config file, under the `[general]` section, specify: ```toml use_middle_proxy = true +```` + +If the conditions from step 1 are not satisfied: +1. Disable Middle-End mode: + - set `use_middle_proxy = false` + - OR + - Middle-End Proxy will be disabled automatically after a timeout, but this will increase startup time + +2. In the config file, add the following at the end: +```toml +[dc_overrides] +"203" = "91.105.192.100:443" ``` -in version `telemt` 3.0.0 or later. If you have expertise in asynchronous network applications, traffic analysis, reverse engineering, or network forensics — we welcome ideas, suggestions, and pull requests. @@ -175,10 +208,6 @@ then Ctrl+X -> Y -> Enter to save ## Configuration ### Minimal Configuration for First Start ```toml -# === UI === -# Users to show in the startup log (tg:// links) -show_link = ["hello"] - # === General Settings === [general] prefer_ipv6 = false @@ -202,11 +231,19 @@ listen_addr_ipv6 = "::" # Listen on multiple interfaces/IPs (overrides listen_addr_*) [[server.listeners]] ip = "0.0.0.0" -# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links +# announce = "my.hostname.tld" # Optional: hostname for tg:// links +# OR +# announce = "1.2.3.4" # Optional: Public IP for tg:// links [[server.listeners]] ip = "::" +# Users to show in the startup log (tg:// links) +[general.links] +show = ["hello"] # Users to show in the startup log (tg:// links) +# public_host = "proxy.example.com" # Host (IP or domain) for tg:// links +# public_port = 443 # Port for tg:// links (default: server.port) + # === Timeouts (in seconds) === [timeouts] client_handshake = 15 @@ -254,6 +291,10 @@ weight = 10 # address = "127.0.0.1:9050" # enabled = false # weight = 1 + +# === DC Address Overrides === +# [dc_overrides] +# "203" = "91.105.192.100:443" ``` ### Advanced #### Adtag diff --git a/config.toml b/config.toml index 0f5d438..0cddaac 100644 --- a/config.toml +++ b/config.toml @@ -1,7 +1,3 @@ -# === UI === -# Users to show in the startup log (tg:// links) -show_link = ["hello"] - # === General Settings === [general] prefer_ipv6 = true @@ -24,6 +20,8 @@ tls = true port = 443 listen_addr_ipv4 = "0.0.0.0" listen_addr_ipv6 = "::" +# listen_unix_sock = "/var/run/telemt.sock" # Unix socket +# listen_unix_sock_perm = "0666" # Socket file permissions # metrics_port = 9090 # metrics_whitelist = ["127.0.0.1", "::1"] @@ -35,6 +33,12 @@ ip = "0.0.0.0" [[server.listeners]] ip = "::" +# Users to show in the startup log (tg:// links) +[general.links] +show = ["hello"] # Users to show in the startup log (tg:// links) +# public_host = "proxy.example.com" # Host (IP or domain) for tg:// links +# public_port = 443 # Port for tg:// links (default: server.port) + # === Timeouts (in seconds) === [timeouts] client_handshake = 15 @@ -65,7 +69,7 @@ hello = "00000000000000000000000000000000" # hello = 50 # [access.user_max_unique_ips] -# hello = 5 +# hello = 5 # [access.user_data_quota] # hello = 1073741824 # 1 GB @@ -80,4 +84,8 @@ weight = 10 # type = "socks5" # address = "127.0.0.1:1080" # enabled = false -# weight = 1 \ No newline at end of file +# weight = 1 + +# === DC Address Overrides === +# [dc_overrides] +# "203" = "91.105.192.100:443" diff --git a/docker-compose.yml b/docker-compose.yml index 5a23f14..8386caf 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,8 +5,13 @@ services: restart: unless-stopped ports: - "443:443" + - "9090:9090" + # Allow caching 'proxy-secret' in read-only container + working_dir: /run/telemt volumes: - - ./config.toml:/app/config.toml:ro + - ./config.toml:/run/telemt/config.toml:ro + tmpfs: + - /run/telemt:rw,mode=1777,size=1m environment: - RUST_LOG=info # Uncomment this line if you want to use host network for IPv6, but bridge is default and usually better diff --git a/src/config/mod.rs b/src/config/mod.rs index a3dee7a..f9d2131 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -3,6 +3,7 @@ use crate::error::{ProxyError, Result}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use serde::de::Deserializer; use std::collections::HashMap; use std::net::IpAddr; use std::path::Path; @@ -53,6 +54,40 @@ fn default_metrics_whitelist() -> Vec { vec!["127.0.0.1".parse().unwrap(), "::1".parse().unwrap()] } +fn default_unknown_dc_log_path() -> Option { + Some("unknown-dc.txt".to_string()) +} + +// ============= Custom Deserializers ============= + +#[derive(Deserialize)] +#[serde(untagged)] +enum OneOrMany { + One(String), + Many(Vec), +} + +fn deserialize_dc_overrides<'de, D>( + deserializer: D, +) -> std::result::Result>, D::Error> +where + D: Deserializer<'de>, +{ + let raw: HashMap = HashMap::deserialize(deserializer)?; + let mut out = HashMap::new(); + for (dc, val) in raw { + let mut addrs = match val { + OneOrMany::One(s) => vec![s], + OneOrMany::Many(v) => v, + }; + addrs.retain(|s| !s.trim().is_empty()); + if !addrs.is_empty() { + out.insert(dc, addrs); + } + } + Ok(out) +} + // ============= Log Level ============= /// Logging verbosity level @@ -95,6 +130,50 @@ impl LogLevel { } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dc_overrides_allow_string_and_array() { + let toml = r#" + [dc_overrides] + "201" = "149.154.175.50:443" + "202" = ["149.154.167.51:443", "149.154.175.100:443"] + "#; + let cfg: ProxyConfig = toml::from_str(toml).unwrap(); + assert_eq!(cfg.dc_overrides["201"], vec!["149.154.175.50:443"]); + assert_eq!( + cfg.dc_overrides["202"], + vec!["149.154.167.51:443", "149.154.175.100:443"] + ); + } + + #[test] + fn dc_overrides_inject_dc203_default() { + let toml = r#" + [general] + use_middle_proxy = false + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_dc_override_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert!(cfg + .dc_overrides + .get("203") + .map(|v| v.contains(&"91.105.192.100:443".to_string())) + .unwrap_or(false)); + let _ = std::fs::remove_file(path); + } +} + impl std::fmt::Display for LogLevel { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -163,12 +242,41 @@ pub struct GeneralConfig { #[serde(default)] pub middle_proxy_nat_stun: Option, + /// Ignore STUN/interface IP mismatch (keep using Middle Proxy even if NAT detected). + #[serde(default)] + pub stun_iface_mismatch_ignore: bool, + + /// Log unknown (non-standard) DC requests to a file (default: unknown-dc.txt). Set to null to disable. + #[serde(default = "default_unknown_dc_log_path")] + pub unknown_dc_log_path: Option, + #[serde(default)] pub log_level: LogLevel, /// Disable colored output in logs (useful for files/systemd) #[serde(default)] pub disable_colors: bool, + + /// [general.links] — proxy link generation overrides + #[serde(default)] + pub links: LinksConfig, +} + +/// `[general.links]` — proxy link generation settings. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct LinksConfig { + /// List of usernames whose tg:// links to display at startup. + /// `"*"` = all users, `["alice", "bob"]` = specific users. + #[serde(default)] + pub show: ShowLink, + + /// Public hostname/IP for tg:// link generation (overrides detected IP). + #[serde(default)] + pub public_host: Option, + + /// Public port for tg:// link generation (overrides server.port). + #[serde(default)] + pub public_port: Option, } impl Default for GeneralConfig { @@ -183,8 +291,11 @@ impl Default for GeneralConfig { middle_proxy_nat_ip: None, middle_proxy_nat_probe: false, middle_proxy_nat_stun: None, + stun_iface_mismatch_ignore: false, + unknown_dc_log_path: default_unknown_dc_log_path(), log_level: LogLevel::Normal, disable_colors: false, + links: LinksConfig::default(), } } } @@ -194,8 +305,8 @@ pub struct ServerConfig { #[serde(default = "default_port")] pub port: u16, - #[serde(default = "default_listen_addr")] - pub listen_addr_ipv4: String, + #[serde(default)] + pub listen_addr_ipv4: Option, #[serde(default)] pub listen_addr_ipv6: Option, @@ -203,6 +314,16 @@ pub struct ServerConfig { #[serde(default)] pub listen_unix_sock: Option, + /// Unix socket file permissions (octal, e.g. "0666" or "0777"). + /// Applied via chmod after bind. Default: no change (inherits umask). + #[serde(default)] + pub listen_unix_sock_perm: Option, + + /// Enable TCP listening. Default: true when no unix socket, false when + /// listen_unix_sock is set. Set explicitly to override auto-detection. + #[serde(default)] + pub listen_tcp: Option, + #[serde(default)] pub metrics_port: Option, @@ -217,9 +338,11 @@ impl Default for ServerConfig { fn default() -> Self { Self { port: default_port(), - listen_addr_ipv4: default_listen_addr(), + listen_addr_ipv4: Some(default_listen_addr()), listen_addr_ipv6: Some("::".to_string()), listen_unix_sock: None, + listen_unix_sock_perm: None, + listen_tcp: None, metrics_port: None, metrics_whitelist: default_metrics_whitelist(), listeners: Vec::new(), @@ -374,6 +497,12 @@ pub struct UpstreamConfig { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ListenerConfig { pub ip: IpAddr, + /// IP address or hostname to announce in proxy links. + /// Takes precedence over `announce_ip` if both are set. + #[serde(default)] + pub announce: Option, + /// Deprecated: Use `announce` instead. IP address to announce in proxy links. + /// Migrated to `announce` automatically if `announce` is not set. #[serde(default)] pub announce_ip: Option, } @@ -499,13 +628,13 @@ pub struct ProxyConfig { pub show_link: ShowLink, /// DC address overrides for non-standard DCs (CDN, media, test, etc.) - /// Keys are DC indices as strings, values are "ip:port" addresses. + /// Keys are DC indices as strings, values are one or more \"ip:port\" addresses. /// Matches the C implementation's `proxy_for :` config directive. /// Example in config.toml: /// [dc_overrides] - /// "203" = "149.154.175.100:443" - #[serde(default)] - pub dc_overrides: HashMap, + /// \"203\" = [\"149.154.175.100:443\", \"91.105.192.100:443\"] + #[serde(default, deserialize_with = "deserialize_dc_overrides")] + pub dc_overrides: HashMap>, /// Default DC index (1-5) for unmapped non-standard DCs. /// Matches the C implementation's `default ` config directive. @@ -572,11 +701,28 @@ impl ProxyConfig { use rand::Rng; config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); - // Migration: Populate listeners if empty - if config.server.listeners.is_empty() { - if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::() { + // Resolve listen_tcp: explicit value wins, otherwise auto-detect. + // If unix socket is set → TCP only when listen_addr_ipv4 or listeners are explicitly provided. + // If no unix socket → TCP always (backward compat). + let listen_tcp = config.server.listen_tcp.unwrap_or_else(|| { + if config.server.listen_unix_sock.is_some() { + // Unix socket present: TCP only if user explicitly set addresses or listeners + config.server.listen_addr_ipv4.is_some() + || !config.server.listeners.is_empty() + } else { + true + } + }); + + // Migration: Populate listeners if empty (skip when listen_tcp = false) + if config.server.listeners.is_empty() && listen_tcp { + let ipv4_str = config.server.listen_addr_ipv4 + .as_deref() + .unwrap_or("0.0.0.0"); + if let Ok(ipv4) = ipv4_str.parse::() { config.server.listeners.push(ListenerConfig { ip: ipv4, + announce: None, announce_ip: None, }); } @@ -584,12 +730,25 @@ impl ProxyConfig { if let Ok(ipv6) = ipv6_str.parse::() { config.server.listeners.push(ListenerConfig { ip: ipv6, + announce: None, announce_ip: None, }); } } } + // Migration: announce_ip → announce for each listener + for listener in &mut config.server.listeners { + if listener.announce.is_none() && listener.announce_ip.is_some() { + listener.announce = Some(listener.announce_ip.unwrap().to_string()); + } + } + + // Migration: show_link (top-level) → general.links.show + if !config.show_link.is_empty() && config.general.links.show.is_empty() { + config.general.links.show = config.show_link.clone(); + } + // Migration: Populate upstreams if empty (Default Direct) if config.upstreams.is_empty() { config.upstreams.push(UpstreamConfig { @@ -599,6 +758,12 @@ impl ProxyConfig { }); } + // Ensure default DC203 override is present. + config + .dc_overrides + .entry("203".to_string()) + .or_insert_with(|| vec!["91.105.192.100:443".to_string()]); + Ok(config) } diff --git a/src/main.rs b/src/main.rs index 8a974be..2dd9a56 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,12 +8,15 @@ use tokio::signal; use tokio::sync::Semaphore; use tracing::{debug, error, info, warn}; use tracing_subscriber::{EnvFilter, fmt, prelude::*, reload}; +#[cfg(unix)] +use tokio::net::UnixListener; mod cli; mod config; mod crypto; mod error; mod ip_tracker; +mod metrics; mod protocol; mod proxy; mod stats; @@ -27,7 +30,10 @@ use crate::ip_tracker::UserIpTracker; use crate::proxy::ClientHandler; use crate::stats::{ReplayChecker, Stats}; use crate::stream::BufferPool; -use crate::transport::middle_proxy::{MePool, fetch_proxy_config}; +use crate::transport::middle_proxy::{ + MePool, fetch_proxy_config, run_me_ping, MePingFamily, MePingSample, format_sample_line, + stun_probe, +}; use crate::transport::{ListenOptions, UpstreamManager, create_listener}; use crate::util::ip::detect_ip; use crate::protocol::constants::{TG_MIDDLE_PROXIES_V4, TG_MIDDLE_PROXIES_V6}; @@ -100,6 +106,37 @@ fn parse_cli() -> (String, bool, Option) { (config_path, silent, log_level) } +fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) { + info!("--- Proxy Links ({}) ---", host); + for user_name in config.general.links.show.resolve_users(&config.access.users) { + if let Some(secret) = config.access.users.get(user_name) { + info!("User: {}", user_name); + if config.general.modes.classic { + info!( + " Classic: tg://proxy?server={}&port={}&secret={}", + host, port, secret + ); + } + if config.general.modes.secure { + info!( + " DD: tg://proxy?server={}&port={}&secret=dd{}", + host, port, secret + ); + } + if config.general.modes.tls { + let domain_hex = hex::encode(&config.censorship.tls_domain); + info!( + " EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", + host, port, secret, domain_hex + ); + } + } else { + warn!("User '{}' in show_link not found", user_name); + } + } + info!("------------------------"); +} + #[tokio::main] async fn main() -> std::result::Result<(), Box> { let (config_path, cli_silent, cli_log_level) = parse_cli(); @@ -183,7 +220,7 @@ async fn main() -> std::result::Result<(), Box> { } let prefer_ipv6 = config.general.prefer_ipv6; - let use_middle_proxy = config.general.use_middle_proxy; + let mut use_middle_proxy = config.general.use_middle_proxy; let config = Arc::new(config); let stats = Arc::new(Stats::new()); let rng = Arc::new(SecureRandom::new()); @@ -207,6 +244,41 @@ async fn main() -> std::result::Result<(), Box> { // Connection concurrency limit let _max_connections = Arc::new(Semaphore::new(10_000)); + // STUN check before choosing transport + if use_middle_proxy { + match stun_probe(config.general.middle_proxy_nat_stun.clone()).await { + Ok(Some(probe)) => { + info!( + local_ip = %probe.local_addr.ip(), + reflected_ip = %probe.reflected_addr.ip(), + "STUN Autodetect:" + ); + if probe.local_addr.ip() != probe.reflected_addr.ip() + && !config.general.stun_iface_mismatch_ignore + { + match crate::transport::middle_proxy::detect_public_ip().await { + Some(ip) => { + info!( + local_ip = %probe.local_addr.ip(), + reflected_ip = %probe.reflected_addr.ip(), + public_ip = %ip, + "STUN mismatch but public IP auto-detected, continuing with middle proxy" + ); + } + None => { + warn!( + "STUN/IP-on-Interface mismatch and public IP auto-detect failed -> fallback to direct-DC" + ); + use_middle_proxy = false; + } + } + } + } + Ok(None) => warn!("STUN probe returned no address; continuing"), + Err(e) => warn!(error = %e, "STUN probe failed; continuing"), + } + } + // ===================================================================== // Middle Proxy initialization (if enabled) // ===================================================================== @@ -231,25 +303,25 @@ async fn main() -> std::result::Result<(), Box> { // proxy-secret is from: https://core.telegram.org/getProxySecret // ============================================================= let proxy_secret_path = config.general.proxy_secret_path.as_deref(); - match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).await { - Ok(proxy_secret) => { - info!( - secret_len = proxy_secret.len(), - key_sig = format_args!( - "0x{:08x}", - if proxy_secret.len() >= 4 { - u32::from_le_bytes([ - proxy_secret[0], - proxy_secret[1], - proxy_secret[2], - proxy_secret[3], - ]) - } else { - 0 - } - ), - "Proxy-secret loaded" - ); +match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).await { + Ok(proxy_secret) => { + info!( + secret_len = proxy_secret.len() as usize, // ← ЯВНЫЙ ТИП usize + key_sig = format_args!( + "0x{:08x}", + if proxy_secret.len() >= 4 { + u32::from_le_bytes([ + proxy_secret[0], + proxy_secret[1], + proxy_secret[2], + proxy_secret[3], + ]) + } else { + 0 + } + ), + "Proxy-secret loaded" + ); // Load ME config (v4/v6) + default DC let mut cfg_v4 = fetch_proxy_config( @@ -295,6 +367,18 @@ async fn main() -> std::result::Result<(), Box> { .await; }); + // Periodic ME connection rotation + let pool_clone_rot = pool.clone(); + let rng_clone_rot = rng.clone(); + tokio::spawn(async move { + crate::transport::middle_proxy::me_rotation_task( + pool_clone_rot, + rng_clone_rot, + std::time::Duration::from_secs(1800), + ) + .await; + }); + // Periodic updater: getProxyConfig + proxy-secret let pool_clone2 = pool.clone(); let rng_clone2 = rng.clone(); @@ -325,90 +409,154 @@ async fn main() -> std::result::Result<(), Box> { }; if me_pool.is_some() { - info!("Transport: Middle Proxy (supports all DCs including CDN)"); + info!("Transport: Middle-End Proxy - all DC-over-RPC"); } else { - info!("Transport: Direct TCP (standard DCs only)"); + info!("Transport: Direct DC - TCP - standard DC-over-TCP"); } - // Startup DC ping (only meaningful in direct mode) - if me_pool.is_none() { - info!("================= Telegram DC Connectivity ================="); + // Middle-End ping before DC connectivity + if let Some(ref pool) = me_pool { + let me_results = run_me_ping(pool, &rng).await; - let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await; + let v4_ok = me_results.iter().any(|r| { + matches!(r.family, MePingFamily::V4) + && r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some()) + }); + let v6_ok = me_results.iter().any(|r| { + matches!(r.family, MePingFamily::V6) + && r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some()) + }); - for upstream_result in &ping_results { - // Show which IP version is in use and which is fallback - if upstream_result.both_available { - if prefer_ipv6 { - info!(" IPv6 in use and IPv4 is fallback"); - } else { - info!(" IPv4 in use and IPv6 is fallback"); - } - } else { - let v6_works = upstream_result - .v6_results - .iter() - .any(|r| r.rtt_ms.is_some()); - let v4_works = upstream_result - .v4_results - .iter() - .any(|r| r.rtt_ms.is_some()); - if v6_works && !v4_works { - info!(" IPv6 only (IPv4 unavailable)"); - } else if v4_works && !v6_works { - info!(" IPv4 only (IPv6 unavailable)"); - } else if !v6_works && !v4_works { - info!(" No connectivity!"); - } - } - - info!(" via {}", upstream_result.upstream_name); - info!("============================================================"); - - // Print IPv6 results first - for dc in &upstream_result.v6_results { - let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port()); - match &dc.rtt_ms { - Some(rtt) => { - // Align: IPv6 addresses are longer, use fewer tabs - // [2001:b28:f23d:f001::a]:443 = ~28 chars - info!(" DC{} [IPv6] {}:\t\t{:.0} ms", dc.dc_idx, addr_str, rtt); - } - None => { - let err = dc.error.as_deref().unwrap_or("fail"); - info!(" DC{} [IPv6] {}:\t\tFAIL ({})", dc.dc_idx, addr_str, err); - } - } - } - - info!("============================================================"); - - // Print IPv4 results - for dc in &upstream_result.v4_results { - let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port()); - match &dc.rtt_ms { - Some(rtt) => { - // Align: IPv4 addresses are shorter, use more tabs - // 149.154.175.50:443 = ~18 chars - info!( - " DC{} [IPv4] {}:\t\t\t\t{:.0} ms", - dc.dc_idx, addr_str, rtt - ); - } - None => { - let err = dc.error.as_deref().unwrap_or("fail"); - info!( - " DC{} [IPv4] {}:\t\t\t\tFAIL ({})", - dc.dc_idx, addr_str, err - ); - } - } - } - - info!("============================================================"); + info!("================= Telegram ME Connectivity ================="); + if v4_ok && v6_ok { + info!(" IPv4 and IPv6 available"); + } else if v4_ok { + info!(" IPv4 only / IPv6 unavailable"); + } else if v6_ok { + info!(" IPv6 only / IPv4 unavailable"); + } else { + info!(" No ME connectivity"); } + info!(" via direct"); + info!("============================================================"); + + use std::collections::BTreeMap; + let mut grouped: BTreeMap> = BTreeMap::new(); + for report in me_results { + for s in report.samples { + let key = s.dc.abs(); + grouped.entry(key).or_default().push(s); + } + } + + let family_order = if prefer_ipv6 { + vec![(MePingFamily::V6, true), (MePingFamily::V6, false), (MePingFamily::V4, true), (MePingFamily::V4, false)] + } else { + vec![(MePingFamily::V4, true), (MePingFamily::V4, false), (MePingFamily::V6, true), (MePingFamily::V6, false)] + }; + + for (dc_abs, samples) in grouped { + for (family, is_pos) in &family_order { + let fam_samples: Vec<&MePingSample> = samples + .iter() + .filter(|s| matches!(s.family, f if &f == family) && (s.dc >= 0) == *is_pos) + .collect(); + if fam_samples.is_empty() { + continue; + } + + let fam_label = match family { + MePingFamily::V4 => "IPv4", + MePingFamily::V6 => "IPv6", + }; + info!(" DC{} [{}]", dc_abs, fam_label); + for sample in fam_samples { + let line = format_sample_line(sample); + info!("{}", line); + } + } + } + info!("============================================================"); } + info!("================= Telegram DC Connectivity ================="); + + let ping_results = upstream_manager + .ping_all_dcs(prefer_ipv6, &config.dc_overrides) + .await; + + for upstream_result in &ping_results { + let v6_works = upstream_result + .v6_results + .iter() + .any(|r| r.rtt_ms.is_some()); + let v4_works = upstream_result + .v4_results + .iter() + .any(|r| r.rtt_ms.is_some()); + + if upstream_result.both_available { + if prefer_ipv6 { + info!(" IPv6 in use / IPv4 is fallback"); + } else { + info!(" IPv4 in use / IPv6 is fallback"); + } + } else { + if v6_works && !v4_works { + info!(" IPv6 only / IPv4 unavailable)"); + } else if v4_works && !v6_works { + info!(" IPv4 only / IPv6 unavailable)"); + } else if !v6_works && !v4_works { + info!(" No DC connectivity"); + } + } + + info!(" via {}", upstream_result.upstream_name); + info!("============================================================"); + + // Print IPv6 results first (only if IPv6 is available) + if v6_works { + for dc in &upstream_result.v6_results { + let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port()); + match &dc.rtt_ms { + Some(rtt) => { + info!(" DC{} [IPv6] {} - {:.0} ms", dc.dc_idx, addr_str, rtt); + } + None => { + let err = dc.error.as_deref().unwrap_or("fail"); + info!(" DC{} [IPv6] {} - FAIL ({})", dc.dc_idx, addr_str, err); + } + } + } + + info!("============================================================"); + } + + // Print IPv4 results (only if IPv4 is available) + if v4_works { + for dc in &upstream_result.v4_results { + let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port()); + match &dc.rtt_ms { + Some(rtt) => { + info!( + " DC{} [IPv4] {}\t\t\t\t{:.0} ms", + dc.dc_idx, addr_str, rtt + ); + } + None => { + let err = dc.error.as_deref().unwrap_or("fail"); + info!( + " DC{} [IPv4] {}:\t\t\t\tFAIL ({})", + dc.dc_idx, addr_str, err + ); + } + } + } + + info!("============================================================"); + } + } + // Background tasks let um_clone = upstream_manager.clone(); tokio::spawn(async move { @@ -440,47 +588,28 @@ async fn main() -> std::result::Result<(), Box> { let listener = TcpListener::from_std(socket.into())?; info!("Listening on {}", addr); - let public_ip = if let Some(ip) = listener_conf.announce_ip { - ip + // Resolve the public host for link generation + let public_host = if let Some(ref announce) = listener_conf.announce { + announce.clone() // Use announce (IP or hostname) if explicitly set } else if listener_conf.ip.is_unspecified() { + // Auto-detect for unspecified addresses if listener_conf.ip.is_ipv4() { - detected_ip.ipv4.unwrap_or(listener_conf.ip) + detected_ip.ipv4 + .map(|ip| ip.to_string()) + .unwrap_or_else(|| listener_conf.ip.to_string()) } else { - detected_ip.ipv6.unwrap_or(listener_conf.ip) + detected_ip.ipv6 + .map(|ip| ip.to_string()) + .unwrap_or_else(|| listener_conf.ip.to_string()) } } else { - listener_conf.ip + listener_conf.ip.to_string() }; - if !config.show_link.is_empty() { - info!("--- Proxy Links ({}) ---", public_ip); - for user_name in config.show_link.resolve_users(&config.access.users) { - if let Some(secret) = config.access.users.get(user_name) { - info!("User: {}", user_name); - if config.general.modes.classic { - info!( - " Classic: tg://proxy?server={}&port={}&secret={}", - public_ip, config.server.port, secret - ); - } - if config.general.modes.secure { - info!( - " DD: tg://proxy?server={}&port={}&secret=dd{}", - public_ip, config.server.port, secret - ); - } - if config.general.modes.tls { - let domain_hex = hex::encode(&config.censorship.tls_domain); - info!( - " EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", - public_ip, config.server.port, secret, domain_hex - ); - } - } else { - warn!("User '{}' in show_link not found", user_name); - } - } - info!("------------------------"); + // Show per-listener proxy links only when public_host is not set + if config.general.links.public_host.is_none() && !config.general.links.show.is_empty() { + let link_port = config.general.links.public_port.unwrap_or(config.server.port); + print_proxy_links(&public_host, link_port, &config); } listeners.push(listener); @@ -491,7 +620,104 @@ async fn main() -> std::result::Result<(), Box> { } } - if listeners.is_empty() { + // Show proxy links once when public_host is set, OR when there are no TCP listeners + // (unix-only mode) — use detected IP as fallback + if !config.general.links.show.is_empty() && (config.general.links.public_host.is_some() || listeners.is_empty()) { + let (host, port) = if let Some(ref h) = config.general.links.public_host { + (h.clone(), config.general.links.public_port.unwrap_or(config.server.port)) + } else { + let ip = detected_ip + .ipv4 + .or(detected_ip.ipv6) + .map(|ip| ip.to_string()); + if ip.is_none() { + warn!("show_link is configured but public IP could not be detected. Set public_host in config."); + } + (ip.unwrap_or_else(|| "UNKNOWN".to_string()), config.general.links.public_port.unwrap_or(config.server.port)) + }; + + print_proxy_links(&host, port, &config); + } + + // Unix socket setup (before listeners check so unix-only config works) + let mut has_unix_listener = false; + #[cfg(unix)] + if let Some(ref unix_path) = config.server.listen_unix_sock { + // Remove stale socket file if present (standard practice) + let _ = tokio::fs::remove_file(unix_path).await; + + let unix_listener = UnixListener::bind(unix_path)?; + + // Apply socket permissions if configured + if let Some(ref perm_str) = config.server.listen_unix_sock_perm { + match u32::from_str_radix(perm_str.trim_start_matches('0'), 8) { + Ok(mode) => { + use std::os::unix::fs::PermissionsExt; + let perms = std::fs::Permissions::from_mode(mode); + if let Err(e) = std::fs::set_permissions(unix_path, perms) { + error!("Failed to set unix socket permissions to {}: {}", perm_str, e); + } else { + info!("Listening on unix:{} (mode {})", unix_path, perm_str); + } + } + Err(e) => { + warn!("Invalid listen_unix_sock_perm '{}': {}. Ignoring.", perm_str, e); + info!("Listening on unix:{}", unix_path); + } + } + } else { + info!("Listening on unix:{}", unix_path); + } + + has_unix_listener = true; + + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let me_pool = me_pool.clone(); + let ip_tracker = ip_tracker.clone(); + + tokio::spawn(async move { + let unix_conn_counter = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)); + + loop { + match unix_listener.accept().await { + Ok((stream, _)) => { + let conn_id = unix_conn_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let fake_peer = SocketAddr::from(([127, 0, 0, 1], (conn_id % 65535) as u16)); + + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let me_pool = me_pool.clone(); + let ip_tracker = ip_tracker.clone(); + + tokio::spawn(async move { + if let Err(e) = crate::proxy::client::handle_client_stream( + stream, fake_peer, config, stats, + upstream_manager, replay_checker, buffer_pool, rng, + me_pool, ip_tracker, + ).await { + debug!(error = %e, "Unix socket connection error"); + } + }); + } + Err(e) => { + error!("Unix socket accept error: {}", e); + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + } + }); + } + + if listeners.is_empty() && !has_unix_listener { error!("No listeners. Exiting."); std::process::exit(1); } @@ -506,6 +732,14 @@ async fn main() -> std::result::Result<(), Box> { .reload(runtime_filter) .expect("Failed to switch log filter"); + if let Some(port) = config.server.metrics_port { + let stats = stats.clone(); + let whitelist = config.server.metrics_whitelist.clone(); + tokio::spawn(async move { + metrics::serve(port, stats, whitelist).await; + }); + } + for listener in listeners { let config = config.clone(); let stats = stats.clone(); diff --git a/src/metrics.rs b/src/metrics.rs new file mode 100644 index 0000000..24acf30 --- /dev/null +++ b/src/metrics.rs @@ -0,0 +1,197 @@ +use std::convert::Infallible; +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; + +use http_body_util::Full; +use hyper::body::Bytes; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper::{Request, Response, StatusCode}; +use tokio::net::TcpListener; +use tracing::{info, warn, debug}; + +use crate::stats::Stats; + +pub async fn serve(port: u16, stats: Arc, whitelist: Vec) { + let addr = SocketAddr::from(([0, 0, 0, 0], port)); + let listener = match TcpListener::bind(addr).await { + Ok(l) => l, + Err(e) => { + warn!(error = %e, "Failed to bind metrics on {}", addr); + return; + } + }; + info!("Metrics endpoint: http://{}/metrics", addr); + + loop { + let (stream, peer) = match listener.accept().await { + Ok(v) => v, + Err(e) => { + warn!(error = %e, "Metrics accept error"); + continue; + } + }; + + if !whitelist.is_empty() && !whitelist.contains(&peer.ip()) { + debug!(peer = %peer, "Metrics request denied by whitelist"); + continue; + } + + let stats = stats.clone(); + tokio::spawn(async move { + let svc = service_fn(move |req| { + let stats = stats.clone(); + async move { handle(req, &stats) } + }); + if let Err(e) = http1::Builder::new() + .serve_connection(hyper_util::rt::TokioIo::new(stream), svc) + .await + { + debug!(error = %e, "Metrics connection error"); + } + }); + } +} + +fn handle(req: Request, stats: &Stats) -> Result>, Infallible> { + if req.uri().path() != "/metrics" { + let resp = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Full::new(Bytes::from("Not Found\n"))) + .unwrap(); + return Ok(resp); + } + + let body = render_metrics(stats); + let resp = Response::builder() + .status(StatusCode::OK) + .header("content-type", "text/plain; version=0.0.4; charset=utf-8") + .body(Full::new(Bytes::from(body))) + .unwrap(); + Ok(resp) +} + +fn render_metrics(stats: &Stats) -> String { + use std::fmt::Write; + let mut out = String::with_capacity(4096); + + let _ = writeln!(out, "# HELP telemt_uptime_seconds Proxy uptime"); + let _ = writeln!(out, "# TYPE telemt_uptime_seconds gauge"); + let _ = writeln!(out, "telemt_uptime_seconds {:.1}", stats.uptime_secs()); + + let _ = writeln!(out, "# HELP telemt_connections_total Total accepted connections"); + let _ = writeln!(out, "# TYPE telemt_connections_total counter"); + let _ = writeln!(out, "telemt_connections_total {}", stats.get_connects_all()); + + let _ = writeln!(out, "# HELP telemt_connections_bad_total Bad/rejected connections"); + let _ = writeln!(out, "# TYPE telemt_connections_bad_total counter"); + let _ = writeln!(out, "telemt_connections_bad_total {}", stats.get_connects_bad()); + + let _ = writeln!(out, "# HELP telemt_handshake_timeouts_total Handshake timeouts"); + let _ = writeln!(out, "# TYPE telemt_handshake_timeouts_total counter"); + let _ = writeln!(out, "telemt_handshake_timeouts_total {}", stats.get_handshake_timeouts()); + + let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections"); + let _ = writeln!(out, "# TYPE telemt_user_connections_total counter"); + let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections"); + let _ = writeln!(out, "# TYPE telemt_user_connections_current gauge"); + let _ = writeln!(out, "# HELP telemt_user_octets_from_client Per-user bytes received"); + let _ = writeln!(out, "# TYPE telemt_user_octets_from_client counter"); + let _ = writeln!(out, "# HELP telemt_user_octets_to_client Per-user bytes sent"); + let _ = writeln!(out, "# TYPE telemt_user_octets_to_client counter"); + let _ = writeln!(out, "# HELP telemt_user_msgs_from_client Per-user messages received"); + let _ = writeln!(out, "# TYPE telemt_user_msgs_from_client counter"); + let _ = writeln!(out, "# HELP telemt_user_msgs_to_client Per-user messages sent"); + let _ = writeln!(out, "# TYPE telemt_user_msgs_to_client counter"); + + for entry in stats.iter_user_stats() { + let user = entry.key(); + let s = entry.value(); + let _ = writeln!(out, "telemt_user_connections_total{{user=\"{}\"}} {}", user, s.connects.load(std::sync::atomic::Ordering::Relaxed)); + let _ = writeln!(out, "telemt_user_connections_current{{user=\"{}\"}} {}", user, s.curr_connects.load(std::sync::atomic::Ordering::Relaxed)); + let _ = writeln!(out, "telemt_user_octets_from_client{{user=\"{}\"}} {}", user, s.octets_from_client.load(std::sync::atomic::Ordering::Relaxed)); + let _ = writeln!(out, "telemt_user_octets_to_client{{user=\"{}\"}} {}", user, s.octets_to_client.load(std::sync::atomic::Ordering::Relaxed)); + let _ = writeln!(out, "telemt_user_msgs_from_client{{user=\"{}\"}} {}", user, s.msgs_from_client.load(std::sync::atomic::Ordering::Relaxed)); + let _ = writeln!(out, "telemt_user_msgs_to_client{{user=\"{}\"}} {}", user, s.msgs_to_client.load(std::sync::atomic::Ordering::Relaxed)); + } + + out +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_render_metrics_format() { + let stats = Arc::new(Stats::new()); + stats.increment_connects_all(); + stats.increment_connects_all(); + stats.increment_connects_bad(); + stats.increment_handshake_timeouts(); + stats.increment_user_connects("alice"); + stats.increment_user_curr_connects("alice"); + stats.add_user_octets_from("alice", 1024); + stats.add_user_octets_to("alice", 2048); + stats.increment_user_msgs_from("alice"); + stats.increment_user_msgs_to("alice"); + stats.increment_user_msgs_to("alice"); + + let output = render_metrics(&stats); + + assert!(output.contains("telemt_connections_total 2")); + assert!(output.contains("telemt_connections_bad_total 1")); + assert!(output.contains("telemt_handshake_timeouts_total 1")); + assert!(output.contains("telemt_user_connections_total{user=\"alice\"} 1")); + assert!(output.contains("telemt_user_connections_current{user=\"alice\"} 1")); + assert!(output.contains("telemt_user_octets_from_client{user=\"alice\"} 1024")); + assert!(output.contains("telemt_user_octets_to_client{user=\"alice\"} 2048")); + assert!(output.contains("telemt_user_msgs_from_client{user=\"alice\"} 1")); + assert!(output.contains("telemt_user_msgs_to_client{user=\"alice\"} 2")); + } + + #[test] + fn test_render_empty_stats() { + let stats = Stats::new(); + let output = render_metrics(&stats); + assert!(output.contains("telemt_connections_total 0")); + assert!(output.contains("telemt_connections_bad_total 0")); + assert!(output.contains("telemt_handshake_timeouts_total 0")); + assert!(!output.contains("user=")); + } + + #[test] + fn test_render_has_type_annotations() { + let stats = Stats::new(); + let output = render_metrics(&stats); + assert!(output.contains("# TYPE telemt_uptime_seconds gauge")); + assert!(output.contains("# TYPE telemt_connections_total counter")); + assert!(output.contains("# TYPE telemt_connections_bad_total counter")); + assert!(output.contains("# TYPE telemt_handshake_timeouts_total counter")); + } + + #[tokio::test] + async fn test_endpoint_integration() { + let stats = Arc::new(Stats::new()); + stats.increment_connects_all(); + stats.increment_connects_all(); + stats.increment_connects_all(); + + let port = 19091u16; + let s = stats.clone(); + tokio::spawn(async move { + serve(port, s, vec![]).await; + }); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + let resp = reqwest::get(format!("http://127.0.0.1:{}/metrics", port)) + .await.unwrap(); + assert_eq!(resp.status(), 200); + let body = resp.text().await.unwrap(); + assert!(body.contains("telemt_connections_total 3")); + + let resp404 = reqwest::get(format!("http://127.0.0.1:{}/other", port)) + .await.unwrap(); + assert_eq!(resp404.status(), 404); + } +} diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 041e7cb..87d6b52 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1,6 +1,8 @@ //! Client Handler +use std::future::Future; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; @@ -8,6 +10,17 @@ use tokio::net::TcpStream; use tokio::time::timeout; use tracing::{debug, warn}; +/// Post-handshake future (relay phase, runs outside handshake timeout) +type PostHandshakeFuture = Pin> + Send>>; + +/// Result of the handshake phase +enum HandshakeOutcome { + /// Handshake succeeded, relay work to do (outside timeout) + NeedsRelay(PostHandshakeFuture), + /// Already fully handled (bad client masking, etc.) + Handled, +} + use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use crate::error::{HandshakeResult, ProxyError, Result}; @@ -24,6 +37,160 @@ use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle use crate::proxy::masking::handle_bad_client; use crate::proxy::middle_relay::handle_via_middle_proxy; +pub async fn handle_client_stream( + mut stream: S, + peer: SocketAddr, + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + me_pool: Option>, + ip_tracker: Arc, +) -> Result<()> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + stats.increment_connects_all(); + debug!(peer = %peer, "New connection (generic stream)"); + + let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake); + let stats_for_timeout = stats.clone(); + + // For non-TCP streams, use a synthetic local address + let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) + .parse() + .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); + + // Phase 1: handshake (with timeout) + let outcome = match timeout(handshake_timeout, async { + let mut first_bytes = [0u8; 5]; + stream.read_exact(&mut first_bytes).await?; + + let is_tls = tls::is_tls_handshake(&first_bytes[..3]); + debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); + + if is_tls { + let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; + + if tls_len < 512 { + debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); + stats.increment_connects_bad(); + let (reader, writer) = tokio::io::split(stream); + handle_bad_client(reader, writer, &first_bytes, &config).await; + return Ok(HandshakeOutcome::Handled); + } + + let mut handshake = vec![0u8; 5 + tls_len]; + handshake[..5].copy_from_slice(&first_bytes); + stream.read_exact(&mut handshake[5..]).await?; + + let (read_half, write_half) = tokio::io::split(stream); + + let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( + &handshake, read_half, write_half, peer, + &config, &replay_checker, &rng, + ).await { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient { reader, writer } => { + stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &config).await; + return Ok(HandshakeOutcome::Handled); + } + HandshakeResult::Error(e) => return Err(e), + }; + + debug!(peer = %peer, "Reading MTProto handshake through TLS"); + let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; + let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into() + .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; + + let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( + &mtproto_handshake, tls_reader, tls_writer, peer, + &config, &replay_checker, true, + ).await { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient { reader: _, writer: _ } => { + stats.increment_connects_bad(); + debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); + return Ok(HandshakeOutcome::Handled); + } + HandshakeResult::Error(e) => return Err(e), + }; + + Ok(HandshakeOutcome::NeedsRelay(Box::pin( + RunningClientHandler::handle_authenticated_static( + crypto_reader, crypto_writer, success, + upstream_manager, stats, config, buffer_pool, rng, me_pool, + local_addr, peer, ip_tracker.clone(), + ), + ))) + } else { + if !config.general.modes.classic && !config.general.modes.secure { + debug!(peer = %peer, "Non-TLS modes disabled"); + stats.increment_connects_bad(); + let (reader, writer) = tokio::io::split(stream); + handle_bad_client(reader, writer, &first_bytes, &config).await; + return Ok(HandshakeOutcome::Handled); + } + + let mut handshake = [0u8; HANDSHAKE_LEN]; + handshake[..5].copy_from_slice(&first_bytes); + stream.read_exact(&mut handshake[5..]).await?; + + let (read_half, write_half) = tokio::io::split(stream); + + let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( + &handshake, read_half, write_half, peer, + &config, &replay_checker, false, + ).await { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient { reader, writer } => { + stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &config).await; + return Ok(HandshakeOutcome::Handled); + } + HandshakeResult::Error(e) => return Err(e), + }; + + Ok(HandshakeOutcome::NeedsRelay(Box::pin( + RunningClientHandler::handle_authenticated_static( + crypto_reader, + crypto_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + me_pool, + local_addr, + peer, + ip_tracker.clone(), + ) + ))) + } + }).await { + Ok(Ok(outcome)) => outcome, + Ok(Err(e)) => { + debug!(peer = %peer, error = %e, "Handshake failed"); + return Err(e); + } + Err(_) => { + stats_for_timeout.increment_handshake_timeouts(); + debug!(peer = %peer, "Handshake timeout"); + return Err(ProxyError::TgHandshakeTimeout); + } + }; + + // Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts) + match outcome { + HandshakeOutcome::NeedsRelay(fut) => fut.await, + HandshakeOutcome::Handled => Ok(()), + } +} + pub struct ClientHandler; pub struct RunningClientHandler { @@ -72,6 +239,7 @@ impl RunningClientHandler { self.stats.increment_connects_all(); let peer = self.peer; + let ip_tracker = self.ip_tracker.clone(); debug!(peer = %peer, "New connection"); if let Err(e) = configure_client_socket( @@ -85,31 +253,34 @@ impl RunningClientHandler { let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); let stats = self.stats.clone(); - let result = timeout(handshake_timeout, self.do_handshake()).await; - - match result { - Ok(Ok(())) => { - debug!(peer = %peer, "Connection handled successfully"); - Ok(()) - } + // Phase 1: handshake (with timeout) + let outcome = match timeout(handshake_timeout, self.do_handshake()).await { + Ok(Ok(outcome)) => outcome, Ok(Err(e)) => { debug!(peer = %peer, error = %e, "Handshake failed"); - Err(e) + return Err(e); } Err(_) => { stats.increment_handshake_timeouts(); debug!(peer = %peer, "Handshake timeout"); - Err(ProxyError::TgHandshakeTimeout) + return Err(ProxyError::TgHandshakeTimeout); } + }; + + // Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts) + match outcome { + HandshakeOutcome::NeedsRelay(fut) => fut.await, + HandshakeOutcome::Handled => Ok(()), } } - async fn do_handshake(mut self) -> Result<()> { + async fn do_handshake(mut self) -> Result { let mut first_bytes = [0u8; 5]; self.stream.read_exact(&mut first_bytes).await?; let is_tls = tls::is_tls_handshake(&first_bytes[..3]); let peer = self.peer; + let ip_tracker = self.ip_tracker.clone(); debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); @@ -120,8 +291,9 @@ impl RunningClientHandler { } } - async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> { + async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result { let peer = self.peer; + let ip_tracker = self.ip_tracker.clone(); let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; @@ -132,7 +304,7 @@ impl RunningClientHandler { self.stats.increment_connects_bad(); let (reader, writer) = self.stream.into_split(); handle_bad_client(reader, writer, &first_bytes, &self.config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } let mut handshake = vec![0u8; 5 + tls_len]; @@ -162,7 +334,7 @@ impl RunningClientHandler { HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), }; @@ -191,37 +363,39 @@ impl RunningClientHandler { } => { stats.increment_connects_bad(); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(()); + return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), }; - Self::handle_authenticated_static( - crypto_reader, - crypto_writer, - success, - self.upstream_manager, - self.stats, - self.config, - buffer_pool, - self.rng, - self.me_pool, - local_addr, - peer, - self.ip_tracker, - ) - .await + Ok(HandshakeOutcome::NeedsRelay(Box::pin( + Self::handle_authenticated_static( + crypto_reader, + crypto_writer, + success, + self.upstream_manager, + self.stats, + self.config, + buffer_pool, + self.rng, + self.me_pool, + local_addr, + peer, + self.ip_tracker, + ), + ))) } - async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { + async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result { let peer = self.peer; + let ip_tracker = self.ip_tracker.clone(); if !self.config.general.modes.classic && !self.config.general.modes.secure { debug!(peer = %peer, "Non-TLS modes disabled"); self.stats.increment_connects_bad(); let (reader, writer) = self.stream.into_split(); handle_bad_client(reader, writer, &first_bytes, &self.config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } let mut handshake = [0u8; HANDSHAKE_LEN]; @@ -251,26 +425,27 @@ impl RunningClientHandler { HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), }; - Self::handle_authenticated_static( - crypto_reader, - crypto_writer, - success, - self.upstream_manager, - self.stats, - self.config, - buffer_pool, - self.rng, - self.me_pool, - local_addr, - peer, - self.ip_tracker, - ) - .await + Ok(HandshakeOutcome::NeedsRelay(Box::pin( + Self::handle_authenticated_static( + crypto_reader, + crypto_writer, + success, + self.upstream_manager, + self.stats, + self.config, + buffer_pool, + self.rng, + self.me_pool, + local_addr, + peer, + self.ip_tracker, + ), + ))) } /// Main dispatch after successful handshake. diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 3cce39e..ff50bca 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -1,3 +1,5 @@ +use std::fs::OpenOptions; +use std::io::Write; use std::net::SocketAddr; use std::sync::Arc; @@ -87,17 +89,25 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { let num_dcs = datacenters.len(); let dc_key = dc_idx.to_string(); - if let Some(addr_str) = config.dc_overrides.get(&dc_key) { - match addr_str.parse::() { - Ok(addr) => { - debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config"); - return Ok(addr); - } - Err(_) => { - warn!(dc_idx = dc_idx, addr_str = %addr_str, - "Invalid DC override address in config, ignoring"); + if let Some(addrs) = config.dc_overrides.get(&dc_key) { + let prefer_v6 = config.general.prefer_ipv6; + let mut parsed = Vec::new(); + for addr_str in addrs { + match addr_str.parse::() { + Ok(addr) => parsed.push(addr), + Err(_) => warn!(dc_idx = dc_idx, addr_str = %addr_str, "Invalid DC override address in config, ignoring"), } } + + if let Some(addr) = parsed + .iter() + .find(|a| a.is_ipv6() == prefer_v6) + .or_else(|| parsed.first()) + .copied() + { + debug!(dc_idx = dc_idx, addr = %addr, count = parsed.len(), "Using DC override from config"); + return Ok(addr); + } } let abs_dc = dc_idx.unsigned_abs() as usize; @@ -105,6 +115,16 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT)); } + // Unknown DC requested by client without override: log and fall back. + if !config.dc_overrides.contains_key(&dc_key) { + warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster"); + if let Some(path) = &config.general.unknown_dc_log_path { + if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { + let _ = writeln!(file, "dc_idx={dc_idx}"); + } + } + } + let default_dc = config.default_dc.unwrap_or(2) as usize; let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs { default_dc - 1 diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 01ccc20..09dd532 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -263,7 +263,14 @@ where } } - client_writer.flush().await.map_err(ProxyError::Io) + // Avoid unconditional per-frame flush (throughput killer on large downloads). + // Flush only when low-latency ack semantics are requested or when + // CryptoWriter has buffered pending ciphertext that must be drained. + if quickack || client_writer.has_pending() { + client_writer.flush().await.map_err(ProxyError::Io)?; + } + + Ok(()) } async fn write_client_ack( @@ -283,5 +290,6 @@ where .write_all(&bytes) .await .map_err(ProxyError::Io)?; + // ACK should remain low-latency. client_writer.flush().await.map_err(ProxyError::Io) } diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 5c3a084..2cdcdf9 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -97,6 +97,12 @@ impl Stats { .unwrap_or(0) } + pub fn get_handshake_timeouts(&self) -> u64 { self.handshake_timeouts.load(Ordering::Relaxed) } + + pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, UserStats> { + self.user_stats.iter() + } + pub fn uptime_secs(&self) -> f64 { self.start_time.read() .map(|t| t.elapsed().as_secs_f64()) diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs index 51daee9..326bf90 100644 --- a/src/transport/middle_proxy/codec.rs +++ b/src/transport/middle_proxy/codec.rs @@ -174,6 +174,7 @@ impl RpcWriter { if buf.len() >= 16 { self.iv.copy_from_slice(&buf[buf.len() - 16..]); } - self.writer.write_all(&buf).await.map_err(ProxyError::Io) + self.writer.write_all(&buf).await.map_err(ProxyError::Io)?; + self.writer.flush().await.map_err(ProxyError::Io) } } diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs index aed5a54..8ac6986 100644 --- a/src/transport/middle_proxy/config_updater.rs +++ b/src/transport/middle_proxy/config_updater.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use std::time::Duration; use regex::Regex; +use httpdate; use tracing::{debug, info, warn}; use crate::error::Result; @@ -11,6 +12,7 @@ use crate::error::Result; use super::MePool; use super::secret::download_proxy_secret; use crate::crypto::SecureRandom; +use std::time::SystemTime; #[derive(Debug, Clone, Default)] pub struct ProxyConfigData { @@ -19,9 +21,29 @@ pub struct ProxyConfigData { } pub async fn fetch_proxy_config(url: &str) -> Result { - let text = reqwest::get(url) + let resp = reqwest::get(url) .await .map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config GET failed: {e}")))? + ; + + if let Some(date) = resp.headers().get(reqwest::header::DATE) { + if let Ok(date_str) = date.to_str() { + if let Ok(server_time) = httpdate::parse_http_date(date_str) { + if let Ok(skew) = SystemTime::now().duration_since(server_time).or_else(|e| { + server_time.duration_since(SystemTime::now()).map_err(|_| e) + }) { + let skew_secs = skew.as_secs(); + if skew_secs > 60 { + warn!(skew_secs, "Time skew >60s detected from fetch_proxy_config Date header"); + } else if skew_secs > 30 { + warn!(skew_secs, "Time skew >30s detected from fetch_proxy_config Date header"); + } + } + } + } + } + + let text = resp .text() .await .map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config read failed: {e}")))?; diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs new file mode 100644 index 0000000..0a410c8 --- /dev/null +++ b/src/transport/middle_proxy/handshake.rs @@ -0,0 +1,412 @@ +use std::net::{IpAddr, SocketAddr}; +use std::time::{Duration, Instant}; +use socket2::{SockRef, TcpKeepalive}; +#[cfg(target_os = "linux")] +use libc; +#[cfg(target_os = "linux")] +use std::os::fd::{AsRawFd, RawFd}; +#[cfg(target_os = "linux")] +use std::os::raw::c_int; + +use bytes::BytesMut; +use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}; +use tokio::net::TcpStream; +use tokio::time::timeout; +use tracing::{debug, info, warn}; + +use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256}; +use crate::error::{ProxyError, Result}; +use crate::protocol::constants::{ + ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32, + RPC_HANDSHAKE_U32, RPC_PING_U32, RPC_PONG_U32, RPC_NONCE_U32, +}; + +use super::codec::{ + build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace, + cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext, +}; +use super::wire::{extract_ip_material, IpMaterial}; +use super::MePool; + +/// Result of a successful ME handshake with timings. +pub(crate) struct HandshakeOutput { + pub rd: ReadHalf, + pub wr: WriteHalf, + pub read_key: [u8; 32], + pub read_iv: [u8; 16], + pub write_key: [u8; 32], + pub write_iv: [u8; 16], + pub handshake_ms: f64, +} + +impl MePool { + /// TCP connect with timeout + return RTT in milliseconds. + pub(crate) async fn connect_tcp(&self, addr: SocketAddr) -> Result<(TcpStream, f64)> { + let start = Instant::now(); + let stream = timeout(Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), TcpStream::connect(addr)) + .await + .map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string() })??; + let connect_ms = start.elapsed().as_secs_f64() * 1000.0; + stream.set_nodelay(true).ok(); + if let Err(e) = Self::configure_keepalive(&stream) { + warn!(error = %e, "ME keepalive setup failed"); + } + #[cfg(target_os = "linux")] + if let Err(e) = Self::configure_user_timeout(stream.as_raw_fd()) { + warn!(error = %e, "ME TCP_USER_TIMEOUT setup failed"); + } + Ok((stream, connect_ms)) + } + + fn configure_keepalive(stream: &TcpStream) -> std::io::Result<()> { + let sock = SockRef::from(stream); + let ka = TcpKeepalive::new() + .with_time(Duration::from_secs(30)) + .with_interval(Duration::from_secs(10)) + .with_retries(3); + sock.set_tcp_keepalive(&ka)?; + sock.set_keepalive(true)?; + Ok(()) + } + + #[cfg(target_os = "linux")] + fn configure_user_timeout(fd: RawFd) -> std::io::Result<()> { + let timeout_ms: c_int = 30_000; + let rc = unsafe { + libc::setsockopt( + fd, + libc::IPPROTO_TCP, + libc::TCP_USER_TIMEOUT, + &timeout_ms as *const _ as *const libc::c_void, + std::mem::size_of_val(&timeout_ms) as libc::socklen_t, + ) + }; + if rc != 0 { + return Err(std::io::Error::last_os_error()); + } + Ok(()) + } + + /// Perform full ME RPC handshake on an established TCP stream. + /// Returns cipher keys/ivs and split halves; does not register writer. + pub(crate) async fn handshake_only( + &self, + stream: TcpStream, + addr: SocketAddr, + rng: &SecureRandom, + ) -> Result { + let hs_start = Instant::now(); + + let local_addr = stream.local_addr().map_err(ProxyError::Io)?; + let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?; + + let _ = self.maybe_detect_nat_ip(local_addr.ip()).await; + let reflected = if self.nat_probe { + self.maybe_reflect_public_addr().await + } else { + None + }; + + let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected); + let peer_addr_nat = SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port()); + let (mut rd, mut wr) = tokio::io::split(stream); + + let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap(); + let crypto_ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as u32; + + let ks = self.key_selector().await; + let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); + let nonce_frame = build_rpc_frame(-2, &nonce_payload); + let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]); + info!( + key_selector = format_args!("0x{ks:08x}"), + crypto_ts, + frame_len = nonce_frame.len(), + nonce_frame_hex = %dump, + "Sending ME nonce frame" + ); + wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?; + wr.flush().await.map_err(ProxyError::Io)?; + + let (srv_seq, srv_nonce_payload) = timeout( + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS), + read_rpc_frame_plaintext(&mut rd), + ) + .await + .map_err(|_| ProxyError::TgHandshakeTimeout)??; + + if srv_seq != -2 { + return Err(ProxyError::InvalidHandshake(format!("Expected seq=-2, got {srv_seq}"))); + } + + let (srv_key_select, schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?; + if schema != RPC_CRYPTO_AES_U32 { + warn!(schema = format_args!("0x{schema:08x}"), "Unsupported ME crypto schema"); + return Err(ProxyError::InvalidHandshake(format!( + "Unsupported crypto schema: 0x{schema:x}" + ))); + } + + if srv_key_select != ks { + return Err(ProxyError::InvalidHandshake(format!( + "Server key_select 0x{srv_key_select:08x} != client 0x{ks:08x}" + ))); + } + + let skew = crypto_ts.abs_diff(srv_ts); + if skew > 30 { + return Err(ProxyError::InvalidHandshake(format!( + "nonce crypto_ts skew too large: client={crypto_ts}, server={srv_ts}, skew={skew}s" + ))); + } + + info!( + %local_addr, + %local_addr_nat, + reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string), + %peer_addr, + %peer_addr_nat, + key_selector = format_args!("0x{ks:08x}"), + crypto_schema = format_args!("0x{schema:08x}"), + skew_secs = skew, + "ME key derivation parameters" + ); + + let ts_bytes = crypto_ts.to_le_bytes(); + let server_port_bytes = peer_addr_nat.port().to_le_bytes(); + let client_port_bytes = local_addr_nat.port().to_le_bytes(); + + let server_ip = extract_ip_material(peer_addr_nat); + let client_ip = extract_ip_material(local_addr_nat); + + let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) = match (server_ip, client_ip) { + (IpMaterial::V4(mut srv), IpMaterial::V4(mut clt)) => { + srv.reverse(); + clt.reverse(); + (Some(srv), Some(clt), None, None, clt, srv) + } + (IpMaterial::V6(srv), IpMaterial::V6(clt)) => { + let zero = [0u8; 4]; + (None, None, Some(clt), Some(srv), zero, zero) + } + _ => { + return Err(ProxyError::InvalidHandshake( + "mixed IPv4/IPv6 endpoints are not supported for ME key derivation".to_string(), + )); + } + }; + + let diag_level: u8 = std::env::var("ME_DIAG").ok().and_then(|v| v.parse().ok()).unwrap_or(0); + + let secret: Vec = self.proxy_secret.read().await.clone(); + + let prekey_client = build_middleproxy_prekey( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"CLIENT", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + &secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + let prekey_server = build_middleproxy_prekey( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"SERVER", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + &secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + + let (wk, wi) = derive_middleproxy_keys( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"CLIENT", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + &secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + let (rk, ri) = derive_middleproxy_keys( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"SERVER", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + &secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + + let hs_payload = build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port()); + let hs_frame = build_rpc_frame(-1, &hs_payload); + if diag_level >= 1 { + info!( + write_key = %hex_dump(&wk), + write_iv = %hex_dump(&wi), + read_key = %hex_dump(&rk), + read_iv = %hex_dump(&ri), + srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), + clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), + srv_port = %hex_dump(&server_port_bytes), + clt_port = %hex_dump(&client_port_bytes), + crypto_ts = %hex_dump(&ts_bytes), + nonce_srv = %hex_dump(&srv_nonce), + nonce_clt = %hex_dump(&my_nonce), + prekey_sha256_client = %hex_dump(&sha256(&prekey_client)), + prekey_sha256_server = %hex_dump(&sha256(&prekey_server)), + hs_plain = %hex_dump(&hs_frame), + proxy_secret_sha256 = %hex_dump(&sha256(&secret)), + "ME diag: derived keys and handshake plaintext" + ); + } + if diag_level >= 2 { + info!( + prekey_client = %hex_dump(&prekey_client), + prekey_server = %hex_dump(&prekey_server), + "ME diag: full prekey buffers" + ); + } + + let (encrypted_hs, mut write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; + if diag_level >= 1 { + info!( + hs_cipher = %hex_dump(&encrypted_hs), + "ME diag: handshake ciphertext" + ); + } + wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?; + wr.flush().await.map_err(ProxyError::Io)?; + + let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS); + let mut enc_buf = BytesMut::with_capacity(256); + let mut dec_buf = BytesMut::with_capacity(256); + let mut read_iv = ri; + let mut handshake_ok = false; + + while Instant::now() < deadline && !handshake_ok { + let remaining = deadline - Instant::now(); + let mut tmp = [0u8; 256]; + let n = match timeout(remaining, rd.read(&mut tmp)).await { + Ok(Ok(0)) => { + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "ME closed during handshake", + ))); + } + Ok(Ok(n)) => n, + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => return Err(ProxyError::TgHandshakeTimeout), + }; + + enc_buf.extend_from_slice(&tmp[..n]); + + let blocks = enc_buf.len() / 16 * 16; + if blocks > 0 { + let mut chunk = vec![0u8; blocks]; + chunk.copy_from_slice(&enc_buf[..blocks]); + read_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?; + dec_buf.extend_from_slice(&chunk); + let _ = enc_buf.split_to(blocks); + } + + while dec_buf.len() >= 4 { + let fl = u32::from_le_bytes(dec_buf[0..4].try_into().unwrap()) as usize; + + if fl == 4 { + let _ = dec_buf.split_to(4); + continue; + } + if !(12..=(1 << 24)).contains(&fl) { + return Err(ProxyError::InvalidHandshake(format!( + "Bad HS response frame len: {fl}" + ))); + } + if dec_buf.len() < fl { + break; + } + + let frame = dec_buf.split_to(fl); + let pe = fl - 4; + let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); + let ac = crate::crypto::crc32(&frame[..pe]); + if ec != ac { + return Err(ProxyError::InvalidHandshake(format!( + "HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}" + ))); + } + + let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap()); + if hs_type == RPC_HANDSHAKE_ERROR_U32 { + let err_code = if frame.len() >= 16 { + i32::from_le_bytes(frame[12..16].try_into().unwrap()) + } else { + -1 + }; + return Err(ProxyError::InvalidHandshake(format!( + "ME rejected handshake (error={err_code})" + ))); + } + if hs_type != RPC_HANDSHAKE_U32 { + return Err(ProxyError::InvalidHandshake(format!( + "Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" + ))); + } + + handshake_ok = true; + break; + } + } + + if !handshake_ok { + return Err(ProxyError::TgHandshakeTimeout); + } + + let handshake_ms = hs_start.elapsed().as_secs_f64() * 1000.0; + info!(%addr, "RPC handshake OK"); + + Ok(HandshakeOutput { + rd, + wr, + read_key: rk, + read_iv, + write_key: wk, + write_iv, + handshake_ms, + }) + } +} + +fn hex_dump(data: &[u8]) -> String { + const MAX: usize = 64; + let mut out = String::with_capacity(data.len() * 2 + 3); + for (i, b) in data.iter().take(MAX).enumerate() { + if i > 0 { + out.push(' '); + } + out.push_str(&format!("{b:02x}")); + } + if data.len() > MAX { + out.push_str(" …"); + } + out +} diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index 8f0f5a6..d2bb51a 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -1,6 +1,7 @@ +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use tracing::{debug, info, warn}; use rand::seq::SliceRandom; @@ -10,6 +11,8 @@ use crate::crypto::SecureRandom; use super::MePool; pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_connections: usize) { + let mut backoff: HashMap = HashMap::new(); + let mut last_attempt: HashMap = HashMap::new(); loop { tokio::time::sleep(Duration::from_secs(30)).await; // Per-DC coverage check @@ -19,7 +22,7 @@ pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_c .read() .await .iter() - .map(|(a, _)| *a) + .map(|w| w.addr) .collect(); for (dc, addrs) in map.iter() { @@ -29,18 +32,81 @@ pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_c .collect(); let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a)); if !has_coverage { - warn!(dc = %dc, "DC has no ME coverage, reconnecting..."); + let delay = *backoff.get(dc).unwrap_or(&30); + let now = Instant::now(); + if let Some(last) = last_attempt.get(dc) { + if now.duration_since(*last).as_secs() < delay { + continue; + } + } + warn!(dc = %dc, delay, "DC has no ME coverage, reconnecting..."); let mut shuffled = dc_addrs.clone(); shuffled.shuffle(&mut rand::rng()); + let mut reconnected = false; for addr in shuffled { match pool.connect_one(addr, &rng).await { Ok(()) => { info!(%addr, dc = %dc, "ME reconnected for DC coverage"); + backoff.insert(*dc, 30); + last_attempt.insert(*dc, now); + reconnected = true; break; } Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed"), } } + if !reconnected { + let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300); + backoff.insert(*dc, next); + last_attempt.insert(*dc, now); + } + } + } + + // IPv6 coverage check (if available) + let map_v6 = pool.proxy_map_v6.read().await.clone(); + let writer_addrs_v6: std::collections::HashSet = pool + .writers + .read() + .await + .iter() + .map(|w| w.addr) + .collect(); + for (dc, addrs) in map_v6.iter() { + let dc_addrs: Vec = addrs + .iter() + .map(|(ip, port)| SocketAddr::new(*ip, *port)) + .collect(); + let has_coverage = dc_addrs.iter().any(|a| writer_addrs_v6.contains(a)); + if !has_coverage { + let delay = *backoff.get(dc).unwrap_or(&30); + let now = Instant::now(); + if let Some(last) = last_attempt.get(dc) { + if now.duration_since(*last).as_secs() < delay { + continue; + } + } + warn!(dc = %dc, delay, "IPv6 DC has no ME coverage, reconnecting..."); + let mut shuffled = dc_addrs.clone(); + shuffled.shuffle(&mut rand::rng()); + let mut reconnected = false; + for addr in shuffled { + match pool.connect_one(addr, &rng).await { + Ok(()) => { + info!(%addr, dc = %dc, "ME reconnected for IPv6 DC coverage"); + backoff.insert(*dc, 30); + last_attempt.insert(*dc, now); + reconnected = true; + break; + } + Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed (IPv6)"), + } + } + if !reconnected { + let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300); + backoff.insert(*dc, next); + last_attempt.insert(*dc, now); + } } } } diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index e617158..26d07dd 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -1,23 +1,29 @@ //! Middle Proxy RPC transport. mod codec; +mod handshake; mod health; mod pool; mod pool_nat; +mod ping; mod reader; mod registry; mod send; mod secret; +mod rotation; mod config_updater; mod wire; use bytes::Bytes; pub use health::me_health_monitor; +pub use ping::{run_me_ping, format_sample_line, MePingReport, MePingSample, MePingFamily}; pub use pool::MePool; +pub use pool_nat::{stun_probe, detect_public_ip, StunProbeResult}; pub use registry::ConnRegistry; pub use secret::fetch_proxy_secret; pub use config_updater::{fetch_proxy_config, me_config_updater}; +pub use rotation::me_rotation_task; pub use wire::proto_flags_for_tag; #[derive(Debug)] diff --git a/src/transport/middle_proxy/ping.rs b/src/transport/middle_proxy/ping.rs new file mode 100644 index 0000000..22b1f6d --- /dev/null +++ b/src/transport/middle_proxy/ping.rs @@ -0,0 +1,164 @@ +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; + +use crate::crypto::SecureRandom; +use crate::error::ProxyError; + +use super::MePool; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MePingFamily { + V4, + V6, +} + +#[derive(Debug, Clone)] +pub struct MePingSample { + pub dc: i32, + pub addr: SocketAddr, + pub connect_ms: Option, + pub handshake_ms: Option, + pub error: Option, + pub family: MePingFamily, +} + +#[derive(Debug, Clone)] +pub struct MePingReport { + pub dc: i32, + pub family: MePingFamily, + pub samples: Vec, +} + +pub fn format_sample_line(sample: &MePingSample) -> String { + let sign = if sample.dc >= 0 { "+" } else { "-" }; + let addr = format!("{}:{}", sample.addr.ip(), sample.addr.port()); + + match (sample.connect_ms, sample.handshake_ms.as_ref(), sample.error.as_ref()) { + (Some(conn), Some(hs), None) => format!( + " {sign} {addr}\tPing: {:.0} ms / RPC: {:.0} ms / OK", + conn, hs + ), + (Some(conn), None, Some(err)) => format!( + " {sign} {addr}\tPing: {:.0} ms / RPC: FAIL ({err})", + conn + ), + (None, _, Some(err)) => format!(" {sign} {addr}\tPing: FAIL ({err})"), + (Some(conn), None, None) => format!(" {sign} {addr}\tPing: {:.0} ms / RPC: FAIL", conn), + _ => format!(" {sign} {addr}\tPing: FAIL"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + + fn sample(base: MePingSample) -> MePingSample { + base + } + + #[test] + fn ok_line_contains_both_timings() { + let s = sample(MePingSample { + dc: 4, + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), 8888), + connect_ms: Some(12.3), + handshake_ms: Some(34.7), + error: None, + family: MePingFamily::V4, + }); + let line = format_sample_line(&s); + assert!(line.contains("Ping: 12 ms")); + assert!(line.contains("RPC: 35 ms")); + assert!(line.contains("OK")); + } + + #[test] + fn error_line_mentions_reason() { + let s = sample(MePingSample { + dc: -5, + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8)), 80), + connect_ms: Some(10.0), + handshake_ms: None, + error: Some("handshake timeout".to_string()), + family: MePingFamily::V4, + }); + let line = format_sample_line(&s); + assert!(line.contains("- 5.6.7.8:80")); + assert!(line.contains("handshake timeout")); + } +} + +pub async fn run_me_ping(pool: &Arc, rng: &SecureRandom) -> Vec { + let mut reports = Vec::new(); + + let v4_map = pool.proxy_map_v4.read().await.clone(); + let v6_map = pool.proxy_map_v6.read().await.clone(); + + let mut grouped: Vec<(MePingFamily, i32, Vec<(IpAddr, u16)>)> = Vec::new(); + for (dc, addrs) in v4_map { + grouped.push((MePingFamily::V4, dc, addrs)); + } + for (dc, addrs) in v6_map { + grouped.push((MePingFamily::V6, dc, addrs)); + } + + for (family, dc, addrs) in grouped { + let mut samples = Vec::new(); + for (ip, port) in addrs { + let addr = SocketAddr::new(ip, port); + let mut connect_ms = None; + let mut handshake_ms = None; + let mut error = None; + + match pool.connect_tcp(addr).await { + Ok((stream, conn_rtt)) => { + connect_ms = Some(conn_rtt); + match pool.handshake_only(stream, addr, rng).await { + Ok(hs) => { + handshake_ms = Some(hs.handshake_ms); + // drop halves to close + drop(hs.rd); + drop(hs.wr); + } + Err(e) => { + error = Some(short_err(&e)); + } + } + } + Err(e) => { + error = Some(short_err(&e)); + } + } + + samples.push(MePingSample { + dc, + addr, + connect_ms, + handshake_ms, + error, + family, + }); + } + + reports.push(MePingReport { + dc, + family, + samples, + }); + } + + reports +} + +fn short_err(err: &ProxyError) -> String { + match err { + ProxyError::ConnectionTimeout { .. } => "connect timeout".to_string(), + ProxyError::TgHandshakeTimeout => "handshake timeout".to_string(), + ProxyError::InvalidHandshake(e) => format!("bad handshake: {e}"), + ProxyError::Crypto(e) => format!("crypto: {e}"), + ProxyError::Proxy(e) => format!("proxy: {e}"), + ProxyError::Io(e) => format!("io: {e}"), + _ => format!("{err}"), + } +} diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index f38a81c..7305f5e 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -1,36 +1,40 @@ use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; -use std::sync::atomic::{AtomicI32, AtomicU64}; -use std::time::Duration; - +use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering}; use bytes::BytesMut; use rand::Rng; use rand::seq::SliceRandom; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; use tokio::sync::{Mutex, RwLock}; -use tokio::time::{Instant, timeout}; +use tokio_util::sync::CancellationToken; use tracing::{debug, info, warn}; +use std::time::Duration; -use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256}; +use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; use super::ConnRegistry; -use super::codec::{ - RpcWriter, build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace, - cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext, -}; +use super::registry::{BoundConn, ConnMeta}; +use super::codec::RpcWriter; use super::reader::reader_loop; -use super::wire::{IpMaterial, extract_ip_material}; +use super::MeResponse; const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; +#[derive(Clone)] +pub struct MeWriter { + pub id: u64, + pub addr: SocketAddr, + pub writer: Arc>, + pub cancel: CancellationToken, + pub degraded: Arc, +} + pub struct MePool { pub(super) registry: Arc, - pub(super) writers: Arc>)>>> , + pub(super) writers: Arc>>, pub(super) rr: AtomicU64, pub(super) proxy_tag: Option>, pub(super) proxy_secret: Arc>>, @@ -41,6 +45,10 @@ pub struct MePool { pub(super) proxy_map_v4: Arc>>>, pub(super) proxy_map_v6: Arc>>>, pub(super) default_dc: AtomicI32, + pub(super) next_writer_id: AtomicU64, + pub(super) ping_tracker: Arc>>, + pub(super) rtt_stats: Arc>>, + pub(super) nat_reflection_cache: Arc>>, pool_size: usize, } @@ -69,6 +77,10 @@ impl MePool { proxy_map_v4: Arc::new(RwLock::new(proxy_map_v4)), proxy_map_v6: Arc::new(RwLock::new(proxy_map_v6)), default_dc: AtomicI32::new(default_dc.unwrap_or(0)), + next_writer_id: AtomicU64::new(1), + ping_tracker: Arc::new(Mutex::new(HashMap::new())), + rtt_stats: Arc::new(Mutex::new(HashMap::new())), + nat_reflection_cache: Arc::new(Mutex::new(None)), }) } @@ -85,16 +97,19 @@ impl MePool { &self.registry } - fn writers_arc(&self) -> Arc>)>>> - { + fn writers_arc(&self) -> Arc>> { self.writers.clone() } - pub async fn reconcile_connections(&self, rng: &SecureRandom) { + pub async fn reconcile_connections(self: &Arc, rng: &SecureRandom) { use std::collections::HashSet; let map = self.proxy_map_v4.read().await.clone(); + let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map + .iter() + .map(|(dc, addrs)| (*dc, addrs.clone())) + .collect(); let writers = self.writers.read().await; - let current: HashSet = writers.iter().map(|(a, _)| *a).collect(); + let current: HashSet = writers.iter().map(|w| w.addr).collect(); drop(writers); for (_dc, addrs) in map.iter() { @@ -157,7 +172,7 @@ impl MePool { // No-op here to avoid total outage. } - async fn key_selector(&self) -> u32 { + pub(super) async fn key_selector(&self) -> u32 { let secret = self.proxy_secret.read().await; if secret.len() >= 4 { u32::from_le_bytes([secret[0], secret[1], secret[2], secret[3]]) @@ -166,8 +181,12 @@ impl MePool { } } - pub async fn init(self: &Arc, pool_size: usize, rng: &SecureRandom) -> Result<()> { - let map = self.proxy_map_v4.read().await; + pub async fn init(self: &Arc, pool_size: usize, rng: &Arc) -> Result<()> { + let map = self.proxy_map_v4.read().await.clone(); + let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map + .iter() + .map(|(dc, addrs)| (*dc, addrs.clone())) + .collect(); let ks = self.key_selector().await; info!( me_servers = map.len(), @@ -177,38 +196,28 @@ impl MePool { "Initializing ME pool" ); - // Ensure at least one connection per DC with failover over all addresses - for (dc, addrs) in map.iter() { + // Ensure at least one connection per DC; run DCs in parallel. + let mut join = tokio::task::JoinSet::new(); + for (dc, addrs) in dc_addrs.iter().cloned() { if addrs.is_empty() { continue; } - let mut connected = false; - let mut shuffled = addrs.clone(); - shuffled.shuffle(&mut rand::rng()); - for (ip, port) in shuffled { - let addr = SocketAddr::new(ip, port); - match self.connect_one(addr, rng).await { - Ok(()) => { - info!(%addr, dc = %dc, "ME connected"); - connected = true; - break; - } - Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"), - } - } - if !connected { - warn!(dc = %dc, "All ME servers for DC failed at init"); - } + let pool = Arc::clone(self); + let rng_clone = Arc::clone(rng); + join.spawn(async move { + pool.connect_primary_for_dc(dc, addrs, rng_clone).await; + }); } + while let Some(_res) = join.join_next().await {} // Additional connections up to pool_size total (round-robin across DCs) - for (dc, addrs) in map.iter() { + for (dc, addrs) in dc_addrs.iter() { for (ip, port) in addrs { if self.connection_count() >= pool_size { break; } let addr = SocketAddr::new(*ip, *port); - if let Err(e) = self.connect_one(addr, rng).await { + if let Err(e) = self.connect_one(addr, rng.as_ref()).await { debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); } } @@ -223,360 +232,97 @@ impl MePool { Ok(()) } - pub(crate) async fn connect_one( - &self, - addr: SocketAddr, - rng: &SecureRandom, - ) -> Result<()> { - let secret_guard = self.proxy_secret.read().await; - let secret: Vec = secret_guard.clone(); - if secret.len() < 32 { - return Err(ProxyError::Proxy( - "proxy-secret too short for ME auth".into(), - )); + pub(crate) async fn connect_one(self: &Arc, addr: SocketAddr, rng: &SecureRandom) -> Result<()> { + let secret_len = self.proxy_secret.read().await.len(); + if secret_len < 32 { + return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); } - let stream = timeout( - Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), - TcpStream::connect(addr), - ) - .await - .map_err(|_| ProxyError::ConnectionTimeout { - addr: addr.to_string(), - })? - .map_err(ProxyError::Io)?; - stream.set_nodelay(true).ok(); - - let local_addr = stream.local_addr().map_err(ProxyError::Io)?; - let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?; - let _ = self.maybe_detect_nat_ip(local_addr.ip()).await; - let reflected = if self.nat_probe { - self.maybe_reflect_public_addr().await - } else { - None - }; - let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected); - let peer_addr_nat = - SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port()); - let (mut rd, mut wr) = tokio::io::split(stream); - - let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap(); - let crypto_ts = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs() as u32; - - let ks = self.key_selector().await; - let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); - let nonce_frame = build_rpc_frame(-2, &nonce_payload); - let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]); - info!( - key_selector = format_args!("0x{ks:08x}"), - crypto_ts, - frame_len = nonce_frame.len(), - nonce_frame_hex = %dump, - "Sending ME nonce frame" - ); - wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?; - wr.flush().await.map_err(ProxyError::Io)?; - - let (srv_seq, srv_nonce_payload) = timeout( - Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS), - read_rpc_frame_plaintext(&mut rd), - ) - .await - .map_err(|_| ProxyError::TgHandshakeTimeout)??; - - if srv_seq != -2 { - return Err(ProxyError::InvalidHandshake(format!( - "Expected seq=-2, got {srv_seq}" - ))); - } - - let (srv_key_select, schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?; - if schema != RPC_CRYPTO_AES_U32 { - warn!(schema = format_args!("0x{schema:08x}"), "Unsupported ME crypto schema"); - return Err(ProxyError::InvalidHandshake(format!( - "Unsupported crypto schema: 0x{schema:x}" - ))); - } - - if srv_key_select != ks { - return Err(ProxyError::InvalidHandshake(format!( - "Server key_select 0x{srv_key_select:08x} != client 0x{ks:08x}" - ))); - } - - let skew = crypto_ts.abs_diff(srv_ts); - if skew > 30 { - return Err(ProxyError::InvalidHandshake(format!( - "nonce crypto_ts skew too large: client={crypto_ts}, server={srv_ts}, skew={skew}s" - ))); - } - - info!( - %local_addr, - %local_addr_nat, - reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string), - %peer_addr, - %peer_addr_nat, - key_selector = format_args!("0x{ks:08x}"), - crypto_schema = format_args!("0x{schema:08x}"), - skew_secs = skew, - "ME key derivation parameters" - ); - - let ts_bytes = crypto_ts.to_le_bytes(); - let server_port_bytes = peer_addr_nat.port().to_le_bytes(); - let client_port_bytes = local_addr_nat.port().to_le_bytes(); - - let server_ip = extract_ip_material(peer_addr_nat); - let client_ip = extract_ip_material(local_addr_nat); - - let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) = - match (server_ip, client_ip) { - // IPv4: reverse byte order for KDF (Python/C reference behavior) - (IpMaterial::V4(mut srv), IpMaterial::V4(mut clt)) => { - srv.reverse(); - clt.reverse(); - (Some(srv), Some(clt), None, None, clt, srv) - } - (IpMaterial::V6(srv), IpMaterial::V6(clt)) => { - let zero = [0u8; 4]; - (None, None, Some(clt), Some(srv), zero, zero) - } - _ => { - return Err(ProxyError::InvalidHandshake( - "mixed IPv4/IPv6 endpoints are not supported for ME key derivation" - .to_string(), - )); - } - }; - - let diag_level: u8 = std::env::var("ME_DIAG") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(0); - - let prekey_client = build_middleproxy_prekey( - &srv_nonce, - &my_nonce, - &ts_bytes, - srv_ip_opt.as_ref().map(|x| &x[..]), - &client_port_bytes, - b"CLIENT", - clt_ip_opt.as_ref().map(|x| &x[..]), - &server_port_bytes, - &secret, - clt_v6_opt.as_ref(), - srv_v6_opt.as_ref(), - ); - let prekey_server = build_middleproxy_prekey( - &srv_nonce, - &my_nonce, - &ts_bytes, - srv_ip_opt.as_ref().map(|x| &x[..]), - &client_port_bytes, - b"SERVER", - clt_ip_opt.as_ref().map(|x| &x[..]), - &server_port_bytes, - &secret, - clt_v6_opt.as_ref(), - srv_v6_opt.as_ref(), - ); - - let (wk, wi) = derive_middleproxy_keys( - &srv_nonce, - &my_nonce, - &ts_bytes, - srv_ip_opt.as_ref().map(|x| &x[..]), - &client_port_bytes, - b"CLIENT", - clt_ip_opt.as_ref().map(|x| &x[..]), - &server_port_bytes, - &secret, - clt_v6_opt.as_ref(), - srv_v6_opt.as_ref(), - ); - let (rk, ri) = derive_middleproxy_keys( - &srv_nonce, - &my_nonce, - &ts_bytes, - srv_ip_opt.as_ref().map(|x| &x[..]), - &client_port_bytes, - b"SERVER", - clt_ip_opt.as_ref().map(|x| &x[..]), - &server_port_bytes, - &secret, - clt_v6_opt.as_ref(), - srv_v6_opt.as_ref(), - ); - - let hs_payload = - build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port()); - let hs_frame = build_rpc_frame(-1, &hs_payload); - if diag_level >= 1 { - info!( - write_key = %hex_dump(&wk), - write_iv = %hex_dump(&wi), - read_key = %hex_dump(&rk), - read_iv = %hex_dump(&ri), - srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), - clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), - srv_port = %hex_dump(&server_port_bytes), - clt_port = %hex_dump(&client_port_bytes), - crypto_ts = %hex_dump(&ts_bytes), - nonce_srv = %hex_dump(&srv_nonce), - nonce_clt = %hex_dump(&my_nonce), - prekey_sha256_client = %hex_dump(&sha256(&prekey_client)), - prekey_sha256_server = %hex_dump(&sha256(&prekey_server)), - hs_plain = %hex_dump(&hs_frame), - proxy_secret_sha256 = %hex_dump(&sha256(&secret)), - "ME diag: derived keys and handshake plaintext" - ); - } - if diag_level >= 2 { - info!( - prekey_client = %hex_dump(&prekey_client), - prekey_server = %hex_dump(&prekey_server), - "ME diag: full prekey buffers" - ); - } - - let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; - if diag_level >= 1 { - info!( - hs_cipher = %hex_dump(&encrypted_hs), - "ME diag: handshake ciphertext" - ); - } - wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?; - wr.flush().await.map_err(ProxyError::Io)?; - - let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS); - let mut enc_buf = BytesMut::with_capacity(256); - let mut dec_buf = BytesMut::with_capacity(256); - let mut read_iv = ri; - let mut handshake_ok = false; - - while Instant::now() < deadline && !handshake_ok { - let remaining = deadline - Instant::now(); - let mut tmp = [0u8; 256]; - let n = match timeout(remaining, rd.read(&mut tmp)).await { - Ok(Ok(0)) => { - return Err(ProxyError::Io(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "ME closed during handshake", - ))); - } - Ok(Ok(n)) => n, - Ok(Err(e)) => return Err(ProxyError::Io(e)), - Err(_) => return Err(ProxyError::TgHandshakeTimeout), - }; - - enc_buf.extend_from_slice(&tmp[..n]); - - let blocks = enc_buf.len() / 16 * 16; - if blocks > 0 { - let mut chunk = vec![0u8; blocks]; - chunk.copy_from_slice(&enc_buf[..blocks]); - read_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?; - dec_buf.extend_from_slice(&chunk); - let _ = enc_buf.split_to(blocks); - } - - while dec_buf.len() >= 4 { - let fl = u32::from_le_bytes(dec_buf[0..4].try_into().unwrap()) as usize; - - if fl == 4 { - let _ = dec_buf.split_to(4); - continue; - } - if !(12..=(1 << 24)).contains(&fl) { - return Err(ProxyError::InvalidHandshake(format!( - "Bad HS response frame len: {fl}" - ))); - } - if dec_buf.len() < fl { - break; - } - - let frame = dec_buf.split_to(fl); - let pe = fl - 4; - let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); - let ac = crate::crypto::crc32(&frame[..pe]); - if ec != ac { - return Err(ProxyError::InvalidHandshake(format!( - "HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}" - ))); - } - - let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap()); - if hs_type == RPC_HANDSHAKE_ERROR_U32 { - let err_code = if frame.len() >= 16 { - i32::from_le_bytes(frame[12..16].try_into().unwrap()) - } else { - -1 - }; - return Err(ProxyError::InvalidHandshake(format!( - "ME rejected handshake (error={err_code})" - ))); - } - if hs_type != RPC_HANDSHAKE_U32 { - return Err(ProxyError::InvalidHandshake(format!( - "Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" - ))); - } - - handshake_ok = true; - break; - } - } - - if !handshake_ok { - return Err(ProxyError::TgHandshakeTimeout); - } - - info!(%addr, "RPC handshake OK"); + let (stream, _connect_ms) = self.connect_tcp(addr).await?; + let hs = self.handshake_only(stream, addr, rng).await?; + let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed); + let cancel = CancellationToken::new(); + let degraded = Arc::new(AtomicBool::new(false)); let rpc_w = Arc::new(Mutex::new(RpcWriter { - writer: wr, - key: wk, - iv: write_iv, + writer: hs.wr, + key: hs.write_key, + iv: hs.write_iv, seq_no: 0, })); - self.writers.write().await.push((addr, rpc_w.clone())); + let writer = MeWriter { + id: writer_id, + addr, + writer: rpc_w.clone(), + cancel: cancel.clone(), + degraded: degraded.clone(), + }; + self.writers.write().await.push(writer.clone()); let reg = self.registry.clone(); - let w_pong = rpc_w.clone(); - let w_pool = self.writers_arc(); - let w_ping = rpc_w.clone(); - let w_pool_ping = self.writers_arc(); + let writers_arc = self.writers_arc(); + let ping_tracker = self.ping_tracker.clone(); + let rtt_stats = self.rtt_stats.clone(); + let pool = Arc::downgrade(self); + let cancel_ping = cancel.clone(); + let rpc_w_ping = rpc_w.clone(); + let ping_tracker_ping = ping_tracker.clone(); + tokio::spawn(async move { - if let Err(e) = - reader_loop(rd, rk, read_iv, reg, enc_buf, dec_buf, w_pong.clone()).await - { + let cancel_reader = cancel.clone(); + let res = reader_loop( + hs.rd, + hs.read_key, + hs.read_iv, + reg.clone(), + BytesMut::new(), + BytesMut::new(), + rpc_w.clone(), + ping_tracker.clone(), + rtt_stats.clone(), + writer_id, + degraded.clone(), + cancel_reader.clone(), + ) + .await; + if let Some(pool) = pool.upgrade() { + pool.remove_writer_and_reroute(writer_id).await; + } + if let Err(e) = res { warn!(error = %e, "ME reader ended"); } - let mut ws = w_pool.write().await; - ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_pong)); + let mut ws = writers_arc.write().await; + ws.retain(|w| w.id != writer_id); info!(remaining = ws.len(), "Dead ME writer removed from pool"); }); + + let pool_ping = Arc::downgrade(self); tokio::spawn(async move { let mut ping_id: i64 = rand::random::(); loop { let jitter = rand::rng() .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; - tokio::time::sleep(Duration::from_secs(wait)).await; + tokio::select! { + _ = cancel_ping.cancelled() => { + break; + } + _ = tokio::time::sleep(Duration::from_secs(wait)) => {} + } let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); p.extend_from_slice(&ping_id.to_le_bytes()); ping_id = ping_id.wrapping_add(1); - if let Err(e) = w_ping.lock().await.send(&p).await { + { + let mut tracker = ping_tracker_ping.lock().await; + tracker.insert(ping_id, (std::time::Instant::now(), writer_id)); + } + if let Err(e) = rpc_w_ping.lock().await.send(&p).await { debug!(error = %e, "Active ME ping failed, removing dead writer"); - let mut ws = w_pool_ping.write().await; - ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_ping)); + cancel_ping.cancel(); + if let Some(pool) = pool_ping.upgrade() { + pool.remove_writer_and_reroute(writer_id).await; + } break; } } @@ -585,6 +331,124 @@ impl MePool { Ok(()) } + async fn connect_primary_for_dc( + self: Arc, + dc: i32, + mut addrs: Vec<(IpAddr, u16)>, + rng: Arc, + ) { + if addrs.is_empty() { + return; + } + addrs.shuffle(&mut rand::rng()); + for (ip, port) in addrs { + let addr = SocketAddr::new(ip, port); + match self.connect_one(addr, rng.as_ref()).await { + Ok(()) => { + info!(%addr, dc = %dc, "ME connected"); + return; + } + Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"), + } + } + warn!(dc = %dc, "All ME servers for DC failed at init"); + } + + pub(crate) async fn remove_writer_and_reroute(&self, writer_id: u64) { + let mut queue = self.remove_writer_only(writer_id).await; + while let Some(bound) = queue.pop() { + if !self.reroute_conn(&bound, &mut queue).await { + let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await; + } + } + } + + async fn remove_writer_only(&self, writer_id: u64) -> Vec { + { + let mut ws = self.writers.write().await; + if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { + let w = ws.remove(pos); + w.cancel.cancel(); + } + } + self.registry.writer_lost(writer_id).await + } + + async fn reroute_conn(&self, bound: &BoundConn, backlog: &mut Vec) -> bool { + let payload = super::wire::build_proxy_req_payload( + bound.conn_id, + bound.meta.client_addr, + bound.meta.our_addr, + &[], + self.proxy_tag.as_deref(), + bound.meta.proto_flags, + ); + + let mut attempts = 0; + loop { + let writers_snapshot = { + let ws = self.writers.read().await; + if ws.is_empty() { + return false; + } + ws.clone() + }; + let mut candidates = self.candidate_indices_for_dc(&writers_snapshot, bound.meta.target_dc).await; + if candidates.is_empty() { + return false; + } + candidates.sort_by_key(|idx| { + writers_snapshot[*idx] + .degraded + .load(Ordering::Relaxed) + .then_some(1usize) + .unwrap_or(0) + }); + let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidates.len(); + + for offset in 0..candidates.len() { + let idx = candidates[(start + offset) % candidates.len()]; + let w = &writers_snapshot[idx]; + if let Ok(mut guard) = w.writer.try_lock() { + let send_res = guard.send(&payload).await; + drop(guard); + match send_res { + Ok(()) => { + self.registry + .bind_writer(bound.conn_id, w.id, w.writer.clone(), bound.meta.clone()) + .await; + return true; + } + Err(e) => { + warn!(error = %e, writer_id = w.id, "ME reroute send failed"); + backlog.extend(self.remove_writer_only(w.id).await); + } + } + continue; + } + } + + let w = writers_snapshot[candidates[start]].clone(); + match w.writer.lock().await.send(&payload).await { + Ok(()) => { + self.registry + .bind_writer(bound.conn_id, w.id, w.writer.clone(), bound.meta.clone()) + .await; + return true; + } + Err(e) => { + warn!(error = %e, writer_id = w.id, "ME reroute send failed (blocking)"); + backlog.extend(self.remove_writer_only(w.id).await); + } + } + + attempts += 1; + if attempts > 3 { + return false; + } + } + } + } fn hex_dump(data: &[u8]) -> String { diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs index 633d0af..3a69118 100644 --- a/src/transport/middle_proxy/pool_nat.rs +++ b/src/transport/middle_proxy/pool_nat.rs @@ -1,10 +1,27 @@ use std::net::{IpAddr, Ipv4Addr}; +use std::time::Duration; use tracing::{info, warn}; use crate::error::{ProxyError, Result}; use super::MePool; +use std::time::Instant; + +#[derive(Debug, Clone, Copy)] +pub struct StunProbeResult { + pub local_addr: std::net::SocketAddr, + pub reflected_addr: std::net::SocketAddr, +} + +pub async fn stun_probe(stun_addr: Option) -> Result> { + let stun_addr = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string()); + fetch_stun_binding(&stun_addr).await +} + +pub async fn detect_public_ip() -> Option { + fetch_public_ipv4_with_retry().await.ok().flatten().map(IpAddr::V4) +} impl MePool { pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr { @@ -82,16 +99,30 @@ impl MePool { } pub(super) async fn maybe_reflect_public_addr(&self) -> Option { + const STUN_CACHE_TTL: Duration = Duration::from_secs(600); + if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { + if let Some((ts, addr)) = *cache { + if ts.elapsed() < STUN_CACHE_TTL { + return Some(addr); + } + } + } + let stun_addr = self .nat_stun .clone() .unwrap_or_else(|| "stun.l.google.com:19302".to_string()); match fetch_stun_binding(&stun_addr).await { Ok(sa) => { - if let Some(sa) = sa { - info!(%sa, "NAT probe: reflected address"); + if let Some(result) = sa { + info!(local = %result.local_addr, reflected = %result.reflected_addr, "NAT probe: reflected address"); + if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { + *cache = Some((Instant::now(), result.reflected_addr)); + } + Some(result.reflected_addr) + } else { + None } - sa } Err(e) => { warn!(error = %e, "NAT probe failed"); @@ -128,7 +159,7 @@ async fn fetch_public_ipv4_once(url: &str) -> Result> { Ok(ip) } -async fn fetch_stun_binding(stun_addr: &str) -> Result> { +async fn fetch_stun_binding(stun_addr: &str) -> Result> { use rand::RngCore; use tokio::net::UdpSocket; @@ -196,10 +227,17 @@ async fn fetch_stun_binding(stun_addr: &str) -> Result {} } diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 83df742..3ca02d5 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -1,9 +1,13 @@ +use std::collections::HashMap; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Instant; use bytes::{Bytes, BytesMut}; use tokio::io::AsyncReadExt; use tokio::net::TcpStream; use tokio::sync::Mutex; +use tokio_util::sync::CancellationToken; use tracing::{debug, trace, warn}; use crate::crypto::{AesCbc, crc32}; @@ -21,12 +25,21 @@ pub(crate) async fn reader_loop( enc_leftover: BytesMut, mut dec: BytesMut, writer: Arc>, + ping_tracker: Arc>>, + rtt_stats: Arc>>, + _writer_id: u64, + degraded: Arc, + cancel: CancellationToken, ) -> Result<()> { let mut raw = enc_leftover; + let mut expected_seq: i32 = 0; loop { let mut tmp = [0u8; 16_384]; - let n = rd.read(&mut tmp).await.map_err(ProxyError::Io)?; + let n = tokio::select! { + res = rd.read(&mut tmp) => res.map_err(ProxyError::Io)?, + _ = cancel.cancelled() => return Ok(()), + }; if n == 0 { return Ok(()); } @@ -70,6 +83,14 @@ pub(crate) async fn reader_loop( continue; } + let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap()); + if seq_no != expected_seq { + warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch"); + expected_seq = seq_no.wrapping_add(1); + } else { + expected_seq = expected_seq.wrapping_add(1); + } + let payload = &frame[8..pe]; if payload.len() < 4 { continue; @@ -119,6 +140,23 @@ pub(crate) async fn reader_loop( warn!(error = %e, "PONG send failed"); break; } + } else if pt == RPC_PONG_U32 && body.len() >= 8 { + let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); + if let Some((sent, wid)) = { + let mut guard = ping_tracker.lock().await; + guard.remove(&ping_id) + } { + let rtt = sent.elapsed().as_secs_f64() * 1000.0; + let mut stats = rtt_stats.lock().await; + let entry = stats.entry(wid).or_insert((rtt, rtt)); + entry.1 = entry.1 * 0.8 + rtt * 0.2; + if rtt < entry.0 { + entry.0 = rtt; + } + let degraded_now = entry.1 > entry.0 * 2.0; + degraded.store(degraded_now, Ordering::Relaxed); + trace!(writer_id = wid, rtt_ms = rtt, ema_ms = entry.1, base_ms = entry.0, degraded = degraded_now, "ME RTT sample"); + } } else { debug!( rpc_type = format_args!("0x{pt:08x}"), diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index abb2b0c..9905d1d 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -1,58 +1,133 @@ use std::collections::HashMap; +use std::net::SocketAddr; use std::sync::atomic::{AtomicU64, Ordering}; - -use tokio::sync::{RwLock, mpsc}; - -use super::MeResponse; -use super::codec::RpcWriter; use std::sync::Arc; -use tokio::sync::Mutex; + +use tokio::sync::{mpsc, Mutex, RwLock}; + +use super::codec::RpcWriter; +use super::MeResponse; + +#[derive(Clone)] +pub struct ConnMeta { + pub target_dc: i16, + pub client_addr: SocketAddr, + pub our_addr: SocketAddr, + pub proto_flags: u32, +} + +#[derive(Clone)] +pub struct BoundConn { + pub conn_id: u64, + pub meta: ConnMeta, +} + +#[derive(Clone)] +pub struct ConnWriter { + pub writer_id: u64, + pub writer: Arc>, +} pub struct ConnRegistry { map: RwLock>>, writers: RwLock>>>, + writer_for_conn: RwLock>, + conns_for_writer: RwLock>>, + meta: RwLock>, next_id: AtomicU64, } impl ConnRegistry { pub fn new() -> Self { - // Avoid fully predictable conn_id sequence from 1. let start = rand::random::() | 1; Self { map: RwLock::new(HashMap::new()), writers: RwLock::new(HashMap::new()), + writer_for_conn: RwLock::new(HashMap::new()), + conns_for_writer: RwLock::new(HashMap::new()), + meta: RwLock::new(HashMap::new()), next_id: AtomicU64::new(start), } } pub async fn register(&self) -> (u64, mpsc::Receiver) { let id = self.next_id.fetch_add(1, Ordering::Relaxed); - let (tx, rx) = mpsc::channel(256); + let (tx, rx) = mpsc::channel(1024); self.map.write().await.insert(id, tx); (id, rx) } pub async fn unregister(&self, id: u64) { self.map.write().await.remove(&id); - self.writers.write().await.remove(&id); + self.meta.write().await.remove(&id); + if let Some(writer_id) = self.writer_for_conn.write().await.remove(&id) { + if let Some(list) = self.conns_for_writer.write().await.get_mut(&writer_id) { + list.retain(|c| *c != id); + } + } } pub async fn route(&self, id: u64, resp: MeResponse) -> bool { let m = self.map.read().await; if let Some(tx) = m.get(&id) { - tx.send(resp).await.is_ok() + tx.try_send(resp).is_ok() } else { false } } - pub async fn set_writer(&self, id: u64, w: Arc>) { - let mut guard = self.writers.write().await; - guard.entry(id).or_insert_with(|| w); + pub async fn bind_writer( + &self, + conn_id: u64, + writer_id: u64, + writer: Arc>, + meta: ConnMeta, + ) { + self.meta.write().await.entry(conn_id).or_insert(meta); + self.writer_for_conn.write().await.insert(conn_id, writer_id); + self.writers.write().await.entry(writer_id).or_insert_with(|| writer.clone()); + self.conns_for_writer + .write() + .await + .entry(writer_id) + .or_insert_with(Vec::new) + .push(conn_id); } - pub async fn get_writer(&self, id: u64) -> Option>> { - let guard = self.writers.read().await; - guard.get(&id).cloned() + pub async fn get_writer(&self, conn_id: u64) -> Option { + let writer_id = { + let guard = self.writer_for_conn.read().await; + guard.get(&conn_id).cloned() + }?; + let writer = { + let guard = self.writers.read().await; + guard.get(&writer_id).cloned() + }?; + Some(ConnWriter { writer_id, writer }) + } + + pub async fn writer_lost(&self, writer_id: u64) -> Vec { + self.writers.write().await.remove(&writer_id); + let conns = self.conns_for_writer.write().await.remove(&writer_id).unwrap_or_default(); + + let mut out = Vec::new(); + let mut writer_for_conn = self.writer_for_conn.write().await; + let meta = self.meta.read().await; + + for conn_id in conns { + writer_for_conn.remove(&conn_id); + if let Some(m) = meta.get(&conn_id) { + out.push(BoundConn { + conn_id, + meta: m.clone(), + }); + } + } + out + } + + pub async fn get_meta(&self, conn_id: u64) -> Option { + let guard = self.meta.read().await; + guard.get(&conn_id).cloned() } } diff --git a/src/transport/middle_proxy/rotation.rs b/src/transport/middle_proxy/rotation.rs new file mode 100644 index 0000000..5313bdb --- /dev/null +++ b/src/transport/middle_proxy/rotation.rs @@ -0,0 +1,37 @@ +use std::sync::Arc; +use std::time::Duration; + +use tracing::{info, warn}; + +use crate::crypto::SecureRandom; + +use super::MePool; + +/// Periodically refresh ME connections to avoid long-lived degradation. +pub async fn me_rotation_task(pool: Arc, rng: Arc, interval: Duration) { + let interval = interval.max(Duration::from_secs(600)); + loop { + tokio::time::sleep(interval).await; + + let candidate = { + let ws = pool.writers.read().await; + ws.get(0).cloned() + }; + + let Some(w) = candidate else { + continue; + }; + + info!(addr = %w.addr, writer_id = w.id, "Rotating ME connection"); + match pool.connect_one(w.addr, rng.as_ref()).await { + Ok(()) => { + // Remove old writer after new one is up. + pool.remove_writer_and_reroute(w.id).await; + } + Err(e) => { + warn!(addr = %w.addr, writer_id = w.id, error = %e, "ME rotation connect failed"); + } + } + } +} + diff --git a/src/transport/middle_proxy/secret.rs b/src/transport/middle_proxy/secret.rs index 9dba939..a9e224d 100644 --- a/src/transport/middle_proxy/secret.rs +++ b/src/transport/middle_proxy/secret.rs @@ -1,6 +1,8 @@ use std::time::Duration; use tracing::{debug, info, warn}; +use std::time::SystemTime; +use httpdate; use crate::error::{ProxyError, Result}; @@ -63,6 +65,23 @@ pub async fn download_proxy_secret() -> Result> { ))); } + if let Some(date) = resp.headers().get(reqwest::header::DATE) { + if let Ok(date_str) = date.to_str() { + if let Ok(server_time) = httpdate::parse_http_date(date_str) { + if let Ok(skew) = SystemTime::now().duration_since(server_time).or_else(|e| { + server_time.duration_since(SystemTime::now()).map_err(|_| e) + }) { + let skew_secs = skew.as_secs(); + if skew_secs > 60 { + warn!(skew_secs, "Time skew >60s detected from proxy-secret Date header"); + } else if skew_secs > 30 { + warn!(skew_secs, "Time skew >30s detected from proxy-secret Date header"); + } + } + } + } + } + let data = resp .bytes() .await diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 29e6e50..174127d 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -1,6 +1,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::Ordering; +use std::time::Duration; use tokio::sync::Mutex; use tracing::{debug, warn}; @@ -9,14 +10,14 @@ use crate::error::{ProxyError, Result}; use crate::protocol::constants::RPC_CLOSE_EXT_U32; use super::MePool; -use super::codec::RpcWriter; use super::wire::build_proxy_req_payload; use crate::crypto::SecureRandom; use rand::seq::SliceRandom; +use super::registry::ConnMeta; impl MePool { pub async fn send_proxy_req( - &self, + self: &Arc, conn_id: u64, target_dc: i16, client_addr: SocketAddr, @@ -32,18 +33,50 @@ impl MePool { self.proxy_tag.as_deref(), proto_flags, ); + let meta = ConnMeta { + target_dc, + client_addr, + our_addr, + proto_flags, + }; + let mut emergency_attempts = 0; loop { - let ws = self.writers.read().await; - if ws.is_empty() { - return Err(ProxyError::Proxy("All ME connections dead".into())); + if let Some(current) = self.registry.get_writer(conn_id).await { + let send_res = { + if let Ok(mut guard) = current.writer.try_lock() { + let r = guard.send(&payload).await; + drop(guard); + r + } else { + current.writer.lock().await.send(&payload).await + } + }; + match send_res { + Ok(()) => return Ok(()), + Err(e) => { + warn!(error = %e, writer_id = current.writer_id, "ME write failed"); + self.remove_writer_and_reroute(current.writer_id).await; + continue; + } + } } - let writers: Vec<(SocketAddr, Arc>)> = ws.iter().cloned().collect(); - drop(ws); - let mut candidate_indices = self.candidate_indices_for_dc(&writers, target_dc).await; + let mut writers_snapshot = { + let ws = self.writers.read().await; + if ws.is_empty() { + return Err(ProxyError::Proxy("All ME connections dead".into())); + } + ws.clone() + }; + + let mut candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await; if candidate_indices.is_empty() { - // Emergency: try to connect to target DC addresses on the fly, then recompute writers + // Emergency connect-on-demand + if emergency_attempts >= 3 { + return Err(ProxyError::Proxy("No ME writers available for target DC".into())); + } + emergency_attempts += 1; let map = self.proxy_map_v4.read().await; if let Some(addrs) = map.get(&(target_dc as i32)) { let mut shuffled = addrs.clone(); @@ -55,65 +88,73 @@ impl MePool { break; } } + tokio::time::sleep(Duration::from_millis(100 * emergency_attempts)).await; let ws2 = self.writers.read().await; - let writers: Vec<(SocketAddr, Arc>)> = ws2.iter().cloned().collect(); + writers_snapshot = ws2.clone(); drop(ws2); - candidate_indices = self.candidate_indices_for_dc(&writers, target_dc).await; + candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await; } if candidate_indices.is_empty() { return Err(ProxyError::Proxy("No ME writers available for target DC".into())); } } + + candidate_indices.sort_by_key(|idx| { + writers_snapshot[*idx] + .degraded + .load(Ordering::Relaxed) + .then_some(1usize) + .unwrap_or(0) + }); + let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len(); - // Prefer immediately available writer to avoid waiting on stalled connection. for offset in 0..candidate_indices.len() { - let cidx = (start + offset) % candidate_indices.len(); - let idx = candidate_indices[cidx]; - let w = writers[idx].1.clone(); - if let Ok(mut guard) = w.try_lock() { + let idx = candidate_indices[(start + offset) % candidate_indices.len()]; + let w = &writers_snapshot[idx]; + if let Ok(mut guard) = w.writer.try_lock() { let send_res = guard.send(&payload).await; drop(guard); match send_res { - Ok(()) => return Ok(()), + Ok(()) => { + self.registry + .bind_writer(conn_id, w.id, w.writer.clone(), meta.clone()) + .await; + return Ok(()); + } Err(e) => { - warn!(error = %e, "ME write failed, removing dead conn"); - let mut ws = self.writers.write().await; - ws.retain(|(_, o)| !Arc::ptr_eq(o, &w)); - if ws.is_empty() { - return Err(ProxyError::Proxy("All ME connections dead".into())); - } + warn!(error = %e, writer_id = w.id, "ME write failed"); + self.remove_writer_and_reroute(w.id).await; continue; } } } } - // All writers are currently busy, wait for the selected one. - let w = writers[candidate_indices[start]].1.clone(); - match w.lock().await.send(&payload).await { - Ok(()) => return Ok(()), + let w = writers_snapshot[candidate_indices[start]].clone(); + match w.writer.lock().await.send(&payload).await { + Ok(()) => { + self.registry + .bind_writer(conn_id, w.id, w.writer.clone(), meta.clone()) + .await; + return Ok(()); + } Err(e) => { - warn!(error = %e, "ME write failed, removing dead conn"); - let mut ws = self.writers.write().await; - ws.retain(|(_, o)| !Arc::ptr_eq(o, &w)); - if ws.is_empty() { - return Err(ProxyError::Proxy("All ME connections dead".into())); - } + warn!(error = %e, writer_id = w.id, "ME write failed (blocking)"); + self.remove_writer_and_reroute(w.id).await; } } } } - pub async fn send_close(&self, conn_id: u64) -> Result<()> { + pub async fn send_close(self: &Arc, conn_id: u64) -> Result<()> { if let Some(w) = self.registry.get_writer(conn_id).await { let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes()); - if let Err(e) = w.lock().await.send(&p).await { + if let Err(e) = w.writer.lock().await.send(&p).await { debug!(error = %e, "ME close write failed"); - let mut ws = self.writers.write().await; - ws.retain(|(_, o)| !Arc::ptr_eq(o, &w)); + self.remove_writer_and_reroute(w.writer_id).await; } } else { debug!(conn_id, "ME close skipped (writer missing)"); @@ -129,7 +170,7 @@ impl MePool { pub(super) async fn candidate_indices_for_dc( &self, - writers: &[(SocketAddr, Arc>)], + writers: &[super::pool::MeWriter], target_dc: i16, ) -> Vec { let mut preferred = Vec::::new(); @@ -165,8 +206,8 @@ impl MePool { } let mut out = Vec::new(); - for (idx, (addr, _)) in writers.iter().enumerate() { - if preferred.iter().any(|p| p == addr) { + for (idx, w) in writers.iter().enumerate() { + if preferred.iter().any(|p| *p == w.addr) { out.push(idx); } } diff --git a/src/transport/pool.rs b/src/transport/pool.rs index 1daa998..8d83321 100644 --- a/src/transport/pool.rs +++ b/src/transport/pool.rs @@ -285,12 +285,17 @@ where #[cfg(test)] mod tests { use super::*; + use std::io::ErrorKind; use tokio::net::TcpListener; #[tokio::test] async fn test_pool_basic() { // Start a test server - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener = match TcpListener::bind("127.0.0.1:0").await { + Ok(l) => l, + Err(e) if e.kind() == ErrorKind::PermissionDenied => return, + Err(e) => panic!("bind failed: {e}"), + }; let addr = listener.local_addr().unwrap(); // Accept connections in background @@ -303,7 +308,11 @@ mod tests { let pool = ConnectionPool::new(); // Get a connection - let conn1 = pool.get(addr).await.unwrap(); + let conn1 = match pool.get(addr).await { + Ok(c) => c, + Err(ProxyError::Io(e)) if e.kind() == ErrorKind::PermissionDenied => return, + Err(e) => panic!("connect failed: {e}"), + }; // Return it to pool pool.put(addr, conn1).await; @@ -335,4 +344,4 @@ mod tests { assert_eq!(stats.endpoints, 0); assert_eq!(stats.total_connections, 0); } -} \ No newline at end of file +} diff --git a/src/transport/socket.rs b/src/transport/socket.rs index a07c21c..a4a7034 100644 --- a/src/transport/socket.rs +++ b/src/transport/socket.rs @@ -205,15 +205,29 @@ pub fn create_listener(addr: SocketAddr, options: &ListenOptions) -> Result l, + Err(e) if e.kind() == ErrorKind::PermissionDenied => return, + Err(e) => panic!("bind failed: {e}"), + }; let addr = listener.local_addr().unwrap(); - let stream = TcpStream::connect(addr).await.unwrap(); - configure_tcp_socket(&stream, true, Duration::from_secs(30)).unwrap(); + let stream = match TcpStream::connect(addr).await { + Ok(s) => s, + Err(e) if e.kind() == ErrorKind::PermissionDenied => return, + Err(e) => panic!("connect failed: {e}"), + }; + if let Err(e) = configure_tcp_socket(&stream, true, Duration::from_secs(30)) { + if e.kind() == ErrorKind::PermissionDenied { + return; + } + panic!("configure_tcp_socket failed: {e}"); + } } #[test] @@ -234,4 +248,4 @@ mod tests { assert!(opts.reuse_port); assert_eq!(opts.backlog, 1024); } -} \ No newline at end of file +} diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index 4b5fe9c..db0d366 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -2,6 +2,7 @@ //! //! IPv6/IPv4 connectivity checks with configurable preference. +use std::collections::HashMap; use std::net::{SocketAddr, IpAddr}; use std::sync::Arc; use std::time::Duration; @@ -350,7 +351,11 @@ impl UpstreamManager { /// Ping all Telegram DCs through all upstreams. /// Tests BOTH IPv6 and IPv4, returns separate results for each. - pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec { + pub async fn ping_all_dcs( + &self, + prefer_ipv6: bool, + dc_overrides: &HashMap>, + ) -> Vec { let upstreams: Vec<(usize, UpstreamConfig)> = { let guard = self.upstreams.read().await; guard.iter().enumerate() @@ -450,6 +455,58 @@ impl UpstreamManager { v4_results.push(ping_result); } + // === Ping DC overrides (v4/v6) === + for (dc_key, addrs) in dc_overrides { + let dc_num: i16 = match dc_key.parse::() { + Ok(v) if v > 0 => v, + Err(_) => { + warn!(dc = %dc_key, "Invalid dc_overrides key, skipping"); + continue; + }, + _ => continue, + }; + let dc_idx = dc_num as usize; + for addr_str in addrs { + match addr_str.parse::() { + Ok(addr) => { + let is_v6 = addr.is_ipv6(); + let result = tokio::time::timeout( + Duration::from_secs(DC_PING_TIMEOUT_SECS), + self.ping_single_dc(&upstream_config, addr) + ).await; + + let ping_result = match result { + Ok(Ok(rtt_ms)) => DcPingResult { + dc_idx, + dc_addr: addr, + rtt_ms: Some(rtt_ms), + error: None, + }, + Ok(Err(e)) => DcPingResult { + dc_idx, + dc_addr: addr, + rtt_ms: None, + error: Some(e.to_string()), + }, + Err(_) => DcPingResult { + dc_idx, + dc_addr: addr, + rtt_ms: None, + error: Some("timeout".to_string()), + }, + }; + + if is_v6 { + v6_results.push(ping_result); + } else { + v4_results.push(ping_result); + } + } + Err(_) => warn!(dc = %dc_idx, addr = %addr_str, "Invalid dc_overrides address, skipping"), + } + } + } + // Check if both IP versions have at least one working DC let v6_has_working = v6_results.iter().any(|r| r.rtt_ms.is_some()); let v4_has_working = v4_results.iter().any(|r| r.rtt_ms.is_some()); @@ -624,4 +681,4 @@ impl UpstreamManager { Some(SocketAddr::new(ip, TG_DATACENTER_PORT)) } -} \ No newline at end of file +} diff --git a/telemt b/telemt deleted file mode 100644 index fbb8d6f..0000000 Binary files a/telemt and /dev/null differ diff --git a/tools/dc.py b/tools/dc.py index f142baf..d431966 100644 --- a/tools/dc.py +++ b/tools/dc.py @@ -1,121 +1,204 @@ +"""Telegram datacenter server checker.""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from itertools import groupby +from operator import attrgetter +from pathlib import Path +from typing import TYPE_CHECKING + from telethon import TelegramClient from telethon.tl.functions.help import GetConfigRequest -import asyncio -api_id = '' -api_hash = '' +if TYPE_CHECKING: + from telethon.tl.types import DcOption -async def get_all_servers(): - print("🔄 Подключаемся к Telegram...") - client = TelegramClient('session', api_id, api_hash) - - await client.start() - print("✅ Подключение установлено!\n") - - print("📡 Запрашиваем конфигурацию серверов...") - config = await client(GetConfigRequest()) - - print(f"📊 Получено серверов: {len(config.dc_options)}\n") - print("="*80) - - # Группируем серверы по DC ID - dc_groups = {} - for dc in config.dc_options: - if dc.id not in dc_groups: - dc_groups[dc.id] = [] - dc_groups[dc.id].append(dc) - - # Выводим все серверы, сгруппированные по DC - for dc_id in sorted(dc_groups.keys()): - servers = dc_groups[dc_id] - print(f"\n🌐 DATACENTER {dc_id} ({len(servers)} серверов)") - print("-" * 80) - - for dc in servers: - # Собираем флаги - flags = [] - if dc.ipv6: - flags.append("IPv6") - if dc.media_only: - flags.append("🎬 MEDIA-ONLY") - if dc.cdn: - flags.append("📦 CDN") - if dc.tcpo_only: - flags.append("🔒 TCPO") - if dc.static: - flags.append("📌 STATIC") - - flags_str = f" [{', '.join(flags)}]" if flags else " [STANDARD]" - - # Форматируем IP (выравниваем для читаемости) - ip_display = f"{dc.ip_address:45}" - - print(f" {ip_display}:{dc.port:5}{flags_str}") - - # Статистика - print("\n" + "="*80) - print("📈 СТАТИСТИКА:") - print("="*80) - - total = len(config.dc_options) - ipv4_count = sum(1 for dc in config.dc_options if not dc.ipv6) - ipv6_count = sum(1 for dc in config.dc_options if dc.ipv6) - media_count = sum(1 for dc in config.dc_options if dc.media_only) - cdn_count = sum(1 for dc in config.dc_options if dc.cdn) - tcpo_count = sum(1 for dc in config.dc_options if dc.tcpo_only) - static_count = sum(1 for dc in config.dc_options if dc.static) - - print(f" Всего серверов: {total}") - print(f" IPv4 серверы: {ipv4_count}") - print(f" IPv6 серверы: {ipv6_count}") - print(f" Media-only: {media_count}") - print(f" CDN серверы: {cdn_count}") - print(f" TCPO-only: {tcpo_count}") - print(f" Static: {static_count}") - - # Дополнительная информация из config - print("\n" + "="*80) - print("ℹ️ ДОПОЛНИТЕЛЬНАЯ ИНФОРМАЦИЯ:") - print("="*80) - print(f" Дата конфигурации: {config.date}") - print(f" Expires: {config.expires}") - print(f" Test mode: {config.test_mode}") - print(f" This DC: {config.this_dc}") - - # Сохраняем в файл - print("\n💾 Сохраняем результаты в файл telegram_servers.txt...") - with open('telegram_servers.txt', 'w', encoding='utf-8') as f: - f.write("TELEGRAM DATACENTER SERVERS\n") - f.write("="*80 + "\n\n") - - for dc_id in sorted(dc_groups.keys()): - servers = dc_groups[dc_id] - f.write(f"\nDATACENTER {dc_id} ({len(servers)} servers)\n") - f.write("-" * 80 + "\n") - - for dc in servers: - flags = [] - if dc.ipv6: - flags.append("IPv6") - if dc.media_only: - flags.append("MEDIA-ONLY") - if dc.cdn: - flags.append("CDN") - if dc.tcpo_only: - flags.append("TCPO") - if dc.static: - flags.append("STATIC") - - flags_str = f" [{', '.join(flags)}]" if flags else " [STANDARD]" - f.write(f" {dc.ip_address}:{dc.port}{flags_str}\n") - - f.write(f"\n\nTotal servers: {total}\n") - f.write(f"Generated: {config.date}\n") - - print("✅ Результаты сохранены в telegram_servers.txt") - - await client.disconnect() - print("\n👋 Отключились от Telegram") +API_ID: int = 123456 +API_HASH: str = "" +SESSION_NAME: str = "session" +OUTPUT_FILE: Path = Path("telegram_servers.txt") -if __name__ == '__main__': - asyncio.run(get_all_servers()) \ No newline at end of file +_CONSOLE_FLAG_MAP: dict[str, str] = { + "IPv6": "IPv6", + "MEDIA-ONLY": "🎬 MEDIA-ONLY", + "CDN": "📦 CDN", + "TCPO": "🔒 TCPO", + "STATIC": "📌 STATIC", +} + + +@dataclass(frozen=True, slots=True) +class DCServer: + """Typed representation of a Telegram DC server. + + Attributes: + dc_id: Datacenter identifier. + ip: Server IP address. + port: Server port. + flags: Active flag labels (plain, without emoji). + """ + + dc_id: int + ip: str + port: int + flags: frozenset[str] = field(default_factory=frozenset) + + @classmethod + def from_option(cls, dc: DcOption) -> DCServer: + """Create from a Telethon DcOption. + + Args: + dc: Raw DcOption object. + + Returns: + Parsed DCServer instance. + """ + checks: dict[str, bool] = { + "IPv6": dc.ipv6, + "MEDIA-ONLY": dc.media_only, + "CDN": dc.cdn, + "TCPO": dc.tcpo_only, + "STATIC": dc.static, + } + return cls( + dc_id=dc.id, + ip=dc.ip_address, + port=dc.port, + flags=frozenset(k for k, v in checks.items() if v), + ) + + def flags_display(self, *, emoji: bool = False) -> str: + """Formatted flags string. + + Args: + emoji: Whether to include emoji prefixes. + + Returns: + Bracketed flags or '[STANDARD]'. + """ + if not self.flags: + return "[STANDARD]" + labels = sorted( + _CONSOLE_FLAG_MAP[f] if emoji else f for f in self.flags + ) + return f"[{', '.join(labels)}]" + + +class TelegramDCChecker: + """Fetches and displays Telegram DC configuration. + + Attributes: + _client: Telethon client instance. + _servers: Parsed server list. + """ + + def __init__(self) -> None: + """Initialize the checker.""" + self._client = TelegramClient(SESSION_NAME, API_ID, API_HASH) + self._servers: list[DCServer] = [] + + async def run(self) -> None: + """Connect, fetch config, display and save results.""" + print("🔄 Подключаемся к Telegram...") # noqa: T201 + try: + await self._client.start() + print("✅ Подключение установлено!\n") # noqa: T201 + + print("📡 Запрашиваем конфигурацию серверов...") # noqa: T201 + config = await self._client(GetConfigRequest()) + self._servers = [DCServer.from_option(dc) for dc in config.dc_options] + + self._print(config) + self._save(config) + finally: + await self._client.disconnect() + print("\n👋 Отключились от Telegram") # noqa: T201 + + def _grouped(self) -> dict[int, list[DCServer]]: + """Group servers by DC ID. + + Returns: + Ordered mapping of DC ID to servers. + """ + ordered = sorted(self._servers, key=attrgetter("dc_id")) + return {k: list(g) for k, g in groupby(ordered, key=attrgetter("dc_id"))} + + def _print(self, config: object) -> None: + """Print results to stdout in original format. + + Args: + config: Raw Telegram config. + """ + sep = "=" * 80 + dash = "-" * 80 + total = len(self._servers) + + print(f"📊 Получено серверов: {total}\n") # noqa: T201 + print(sep) # noqa: T201 + + for dc_id, servers in self._grouped().items(): + print(f"\n🌐 DATACENTER {dc_id} ({len(servers)} серверов)") # noqa: T201 + print(dash) # noqa: T201 + for s in servers: + print(f" {s.ip:45}:{s.port:5} {s.flags_display(emoji=True)}") # noqa: T201 + + ipv4 = total - self._flag_count("IPv6") + print(f"\n{sep}") # noqa: T201 + print("📈 СТАТИСТИКА:") # noqa: T201 + print(sep) # noqa: T201 + print(f" Всего серверов: {total}") # noqa: T201 + print(f" IPv4 серверы: {ipv4}") # noqa: T201 + print(f" IPv6 серверы: {self._flag_count('IPv6')}") # noqa: T201 + print(f" Media-only: {self._flag_count('MEDIA-ONLY')}") # noqa: T201 + print(f" CDN серверы: {self._flag_count('CDN')}") # noqa: T201 + print(f" TCPO-only: {self._flag_count('TCPO')}") # noqa: T201 + print(f" Static: {self._flag_count('STATIC')}") # noqa: T201 + + print(f"\n{sep}") # noqa: T201 + print("ℹ️ ДОПОЛНИТЕЛЬНАЯ ИНФОРМАЦИЯ:") # noqa: T201 + print(sep) # noqa: T201 + print(f" Дата конфигурации: {config.date}") # noqa: T201 # type: ignore[attr-defined] + print(f" Expires: {config.expires}") # noqa: T201 # type: ignore[attr-defined] + print(f" Test mode: {config.test_mode}") # noqa: T201 # type: ignore[attr-defined] + print(f" This DC: {config.this_dc}") # noqa: T201 # type: ignore[attr-defined] + + def _flag_count(self, flag: str) -> int: + """Count servers with a given flag. + + Args: + flag: Flag name. + + Returns: + Count of matching servers. + """ + return sum(1 for s in self._servers if flag in s.flags) + + def _save(self, config: object) -> None: + """Save results to file in original format. + + Args: + config: Raw Telegram config. + """ + parts: list[str] = [] + parts.append("TELEGRAM DATACENTER SERVERS\n") + parts.append("=" * 80 + "\n\n") + + for dc_id, servers in self._grouped().items(): + parts.append(f"\nDATACENTER {dc_id} ({len(servers)} servers)\n") + parts.append("-" * 80 + "\n") + for s in servers: + parts.append(f" {s.ip}:{s.port} {s.flags_display(emoji=False)}\n") + + parts.append(f"\n\nTotal servers: {len(self._servers)}\n") + parts.append(f"Generated: {config.date}\n") # type: ignore[attr-defined] + + OUTPUT_FILE.write_text("".join(parts), encoding="utf-8") + + print(f"\n💾 Сохраняем результаты в файл {OUTPUT_FILE}...") # noqa: T201 + print(f"✅ Результаты сохранены в {OUTPUT_FILE}") # noqa: T201 + + +if __name__ == "__main__": + asyncio.run(TelegramDCChecker().run())