diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1e8d5a0..116c1d4 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -3,8 +3,8 @@ name: Release on: push: tags: - - '[0-9]+.[0-9]+.[0-9]+' # Matches tags like 3.0.0, 3.1.2, etc. - workflow_dispatch: # Manual trigger from GitHub Actions UI + - '[0-9]+.[0-9]+.[0-9]+' # Matches tags like 3.0.0, 3.1.2, etc. + workflow_dispatch: # Manual trigger from GitHub Actions UI permissions: contents: read @@ -84,6 +84,32 @@ jobs: target/${{ matrix.target }}/release/${{ matrix.asset_name }}.tar.gz target/${{ matrix.target }}/release/${{ matrix.asset_name }}.sha256 + build-docker-image: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + + - name: Login to GitHub Container Registry + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.TOKEN_GH_DEPLOY }} + + - name: Build and push + uses: docker/build-push-action@v6 + with: + context: . 
+ push: true + tags: ghcr.io/${{ github.repository }}:${{ github.ref_name }} + release: name: Create Release needs: build @@ -108,17 +134,17 @@ jobs: # Extract version from tag (remove 'v' prefix if present) VERSION="${GITHUB_REF#refs/tags/}" VERSION="${VERSION#v}" - + # Install cargo-edit for version bumping cargo install cargo-edit - + # Update Cargo.toml version cargo set-version "$VERSION" - + # Configure git git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" - + # Commit and push changes #git add Cargo.toml Cargo.lock #git commit -m "chore: bump version to $VERSION" || echo "No changes to commit" diff --git a/AGENTS_SYSTEM_PROMT.md b/AGENTS_SYSTEM_PROMT.md index cec8c38..e6c5f2e 100644 --- a/AGENTS_SYSTEM_PROMT.md +++ b/AGENTS_SYSTEM_PROMT.md @@ -1,6 +1,7 @@ ## System Prompt — Production Rust Codebase: Modification and Architecture Guidelines -You are a senior Rust systems engineer acting as a strict code reviewer and implementation partner. Your responses are precise, minimal, and architecturally sound. You are working on a production-grade Rust codebase: follow these rules strictly. +You are a senior Rust Engineer and principal Rust Architect acting as a strict code reviewer and implementation partner. +Your responses are precise, minimal, and architecturally sound. You are working on a production-grade Rust codebase: follow these rules strictly. --- @@ -32,6 +33,11 @@ The user can override this behavior with explicit commands: - `"Make minimal changes"` — no coordinated fixes, narrowest possible diff. - `"Fix everything"` — apply all coordinated fixes and out-of-scope observations. +### Core Rule + +The codebase must never enter an invalid intermediate state. +No response may leave the repository in a condition that requires follow-up fixes. + --- ### 1. Comments and Documentation @@ -131,16 +137,32 @@ You MUST: - Document non-obvious logic with comments that describe *why*, not *what*. 
- Limit changes strictly to the requested scope (plus coordinated fixes per Section 0). - Keep all existing symbol names unless renaming is explicitly requested. -- Preserve global formatting as-is. +- Preserve global formatting as-is +- Result every modification in a self-contained, compilable, runnable state of the codebase You MUST NOT: -- Use placeholders: no `// ... rest of code`, no `// implement here`, no `/* TODO */` stubs that replace existing working code. Write full, working implementation. If the implementation is unclear, ask first. -- Refactor code outside the requested scope. -- Make speculative improvements. +- Use placeholders: no `// ... rest of code`, no `// implement here`, no `/* TODO */` stubs that replace existing working code. Write full, working implementation. If the implementation is unclear, ask first +- Refactor code outside the requested scope +- Make speculative improvements +- Spawn multiple agents for EDITING +- Produce partial changes +- Introduce references to entities that are not yet implemented +- Leave TODO placeholders in production paths Note: `todo!()` and `unimplemented!()` are allowed as idiomatic Rust markers for genuinely unfinished code paths. +Every change must: + - compile, + - pass type checks, + - have no broken imports, + - preserve invariants, + - not rely on future patches. + +If the task requires multiple phases: + - either implement all required phases, + - or explicitly refuse and explain missing dependencies. + --- ### 8. Decision Process for Complex Changes @@ -160,6 +182,7 @@ When facing a non-trivial modification, follow this sequence: - When provided with partial code, assume the rest of the codebase exists and functions correctly unless stated otherwise. - Reference existing types, functions, and module structures by their actual names as shown in the provided code. - When the provided context is insufficient to make a safe change, request the missing context explicitly. 
+- Spawn multiple agents for SEARCHING information, code, functions --- @@ -167,14 +190,14 @@ When facing a non-trivial modification, follow this sequence: #### Language Policy -- Code, comments, commit messages, documentation: **English**. -- Reasoning and explanations in response text: **Russian**. +- Code, comments, commit messages, documentation ONLY IN **English**! +- Reasoning and explanations in response text in the language of the prompt #### Response Structure Your response MUST consist of two sections: -**Section 1: `## Reasoning` (in Russian)** +**Section 1: `## Reasoning`** - What needs to be done and why. - Which files and modules are affected. @@ -205,3 +228,183 @@ If the response exceeds the output limit: 2. List the files that will be provided in subsequent parts. 3. Wait for user confirmation before continuing. 4. No single file may be split across parts. + +## 11. Anti-LLM Degeneration Safeguards (Principal-Paranoid, Visionary) + +This section exists to prevent common LLM failure modes: scope creep, semantic drift, cargo-cult refactors, performance regressions, contract breakage, and hidden behavior changes. + +### 11.1 Non-Negotiable Invariants + +- **No semantic drift:** Do not reinterpret requirements, rename concepts, or change meaning of existing terms. +- **No “helpful refactors”:** Any refactor not explicitly requested is forbidden. +- **No architectural drift:** Do not introduce new layers, patterns, abstractions, or “clean architecture” migrations unless requested. +- **No dependency drift:** Do not add crates, features, or versions unless explicitly requested. +- **No behavior drift:** If a change could alter runtime behavior, you MUST call it out explicitly in `## Reasoning` and justify it. + +### 11.2 Minimal Surface Area Rule + +- Touch the smallest number of files possible. +- Prefer local changes over cross-cutting edits. +- Do not “align style” across a file/module—only adjust the modified region. 
+- Do not reorder items, imports, or code unless required for correctness. + +### 11.3 No Implicit Contract Changes + +Contracts include: +- public APIs, trait bounds, visibility, error types, timeouts/retries, logging semantics, metrics semantics, +- protocol formats, framing, padding, keepalive cadence, state machine transitions, +- concurrency guarantees, cancellation behavior, backpressure behavior. + +Rule: +- If you change a contract, you MUST update all dependents in the same patch AND document the contract delta explicitly. + +### 11.4 Hot-Path Preservation (Performance Paranoia) + +- Do not introduce extra allocations, cloning, or formatting in hot paths. +- Do not add logging/metrics on hot paths unless requested. +- Do not add new locks or broaden lock scope. +- Prefer `&str` / slices / borrowed data where the codebase already does so. +- Avoid `String` building for errors/logs if it changes current patterns. + +If you cannot prove performance neutrality, label it as risk in `## Reasoning`. + +### 11.5 Async / Concurrency Safety (Cancellation & Backpressure) + +- No blocking calls inside async contexts. +- Preserve cancellation safety: do not introduce `await` between lock acquisition and critical invariants unless already present. +- Preserve backpressure: do not replace bounded channels with unbounded, do not remove flow control. +- Do not change task lifecycle semantics (spawn patterns, join handles, shutdown order) unless requested. +- Do not introduce `tokio::spawn` / background tasks unless explicitly requested. + +### 11.6 Error Semantics Integrity + +- Do not replace structured errors with generic strings. +- Do not widen/narrow error types or change error categories without explicit approval. +- Avoid introducing panics in production paths (`unwrap`, `expect`) unless the codebase already treats that path as impossible and documented. 
+ +### 11.7 “No New Abstractions” Default + +Default stance: +- No new traits, generics, macros, builder patterns, type-level cleverness, or “frameworking”. +- If abstraction is necessary, prefer the smallest possible local helper (private function) and justify it. + +### 11.8 Negative-Diff Protection + +Avoid “diff inflation” patterns: +- mass edits, +- moving code between files, +- rewrapping long lines, +- rearranging module order, +- renaming for aesthetics. + +If a diff becomes large, STOP and ask before proceeding. + +### 11.9 Consistency with Existing Style (But Not Style Refactors) + +- Follow existing conventions of the touched module (naming, error style, return patterns). +- Do not enforce global “best practices” that the codebase does not already use. + +### 11.10 Two-Phase Safety Gate (Plan → Patch) + +For non-trivial changes: +1) Provide a micro-plan (1–5 bullets): what files, what functions, what invariants, what risks. +2) Implement exactly that plan—no extra improvements. + +### 11.11 Pre-Response Checklist (Hard Gate) + +Before final output, verify internally: + +- No unresolved symbols / broken imports. +- No partially updated call sites. +- No new public surface changes unless requested. +- No transitional states / TODO placeholders replacing working code. +- Changes are atomic: the repository remains buildable and runnable. +- Any behavior change is explicitly stated. + +If any check fails: fix it before responding. + +### 11.12 Truthfulness Policy (No Hallucinated Claims) + +- Do not claim “this compiles” or “tests pass” unless you actually verified with the available tooling/context. 
+- If verification is not possible, state: “Not executed; reasoning-based consistency check only.” + +### 11.13 Visionary Guardrail: Preserve Optionality + +When multiple valid designs exist, prefer the one that: +- minimally constrains future evolution, +- preserves existing extension points, +- avoids locking the project into a new paradigm, +- keeps interfaces stable and implementation local. + +Default to reversible changes. + +### 11.14 Stop Conditions + +STOP and ask targeted questions if: +- required context is missing, +- a change would cross module boundaries, +- a contract might change, +- concurrency/protocol invariants are unclear, +- the diff is growing beyond a minimal patch. + +No guessing. + +### 12. Invariant Preservation + +You MUST explicitly preserve: +- Thread-safety guarantees (`Send` / `Sync` expectations). +- Memory safety assumptions (no hidden `unsafe` expansions). +- Lock ordering and deadlock invariants. +- State machine correctness (no new invalid transitions). +- Backward compatibility of serialized formats (if applicable). + +If a change touches concurrency, networking, protocol logic, or state machines, +you MUST explain why existing invariants remain valid. + +### 13. Error Handling Policy + +- Do not replace structured errors with generic strings. +- Preserve existing error propagation semantics. +- Do not widen or narrow error types without approval. +- Avoid introducing panics in production paths. +- Prefer explicit error mapping over implicit conversions. + +### 14. Test Safety + +- Do not modify existing tests unless the task explicitly requires it. +- Do not weaken assertions. +- Preserve determinism in testable components. + +### 15. Security Constraints + +- Do not weaken cryptographic assumptions. +- Do not modify key derivation logic without explicit request. +- Do not change constant-time behavior. +- Do not introduce logging of secrets. +- Preserve TLS/MTProto protocol correctness. + +### 16. 
Logging Policy + +- Do not introduce excessive logging in hot paths. +- Do not log sensitive data. +- Preserve existing log levels and style. + +### 17. Pre-Response Verification Checklist + +Before producing the final answer, verify internally: + +- The change compiles conceptually. +- No unresolved symbols exist. +- All modified call sites are updated. +- No accidental behavioral changes were introduced. +- Architectural boundaries remain intact. + +### 18. Atomic Change Principle +Every patch must be **atomic and production-safe**. +* **Self-contained** — no dependency on future patches or unimplemented components. +* **Build-safe** — the project must compile successfully after the change. +* **Contract-consistent** — no partial interface or behavioral changes; all dependent code must be updated within the same patch. +* **No transitional states** — no placeholders, incomplete refactors, or temporary inconsistencies. + +**Invariant:** After any single patch, the repository remains fully functional and buildable. 
+ diff --git a/Cargo.toml b/Cargo.toml index f8ee25f..15563dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.0.8" +version = "3.0.11" edition = "2024" [dependencies] @@ -20,6 +20,7 @@ sha1 = "0.10" md-5 = "0.10" hmac = "0.12" crc32fast = "1.4" +crc32c = "0.6" zeroize = { version = "1.8", features = ["derive"] } # Network diff --git a/README.md b/README.md index 3705856..a88b8df 100644 --- a/README.md +++ b/README.md @@ -10,41 +10,77 @@ ### 🇷🇺 RU -18 февраля мы опубликовали `telemt 3.0.3`, он имеет: +#### Драфтинг LTS и текущие улучшения -- улучшенный механизм Middle-End Health Check -- высокоскоростное восстановление инициализации Middle-End -- меньше задержек на hot-path -- более корректную работу в Dualstack, а именно - IPv6 Middle-End -- аккуратное переподключение клиента без дрифта сессий между Middle-End -- автоматическая деградация на Direct-DC при массовой (>2 ME-DC-групп) недоступности Middle-End -- автодетект IP за NAT, при возможности - будет выполнен хендшейк с ME, при неудаче - автодеградация -- единственный известный специальный DC=203 уже добавлен в код: медиа загружаются с CDN в Direct-DC режиме +С 21 февраля мы начали подготовку LTS-версии. -[Здесь вы можете найти релиз](https://github.com/telemt/telemt/releases/tag/3.0.3) +Мы внимательно анализируем весь доступный фидбек. +Наша цель — сделать LTS-кандидаты максимально стабильными, тщательно отлаженными и готовыми к long-run и highload production-сценариям. -Если у вас есть компетенции в асинхронных сетевых приложениях, анализе трафика, реверс-инжиниринге или сетевых расследованиях - мы открыты к идеям и pull requests! +--- +#### Улучшения от 23 февраля + +23 февраля были внесены улучшения производительности в режимах **DC** и **Middle-End (ME)**, с акцентом на обратный канал (путь клиент → DC / ME). 
+ +Дополнительно реализован ряд изменений, направленных на повышение устойчивости системы: + +- Смягчение сетевой нестабильности +- Повышение устойчивости к десинхронизации криптографии +- Снижение дрейфа сессий при неблагоприятных условиях +- Улучшение обработки ошибок в edge-case транспортных сценариях + +Релиз: +[3.0.9](https://github.com/telemt/telemt/releases/tag/3.0.9) + +--- + +Если у вас есть компетенции в: + +- Асинхронных сетевых приложениях +- Анализе трафика +- Реверс-инжиниринге +- Сетевых расследованиях + +Мы открыты к архитектурным предложениям, идеям и pull requests ### 🇬🇧 EN -On February 18, we released `telemt 3.0.3`. This version introduces: +#### LTS Drafting and Ongoing Improvements -- improved Middle-End Health Check method -- high-speed recovery of Middle-End init -- reduced latency on the hot path -- correct Dualstack support: proper handling of IPv6 Middle-End -- *clean* client reconnection without session "drift" between Middle-End -- automatic degradation to Direct-DC mode in case of large-scale (>2 ME-DC groups) Middle-End unavailability -- automatic public IP detection behind NAT; first - Middle-End handshake is performed, otherwise automatic degradation is applied -- known special DC=203 is now handled natively: media is delivered from the CDN via Direct-DC mode +Starting February 21, we began drafting the upcoming LTS version. -[Release is available here](https://github.com/telemt/telemt/releases/tag/3.0.3) +We are carefully reviewing and analyzing all available feedback. +The goal is to ensure that LTS candidates are maximally stable, thoroughly debugged, and ready for long-run and high-load production scenarios. -If you have expertise in asynchronous network applications, traffic analysis, reverse engineering, or network forensics - we welcome ideas and pull requests! 
+--- +#### February 23 Improvements + +On February 23, we introduced performance improvements for both **DC** and **Middle-End (ME)** modes, specifically optimizing the reverse channel (client → DC / ME data path). + +Additionally, we implemented a set of robustness enhancements designed to: + +- Mitigate network-related instability +- Improve resilience against cryptographic desynchronization +- Reduce session drift under adverse conditions +- Improve error handling in edge-case transport scenarios + +Release: +[3.0.9](https://github.com/telemt/telemt/releases/tag/3.0.9) + +--- + +If you have expertise in: + +- Asynchronous network applications +- Traffic analysis +- Reverse engineering +- Network forensics + +We welcome ideas, architectural feedback, and pull requests. @@ -178,147 +214,21 @@ then Ctrl+X -> Y -> Enter to save ```toml # === General Settings === [general] -fast_mode = true -use_middle_proxy = true # ad_tag = "00000000000000000000000000000000" -# Path to proxy-secret binary (auto-downloaded if missing). -proxy_secret_path = "proxy-secret" -# disable_colors = false # Disable colored output in logs (useful for files/systemd) - -# === Log Level === -# Log level: debug | verbose | normal | silent -# Can be overridden with --silent or --log-level CLI flags -# RUST_LOG env var takes absolute priority over all of these -log_level = "normal" - -# === Middle Proxy - ME === -# Public IP override for ME KDF when behind NAT; leave unset to auto-detect. -# middle_proxy_nat_ip = "203.0.113.10" -# Enable STUN probing to discover public IP:port for ME. -middle_proxy_nat_probe = true -# Primary STUN server (host:port); defaults to Telegram STUN when empty. -middle_proxy_nat_stun = "stun.l.google.com:19302" -# Optional fallback STUN servers list. -middle_proxy_nat_stun_servers = ["stun1.l.google.com:19302", "stun2.l.google.com:19302"] -# Desired number of concurrent ME writers in pool. 
-middle_proxy_pool_size = 16 -# Pre-initialized warm-standby ME connections kept idle. -middle_proxy_warm_standby = 8 -# Ignore STUN/interface mismatch and keep ME enabled even if IP differs. -stun_iface_mismatch_ignore = false -# Keepalive padding frames - fl==4 -me_keepalive_enabled = true -me_keepalive_interval_secs = 25 # Period between keepalives -me_keepalive_jitter_secs = 5 # Jitter added to interval -me_keepalive_payload_random = true # Randomize 4-byte payload (vs zeros) -# Stagger extra ME connections on warmup to de-phase lifecycles. -me_warmup_stagger_enabled = true -me_warmup_step_delay_ms = 500 # Base delay between extra connects -me_warmup_step_jitter_ms = 300 # Jitter for warmup delay -# Reconnect policy knobs. -me_reconnect_max_concurrent_per_dc = 1 # Parallel reconnects per DC - EXPERIMENTAL! UNSTABLE! -me_reconnect_backoff_base_ms = 500 # Backoff start -me_reconnect_backoff_cap_ms = 30000 # Backoff cap -me_reconnect_fast_retry_count = 11 # Quick retries before backoff [general.modes] classic = false secure = false tls = true -[general.links] -show = "*" -# show = ["alice", "bob"] # Only show links for alice and bob -# show = "*" # Show links for all users -# public_host = "proxy.example.com" # Host (IP or domain) for tg:// links -# public_port = 443 # Port for tg:// links (default: server.port) - -# === Network Parameters === -[network] -# Enable/disable families: true/false/auto(None) -ipv4 = true -ipv6 = false # UNSTABLE WITH ME -# prefer = 4 or 6 -prefer = 4 -multipath = false # EXPERIMENTAL! 
- -# === Server Binding === -[server] -port = 443 -listen_addr_ipv4 = "0.0.0.0" -listen_addr_ipv6 = "::" -# listen_unix_sock = "/var/run/telemt.sock" # Unix socket -# listen_unix_sock_perm = "0666" # Socket file permissions -# metrics_port = 9090 -# metrics_whitelist = [ -# "192.168.0.0/24", -# "172.16.0.0/12", -# "127.0.0.1/32", -# "::1/128" -#] - -# Listen on multiple interfaces/IPs - IPv4 -[[server.listeners]] -ip = "0.0.0.0" - -# Listen on multiple interfaces/IPs - IPv6 -[[server.listeners]] -ip = "::" - -# === Timeouts (in seconds) === -[timeouts] -client_handshake = 30 -tg_connect = 10 -client_keepalive = 60 -client_ack = 300 -# Quick ME reconnects for single-address DCs (count and per-attempt timeout, ms). -me_one_retry = 12 -me_one_timeout_ms = 1200 - # === Anti-Censorship & Masking === [censorship] tls_domain = "petrovich.ru" -mask = true -mask_port = 443 -# mask_host = "petrovich.ru" # Defaults to tls_domain if not set -# mask_unix_sock = "/var/run/nginx.sock" # Unix socket (mutually exclusive with mask_host) -fake_cert_len = 2048 - -# === Access Control & Users === -[access] -replay_check_len = 65536 -replay_window_secs = 1800 -ignore_time_skew = false [access.users] # format: "username" = "32_hex_chars_secret" hello = "00000000000000000000000000000000" -# [access.user_max_tcp_conns] -# hello = 50 - -# [access.user_max_unique_ips] -# hello = 5 - -# [access.user_data_quota] -# hello = 1073741824 # 1 GB - -# === Upstreams & Routing === -[[upstreams]] -type = "direct" -enabled = true -weight = 10 - -# [[upstreams]] -# type = "socks5" -# address = "127.0.0.1:1080" -# enabled = false -# weight = 1 - -# === DC Address Overrides === -# [dc_overrides] -# "203" = "91.105.192.100:443" - ``` ### Advanced #### Adtag diff --git a/src/cli.rs b/src/cli.rs index 25d14f0..cf98121 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -213,6 +213,7 @@ listen_addr_ipv6 = "::" [[server.listeners]] ip = "0.0.0.0" +# reuse_allow = false # Set true only when intentionally running 
multiple telemt instances on same port [[server.listeners]] ip = "::" @@ -228,6 +229,7 @@ tls_domain = "{domain}" mask = true mask_port = 443 fake_cert_len = 2048 +tls_full_cert_ttl_secs = 90 [access] replay_check_len = 65536 diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 2dee3e0..90dd6f9 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -122,6 +122,10 @@ pub(crate) fn default_tls_new_session_tickets() -> u8 { 0 } +pub(crate) fn default_tls_full_cert_ttl_secs() -> u64 { + 90 +} + pub(crate) fn default_server_hello_delay_min_ms() -> u64 { 0 } diff --git a/src/config/load.rs b/src/config/load.rs index ec8011a..827687a 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -227,6 +227,7 @@ impl ProxyConfig { announce: None, announce_ip: None, proxy_protocol: None, + reuse_allow: false, }); } if let Some(ipv6_str) = &config.server.listen_addr_ipv6 { @@ -236,6 +237,7 @@ impl ProxyConfig { announce: None, announce_ip: None, proxy_protocol: None, + reuse_allow: false, }); } } diff --git a/src/config/types.rs b/src/config/types.rs index 6c54598..a303db8 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -74,8 +74,8 @@ pub struct ProxyModes { impl Default for ProxyModes { fn default() -> Self { Self { - classic: true, - secure: true, + classic: false, + secure: false, tls: true, } } @@ -118,7 +118,7 @@ impl Default for NetworkConfig { fn default() -> Self { Self { ipv4: true, - ipv6: None, + ipv6: Some(false), prefer: 4, multipath: false, stun_servers: default_stun_servers(), @@ -291,7 +291,7 @@ impl Default for GeneralConfig { middle_proxy_nat_stun: None, middle_proxy_nat_stun_servers: Vec::new(), middle_proxy_pool_size: default_pool_size(), - middle_proxy_warm_standby: 0, + middle_proxy_warm_standby: 8, me_keepalive_enabled: true, me_keepalive_interval_secs: default_keepalive_interval(), me_keepalive_jitter_secs: default_keepalive_jitter(), @@ -299,10 +299,10 @@ impl Default for GeneralConfig { 
me_warmup_stagger_enabled: true, me_warmup_step_delay_ms: default_warmup_step_delay_ms(), me_warmup_step_jitter_ms: default_warmup_step_jitter_ms(), - me_reconnect_max_concurrent_per_dc: 1, + me_reconnect_max_concurrent_per_dc: 4, me_reconnect_backoff_base_ms: default_reconnect_backoff_base_ms(), me_reconnect_backoff_cap_ms: default_reconnect_backoff_cap_ms(), - me_reconnect_fast_retry_count: 1, + me_reconnect_fast_retry_count: 8, stun_iface_mismatch_ignore: false, unknown_dc_log_path: default_unknown_dc_log_path(), log_level: LogLevel::Normal, @@ -474,6 +474,12 @@ pub struct AntiCensorshipConfig { #[serde(default = "default_tls_new_session_tickets")] pub tls_new_session_tickets: u8, + /// TTL in seconds for sending full certificate payload per client IP. + /// First client connection per (SNI domain, client IP) gets full cert payload. + /// Subsequent handshakes within TTL use compact cert metadata payload. + #[serde(default = "default_tls_full_cert_ttl_secs")] + pub tls_full_cert_ttl_secs: u64, + /// Enforce ALPN echo of client preference. #[serde(default = "default_alpn_enforce")] pub alpn_enforce: bool, @@ -494,6 +500,7 @@ impl Default for AntiCensorshipConfig { server_hello_delay_min_ms: default_server_hello_delay_min_ms(), server_hello_delay_max_ms: default_server_hello_delay_max_ms(), tls_new_session_tickets: default_tls_new_session_tickets(), + tls_full_cert_ttl_secs: default_tls_full_cert_ttl_secs(), alpn_enforce: default_alpn_enforce(), } } @@ -603,6 +610,10 @@ pub struct ListenerConfig { /// Per-listener PROXY protocol override. When set, overrides global server.proxy_protocol. #[serde(default)] pub proxy_protocol: Option, + /// Allow multiple telemt instances to listen on the same IP:port (SO_REUSEPORT). + /// Default is false for safety. 
+ #[serde(default)] + pub reuse_allow: bool, } // ============= ShowLink ============= diff --git a/src/crypto/hash.rs b/src/crypto/hash.rs index 1586e50..d3f6f55 100644 --- a/src/crypto/hash.rs +++ b/src/crypto/hash.rs @@ -55,6 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 { crc32fast::hash(data) } +/// CRC32C (Castagnoli) +pub fn crc32c(data: &[u8]) -> u32 { + crc32c::crc32c(data) +} + /// Build the exact prekey buffer used by Telegram Middle Proxy KDF. /// /// Returned buffer layout (IPv4): diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 40951c6..266a3cb 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -5,5 +5,8 @@ pub mod hash; pub mod random; pub use aes::{AesCtr, AesCbc}; -pub use hash::{sha256, sha256_hmac, sha1, md5, crc32, derive_middleproxy_keys, build_middleproxy_prekey}; +pub use hash::{ + build_middleproxy_prekey, crc32, crc32c, derive_middleproxy_keys, md5, sha1, sha256, + sha256_hmac, +}; pub use random::SecureRandom; diff --git a/src/crypto/random.rs b/src/crypto/random.rs index 99aa5f3..f3432e0 100644 --- a/src/crypto/random.rs +++ b/src/crypto/random.rs @@ -49,19 +49,32 @@ impl SecureRandom { } } - /// Generate random bytes - pub fn bytes(&self, len: usize) -> Vec { + /// Fill a caller-provided buffer with random bytes. 
+ pub fn fill(&self, out: &mut [u8]) { let mut inner = self.inner.lock(); const CHUNK_SIZE: usize = 512; - - while inner.buffer.len() < len { - let mut chunk = vec![0u8; CHUNK_SIZE]; - inner.rng.fill_bytes(&mut chunk); - inner.cipher.apply(&mut chunk); - inner.buffer.extend_from_slice(&chunk); + + let mut written = 0usize; + while written < out.len() { + if inner.buffer.is_empty() { + let mut chunk = vec![0u8; CHUNK_SIZE]; + inner.rng.fill_bytes(&mut chunk); + inner.cipher.apply(&mut chunk); + inner.buffer.extend_from_slice(&chunk); + } + + let take = (out.len() - written).min(inner.buffer.len()); + out[written..written + take].copy_from_slice(&inner.buffer[..take]); + inner.buffer.drain(..take); + written += take; } - - inner.buffer.drain(..len).collect() + } + + /// Generate random bytes + pub fn bytes(&self, len: usize) -> Vec { + let mut out = vec![0u8; len]; + self.fill(&mut out); + out } /// Generate random number in range [0, max) diff --git a/src/main.rs b/src/main.rs index a9b0e0a..61debb9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,7 +38,7 @@ use crate::stream::BufferPool; use crate::transport::middle_proxy::{ MePool, fetch_proxy_config, run_me_ping, MePingFamily, MePingSample, format_sample_line, }; -use crate::transport::{ListenOptions, UpstreamManager, create_listener}; +use crate::transport::{ListenOptions, UpstreamManager, create_listener, find_listener_processes}; use crate::tls_front::TlsFrontCache; fn parse_cli() -> (String, bool, Option) { @@ -265,7 +265,7 @@ async fn main() -> std::result::Result<(), Box> { } // Connection concurrency limit - let _max_connections = Arc::new(Semaphore::new(10_000)); + let max_connections = Arc::new(Semaphore::new(10_000)); if use_middle_proxy && !decision.ipv4_me && !decision.ipv6_me { warn!("No usable IP family for Middle Proxy detected; falling back to direct DC"); @@ -715,6 +715,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai continue; } let options = ListenOptions { 
+ reuse_port: listener_conf.reuse_allow, ipv6_only: listener_conf.ip.is_ipv6(), ..Default::default() }; @@ -753,7 +754,33 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai listeners.push((listener, listener_proxy_protocol)); } Err(e) => { - error!("Failed to bind to {}: {}", addr, e); + if e.kind() == std::io::ErrorKind::AddrInUse { + let owners = find_listener_processes(addr); + if owners.is_empty() { + error!( + %addr, + "Failed to bind: address already in use (owner process unresolved)" + ); + } else { + for owner in owners { + error!( + %addr, + pid = owner.pid, + process = %owner.process, + "Failed to bind: address already in use" + ); + } + } + + if !listener_conf.reuse_allow { + error!( + %addr, + "reuse_allow=false; set [[server.listeners]].reuse_allow=true to allow multi-instance listening" + ); + } + } else { + error!("Failed to bind to {}: {}", addr, e); + } } } } @@ -817,6 +844,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai let me_pool = me_pool.clone(); let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); + let max_connections_unix = max_connections.clone(); tokio::spawn(async move { let unix_conn_counter = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)); @@ -824,6 +852,13 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai loop { match unix_listener.accept().await { Ok((stream, _)) => { + let permit = match max_connections_unix.clone().acquire_owned().await { + Ok(permit) => permit, + Err(_) => { + error!("Connection limiter is closed"); + break; + } + }; let conn_id = unix_conn_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); let fake_peer = SocketAddr::from(([127, 0, 0, 1], (conn_id % 65535) as u16)); @@ -839,6 +874,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai let proxy_protocol_enabled = config.server.proxy_protocol; tokio::spawn(async move { + let _permit = permit; 
if let Err(e) = crate::proxy::client::handle_client_stream( stream, fake_peer, config, stats, upstream_manager, replay_checker, buffer_pool, rng, @@ -906,11 +942,19 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai let me_pool = me_pool.clone(); let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); + let max_connections_tcp = max_connections.clone(); tokio::spawn(async move { loop { match listener.accept().await { Ok((stream, peer_addr)) => { + let permit = match max_connections_tcp.clone().acquire_owned().await { + Ok(permit) => permit, + Err(_) => { + error!("Connection limiter is closed"); + break; + } + }; let config = config_rx.borrow_and_update().clone(); let stats = stats.clone(); let upstream_manager = upstream_manager.clone(); @@ -923,6 +967,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai let proxy_protocol_enabled = listener_proxy_protocol; tokio::spawn(async move { + let _permit = permit; if let Err(e) = ClientHandler::new( stream, peer_addr, diff --git a/src/metrics.rs b/src/metrics.rs index 940a0d8..e00091f 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -100,6 +100,14 @@ fn render_metrics(stats: &Stats) -> String { let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter"); let _ = writeln!(out, "telemt_me_keepalive_failed_total {}", stats.get_me_keepalive_failed()); + let _ = writeln!(out, "# HELP telemt_me_keepalive_pong_total ME keepalive pong replies"); + let _ = writeln!(out, "# TYPE telemt_me_keepalive_pong_total counter"); + let _ = writeln!(out, "telemt_me_keepalive_pong_total {}", stats.get_me_keepalive_pong()); + + let _ = writeln!(out, "# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts"); + let _ = writeln!(out, "# TYPE telemt_me_keepalive_timeout_total counter"); + let _ = writeln!(out, "telemt_me_keepalive_timeout_total {}", stats.get_me_keepalive_timeout()); + let _ = writeln!(out, "# HELP 
telemt_me_reconnect_attempts_total ME reconnect attempts"); let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter"); let _ = writeln!(out, "telemt_me_reconnect_attempts_total {}", stats.get_me_reconnect_attempts()); @@ -108,6 +116,30 @@ fn render_metrics(stats: &Stats) -> String { let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter"); let _ = writeln!(out, "telemt_me_reconnect_success_total {}", stats.get_me_reconnect_success()); + let _ = writeln!(out, "# HELP telemt_me_crc_mismatch_total ME CRC mismatches"); + let _ = writeln!(out, "# TYPE telemt_me_crc_mismatch_total counter"); + let _ = writeln!(out, "telemt_me_crc_mismatch_total {}", stats.get_me_crc_mismatch()); + + let _ = writeln!(out, "# HELP telemt_me_seq_mismatch_total ME sequence mismatches"); + let _ = writeln!(out, "# TYPE telemt_me_seq_mismatch_total counter"); + let _ = writeln!(out, "telemt_me_seq_mismatch_total {}", stats.get_me_seq_mismatch()); + + let _ = writeln!(out, "# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn"); + let _ = writeln!(out, "# TYPE telemt_me_route_drop_no_conn_total counter"); + let _ = writeln!(out, "telemt_me_route_drop_no_conn_total {}", stats.get_me_route_drop_no_conn()); + + let _ = writeln!(out, "# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed"); + let _ = writeln!(out, "# TYPE telemt_me_route_drop_channel_closed_total counter"); + let _ = writeln!(out, "telemt_me_route_drop_channel_closed_total {}", stats.get_me_route_drop_channel_closed()); + + let _ = writeln!(out, "# HELP telemt_me_route_drop_queue_full_total ME route drops: queue full"); + let _ = writeln!(out, "# TYPE telemt_me_route_drop_queue_full_total counter"); + let _ = writeln!(out, "telemt_me_route_drop_queue_full_total {}", stats.get_me_route_drop_queue_full()); + + let _ = writeln!(out, "# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths"); + let _ = writeln!(out, "# TYPE 
telemt_secure_padding_invalid_total counter"); + let _ = writeln!(out, "telemt_secure_padding_invalid_total {}", stats.get_secure_padding_invalid()); + let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections"); let _ = writeln!(out, "# TYPE telemt_user_connections_total counter"); let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections"); diff --git a/src/protocol/constants.rs b/src/protocol/constants.rs index 826f2b2..c930a1b 100644 --- a/src/protocol/constants.rs +++ b/src/protocol/constants.rs @@ -156,14 +156,28 @@ pub const MAX_TLS_RECORD_SIZE: usize = 16384; /// RFC 8446 §5.2 allows up to 16384 + 256 bytes of ciphertext pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256; -/// Generate padding length for Secure Intermediate protocol. -/// Total (data + padding) must not be divisible by 4 per MTProto spec. -pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize { - if data_len % 4 == 0 { - (rng.range(3) + 1) as usize // 1-3 - } else { - rng.range(4) as usize // 0-3 +/// Secure Intermediate payload is expected to be 4-byte aligned. +pub fn is_valid_secure_payload_len(data_len: usize) -> bool { + data_len % 4 == 0 +} + +/// Compute Secure Intermediate payload length from wire length. +/// Secure mode strips up to 3 random tail bytes by truncating to 4-byte boundary. +pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option { + if wire_len < 4 { + return None; } + Some(wire_len - (wire_len % 4)) +} + +/// Generate padding length for Secure Intermediate protocol. +/// Data must be 4-byte aligned; padding is 1..=3 so total is never divisible by 4. 
+pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize { + debug_assert!( + is_valid_secure_payload_len(data_len), + "Secure payload must be 4-byte aligned, got {data_len}" + ); + (rng.range(3) + 1) as usize } // ============= Timeouts ============= @@ -297,6 +311,10 @@ pub mod rpc_flags { pub const FLAG_ABRIDGED: u32 = 0x40000000; pub const FLAG_QUICKACK: u32 = 0x80000000; } + + pub mod rpc_crypto_flags { + pub const USE_CRC32C: u32 = 0x800; + } pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5; pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10; @@ -332,4 +350,43 @@ mod tests { assert_eq!(TG_DATACENTERS_V4.len(), 5); assert_eq!(TG_DATACENTERS_V6.len(), 5); } + + #[test] + fn secure_padding_never_produces_aligned_total() { + let rng = SecureRandom::new(); + for data_len in (0..1000).step_by(4) { + for _ in 0..100 { + let padding = secure_padding_len(data_len, &rng); + assert!( + padding <= 3, + "padding out of range: data_len={data_len}, padding={padding}" + ); + assert_ne!( + (data_len + padding) % 4, + 0, + "invariant violated: data_len={data_len}, padding={padding}, total={}", + data_len + padding + ); + } + } + } + + #[test] + fn secure_wire_len_roundtrip_for_aligned_payload() { + for payload_len in (4..4096).step_by(4) { + for padding in 0..=3usize { + let wire_len = payload_len + padding; + let recovered = secure_payload_len_from_wire_len(wire_len); + assert_eq!(recovered, Some(payload_len)); + } + } + } + + #[test] + fn secure_wire_len_rejects_too_short_frames() { + assert_eq!(secure_payload_len_from_wire_len(0), None); + assert_eq!(secure_payload_len_from_wire_len(1), None); + assert_eq!(secure_payload_len_from_wire_len(2), None); + assert_eq!(secure_payload_len_from_wire_len(3), None); + } } diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 8d48c8b..750d839 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -2,6 +2,7 @@ use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; use 
tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn, trace, info}; use zeroize::Zeroize; @@ -108,11 +109,23 @@ where let cached = if config.censorship.tls_emulation { if let Some(cache) = tls_cache.as_ref() { - if let Some(sni) = tls::extract_sni_from_client_hello(handshake) { - Some(cache.get(&sni).await) + let selected_domain = if let Some(sni) = tls::extract_sni_from_client_hello(handshake) { + if cache.contains_domain(&sni).await { + sni + } else { + config.censorship.tls_domain.clone() + } } else { - Some(cache.get(&config.censorship.tls_domain).await) - } + config.censorship.tls_domain.clone() + }; + let cached_entry = cache.get(&selected_domain).await; + let use_full_cert_payload = cache + .take_full_cert_budget_for_ip( + peer.ip(), + Duration::from_secs(config.censorship.tls_full_cert_ttl_secs), + ) + .await; + Some((cached_entry, use_full_cert_payload)) } else { None } @@ -137,12 +150,13 @@ where None }; - let response = if let Some(cached_entry) = cached { + let response = if let Some((cached_entry, use_full_cert_payload)) = cached { emulator::build_emulated_server_hello( secret, &validation.digest, &validation.session_id, &cached_entry, + use_full_cert_payload, rng, selected_alpn.clone(), config.censorship.tls_new_session_tickets, @@ -253,7 +267,11 @@ where let mode_ok = match proto_tag { ProtoTag::Secure => { - if is_tls { config.general.modes.tls } else { config.general.modes.secure } + if is_tls { + config.general.modes.tls || config.general.modes.secure + } else { + config.general.modes.secure || config.general.modes.tls + } } ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, }; diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 0735d01..3b98112 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -2,7 +2,7 @@ use std::net::SocketAddr; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::oneshot; +use 
tokio::sync::{mpsc, oneshot}; use tracing::{debug, info, trace, warn}; use crate::config::ProxyConfig; @@ -14,6 +14,11 @@ use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; +enum C2MeCommand { + Data { payload: Vec, flags: u32 }, + Close, +} + pub(crate) async fn handle_via_middle_proxy( mut crypto_reader: CryptoReader, crypto_writer: CryptoWriter, @@ -59,6 +64,30 @@ where let frame_limit = config.general.max_client_frame; + let (c2me_tx, mut c2me_rx) = mpsc::channel::(1024); + let me_pool_c2me = me_pool.clone(); + let c2me_sender = tokio::spawn(async move { + while let Some(cmd) = c2me_rx.recv().await { + match cmd { + C2MeCommand::Data { payload, flags } => { + me_pool_c2me.send_proxy_req( + conn_id, + success.dc_idx, + peer, + translated_local_addr, + &payload, + flags, + ).await?; + } + C2MeCommand::Close => { + let _ = me_pool_c2me.send_close(conn_id).await; + return Ok(()); + } + } + } + Ok(()) + }); + let (stop_tx, mut stop_rx) = oneshot::channel::<()>(); let mut me_rx_task = me_rx; let stats_clone = stats.clone(); @@ -66,6 +95,7 @@ where let user_clone = user.clone(); let me_writer = tokio::spawn(async move { let mut writer = crypto_writer; + let mut frame_buf = Vec::with_capacity(16 * 1024); loop { tokio::select! { msg = me_rx_task.recv() => { @@ -73,7 +103,44 @@ where Some(MeResponse::Data { flags, data }) => { trace!(conn_id, bytes = data.len(), flags, "ME->C data"); stats_clone.add_user_octets_to(&user_clone, data.len() as u64); - write_client_payload(&mut writer, proto_tag, flags, &data, rng_clone.as_ref()).await?; + write_client_payload( + &mut writer, + proto_tag, + flags, + &data, + rng_clone.as_ref(), + &mut frame_buf, + ) + .await?; + + // Drain all immediately queued ME responses and flush once. 
+ while let Ok(next) = me_rx_task.try_recv() { + match next { + MeResponse::Data { flags, data } => { + trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)"); + stats_clone.add_user_octets_to(&user_clone, data.len() as u64); + write_client_payload( + &mut writer, + proto_tag, + flags, + &data, + rng_clone.as_ref(), + &mut frame_buf, + ).await?; + } + MeResponse::Ack(confirm) => { + trace!(conn_id, confirm, "ME->C quickack (batched)"); + write_client_ack(&mut writer, proto_tag, confirm).await?; + } + MeResponse::Close => { + debug!(conn_id, "ME sent close (batched)"); + let _ = writer.flush().await; + return Ok(()); + } + } + } + + writer.flush().await.map_err(ProxyError::Io)?; } Some(MeResponse::Ack(confirm)) => { trace!(conn_id, confirm, "ME->C quickack"); @@ -81,6 +148,7 @@ where } Some(MeResponse::Close) => { debug!(conn_id, "ME sent close"); + let _ = writer.flush().await; return Ok(()); } None => { @@ -99,8 +167,16 @@ where let mut main_result: Result<()> = Ok(()); let mut client_closed = false; + let mut frame_counter: u64 = 0; loop { - match read_client_payload(&mut crypto_reader, proto_tag, frame_limit, &user).await { + match read_client_payload( + &mut crypto_reader, + proto_tag, + frame_limit, + &user, + &mut frame_counter, + &stats, + ).await { Ok(Some((payload, quickack))) => { trace!(conn_id, bytes = payload.len(), "C->ME frame"); stats.add_user_octets_from(&user, payload.len() as u64); @@ -111,22 +187,20 @@ where if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) { flags |= RPC_FLAG_NOT_ENCRYPTED; } - if let Err(e) = me_pool.send_proxy_req( - conn_id, - success.dc_idx, - peer, - translated_local_addr, - &payload, - flags, - ).await { - main_result = Err(e); + // Keep client read loop lightweight: route heavy ME send path via a dedicated task. 
+ if c2me_tx + .send(C2MeCommand::Data { payload, flags }) + .await + .is_err() + { + main_result = Err(ProxyError::Proxy("ME sender channel closed".into())); break; } } Ok(None) => { debug!(conn_id, "Client EOF"); client_closed = true; - let _ = me_pool.send_close(conn_id).await; + let _ = c2me_tx.send(C2MeCommand::Close).await; break; } Err(e) => { @@ -136,6 +210,11 @@ where } } + drop(c2me_tx); + let c2me_result = c2me_sender + .await + .unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME sender join error: {e}")))); + let _ = stop_tx.send(()); let mut writer_result = me_writer .await @@ -151,10 +230,11 @@ where } } - let result = match (main_result, writer_result) { - (Ok(()), Ok(())) => Ok(()), - (Err(e), _) => Err(e), - (_, Err(e)) => Err(e), + let result = match (main_result, c2me_result, writer_result) { + (Ok(()), Ok(()), Ok(())) => Ok(()), + (Err(e), _, _) => Err(e), + (_, Err(e), _) => Err(e), + (_, _, Err(e)) => Err(e), }; debug!(user = %user, conn_id, "ME relay cleanup"); @@ -168,73 +248,123 @@ async fn read_client_payload( proto_tag: ProtoTag, max_frame: usize, user: &str, + frame_counter: &mut u64, + stats: &Stats, ) -> Result, bool)>> where R: AsyncRead + Unpin + Send + 'static, { - let (len, quickack) = match proto_tag { - ProtoTag::Abridged => { - let mut first = [0u8; 1]; - match client_reader.read_exact(&mut first).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ProxyError::Io(e)), + loop { + let (len, quickack, raw_len_bytes) = match proto_tag { + ProtoTag::Abridged => { + let mut first = [0u8; 1]; + match client_reader.read_exact(&mut first).await { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(ProxyError::Io(e)), + } + + let quickack = (first[0] & 0x80) != 0; + let len_words = if (first[0] & 0x7f) == 0x7f { + let mut ext = [0u8; 3]; + client_reader + .read_exact(&mut ext) + .await + 
.map_err(ProxyError::Io)?; + u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize + } else { + (first[0] & 0x7f) as usize + }; + + let len = len_words + .checked_mul(4) + .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?; + (len, quickack, None) } - - let quickack = (first[0] & 0x80) != 0; - let len_words = if (first[0] & 0x7f) == 0x7f { - let mut ext = [0u8; 3]; - client_reader - .read_exact(&mut ext) - .await - .map_err(ProxyError::Io)?; - u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize - } else { - (first[0] & 0x7f) as usize - }; - - let len = len_words - .checked_mul(4) - .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?; - (len, quickack) - } - ProtoTag::Intermediate | ProtoTag::Secure => { - let mut len_buf = [0u8; 4]; - match client_reader.read_exact(&mut len_buf).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ProxyError::Io(e)), + ProtoTag::Intermediate | ProtoTag::Secure => { + let mut len_buf = [0u8; 4]; + match client_reader.read_exact(&mut len_buf).await { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(ProxyError::Io(e)), + } + let quickack = (len_buf[3] & 0x80) != 0; + ( + (u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, + quickack, + Some(len_buf), + ) } - let quickack = (len_buf[3] & 0x80) != 0; - ((u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, quickack) + }; + + if len == 0 { + continue; } - }; - - if len > max_frame { - warn!( - user = %user, - raw_len = len, - raw_len_hex = format_args!("0x{:08x}", len), - proto = ?proto_tag, - "Frame too large — possible crypto desync or TLS record error" - ); - return Err(ProxyError::Proxy(format!("Frame too large: {len} (max {max_frame})"))); - } - - let mut payload = vec![0u8; len]; - client_reader - .read_exact(&mut payload) - .await - .map_err(ProxyError::Io)?; - - // Secure Intermediate: 
remove random padding (last len%4 bytes) - if proto_tag == ProtoTag::Secure { - let rem = len % 4; - if rem != 0 && payload.len() >= rem { - payload.truncate(len - rem); + if len < 4 && proto_tag != ProtoTag::Abridged { + warn!( + user = %user, + len, + proto = ?proto_tag, + "Frame too small — corrupt or probe" + ); + return Err(ProxyError::Proxy(format!("Frame too small: {len}"))); } + + if len > max_frame { + let len_buf = raw_len_bytes.unwrap_or((len as u32).to_le_bytes()); + let looks_like_tls = raw_len_bytes + .map(|b| b[0] == 0x16 && b[1] == 0x03) + .unwrap_or(false); + let looks_like_http = raw_len_bytes + .map(|b| matches!(b[0], b'G' | b'P' | b'H' | b'C' | b'D')) + .unwrap_or(false); + warn!( + user = %user, + raw_len = len, + raw_len_hex = format_args!("0x{:08x}", len), + raw_bytes = format_args!( + "{:02x} {:02x} {:02x} {:02x}", + len_buf[0], len_buf[1], len_buf[2], len_buf[3] + ), + proto = ?proto_tag, + tls_like = looks_like_tls, + http_like = looks_like_http, + frames_ok = *frame_counter, + "Frame too large — crypto desync forensics" + ); + return Err(ProxyError::Proxy(format!( + "Frame too large: {len} (max {max_frame}), frames_ok={}", + *frame_counter + ))); + } + + let secure_payload_len = if proto_tag == ProtoTag::Secure { + match secure_payload_len_from_wire_len(len) { + Some(payload_len) => payload_len, + None => { + stats.increment_secure_padding_invalid(); + return Err(ProxyError::Proxy(format!( + "Invalid secure frame length: {len}" + ))); + } + } + } else { + len + }; + + let mut payload = vec![0u8; len]; + client_reader + .read_exact(&mut payload) + .await + .map_err(ProxyError::Io)?; + + // Secure Intermediate: strip validated trailing padding bytes. 
+ if proto_tag == ProtoTag::Secure { + payload.truncate(secure_payload_len); + } + *frame_counter += 1; + return Ok(Some((payload, quickack))); } - Ok(Some((payload, quickack))) } async fn write_client_payload( @@ -243,6 +373,7 @@ async fn write_client_payload( flags: u32, data: &[u8], rng: &SecureRandom, + frame_buf: &mut Vec, ) -> Result<()> where W: AsyncWrite + Unpin + Send + 'static, @@ -264,8 +395,12 @@ where if quickack { first |= 0x80; } + frame_buf.clear(); + frame_buf.reserve(1 + data.len()); + frame_buf.push(first); + frame_buf.extend_from_slice(data); client_writer - .write_all(&[first]) + .write_all(&frame_buf) .await .map_err(ProxyError::Io)?; } else if len_words < (1 << 24) { @@ -274,8 +409,12 @@ where first |= 0x80; } let lw = (len_words as u32).to_le_bytes(); + frame_buf.clear(); + frame_buf.reserve(4 + data.len()); + frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]); + frame_buf.extend_from_slice(data); client_writer - .write_all(&[first, lw[0], lw[1], lw[2]]) + .write_all(&frame_buf) .await .map_err(ProxyError::Io)?; } else { @@ -284,47 +423,40 @@ where data.len() ))); } - - client_writer - .write_all(data) - .await - .map_err(ProxyError::Io)?; } ProtoTag::Intermediate | ProtoTag::Secure => { let padding_len = if proto_tag == ProtoTag::Secure { + if !is_valid_secure_payload_len(data.len()) { + return Err(ProxyError::Proxy(format!( + "Secure payload must be 4-byte aligned, got {}", + data.len() + ))); + } secure_padding_len(data.len(), rng) } else { 0 }; - let mut len = (data.len() + padding_len) as u32; + let mut len_val = (data.len() + padding_len) as u32; if quickack { - len |= 0x8000_0000; + len_val |= 0x8000_0000; } - client_writer - .write_all(&len.to_le_bytes()) - .await - .map_err(ProxyError::Io)?; - client_writer - .write_all(data) - .await - .map_err(ProxyError::Io)?; + let total = 4 + data.len() + padding_len; + frame_buf.clear(); + frame_buf.reserve(total); + frame_buf.extend_from_slice(&len_val.to_le_bytes()); + 
frame_buf.extend_from_slice(data); if padding_len > 0 { - let pad = rng.bytes(padding_len); - client_writer - .write_all(&pad) - .await - .map_err(ProxyError::Io)?; + let start = frame_buf.len(); + frame_buf.resize(start + padding_len, 0); + rng.fill(&mut frame_buf[start..]); } + client_writer + .write_all(&frame_buf) + .await + .map_err(ProxyError::Io)?; } } - // Avoid unconditional per-frame flush (throughput killer on large downloads). - // Flush only when low-latency ack semantics are requested or when - // CryptoWriter has buffered pending ciphertext that must be drained. - if quickack || client_writer.has_pending() { - client_writer.flush().await.map_err(ProxyError::Io)?; - } - Ok(()) } diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 38318cc..307da6d 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -21,8 +21,16 @@ pub struct Stats { handshake_timeouts: AtomicU64, me_keepalive_sent: AtomicU64, me_keepalive_failed: AtomicU64, + me_keepalive_pong: AtomicU64, + me_keepalive_timeout: AtomicU64, me_reconnect_attempts: AtomicU64, me_reconnect_success: AtomicU64, + me_crc_mismatch: AtomicU64, + me_seq_mismatch: AtomicU64, + me_route_drop_no_conn: AtomicU64, + me_route_drop_channel_closed: AtomicU64, + me_route_drop_queue_full: AtomicU64, + secure_padding_invalid: AtomicU64, user_stats: DashMap, start_time: parking_lot::RwLock>, } @@ -49,14 +57,45 @@ impl Stats { pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_keepalive_sent(&self) { self.me_keepalive_sent.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_keepalive_failed(&self) { self.me_keepalive_failed.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_keepalive_pong(&self) { self.me_keepalive_pong.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_keepalive_timeout(&self) { self.me_keepalive_timeout.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_keepalive_timeout_by(&self, value: u64) { + 
self.me_keepalive_timeout.fetch_add(value, Ordering::Relaxed); + } pub fn increment_me_reconnect_attempt(&self) { self.me_reconnect_attempts.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_reconnect_success(&self) { self.me_reconnect_success.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_crc_mismatch(&self) { self.me_crc_mismatch.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_seq_mismatch(&self) { self.me_seq_mismatch.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_route_drop_no_conn(&self) { self.me_route_drop_no_conn.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_route_drop_channel_closed(&self) { + self.me_route_drop_channel_closed.fetch_add(1, Ordering::Relaxed); + } + pub fn increment_me_route_drop_queue_full(&self) { + self.me_route_drop_queue_full.fetch_add(1, Ordering::Relaxed); + } + pub fn increment_secure_padding_invalid(&self) { + self.secure_padding_invalid.fetch_add(1, Ordering::Relaxed); + } pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) } pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) } pub fn get_me_keepalive_sent(&self) -> u64 { self.me_keepalive_sent.load(Ordering::Relaxed) } pub fn get_me_keepalive_failed(&self) -> u64 { self.me_keepalive_failed.load(Ordering::Relaxed) } + pub fn get_me_keepalive_pong(&self) -> u64 { self.me_keepalive_pong.load(Ordering::Relaxed) } + pub fn get_me_keepalive_timeout(&self) -> u64 { self.me_keepalive_timeout.load(Ordering::Relaxed) } pub fn get_me_reconnect_attempts(&self) -> u64 { self.me_reconnect_attempts.load(Ordering::Relaxed) } pub fn get_me_reconnect_success(&self) -> u64 { self.me_reconnect_success.load(Ordering::Relaxed) } + pub fn get_me_crc_mismatch(&self) -> u64 { self.me_crc_mismatch.load(Ordering::Relaxed) } + pub fn get_me_seq_mismatch(&self) -> u64 { self.me_seq_mismatch.load(Ordering::Relaxed) } + pub fn get_me_route_drop_no_conn(&self) -> u64 { 
self.me_route_drop_no_conn.load(Ordering::Relaxed) } + pub fn get_me_route_drop_channel_closed(&self) -> u64 { + self.me_route_drop_channel_closed.load(Ordering::Relaxed) + } + pub fn get_me_route_drop_queue_full(&self) -> u64 { + self.me_route_drop_queue_full.load(Ordering::Relaxed) + } + pub fn get_secure_padding_invalid(&self) -> u64 { + self.secure_padding_invalid.load(Ordering::Relaxed) + } pub fn increment_user_connects(&self, user: &str) { self.user_stats.entry(user.to_string()).or_default() @@ -70,7 +109,22 @@ impl Stats { pub fn decrement_user_curr_connects(&self, user: &str) { if let Some(stats) = self.user_stats.get(user) { - stats.curr_connects.fetch_sub(1, Ordering::Relaxed); + let counter = &stats.curr_connects; + let mut current = counter.load(Ordering::Relaxed); + loop { + if current == 0 { + break; + } + match counter.compare_exchange_weak( + current, + current - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } } } diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs index 30bcc95..6b90892 100644 --- a/src/stream/frame_codec.rs +++ b/src/stream/frame_codec.rs @@ -8,7 +8,9 @@ use std::io::{self, Error, ErrorKind}; use std::sync::Arc; use tokio_util::codec::{Decoder, Encoder}; -use crate::protocol::constants::ProtoTag; +use crate::protocol::constants::{ + ProtoTag, is_valid_secure_payload_len, secure_padding_len, secure_payload_len_from_wire_len, +}; use crate::crypto::SecureRandom; use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait}; @@ -274,13 +276,13 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result io::R return Ok(()); } - // Generate padding to make length not divisible by 4 - let padding_len = if data.len() % 4 == 0 { - // Add 1-3 bytes to make it non-aligned - (rng.range(3) + 1) as usize - } else { - // Already non-aligned, can add 0-3 - rng.range(4) as usize - }; + if !is_valid_secure_payload_len(data.len()) { + return 
Err(Error::new( + ErrorKind::InvalidData, + format!("secure payload must be 4-byte aligned, got {}", data.len()), + )); + } + + // Generate padding that keeps total length non-divisible by 4. + let padding_len = secure_padding_len(data.len(), rng); let total_len = data.len() + padding_len; dst.reserve(4 + total_len); @@ -625,4 +628,4 @@ mod tests { let result = codec.decode(&mut buf); assert!(result.is_err()); } -} \ No newline at end of file +} diff --git a/src/stream/frame_stream.rs b/src/stream/frame_stream.rs index 1ea6d1b..1726a06 100644 --- a/src/stream/frame_stream.rs +++ b/src/stream/frame_stream.rs @@ -232,11 +232,13 @@ impl SecureIntermediateFrameReader { let mut data = vec![0u8; len]; self.upstream.read_exact(&mut data).await?; - // Strip padding (not aligned to 4) - if len % 4 != 0 { - let actual_len = len - (len % 4); - data.truncate(actual_len); - } + let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| { + Error::new( + ErrorKind::InvalidData, + format!("Invalid secure frame length: {len}"), + ) + })?; + data.truncate(payload_len); Ok((Bytes::from(data), meta)) } @@ -267,6 +269,13 @@ impl SecureIntermediateFrameWriter { return Ok(()); } + if !is_valid_secure_payload_len(data.len()) { + return Err(Error::new( + ErrorKind::InvalidData, + format!("Secure payload must be 4-byte aligned, got {}", data.len()), + )); + } + // Add padding so total length is never divisible by 4 (MTProto Secure) let padding_len = secure_padding_len(data.len(), &self.rng); let padding = self.rng.bytes(padding_len); @@ -550,9 +559,7 @@ mod tests { writer.flush().await.unwrap(); let (received, _meta) = reader.read_frame().await.unwrap(); - // Received should have padding stripped to align to 4 - let expected_len = (data.len() / 4) * 4; - assert_eq!(received.len(), expected_len); + assert_eq!(received.len(), data.len()); } #[tokio::test] diff --git a/src/tls_front/cache.rs b/src/tls_front/cache.rs index 103c2b1..15a97af 100644 --- a/src/tls_front/cache.rs +++ 
b/src/tls_front/cache.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; +use std::net::IpAddr; use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::time::{SystemTime, Duration}; +use std::time::{Duration, Instant, SystemTime}; use tokio::sync::RwLock; use tokio::time::sleep; @@ -14,6 +15,7 @@ use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsFetchResult}; pub struct TlsFrontCache { memory: RwLock>>, default: Arc, + full_cert_sent: RwLock>, disk_path: PathBuf, } @@ -31,6 +33,7 @@ impl TlsFrontCache { let default = Arc::new(CachedTlsData { server_hello_template: default_template, cert_info: None, + cert_payload: None, app_data_records_sizes: vec![default_len], total_app_data_len: default_len, fetched_at: SystemTime::now(), @@ -45,6 +48,7 @@ impl TlsFrontCache { Self { memory: RwLock::new(map), default, + full_cert_sent: RwLock::new(HashMap::new()), disk_path: disk_path.as_ref().to_path_buf(), } } @@ -54,6 +58,45 @@ impl TlsFrontCache { guard.get(sni).cloned().unwrap_or_else(|| self.default.clone()) } + pub async fn contains_domain(&self, domain: &str) -> bool { + self.memory.read().await.contains_key(domain) + } + + /// Returns true when full cert payload should be sent for client_ip + /// according to TTL policy. 
+ pub async fn take_full_cert_budget_for_ip( + &self, + client_ip: IpAddr, + ttl: Duration, + ) -> bool { + if ttl.is_zero() { + self.full_cert_sent + .write() + .await + .insert(client_ip, Instant::now()); + return true; + } + + let now = Instant::now(); + let mut guard = self.full_cert_sent.write().await; + guard.retain(|_, seen_at| now.duration_since(*seen_at) < ttl); + + match guard.get_mut(&client_ip) { + Some(seen_at) => { + if now.duration_since(*seen_at) >= ttl { + *seen_at = now; + true + } else { + false + } + } + None => { + guard.insert(client_ip, now); + true + } + } + } + pub async fn set(&self, domain: &str, data: CachedTlsData) { let mut guard = self.memory.write().await; guard.insert(domain.to_string(), Arc::new(data)); @@ -142,6 +185,7 @@ impl TlsFrontCache { let data = CachedTlsData { server_hello_template: fetched.server_hello_parsed, cert_info: fetched.cert_info, + cert_payload: fetched.cert_payload, app_data_records_sizes: fetched.app_data_records_sizes.clone(), total_app_data_len: fetched.total_app_data_len, fetched_at: SystemTime::now(), @@ -161,3 +205,50 @@ impl TlsFrontCache { &self.disk_path } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_take_full_cert_budget_for_ip_uses_ttl() { + let cache = TlsFrontCache::new( + &["example.com".to_string()], + 1024, + "tlsfront-test-cache", + ); + let ip: IpAddr = "127.0.0.1".parse().expect("ip"); + let ttl = Duration::from_millis(80); + + assert!(cache + .take_full_cert_budget_for_ip(ip, ttl) + .await); + assert!(!cache + .take_full_cert_budget_for_ip(ip, ttl) + .await); + + tokio::time::sleep(Duration::from_millis(90)).await; + + assert!(cache + .take_full_cert_budget_for_ip(ip, ttl) + .await); + } + + #[tokio::test] + async fn test_take_full_cert_budget_for_ip_zero_ttl_always_allows_full_payload() { + let cache = TlsFrontCache::new( + &["example.com".to_string()], + 1024, + "tlsfront-test-cache", + ); + let ip: IpAddr = "127.0.0.1".parse().expect("ip"); + let ttl 
= Duration::ZERO; + + assert!(cache + .take_full_cert_budget_for_ip(ip, ttl) + .await); + assert!(cache + .take_full_cert_budget_for_ip(ip, ttl) + .await); + } +} diff --git a/src/tls_front/emulator.rs b/src/tls_front/emulator.rs index 4d3e64d..25d2a8c 100644 --- a/src/tls_front/emulator.rs +++ b/src/tls_front/emulator.rs @@ -3,7 +3,7 @@ use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION, }; use crate::protocol::tls::{TLS_DIGEST_LEN, TLS_DIGEST_POS, gen_fake_x25519_key}; -use crate::tls_front::types::CachedTlsData; +use crate::tls_front::types::{CachedTlsData, ParsedCertificateInfo}; const MIN_APP_DATA: usize = 64; const MAX_APP_DATA: usize = 16640; // RFC 8446 §5.2 allows up to 2^14 + 256 @@ -27,12 +27,81 @@ fn jitter_and_clamp_sizes(sizes: &[usize], rng: &SecureRandom) -> Vec { .collect() } +fn app_data_body_capacity(sizes: &[usize]) -> usize { + sizes.iter().map(|&size| size.saturating_sub(17)).sum() +} + +fn ensure_payload_capacity(mut sizes: Vec, payload_len: usize) -> Vec { + if payload_len == 0 { + return sizes; + } + + let mut body_total = app_data_body_capacity(&sizes); + if body_total >= payload_len { + return sizes; + } + + if let Some(last) = sizes.last_mut() { + let free = MAX_APP_DATA.saturating_sub(*last); + let grow = free.min(payload_len - body_total); + *last += grow; + body_total += grow; + } + + while body_total < payload_len { + let remaining = payload_len - body_total; + let chunk = (remaining + 17).min(MAX_APP_DATA).max(MIN_APP_DATA); + sizes.push(chunk); + body_total += chunk.saturating_sub(17); + } + + sizes +} + +fn build_compact_cert_info_payload(cert_info: &ParsedCertificateInfo) -> Option> { + let mut fields = Vec::new(); + + if let Some(subject) = cert_info.subject_cn.as_deref() { + fields.push(format!("CN={subject}")); + } + if let Some(issuer) = cert_info.issuer_cn.as_deref() { + fields.push(format!("ISSUER={issuer}")); + } + if let Some(not_before) = 
cert_info.not_before_unix { + fields.push(format!("NB={not_before}")); + } + if let Some(not_after) = cert_info.not_after_unix { + fields.push(format!("NA={not_after}")); + } + if !cert_info.san_names.is_empty() { + let san = cert_info + .san_names + .iter() + .take(8) + .map(String::as_str) + .collect::>() + .join(","); + fields.push(format!("SAN={san}")); + } + + if fields.is_empty() { + return None; + } + + let mut payload = fields.join(";").into_bytes(); + if payload.len() > 512 { + payload.truncate(512); + } + Some(payload) +} + /// Build a ServerHello + CCS + ApplicationData sequence using cached TLS metadata. pub fn build_emulated_server_hello( secret: &[u8], client_digest: &[u8; TLS_DIGEST_LEN], session_id: &[u8], cached: &CachedTlsData, + use_full_cert_payload: bool, rng: &SecureRandom, alpn: Option>, new_session_tickets: u8, @@ -109,21 +178,60 @@ pub fn build_emulated_server_hello( if sizes.is_empty() { sizes.push(cached.total_app_data_len.max(1024)); } - let sizes = jitter_and_clamp_sizes(&sizes, rng); + let mut sizes = jitter_and_clamp_sizes(&sizes, rng); + let compact_payload = cached + .cert_info + .as_ref() + .and_then(build_compact_cert_info_payload); + let selected_payload: Option<&[u8]> = if use_full_cert_payload { + cached + .cert_payload + .as_ref() + .map(|payload| payload.certificate_message.as_slice()) + .filter(|payload| !payload.is_empty()) + .or_else(|| compact_payload.as_deref()) + } else { + compact_payload.as_deref() + }; + + if let Some(payload) = selected_payload { + sizes = ensure_payload_capacity(sizes, payload.len()); + } let mut app_data = Vec::new(); + let mut payload_offset = 0usize; for size in sizes { let mut rec = Vec::with_capacity(5 + size); rec.push(TLS_RECORD_APPLICATION); rec.extend_from_slice(&TLS_VERSION); rec.extend_from_slice(&(size as u16).to_be_bytes()); - if size > 17 { - let body_len = size - 17; - rec.extend_from_slice(&rng.bytes(body_len)); - rec.push(0x16); // inner content type marker (handshake) - 
rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag + + if let Some(payload) = selected_payload { + if size > 17 { + let body_len = size - 17; + let remaining = payload.len().saturating_sub(payload_offset); + let copy_len = remaining.min(body_len); + if copy_len > 0 { + rec.extend_from_slice(&payload[payload_offset..payload_offset + copy_len]); + payload_offset += copy_len; + } + if body_len > copy_len { + rec.extend_from_slice(&rng.bytes(body_len - copy_len)); + } + rec.push(0x16); // inner content type marker (handshake) + rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag + } else { + rec.extend_from_slice(&rng.bytes(size)); + } } else { - rec.extend_from_slice(&rng.bytes(size)); + if size > 17 { + let body_len = size - 17; + rec.extend_from_slice(&rng.bytes(body_len)); + rec.push(0x16); // inner content type marker (handshake) + rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag + } else { + rec.extend_from_slice(&rng.bytes(size)); + } } app_data.extend_from_slice(&rec); } @@ -158,3 +266,125 @@ pub fn build_emulated_server_hello( response } + +#[cfg(test)] +mod tests { + use std::time::SystemTime; + + use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsCertPayload}; + + use super::build_emulated_server_hello; + use crate::crypto::SecureRandom; + use crate::protocol::constants::{ + TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, + }; + + fn first_app_data_payload(response: &[u8]) -> &[u8] { + let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_start = 5 + hello_len; + let ccs_len = u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize; + let app_start = ccs_start + 5 + ccs_len; + let app_len = u16::from_be_bytes([response[app_start + 3], response[app_start + 4]]) as usize; + &response[app_start + 5..app_start + 5 + app_len] + } + + fn make_cached(cert_payload: Option) -> CachedTlsData { + CachedTlsData { + server_hello_template: ParsedServerHello { + 
version: [0x03, 0x03], + random: [0u8; 32], + session_id: Vec::new(), + cipher_suite: [0x13, 0x01], + compression: 0, + extensions: Vec::new(), + }, + cert_info: None, + cert_payload, + app_data_records_sizes: vec![64], + total_app_data_len: 64, + fetched_at: SystemTime::now(), + domain: "example.com".to_string(), + } + } + + #[test] + fn test_build_emulated_server_hello_uses_cached_cert_payload() { + let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd]; + let cached = make_cached(Some(TlsCertPayload { + cert_chain_der: vec![vec![0x30, 0x01, 0x00]], + certificate_message: cert_msg.clone(), + })); + let rng = SecureRandom::new(); + let response = build_emulated_server_hello( + b"secret", + &[0x11; 32], + &[0x22; 16], + &cached, + true, + &rng, + None, + 0, + ); + + assert_eq!(response[0], TLS_RECORD_HANDSHAKE); + let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_start = 5 + hello_len; + assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); + let app_start = ccs_start + 6; + assert_eq!(response[app_start], TLS_RECORD_APPLICATION); + + let payload = first_app_data_payload(&response); + assert!(payload.starts_with(&cert_msg)); + } + + #[test] + fn test_build_emulated_server_hello_random_fallback_when_no_cert_payload() { + let cached = make_cached(None); + let rng = SecureRandom::new(); + let response = build_emulated_server_hello( + b"secret", + &[0x22; 32], + &[0x33; 16], + &cached, + true, + &rng, + None, + 0, + ); + + let payload = first_app_data_payload(&response); + assert!(payload.len() >= 64); + assert_eq!(payload[payload.len() - 17], 0x16); + } + + #[test] + fn test_build_emulated_server_hello_uses_compact_payload_after_first() { + let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd]; + let mut cached = make_cached(Some(TlsCertPayload { + cert_chain_der: vec![vec![0x30, 0x01, 0x00]], + certificate_message: cert_msg, + })); + cached.cert_info = 
Some(crate::tls_front::types::ParsedCertificateInfo { + not_after_unix: Some(1_900_000_000), + not_before_unix: Some(1_700_000_000), + issuer_cn: Some("Issuer".to_string()), + subject_cn: Some("example.com".to_string()), + san_names: vec!["example.com".to_string(), "www.example.com".to_string()], + }); + + let rng = SecureRandom::new(); + let response = build_emulated_server_hello( + b"secret", + &[0x44; 32], + &[0x55; 16], + &cached, + false, + &rng, + None, + 0, + ); + + let payload = first_app_data_payload(&response); + assert!(payload.starts_with(b"CN=example.com")); + } +} diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 217b50d..4678ea3 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use std::time::Duration; -use anyhow::{Context, Result, anyhow}; +use anyhow::{Result, anyhow}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::time::timeout; @@ -19,7 +19,13 @@ use x509_parser::certificate::X509Certificate; use crate::crypto::SecureRandom; use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_HANDSHAKE}; -use crate::tls_front::types::{ParsedServerHello, TlsExtension, TlsFetchResult, ParsedCertificateInfo}; +use crate::tls_front::types::{ + ParsedCertificateInfo, + ParsedServerHello, + TlsCertPayload, + TlsExtension, + TlsFetchResult, +}; /// No-op verifier: accept any certificate (we only need lengths and metadata). 
#[derive(Debug)] @@ -315,6 +321,46 @@ fn parse_cert_info(certs: &[CertificateDer<'static>]) -> Option Option<[u8; 3]> { + if value > 0x00ff_ffff { + return None; + } + Some([ + ((value >> 16) & 0xff) as u8, + ((value >> 8) & 0xff) as u8, + (value & 0xff) as u8, + ]) +} + +fn encode_tls13_certificate_message(cert_chain_der: &[Vec]) -> Option> { + if cert_chain_der.is_empty() { + return None; + } + + let mut certificate_list = Vec::new(); + for cert in cert_chain_der { + if cert.is_empty() { + return None; + } + certificate_list.extend_from_slice(&u24_bytes(cert.len())?); + certificate_list.extend_from_slice(cert); + certificate_list.extend_from_slice(&0u16.to_be_bytes()); // cert_entry extensions + } + + // Certificate = context_len(1) + certificate_list_len(3) + entries + let body_len = 1usize + .checked_add(3)? + .checked_add(certificate_list.len())?; + + let mut message = Vec::with_capacity(4 + body_len); + message.push(0x0b); // HandshakeType::certificate + message.extend_from_slice(&u24_bytes(body_len)?); + message.push(0x00); // certificate_request_context length + message.extend_from_slice(&u24_bytes(certificate_list.len())?); + message.extend_from_slice(&certificate_list); + Some(message) +} + async fn fetch_via_raw_tls( host: &str, port: u16, @@ -368,26 +414,18 @@ async fn fetch_via_raw_tls( }, total_app_data_len, cert_info: None, + cert_payload: None, }) } -/// Fetch real TLS metadata for the given SNI: negotiated cipher and cert lengths. 
-pub async fn fetch_real_tls( +async fn fetch_via_rustls( host: &str, port: u16, sni: &str, connect_timeout: Duration, upstream: Option>, ) -> Result { - // Preferred path: raw TLS probe for accurate record sizing - match fetch_via_raw_tls(host, port, sni, connect_timeout).await { - Ok(res) => return Ok(res), - Err(e) => { - warn!(sni = %sni, error = %e, "Raw TLS fetch failed, falling back to rustls"); - } - } - - // Fallback: rustls handshake to at least get certificate sizes + // rustls handshake path for certificate and basic negotiated metadata. let stream = if let Some(manager) = upstream { // Resolve host to SocketAddr if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await { @@ -429,8 +467,19 @@ pub async fn fetch_real_tls( .peer_certificates() .map(|slice| slice.to_vec()) .unwrap_or_default(); + let cert_chain_der: Vec> = certs.iter().map(|c| c.as_ref().to_vec()).collect(); + let cert_payload = encode_tls13_certificate_message(&cert_chain_der).map(|certificate_message| { + TlsCertPayload { + cert_chain_der: cert_chain_der.clone(), + certificate_message, + } + }); - let total_cert_len: usize = certs.iter().map(|c| c.len()).sum::().max(1024); + let total_cert_len = cert_payload + .as_ref() + .map(|payload| payload.certificate_message.len()) + .unwrap_or_else(|| cert_chain_der.iter().map(Vec::len).sum::()) + .max(1024); let cert_info = parse_cert_info(&certs); // Heuristic: split across two records if large to mimic real servers a bit. @@ -453,6 +502,7 @@ pub async fn fetch_real_tls( sni = %sni, len = total_cert_len, cipher = format!("0x{:04x}", u16::from_be_bytes(cipher_suite)), + has_cert_payload = cert_payload.is_some(), "Fetched TLS metadata via rustls" ); @@ -461,5 +511,81 @@ pub async fn fetch_real_tls( app_data_records_sizes: app_data_records_sizes.clone(), total_app_data_len: app_data_records_sizes.iter().sum(), cert_info, + cert_payload, }) } + +/// Fetch real TLS metadata for the given SNI. 
+/// +/// Strategy: +/// 1) Probe raw TLS for realistic ServerHello and ApplicationData record sizes. +/// 2) Fetch certificate chain via rustls to build cert payload. +/// 3) Merge both when possible; otherwise auto-fallback to whichever succeeded. +pub async fn fetch_real_tls( + host: &str, + port: u16, + sni: &str, + connect_timeout: Duration, + upstream: Option>, +) -> Result { + let raw_result = match fetch_via_raw_tls(host, port, sni, connect_timeout).await { + Ok(res) => Some(res), + Err(e) => { + warn!(sni = %sni, error = %e, "Raw TLS fetch failed"); + None + } + }; + + match fetch_via_rustls(host, port, sni, connect_timeout, upstream).await { + Ok(rustls_result) => { + if let Some(mut raw) = raw_result { + raw.cert_info = rustls_result.cert_info; + raw.cert_payload = rustls_result.cert_payload; + debug!(sni = %sni, "Fetched TLS metadata via raw probe + rustls cert chain"); + Ok(raw) + } else { + Ok(rustls_result) + } + } + Err(e) => { + if let Some(raw) = raw_result { + warn!(sni = %sni, error = %e, "Rustls cert fetch failed, using raw TLS metadata only"); + Ok(raw) + } else { + Err(e) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::encode_tls13_certificate_message; + + fn read_u24(bytes: &[u8]) -> usize { + ((bytes[0] as usize) << 16) | ((bytes[1] as usize) << 8) | (bytes[2] as usize) + } + + #[test] + fn test_encode_tls13_certificate_message_single_cert() { + let cert = vec![0x30, 0x03, 0x02, 0x01, 0x01]; + let message = encode_tls13_certificate_message(&[cert.clone()]).expect("message"); + + assert_eq!(message[0], 0x0b); + assert_eq!(read_u24(&message[1..4]), message.len() - 4); + assert_eq!(message[4], 0x00); + + let cert_list_len = read_u24(&message[5..8]); + assert_eq!(cert_list_len, cert.len() + 5); + + let cert_len = read_u24(&message[8..11]); + assert_eq!(cert_len, cert.len()); + assert_eq!(&message[11..11 + cert.len()], cert.as_slice()); + assert_eq!(&message[11 + cert.len()..13 + cert.len()], &[0x00, 0x00]); + } + + #[test] + fn 
test_encode_tls13_certificate_message_empty_chain() { + assert!(encode_tls13_certificate_message(&[]).is_none()); + } +} diff --git a/src/tls_front/types.rs b/src/tls_front/types.rs index eef1953..c411081 100644 --- a/src/tls_front/types.rs +++ b/src/tls_front/types.rs @@ -29,11 +29,23 @@ pub struct ParsedCertificateInfo { pub san_names: Vec, } +/// TLS certificate payload captured from profiled upstream. +/// +/// `certificate_message` stores an encoded TLS 1.3 Certificate handshake +/// message body that can be replayed as opaque ApplicationData bytes in FakeTLS. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TlsCertPayload { + pub cert_chain_der: Vec>, + pub certificate_message: Vec, +} + /// Cached data per SNI used by the emulator. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CachedTlsData { pub server_hello_template: ParsedServerHello, pub cert_info: Option, + #[serde(default)] + pub cert_payload: Option, pub app_data_records_sizes: Vec, pub total_app_data_len: usize, #[serde(default = "now_system_time", skip_serializing, skip_deserializing)] @@ -52,4 +64,5 @@ pub struct TlsFetchResult { pub app_data_records_sizes: Vec, pub total_app_data_len: usize, pub cert_info: Option, + pub cert_payload: Option, } diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs index 82f0960..6d83761 100644 --- a/src/transport/middle_proxy/codec.rs +++ b/src/transport/middle_proxy/codec.rs @@ -1,6 +1,6 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use crate::crypto::{AesCbc, crc32}; +use crate::crypto::{AesCbc, crc32, crc32c}; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; @@ -8,17 +8,46 @@ use crate::protocol::constants::*; pub(crate) enum WriterCommand { Data(Vec), DataAndFlush(Vec), - Keepalive, Close, } -pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec { +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum RpcChecksumMode { + Crc32, + Crc32c, +} + +impl 
RpcChecksumMode { + pub(crate) fn from_handshake_flags(flags: u32) -> Self { + if (flags & rpc_crypto_flags::USE_CRC32C) != 0 { + Self::Crc32c + } else { + Self::Crc32 + } + } + + pub(crate) fn advertised_flags(self) -> u32 { + match self { + Self::Crc32 => 0, + Self::Crc32c => rpc_crypto_flags::USE_CRC32C, + } + } +} + +pub(crate) fn rpc_crc(mode: RpcChecksumMode, data: &[u8]) -> u32 { + match mode { + RpcChecksumMode::Crc32 => crc32(data), + RpcChecksumMode::Crc32c => crc32c(data), + } +} + +pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8], crc_mode: RpcChecksumMode) -> Vec { let total_len = (4 + 4 + payload.len() + 4) as u32; let mut frame = Vec::with_capacity(total_len as usize); frame.extend_from_slice(&total_len.to_le_bytes()); frame.extend_from_slice(&seq_no.to_le_bytes()); frame.extend_from_slice(payload); - let c = crc32(&frame); + let c = rpc_crc(crc_mode, &frame); frame.extend_from_slice(&c.to_le_bytes()); frame } @@ -45,7 +74,7 @@ pub(crate) async fn read_rpc_frame_plaintext( let crc_offset = total_len - 4; let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap()); - let actual_crc = crc32(&full[..crc_offset]); + let actual_crc = rpc_crc(RpcChecksumMode::Crc32, &full[..crc_offset]); if expected_crc != actual_crc { return Err(ProxyError::InvalidHandshake(format!( "CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}" @@ -95,24 +124,52 @@ pub(crate) fn build_handshake_payload( our_port: u16, peer_ip: [u8; 4], peer_port: u16, + flags: u32, ) -> [u8; 32] { let mut p = [0u8; 32]; p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes()); + p[4..8].copy_from_slice(&flags.to_le_bytes()); - // Keep C memory layout compatibility for PID IPv4 bytes. 
+ // process_id sender_pid p[8..12].copy_from_slice(&our_ip); p[12..14].copy_from_slice(&our_port.to_le_bytes()); - let pid = (std::process::id() & 0xffff) as u16; - p[14..16].copy_from_slice(&pid.to_le_bytes()); + p[14..16].copy_from_slice(&process_pid16().to_le_bytes()); + p[16..20].copy_from_slice(&process_utime().to_le_bytes()); + + // process_id peer_pid + p[20..24].copy_from_slice(&peer_ip); + p[24..26].copy_from_slice(&peer_port.to_le_bytes()); + p[26..28].copy_from_slice(&0u16.to_le_bytes()); + p[28..32].copy_from_slice(&0u32.to_le_bytes()); + p +} + +pub(crate) fn parse_handshake_flags(payload: &[u8]) -> Result { + if payload.len() != 32 { + return Err(ProxyError::InvalidHandshake(format!( + "Bad handshake payload len: {}", + payload.len() + ))); + } + let hs_type = u32::from_le_bytes(payload[0..4].try_into().unwrap()); + if hs_type != RPC_HANDSHAKE_U32 { + return Err(ProxyError::InvalidHandshake(format!( + "Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" + ))); + } + Ok(u32::from_le_bytes(payload[4..8].try_into().unwrap())) +} + +fn process_pid16() -> u16 { + (std::process::id() & 0xffff) as u16 +} + +fn process_utime() -> u32 { let utime = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() as u32; - p[16..20].copy_from_slice(&utime.to_le_bytes()); - - p[20..24].copy_from_slice(&peer_ip); - p[24..26].copy_from_slice(&peer_port.to_le_bytes()); - p + utime } pub(crate) fn cbc_encrypt_padded( @@ -160,12 +217,13 @@ pub(crate) struct RpcWriter { pub(crate) key: [u8; 32], pub(crate) iv: [u8; 16], pub(crate) seq_no: i32, + pub(crate) crc_mode: RpcChecksumMode, } impl RpcWriter { pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> { - let frame = build_rpc_frame(self.seq_no, payload); - self.seq_no += 1; + let frame = build_rpc_frame(self.seq_no, payload, self.crc_mode); + self.seq_no = self.seq_no.wrapping_add(1); let pad = (16 - (frame.len() % 16)) % 16; let mut buf = frame; @@ 
-189,12 +247,4 @@ impl RpcWriter { self.send(payload).await?; self.writer.flush().await.map_err(ProxyError::Io) } - - pub(crate) async fn send_keepalive(&mut self, payload: [u8; 4]) -> Result<()> { - // Keepalive is a frame with fl == 4 and 4 bytes payload. - let mut frame = Vec::with_capacity(8); - frame.extend_from_slice(&4u32.to_le_bytes()); - frame.extend_from_slice(&payload); - self.send(&frame).await - } } diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs index 6814371..95a9d6e 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -18,13 +18,14 @@ use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_k use crate::error::{ProxyError, Result}; use crate::network::IpFamily; use crate::protocol::constants::{ - ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32, - RPC_HANDSHAKE_U32, RPC_PING_U32, RPC_PONG_U32, RPC_NONCE_U32, + ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, + RPC_HANDSHAKE_ERROR_U32, rpc_crypto_flags, }; use super::codec::{ - build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace, - cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext, + RpcChecksumMode, build_handshake_payload, build_nonce_payload, build_rpc_frame, + cbc_decrypt_inplace, cbc_encrypt_padded, parse_handshake_flags, parse_nonce_payload, + read_rpc_frame_plaintext, rpc_crc, }; use super::wire::{extract_ip_material, IpMaterial}; use super::MePool; @@ -37,6 +38,7 @@ pub(crate) struct HandshakeOutput { pub read_iv: [u8; 16], pub write_key: [u8; 32], pub write_iv: [u8; 16], + pub crc_mode: RpcChecksumMode, pub handshake_ms: f64, } @@ -146,7 +148,7 @@ impl MePool { let ks = self.key_selector().await; let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); - let nonce_frame = build_rpc_frame(-2, &nonce_payload); + let nonce_frame = build_rpc_frame(-2, 
&nonce_payload, RpcChecksumMode::Crc32); let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]); debug!( key_selector = format_args!("0x{ks:08x}"), @@ -284,8 +286,15 @@ impl MePool { srv_v6_opt.as_ref(), ); - let hs_payload = build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port()); - let hs_frame = build_rpc_frame(-1, &hs_payload); + let requested_crc_mode = RpcChecksumMode::Crc32c; + let hs_payload = build_handshake_payload( + hs_our_ip, + local_addr.port(), + hs_peer_ip, + peer_addr.port(), + requested_crc_mode.advertised_flags(), + ); + let hs_frame = build_rpc_frame(-1, &hs_payload, RpcChecksumMode::Crc32); if diag_level >= 1 { info!( write_key = %hex_dump(&wk), @@ -314,7 +323,7 @@ impl MePool { ); } - let (encrypted_hs, mut write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; + let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; if diag_level >= 1 { info!( hs_cipher = %hex_dump(&encrypted_hs), @@ -328,6 +337,7 @@ impl MePool { let mut enc_buf = BytesMut::with_capacity(256); let mut dec_buf = BytesMut::with_capacity(256); let mut read_iv = ri; + let mut negotiated_crc_mode = RpcChecksumMode::Crc32; let mut handshake_ok = false; while Instant::now() < deadline && !handshake_ok { @@ -375,17 +385,23 @@ impl MePool { let frame = dec_buf.split_to(fl); let pe = fl - 4; let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); - let ac = crate::crypto::crc32(&frame[..pe]); + let ac = rpc_crc(RpcChecksumMode::Crc32, &frame[..pe]); if ec != ac { return Err(ProxyError::InvalidHandshake(format!( "HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}" ))); } - let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap()); + let hs_payload = &frame[8..pe]; + if hs_payload.len() < 4 { + return Err(ProxyError::InvalidHandshake( + "Handshake payload too short".to_string(), + )); + } + let hs_type = u32::from_le_bytes(hs_payload[0..4].try_into().unwrap()); if hs_type == RPC_HANDSHAKE_ERROR_U32 { - let err_code = if 
frame.len() >= 16 { - i32::from_le_bytes(frame[12..16].try_into().unwrap()) + let err_code = if hs_payload.len() >= 8 { + i32::from_le_bytes(hs_payload[4..8].try_into().unwrap()) } else { -1 }; @@ -393,11 +409,21 @@ impl MePool { "ME rejected handshake (error={err_code})" ))); } - if hs_type != RPC_HANDSHAKE_U32 { + let hs_flags = parse_handshake_flags(hs_payload)?; + if hs_flags & 0xff != 0 { return Err(ProxyError::InvalidHandshake(format!( - "Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" + "Unsupported handshake flags: 0x{hs_flags:08x}" ))); } + negotiated_crc_mode = if (hs_flags & requested_crc_mode.advertised_flags()) != 0 { + RpcChecksumMode::from_handshake_flags(hs_flags) + } else if (hs_flags & rpc_crypto_flags::USE_CRC32C) != 0 { + return Err(ProxyError::InvalidHandshake(format!( + "Peer negotiated unsupported CRC flags: 0x{hs_flags:08x}" + ))); + } else { + RpcChecksumMode::Crc32 + }; handshake_ok = true; break; @@ -418,6 +444,7 @@ impl MePool { read_iv, write_key: wk, write_iv, + crc_mode: negotiated_crc_mode, handshake_ms, }) } diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 3572671..8faeabf 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -17,13 +17,11 @@ use crate::network::IpFamily; use crate::protocol::constants::*; use super::ConnRegistry; -use super::registry::{BoundConn, ConnMeta}; +use super::registry::BoundConn; use super::codec::{RpcWriter, WriterCommand}; use super::reader::reader_loop; -use super::MeResponse; const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; -const ME_KEEPALIVE_PAYLOAD_LEN: usize = 4; #[derive(Clone)] pub struct MeWriter { @@ -361,7 +359,6 @@ impl MePool { // Additional connections up to pool_size total (round-robin across DCs), staggered to de-phase lifecycles. 
if self.me_warmup_stagger_enabled { - let mut delay_ms = 0u64; for (dc, addrs) in dc_addrs.iter() { for (ip, port) in addrs { if self.connection_count() >= pool_size { @@ -369,7 +366,7 @@ impl MePool { } let addr = SocketAddr::new(*ip, *port); let jitter = rand::rng().random_range(0..=self.me_warmup_step_jitter.as_millis() as u64); - delay_ms = delay_ms.saturating_add(self.me_warmup_step_delay.as_millis() as u64 + jitter); + let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter; tokio::time::sleep(Duration::from_millis(delay_ms)).await; if let Err(e) = self.connect_one(addr, rng.as_ref()).await { debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)"); @@ -418,14 +415,12 @@ impl MePool { let degraded = Arc::new(AtomicBool::new(false)); let draining = Arc::new(AtomicBool::new(false)); let (tx, mut rx) = mpsc::channel::(4096); - let tx_for_keepalive = tx.clone(); - let keepalive_random = self.me_keepalive_payload_random; - let stats = self.stats.clone(); let mut rpc_writer = RpcWriter { writer: hs.wr, key: hs.write_key, iv: hs.write_iv, seq_no: 0, + crc_mode: hs.crc_mode, }; let cancel_wr = cancel.clone(); tokio::spawn(async move { @@ -439,21 +434,6 @@ impl MePool { Some(WriterCommand::DataAndFlush(payload)) => { if rpc_writer.send_and_flush(&payload).await.is_err() { break; } } - Some(WriterCommand::Keepalive) => { - let mut payload = [0u8; ME_KEEPALIVE_PAYLOAD_LEN]; - if keepalive_random { - rand::rng().fill(&mut payload); - } - match rpc_writer.send_keepalive(payload).await { - Ok(()) => { - stats.increment_me_keepalive_sent(); - } - Err(_) => { - stats.increment_me_keepalive_failed(); - break; - } - } - } Some(WriterCommand::Close) | None => break, } } @@ -471,12 +451,15 @@ impl MePool { }; self.writers.write().await.push(writer.clone()); self.conn_count.fetch_add(1, Ordering::Relaxed); - self.writer_available.notify_waiters(); + self.writer_available.notify_one(); let reg = self.registry.clone(); let writers_arc = 
self.writers_arc(); let ping_tracker = self.ping_tracker.clone(); + let ping_tracker_reader = ping_tracker.clone(); let rtt_stats = self.rtt_stats.clone(); + let stats_reader = self.stats.clone(); + let stats_ping = self.stats.clone(); let pool = Arc::downgrade(self); let cancel_ping = cancel.clone(); let tx_ping = tx.clone(); @@ -489,19 +472,20 @@ impl MePool { let keepalive_jitter = self.me_keepalive_jitter; let cancel_reader_token = cancel.clone(); let cancel_ping_token = cancel_ping.clone(); - let cancel_keepalive_token = cancel.clone(); tokio::spawn(async move { let res = reader_loop( hs.rd, hs.read_key, hs.read_iv, + hs.crc_mode, reg.clone(), BytesMut::new(), BytesMut::new(), tx.clone(), - ping_tracker.clone(), + ping_tracker_reader, rtt_stats.clone(), + stats_reader, writer_id, degraded.clone(), cancel_reader_token.clone(), @@ -526,15 +510,40 @@ impl MePool { let pool_ping = Arc::downgrade(self); tokio::spawn(async move { let mut ping_id: i64 = rand::random::(); - loop { + // Per-writer jittered start to avoid phase sync. + let startup_jitter = if keepalive_enabled { + let jitter_cap_ms = keepalive_interval.as_millis() / 2; + let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64)) + } else { let jitter = rand::rng() .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; + Duration::from_secs(wait) + }; + tokio::select! 
{ + _ = cancel_ping_token.cancelled() => return, + _ = tokio::time::sleep(startup_jitter) => {} + } + loop { + let wait = if keepalive_enabled { + let jitter_cap_ms = keepalive_interval.as_millis() / 2; + let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); + keepalive_interval + + Duration::from_millis( + rand::rng().random_range(0..=effective_jitter_ms as u64) + ) + } else { + let jitter = rand::rng() + .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); + let secs = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; + Duration::from_secs(secs) + }; tokio::select! { _ = cancel_ping_token.cancelled() => { break; } - _ = tokio::time::sleep(Duration::from_secs(wait)) => {} + _ = tokio::time::sleep(wait) => {} } let sent_id = ping_id; let mut p = Vec::with_capacity(12); @@ -542,12 +551,19 @@ impl MePool { p.extend_from_slice(&sent_id.to_le_bytes()); { let mut tracker = ping_tracker_ping.lock().await; + let before = tracker.len(); tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); + let expired = before.saturating_sub(tracker.len()); + if expired > 0 { + stats_ping.increment_me_keepalive_timeout_by(expired as u64); + } tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); } ping_id = ping_id.wrapping_add(1); + stats_ping.increment_me_keepalive_sent(); if tx_ping.send(WriterCommand::DataAndFlush(p)).await.is_err() { - debug!("Active ME ping failed, removing dead writer"); + stats_ping.increment_me_keepalive_failed(); + debug!("ME ping failed, removing dead writer"); cancel_ping.cancel(); if let Some(pool) = pool_ping.upgrade() { if cleanup_for_ping @@ -562,27 +578,6 @@ impl MePool { } }); - if keepalive_enabled { - let tx_keepalive = tx_for_keepalive; - let cancel_keepalive = cancel_keepalive_token; - tokio::spawn(async move { - // Per-writer jittered start to avoid phase sync. 
- let jitter_cap_ms = keepalive_interval.as_millis() / 2; - let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); - let initial_jitter_ms = rand::rng().random_range(0..=effective_jitter_ms as u64); - tokio::time::sleep(Duration::from_millis(initial_jitter_ms)).await; - loop { - tokio::select! { - _ = cancel_keepalive.cancelled() => break, - _ = tokio::time::sleep(keepalive_interval + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))) => {} - } - if tx_keepalive.send(WriterCommand::Keepalive).await.is_err() { - break; - } - } - }); - } - Ok(()) } @@ -619,15 +614,19 @@ impl MePool { } async fn remove_writer_only(&self, writer_id: u64) -> Vec { + let mut close_tx: Option> = None; { let mut ws = self.writers.write().await; if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { let w = ws.remove(pos); w.cancel.cancel(); - let _ = w.tx.send(WriterCommand::Close).await; + close_tx = Some(w.tx.clone()); self.conn_count.fetch_sub(1, Ordering::Relaxed); } } + if let Some(tx) = close_tx { + let _ = tx.send(WriterCommand::Close).await; + } self.rtt_stats.lock().await.remove(&writer_id); self.registry.writer_lost(writer_id).await } diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index c22ed68..95bd0d8 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -10,31 +10,33 @@ use tokio::sync::{Mutex, mpsc}; use tokio_util::sync::CancellationToken; use tracing::{debug, trace, warn}; -use crate::crypto::{AesCbc, crc32}; +use crate::crypto::AesCbc; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; +use crate::stats::Stats; -use super::codec::WriterCommand; +use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc}; +use super::registry::RouteResult; use super::{ConnRegistry, MeResponse}; pub(crate) async fn reader_loop( mut rd: tokio::io::ReadHalf, dk: [u8; 32], mut div: [u8; 16], + crc_mode: RpcChecksumMode, 
reg: Arc, enc_leftover: BytesMut, mut dec: BytesMut, tx: mpsc::Sender, ping_tracker: Arc>>, rtt_stats: Arc>>, + stats: Arc, _writer_id: u64, degraded: Arc, cancel: CancellationToken, ) -> Result<()> { let mut raw = enc_leftover; let mut expected_seq: i32 = 0; - let mut crc_errors = 0u32; - let mut seq_mismatch = 0u32; loop { let mut tmp = [0u8; 16_384]; @@ -80,26 +82,28 @@ pub(crate) async fn reader_loop( let frame = dec.split_to(fl); let pe = fl - 4; let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); - if crc32(&frame[..pe]) != ec { - warn!("CRC mismatch in data frame"); - crc_errors += 1; - if crc_errors > 3 { - return Err(ProxyError::Proxy("Too many CRC mismatches".into())); - } - continue; + let actual_crc = rpc_crc(crc_mode, &frame[..pe]); + if actual_crc != ec { + stats.increment_me_crc_mismatch(); + warn!( + frame_len = fl, + expected_crc = format_args!("0x{ec:08x}"), + actual_crc = format_args!("0x{actual_crc:08x}"), + "CRC mismatch — CBC crypto desync, aborting ME connection" + ); + return Err(ProxyError::Proxy("CRC mismatch (crypto desync)".into())); } let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap()); if seq_no != expected_seq { + stats.increment_me_seq_mismatch(); warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch"); - seq_mismatch += 1; - if seq_mismatch > 10 { - return Err(ProxyError::Proxy("Too many seq mismatches".into())); - } - expected_seq = seq_no.wrapping_add(1); - } else { - expected_seq = expected_seq.wrapping_add(1); + return Err(ProxyError::SeqNoMismatch { + expected: expected_seq, + got: seq_no, + }); } + expected_seq = expected_seq.wrapping_add(1); let payload = &frame[8..pe]; if payload.len() < 4 { @@ -116,7 +120,13 @@ pub(crate) async fn reader_loop( trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); let routed = reg.route(cid, MeResponse::Data { flags, data }).await; - if !routed { + if !matches!(routed, RouteResult::Routed) { + match routed { + RouteResult::NoConn => 
stats.increment_me_route_drop_no_conn(), + RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(), + RouteResult::QueueFull => stats.increment_me_route_drop_queue_full(), + RouteResult::Routed => {} + } reg.unregister(cid).await; send_close_conn(&tx, cid).await; } @@ -126,7 +136,13 @@ pub(crate) async fn reader_loop( trace!(cid, cfm, "RPC_SIMPLE_ACK"); let routed = reg.route(cid, MeResponse::Ack(cfm)).await; - if !routed { + if !matches!(routed, RouteResult::Routed) { + match routed { + RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), + RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(), + RouteResult::QueueFull => stats.increment_me_route_drop_queue_full(), + RouteResult::Routed => {} + } reg.unregister(cid).await; send_close_conn(&tx, cid).await; } @@ -152,6 +168,7 @@ pub(crate) async fn reader_loop( } } else if pt == RPC_PONG_U32 && body.len() >= 8 { let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); + stats.increment_me_keepalive_pong(); if let Some((sent, wid)) = { let mut guard = ping_tracker.lock().await; guard.remove(&ping_id) diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index ab4f280..6a9250d 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -1,13 +1,25 @@ use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::Arc; +use std::time::Duration; -use tokio::sync::{mpsc, Mutex, RwLock}; +use tokio::sync::{mpsc, RwLock}; +use tokio::sync::mpsc::error::TrySendError; use super::codec::WriterCommand; use super::MeResponse; +const ROUTE_CHANNEL_CAPACITY: usize = 4096; +const ROUTE_BACKPRESSURE_TIMEOUT: Duration = Duration::from_millis(25); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RouteResult { + Routed, + NoConn, + ChannelClosed, + QueueFull, +} + #[derive(Clone)] pub struct ConnMeta { pub target_dc: 
i16, @@ -64,7 +76,7 @@ impl ConnRegistry { pub async fn register(&self) -> (u64, mpsc::Receiver) { let id = self.next_id.fetch_add(1, Ordering::Relaxed); - let (tx, rx) = mpsc::channel(1024); + let (tx, rx) = mpsc::channel(ROUTE_CHANNEL_CAPACITY); self.inner.write().await.map.insert(id, tx); (id, rx) } @@ -83,12 +95,27 @@ impl ConnRegistry { None } - pub async fn route(&self, id: u64, resp: MeResponse) -> bool { - let inner = self.inner.read().await; - if let Some(tx) = inner.map.get(&id) { - tx.try_send(resp).is_ok() - } else { - false + pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult { + let tx = { + let inner = self.inner.read().await; + inner.map.get(&id).cloned() + }; + + let Some(tx) = tx else { + return RouteResult::NoConn; + }; + + match tx.try_send(resp) { + Ok(()) => RouteResult::Routed, + Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, + Err(TrySendError::Full(resp)) => { + // Absorb short bursts without dropping/closing the session immediately. + match tokio::time::timeout(ROUTE_BACKPRESSURE_TIMEOUT, tx.send(resp)).await { + Ok(Ok(())) => RouteResult::Routed, + Ok(Err(_)) => RouteResult::ChannelClosed, + Err(_) => RouteResult::QueueFull, + } + } } } diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 627906d..2ebafea 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -62,6 +62,8 @@ impl MePool { let mut writers_snapshot = { let ws = self.writers.read().await; if ws.is_empty() { + // Create waiter before recovery attempts so notify_one permits are not missed. 
+ let waiter = self.writer_available.notified(); drop(ws); for family in self.family_order() { let map = match family { @@ -72,13 +74,19 @@ impl MePool { for (ip, port) in addrs { let addr = SocketAddr::new(*ip, *port); if self.connect_one(addr, self.rng.as_ref()).await.is_ok() { - self.writer_available.notify_waiters(); + self.writer_available.notify_one(); break; } } } } - if tokio::time::timeout(Duration::from_secs(3), self.writer_available.notified()).await.is_err() { + if !self.writers.read().await.is_empty() { + continue; + } + if tokio::time::timeout(Duration::from_secs(3), waiter).await.is_err() { + if !self.writers.read().await.is_empty() { + continue; + } return Err(ProxyError::Proxy("All ME connections dead (waited 3s)".into())); } continue; diff --git a/src/transport/socket.rs b/src/transport/socket.rs index f353c52..b41cfd1 100644 --- a/src/transport/socket.rs +++ b/src/transport/socket.rs @@ -1,5 +1,7 @@ //! TCP Socket Configuration +use std::collections::HashSet; +use std::fs; use std::io::Result; use std::net::{SocketAddr, IpAddr}; use std::time::Duration; @@ -234,6 +236,133 @@ pub fn create_listener(addr: SocketAddr, options: &ListenOptions) -> Result Vec { + #[cfg(target_os = "linux")] + { + find_listener_processes_linux(addr) + } + #[cfg(not(target_os = "linux"))] + { + let _ = addr; + Vec::new() + } +} + +#[cfg(target_os = "linux")] +fn find_listener_processes_linux(addr: SocketAddr) -> Vec { + let inodes = listening_inodes_for_port(addr); + if inodes.is_empty() { + return Vec::new(); + } + + let mut out = Vec::new(); + + let proc_entries = match fs::read_dir("/proc") { + Ok(entries) => entries, + Err(_) => return out, + }; + + for entry in proc_entries.flatten() { + let pid = match entry.file_name().to_string_lossy().parse::() { + Ok(pid) => pid, + Err(_) => continue, + }; + + let fd_dir = entry.path().join("fd"); + let fd_entries = match fs::read_dir(fd_dir) { + Ok(entries) => entries, + Err(_) => continue, + }; + + let mut matched = false; + 
for fd in fd_entries.flatten() { + let link_target = match fs::read_link(fd.path()) { + Ok(link) => link, + Err(_) => continue, + }; + + let link_str = link_target.to_string_lossy(); + let Some(rest) = link_str.strip_prefix("socket:[") else { + continue; + }; + let Some(inode_str) = rest.strip_suffix(']') else { + continue; + }; + let Ok(inode) = inode_str.parse::() else { + continue; + }; + + if inodes.contains(&inode) { + matched = true; + break; + } + } + + if matched { + let process = fs::read_to_string(entry.path().join("comm")) + .ok() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| "unknown".to_string()); + out.push(ListenerProcessInfo { pid, process }); + } + } + + out.sort_by_key(|p| p.pid); + out.dedup_by_key(|p| p.pid); + out +} + +#[cfg(target_os = "linux")] +fn listening_inodes_for_port(addr: SocketAddr) -> HashSet { + let path = match addr { + SocketAddr::V4(_) => "/proc/net/tcp", + SocketAddr::V6(_) => "/proc/net/tcp6", + }; + + let mut inodes = HashSet::new(); + let Ok(data) = fs::read_to_string(path) else { + return inodes; + }; + + for line in data.lines().skip(1) { + let cols: Vec<&str> = line.split_whitespace().collect(); + if cols.len() < 10 { + continue; + } + + // LISTEN state in /proc/net/tcp* + if cols[3] != "0A" { + continue; + } + + let Some(port_hex) = cols[1].split(':').nth(1) else { + continue; + }; + let Ok(port) = u16::from_str_radix(port_hex, 16) else { + continue; + }; + if port != addr.port() { + continue; + } + + if let Ok(inode) = cols[9].parse::() { + inodes.insert(inode); + } + } + + inodes +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index 7d8927d..6dcc36f 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -24,6 +24,8 @@ const NUM_DCS: usize = 5; /// Timeout for individual DC ping attempt const DC_PING_TIMEOUT_SECS: u64 = 5; +/// Timeout for direct TG DC TCP connect readiness. 
+const DIRECT_CONNECT_TIMEOUT_SECS: u64 = 10; // ============= RTT Tracking ============= @@ -375,7 +377,16 @@ impl UpstreamManager { let std_stream: std::net::TcpStream = socket.into(); let stream = TcpStream::from_std(std_stream)?; - stream.writable().await?; + let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS); + match tokio::time::timeout(connect_timeout, stream.writable()).await { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: target.to_string(), + }); + } + } if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); } @@ -383,6 +394,7 @@ impl UpstreamManager { Ok(stream) }, UpstreamType::Socks4 { address, interface, user_id } => { + let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS); // Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port let mut stream = if let Ok(proxy_addr) = address.parse::() { // IP:port format - use socket with optional interface binding @@ -405,7 +417,15 @@ impl UpstreamManager { let std_stream: std::net::TcpStream = socket.into(); let stream = TcpStream::from_std(std_stream)?; - stream.writable().await?; + match tokio::time::timeout(connect_timeout, stream.writable()).await { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: proxy_addr.to_string(), + }); + } + } if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); } @@ -416,8 +436,15 @@ impl UpstreamManager { if interface.is_some() { warn!("SOCKS4 interface binding is not supported for hostname addresses, ignoring"); } - TcpStream::connect(address).await - .map_err(ProxyError::Io)? 
+ match tokio::time::timeout(connect_timeout, TcpStream::connect(address)).await { + Ok(Ok(stream)) => stream, + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: address.clone(), + }); + } + } }; // replace socks user_id with config.selected_scope, if set @@ -425,10 +452,19 @@ impl UpstreamManager { .filter(|s| !s.is_empty()); let _user_id: Option<&str> = scope.or(user_id.as_deref()); - connect_socks4(&mut stream, target, _user_id).await?; + match tokio::time::timeout(connect_timeout, connect_socks4(&mut stream, target, _user_id)).await { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(e), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: target.to_string(), + }); + } + } Ok(stream) }, UpstreamType::Socks5 { address, interface, username, password } => { + let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS); // Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port let mut stream = if let Ok(proxy_addr) = address.parse::() { // IP:port format - use socket with optional interface binding @@ -451,7 +487,15 @@ impl UpstreamManager { let std_stream: std::net::TcpStream = socket.into(); let stream = TcpStream::from_std(std_stream)?; - stream.writable().await?; + match tokio::time::timeout(connect_timeout, stream.writable()).await { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: proxy_addr.to_string(), + }); + } + } if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); } @@ -462,8 +506,15 @@ impl UpstreamManager { if interface.is_some() { warn!("SOCKS5 interface binding is not supported for hostname addresses, ignoring"); } - TcpStream::connect(address).await - .map_err(ProxyError::Io)? 
+ match tokio::time::timeout(connect_timeout, TcpStream::connect(address)).await { + Ok(Ok(stream)) => stream, + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: address.clone(), + }); + } + } }; debug!(config = ?config, "Socks5 connection"); @@ -473,7 +524,20 @@ impl UpstreamManager { let _username: Option<&str> = scope.or(username.as_deref()); let _password: Option<&str> = scope.or(password.as_deref()); - connect_socks5(&mut stream, target, _username, _password).await?; + match tokio::time::timeout( + connect_timeout, + connect_socks5(&mut stream, target, _username, _password), + ) + .await + { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(e), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: target.to_string(), + }); + } + } Ok(stream) }, } diff --git a/tools/tlsearch.py b/tools/tlsearch.py new file mode 100644 index 0000000..32b42c7 --- /dev/null +++ b/tools/tlsearch.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python3 +""" +TLS Profile Inspector + +Usage: + python3 tools/tlsearch.py + python3 tools/tlsearch.py tlsfront + python3 tools/tlsearch.py tlsfront/petrovich.ru.json + python3 tools/tlsearch.py tlsfront --only-current +""" + +from __future__ import annotations + +import argparse +import datetime as dt +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Iterable + + +TLS_VERSIONS = { + 0x0301: "TLS 1.0", + 0x0302: "TLS 1.1", + 0x0303: "TLS 1.2", + 0x0304: "TLS 1.3", +} + +EXT_NAMES = { + 0: "server_name", + 5: "status_request", + 10: "supported_groups", + 11: "ec_point_formats", + 13: "signature_algorithms", + 16: "alpn", + 18: "signed_certificate_timestamp", + 21: "padding", + 23: "extended_master_secret", + 35: "session_ticket", + 43: "supported_versions", + 45: "psk_key_exchange_modes", + 51: "key_share", +} + +CIPHER_NAMES = { + 0x1301: "TLS_AES_128_GCM_SHA256", + 0x1302: "TLS_AES_256_GCM_SHA384", + 0x1303: 
"TLS_CHACHA20_POLY1305_SHA256", + 0x1304: "TLS_AES_128_CCM_SHA256", + 0x1305: "TLS_AES_128_CCM_8_SHA256", + 0x009C: "TLS_RSA_WITH_AES_128_GCM_SHA256", + 0x009D: "TLS_RSA_WITH_AES_256_GCM_SHA384", + 0xC02F: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + 0xC030: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + 0xCCA8: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", + 0xCCA9: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", +} + +NAMED_GROUPS = { + 0x001D: "x25519", + 0x0017: "secp256r1", + 0x0018: "secp384r1", + 0x0019: "secp521r1", + 0x0100: "ffdhe2048", + 0x0101: "ffdhe3072", + 0x0102: "ffdhe4096", +} + + +@dataclass +class ProfileRecognition: + schema: str + mode: str + has_cert_info: bool + has_full_cert_payload: bool + cert_message_len: int + cert_chain_count: int + cert_chain_total_len: int + issues: list[str] + + +def to_hex(data: Iterable[int]) -> str: + return "".join(f"{b:02x}" for b in data) + + +def read_u16be(data: list[int], off: int = 0) -> int: + return (data[off] << 8) | data[off + 1] + + +def normalize_u8_list(value: Any) -> list[int]: + if not isinstance(value, list): + return [] + out: list[int] = [] + for item in value: + if isinstance(item, int) and 0 <= item <= 0xFF: + out.append(item) + else: + return [] + return out + + +def as_dict(value: Any) -> dict[str, Any]: + return value if isinstance(value, dict) else {} + + +def as_int(value: Any, default: int = 0) -> int: + return value if isinstance(value, int) else default + + +def decode_version_pair(v: list[int]) -> str: + if len(v) != 2: + return f"invalid({v})" + ver = read_u16be(v) + return f"0x{ver:04x} ({TLS_VERSIONS.get(ver, 'unknown')})" + + +def decode_cipher_suite(v: list[int]) -> str: + if len(v) != 2: + return f"invalid({v})" + cs = read_u16be(v) + name = CIPHER_NAMES.get(cs, "unknown") + return f"0x{cs:04x} ({name})" + + +def decode_supported_versions(data: list[int]) -> str: + if len(data) == 2: + ver = read_u16be(data) + return f"selected=0x{ver:04x} ({TLS_VERSIONS.get(ver, 'unknown')})" + if 
not data: + return "empty" + if len(data) < 3: + return f"raw={to_hex(data)}" + vec_len = data[0] + versions: list[str] = [] + for i in range(1, min(1 + vec_len, len(data)), 2): + if i + 1 >= len(data): + break + ver = read_u16be(data, i) + versions.append(f"0x{ver:04x}({TLS_VERSIONS.get(ver, 'unknown')})") + return "offered=[" + ", ".join(versions) + "]" + + +def decode_key_share(data: list[int]) -> str: + if len(data) < 4: + return f"raw={to_hex(data)}" + group = read_u16be(data, 0) + key_len = read_u16be(data, 2) + key_hex = to_hex(data[4 : 4 + min(key_len, len(data) - 4)]) + gname = NAMED_GROUPS.get(group, "unknown_group") + return f"group=0x{group:04x}({gname}), key_len={key_len}, key={key_hex}" + + +def decode_alpn(data: list[int]) -> str: + if len(data) < 3: + return f"raw={to_hex(data)}" + total = read_u16be(data, 0) + pos = 2 + vals: list[str] = [] + limit = min(len(data), 2 + total) + while pos < limit: + ln = data[pos] + pos += 1 + if pos + ln > limit: + break + raw = bytes(data[pos : pos + ln]) + pos += ln + try: + vals.append(raw.decode("ascii")) + except UnicodeDecodeError: + vals.append(raw.hex()) + return "protocols=[" + ", ".join(vals) + "]" + + +def decode_extension(ext_type: int, data: list[int]) -> str: + if ext_type == 43: + return decode_supported_versions(data) + if ext_type == 51: + return decode_key_share(data) + if ext_type == 16: + return decode_alpn(data) + return f"raw={to_hex(data)}" + + +def ts_to_iso(ts: Any) -> str: + if not isinstance(ts, int): + return "-" + return dt.datetime.fromtimestamp(ts, tz=dt.timezone.utc).isoformat() + + +def recognize_profile(obj: dict[str, Any]) -> ProfileRecognition: + issues: list[str] = [] + + sh = as_dict(obj.get("server_hello_template")) + if not sh: + issues.append("missing server_hello_template") + + version = normalize_u8_list(sh.get("version")) + if version and len(version) != 2: + issues.append("server_hello_template.version must have 2 bytes") + + app_sizes = obj.get("app_data_records_sizes") 
+ if not isinstance(app_sizes, list) or not app_sizes: + issues.append("missing app_data_records_sizes") + elif any((not isinstance(v, int) or v <= 0) for v in app_sizes): + issues.append("app_data_records_sizes contains invalid values") + + if not isinstance(obj.get("total_app_data_len"), int): + issues.append("missing total_app_data_len") + + cert_info = as_dict(obj.get("cert_info")) + has_cert_info = bool( + cert_info.get("subject_cn") + or cert_info.get("issuer_cn") + or cert_info.get("san_names") + or isinstance(cert_info.get("not_before_unix"), int) + or isinstance(cert_info.get("not_after_unix"), int) + ) + + cert_payload = as_dict(obj.get("cert_payload")) + cert_message_len = 0 + cert_chain_count = 0 + cert_chain_total_len = 0 + has_full_cert_payload = False + + if cert_payload: + cert_msg = normalize_u8_list(cert_payload.get("certificate_message")) + if not cert_msg: + issues.append("cert_payload.certificate_message is missing or invalid") + else: + cert_message_len = len(cert_msg) + + chain_raw = cert_payload.get("cert_chain_der") + if not isinstance(chain_raw, list): + issues.append("cert_payload.cert_chain_der is missing or invalid") + else: + for entry in chain_raw: + cert = normalize_u8_list(entry) + if cert: + cert_chain_count += 1 + cert_chain_total_len += len(cert) + else: + issues.append("cert_payload.cert_chain_der has invalid certificate entry") + break + + has_full_cert_payload = cert_message_len > 0 and cert_chain_count > 0 + elif obj.get("cert_payload") is not None: + issues.append("cert_payload is not an object") + + if has_full_cert_payload: + schema = "current" + mode = "full-cert-payload" + elif has_cert_info: + schema = "current-compact" + mode = "compact-cert-info" + else: + schema = "legacy" + mode = "random-fallback" + + if issues: + schema = f"{schema}+issues" + + return ProfileRecognition( + schema=schema, + mode=mode, + has_cert_info=has_cert_info, + has_full_cert_payload=has_full_cert_payload, + cert_message_len=cert_message_len, 
+ cert_chain_count=cert_chain_count, + cert_chain_total_len=cert_chain_total_len, + issues=issues, + ) + + +def decode_profile(path: Path) -> tuple[str, ProfileRecognition]: + obj: dict[str, Any] = json.loads(path.read_text(encoding="utf-8")) + recognition = recognize_profile(obj) + + sh = as_dict(obj.get("server_hello_template")) + version = normalize_u8_list(sh.get("version")) + cipher = normalize_u8_list(sh.get("cipher_suite")) + random_bytes = normalize_u8_list(sh.get("random")) + session_id = normalize_u8_list(sh.get("session_id")) + + lines: list[str] = [] + lines.append(f"[{path.name}]") + lines.append(f" domain: {obj.get('domain', '-')}") + lines.append(f" profile.schema: {recognition.schema}") + lines.append(f" profile.mode: {recognition.mode}") + lines.append(f" profile.has_full_cert_payload: {recognition.has_full_cert_payload}") + lines.append(f" profile.has_cert_info: {recognition.has_cert_info}") + if recognition.has_full_cert_payload: + lines.append(f" profile.cert_message_len: {recognition.cert_message_len}") + lines.append(f" profile.cert_chain_count: {recognition.cert_chain_count}") + lines.append(f" profile.cert_chain_total_len: {recognition.cert_chain_total_len}") + if recognition.issues: + lines.append(" profile.issues:") + for issue in recognition.issues: + lines.append(f" - {issue}") + + lines.append(f" tls.version: {decode_version_pair(version)}") + lines.append(f" tls.cipher: {decode_cipher_suite(cipher)}") + lines.append(f" tls.compression: {sh.get('compression', '-')}") + lines.append(f" tls.random: {to_hex(random_bytes)}") + lines.append(f" tls.session_id_len: {len(session_id)}") + if session_id: + lines.append(f" tls.session_id: {to_hex(session_id)}") + + app_sizes = obj.get("app_data_records_sizes", []) + if isinstance(app_sizes, list): + lines.append(" app_data_records_sizes: " + ", ".join(str(v) for v in app_sizes)) + else: + lines.append(" app_data_records_sizes: -") + lines.append(f" total_app_data_len: 
{obj.get('total_app_data_len', '-')}") + + cert = as_dict(obj.get("cert_info")) + if cert: + lines.append(" cert_info:") + lines.append(f" subject_cn: {cert.get('subject_cn') or '-'}") + lines.append(f" issuer_cn: {cert.get('issuer_cn') or '-'}") + lines.append(f" not_before: {ts_to_iso(cert.get('not_before_unix'))}") + lines.append(f" not_after: {ts_to_iso(cert.get('not_after_unix'))}") + sans = cert.get("san_names") + if isinstance(sans, list) and sans: + lines.append(" san_names: " + ", ".join(str(v) for v in sans)) + else: + lines.append(" san_names: -") + else: + lines.append(" cert_info: -") + + exts = sh.get("extensions", []) + if not isinstance(exts, list): + exts = [] + lines.append(f" extensions[{len(exts)}]:") + for ext in exts: + ext_obj = as_dict(ext) + ext_type = as_int(ext_obj.get("ext_type"), -1) + data = normalize_u8_list(ext_obj.get("data")) + name = EXT_NAMES.get(ext_type, "unknown") + decoded = decode_extension(ext_type, data) + lines.append(f" - type={ext_type} ({name}), len={len(data)}: {decoded}") + + lines.append("") + return ("\n".join(lines), recognition) + + +def collect_files(input_path: Path) -> list[Path]: + if input_path.is_file(): + return [input_path] + return sorted(p for p in input_path.glob("*.json") if p.is_file()) + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Decode TLS profile JSON files and recognize current schema." 
+ ) + parser.add_argument( + "path", + nargs="?", + default="tlsfront", + help="Path to tlsfront directory or a single JSON file.", + ) + parser.add_argument( + "--only-current", + action="store_true", + help="Show only profiles recognized as current/full-cert-payload.", + ) + args = parser.parse_args() + + base = Path(args.path) + if not base.exists(): + print(f"Path not found: {base}") + return 1 + + files = collect_files(base) + if not files: + print(f"No JSON files found in: {base}") + return 1 + + printed = 0 + for path in files: + try: + rendered, recognition = decode_profile(path) + if args.only_current and recognition.schema != "current": + continue + print(rendered, end="") + printed += 1 + except Exception as e: # noqa: BLE001 + print(f"[{path.name}] decode error: {e}\n") + + if args.only_current and printed == 0: + print("No current profiles found.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())