Merge remote-tracking branch 'upstream/main'

This commit is contained in:
ivulit 2026-02-23 10:34:41 +03:00
commit 836eb6f777
No known key found for this signature in database
32 changed files with 2182 additions and 477 deletions

View File

@ -3,8 +3,8 @@ name: Release
on: on:
push: push:
tags: tags:
- '[0-9]+.[0-9]+.[0-9]+' # Matches tags like 3.0.0, 3.1.2, etc. - '[0-9]+.[0-9]+.[0-9]+' # Matches tags like 3.0.0, 3.1.2, etc.
workflow_dispatch: # Manual trigger from GitHub Actions UI workflow_dispatch: # Manual trigger from GitHub Actions UI
permissions: permissions:
contents: read contents: read
@ -84,6 +84,32 @@ jobs:
target/${{ matrix.target }}/release/${{ matrix.asset_name }}.tar.gz target/${{ matrix.target }}/release/${{ matrix.asset_name }}.tar.gz
target/${{ matrix.target }}/release/${{ matrix.asset_name }}.sha256 target/${{ matrix.target }}/release/${{ matrix.asset_name }}.sha256
build-docker-image:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Login to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.TOKEN_GH_DEPLOY }}
- name: Build and push
uses: docker/build-push-action@v6
with:
context: .
push: true
tags: ${{ github.ref }}
release: release:
name: Create Release name: Create Release
needs: build needs: build
@ -108,17 +134,17 @@ jobs:
# Extract version from tag (remove 'v' prefix if present) # Extract version from tag (remove 'v' prefix if present)
VERSION="${GITHUB_REF#refs/tags/}" VERSION="${GITHUB_REF#refs/tags/}"
VERSION="${VERSION#v}" VERSION="${VERSION#v}"
# Install cargo-edit for version bumping # Install cargo-edit for version bumping
cargo install cargo-edit cargo install cargo-edit
# Update Cargo.toml version # Update Cargo.toml version
cargo set-version "$VERSION" cargo set-version "$VERSION"
# Configure git # Configure git
git config user.name "github-actions[bot]" git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com" git config user.email "github-actions[bot]@users.noreply.github.com"
# Commit and push changes # Commit and push changes
#git add Cargo.toml Cargo.lock #git add Cargo.toml Cargo.lock
#git commit -m "chore: bump version to $VERSION" || echo "No changes to commit" #git commit -m "chore: bump version to $VERSION" || echo "No changes to commit"

View File

@ -1,6 +1,7 @@
## System Prompt — Production Rust Codebase: Modification and Architecture Guidelines ## System Prompt — Production Rust Codebase: Modification and Architecture Guidelines
You are a senior Rust systems engineer acting as a strict code reviewer and implementation partner. Your responses are precise, minimal, and architecturally sound. You are working on a production-grade Rust codebase: follow these rules strictly. You are a senior Rust Engineer and pricipal Rust Architect acting as a strict code reviewer and implementation partner.
Your responses are precise, minimal, and architecturally sound. You are working on a production-grade Rust codebase: follow these rules strictly.
--- ---
@ -32,6 +33,11 @@ The user can override this behavior with explicit commands:
- `"Make minimal changes"` — no coordinated fixes, narrowest possible diff. - `"Make minimal changes"` — no coordinated fixes, narrowest possible diff.
- `"Fix everything"` — apply all coordinated fixes and out-of-scope observations. - `"Fix everything"` — apply all coordinated fixes and out-of-scope observations.
### Core Rule
The codebase must never enter an invalid intermediate state.
No response may leave the repository in a condition that requires follow-up fixes.
--- ---
### 1. Comments and Documentation ### 1. Comments and Documentation
@ -131,16 +137,32 @@ You MUST:
- Document non-obvious logic with comments that describe *why*, not *what*. - Document non-obvious logic with comments that describe *why*, not *what*.
- Limit changes strictly to the requested scope (plus coordinated fixes per Section 0). - Limit changes strictly to the requested scope (plus coordinated fixes per Section 0).
- Keep all existing symbol names unless renaming is explicitly requested. - Keep all existing symbol names unless renaming is explicitly requested.
- Preserve global formatting as-is. - Preserve global formatting as-is
- Result every modification in a self-contained, compilable, runnable state of the codebase
You MUST NOT: You MUST NOT:
- Use placeholders: no `// ... rest of code`, no `// implement here`, no `/* TODO */` stubs that replace existing working code. Write full, working implementation. If the implementation is unclear, ask first. - Use placeholders: no `// ... rest of code`, no `// implement here`, no `/* TODO */` stubs that replace existing working code. Write full, working implementation. If the implementation is unclear, ask first
- Refactor code outside the requested scope. - Refactor code outside the requested scope
- Make speculative improvements. - Make speculative improvements
- Spawn multiple agents for EDITING
- Produce partial changes
- Introduce references to entities that are not yet implemented
- Leave TODO placeholders in production paths
Note: `todo!()` and `unimplemented!()` are allowed as idiomatic Rust markers for genuinely unfinished code paths. Note: `todo!()` and `unimplemented!()` are allowed as idiomatic Rust markers for genuinely unfinished code paths.
Every change must:
- compile,
- pass type checks,
- have no broken imports,
- preserve invariants,
- not rely on future patches.
If the task requires multiple phases:
- either implement all required phases,
- or explicitly refuse and explain missing dependencies.
--- ---
### 8. Decision Process for Complex Changes ### 8. Decision Process for Complex Changes
@ -160,6 +182,7 @@ When facing a non-trivial modification, follow this sequence:
- When provided with partial code, assume the rest of the codebase exists and functions correctly unless stated otherwise. - When provided with partial code, assume the rest of the codebase exists and functions correctly unless stated otherwise.
- Reference existing types, functions, and module structures by their actual names as shown in the provided code. - Reference existing types, functions, and module structures by their actual names as shown in the provided code.
- When the provided context is insufficient to make a safe change, request the missing context explicitly. - When the provided context is insufficient to make a safe change, request the missing context explicitly.
- Spawn multiple agents for SEARCHING information, code, functions
--- ---
@ -167,14 +190,14 @@ When facing a non-trivial modification, follow this sequence:
#### Language Policy #### Language Policy
- Code, comments, commit messages, documentation: **English**. - Code, comments, commit messages, documentation ONLY ON **English**!
- Reasoning and explanations in response text: **Russian**. - Reasoning and explanations in response text on language from promt
#### Response Structure #### Response Structure
Your response MUST consist of two sections: Your response MUST consist of two sections:
**Section 1: `## Reasoning` (in Russian)** **Section 1: `## Reasoning`**
- What needs to be done and why. - What needs to be done and why.
- Which files and modules are affected. - Which files and modules are affected.
@ -205,3 +228,183 @@ If the response exceeds the output limit:
2. List the files that will be provided in subsequent parts. 2. List the files that will be provided in subsequent parts.
3. Wait for user confirmation before continuing. 3. Wait for user confirmation before continuing.
4. No single file may be split across parts. 4. No single file may be split across parts.
## 11. Anti-LLM Degeneration Safeguards (Principal-Paranoid, Visionary)
This section exists to prevent common LLM failure modes: scope creep, semantic drift, cargo-cult refactors, performance regressions, contract breakage, and hidden behavior changes.
### 11.1 Non-Negotiable Invariants
- **No semantic drift:** Do not reinterpret requirements, rename concepts, or change meaning of existing terms.
- **No “helpful refactors”:** Any refactor not explicitly requested is forbidden.
- **No architectural drift:** Do not introduce new layers, patterns, abstractions, or “clean architecture” migrations unless requested.
- **No dependency drift:** Do not add crates, features, or versions unless explicitly requested.
- **No behavior drift:** If a change could alter runtime behavior, you MUST call it out explicitly in `## Reasoning` and justify it.
### 11.2 Minimal Surface Area Rule
- Touch the smallest number of files possible.
- Prefer local changes over cross-cutting edits.
- Do not “align style” across a file/module—only adjust the modified region.
- Do not reorder items, imports, or code unless required for correctness.
### 11.3 No Implicit Contract Changes
Contracts include:
- public APIs, trait bounds, visibility, error types, timeouts/retries, logging semantics, metrics semantics,
- protocol formats, framing, padding, keepalive cadence, state machine transitions,
- concurrency guarantees, cancellation behavior, backpressure behavior.
Rule:
- If you change a contract, you MUST update all dependents in the same patch AND document the contract delta explicitly.
### 11.4 Hot-Path Preservation (Performance Paranoia)
- Do not introduce extra allocations, cloning, or formatting in hot paths.
- Do not add logging/metrics on hot paths unless requested.
- Do not add new locks or broaden lock scope.
- Prefer `&str` / slices / borrowed data where the codebase already does so.
- Avoid `String` building for errors/logs if it changes current patterns.
If you cannot prove performance neutrality, label it as risk in `## Reasoning`.
### 11.5 Async / Concurrency Safety (Cancellation & Backpressure)
- No blocking calls inside async contexts.
- Preserve cancellation safety: do not introduce `await` between lock acquisition and critical invariants unless already present.
- Preserve backpressure: do not replace bounded channels with unbounded, do not remove flow control.
- Do not change task lifecycle semantics (spawn patterns, join handles, shutdown order) unless requested.
- Do not introduce `tokio::spawn` / background tasks unless explicitly requested.
### 11.6 Error Semantics Integrity
- Do not replace structured errors with generic strings.
- Do not widen/narrow error types or change error categories without explicit approval.
- Avoid introducing panics in production paths (`unwrap`, `expect`) unless the codebase already treats that path as impossible and documented.
### 11.7 “No New Abstractions” Default
Default stance:
- No new traits, generics, macros, builder patterns, type-level cleverness, or “frameworking”.
- If abstraction is necessary, prefer the smallest possible local helper (private function) and justify it.
### 11.8 Negative-Diff Protection
Avoid “diff inflation” patterns:
- mass edits,
- moving code between files,
- rewrapping long lines,
- rearranging module order,
- renaming for aesthetics.
If a diff becomes large, STOP and ask before proceeding.
### 11.9 Consistency with Existing Style (But Not Style Refactors)
- Follow existing conventions of the touched module (naming, error style, return patterns).
- Do not enforce global “best practices” that the codebase does not already use.
### 11.10 Two-Phase Safety Gate (Plan → Patch)
For non-trivial changes:
1) Provide a micro-plan (15 bullets): what files, what functions, what invariants, what risks.
2) Implement exactly that plan—no extra improvements.
### 11.11 Pre-Response Checklist (Hard Gate)
Before final output, verify internally:
- No unresolved symbols / broken imports.
- No partially updated call sites.
- No new public surface changes unless requested.
- No transitional states / TODO placeholders replacing working code.
- Changes are atomic: the repository remains buildable and runnable.
- Any behavior change is explicitly stated.
If any check fails: fix it before responding.
### 11.12 Truthfulness Policy (No Hallucinated Claims)
- Do not claim “this compiles” or “tests pass” unless you actually verified with the available tooling/context.
- If verification is not possible, state: “Not executed; reasoning-based consistency check only.”
### 11.13 Visionary Guardrail: Preserve Optionality
When multiple valid designs exist, prefer the one that:
- minimally constrains future evolution,
- preserves existing extension points,
- avoids locking the project into a new paradigm,
- keeps interfaces stable and implementation local.
Default to reversible changes.
### 11.14 Stop Conditions
STOP and ask targeted questions if:
- required context is missing,
- a change would cross module boundaries,
- a contract might change,
- concurrency/protocol invariants are unclear,
- the diff is growing beyond a minimal patch.
No guessing.
### 12. Invariant Preservation
You MUST explicitly preserve:
- Thread-safety guarantees (`Send` / `Sync` expectations).
- Memory safety assumptions (no hidden `unsafe` expansions).
- Lock ordering and deadlock invariants.
- State machine correctness (no new invalid transitions).
- Backward compatibility of serialized formats (if applicable).
If a change touches concurrency, networking, protocol logic, or state machines,
you MUST explain why existing invariants remain valid.
### 13. Error Handling Policy
- Do not replace structured errors with generic strings.
- Preserve existing error propagation semantics.
- Do not widen or narrow error types without approval.
- Avoid introducing panics in production paths.
- Prefer explicit error mapping over implicit conversions.
### 14. Test Safety
- Do not modify existing tests unless the task explicitly requires it.
- Do not weaken assertions.
- Preserve determinism in testable components.
### 15. Security Constraints
- Do not weaken cryptographic assumptions.
- Do not modify key derivation logic without explicit request.
- Do not change constant-time behavior.
- Do not introduce logging of secrets.
- Preserve TLS/MTProto protocol correctness.
### 16. Logging Policy
- Do not introduce excessive logging in hot paths.
- Do not log sensitive data.
- Preserve existing log levels and style.
### 17. Pre-Response Verification Checklist
Before producing the final answer, verify internally:
- The change compiles conceptually.
- No unresolved symbols exist.
- All modified call sites are updated.
- No accidental behavioral changes were introduced.
- Architectural boundaries remain intact.
### 18. Atomic Change Principle
Every patch must be **atomic and production-safe**.
* **Self-contained** — no dependency on future patches or unimplemented components.
* **Build-safe** — the project must compile successfully after the change.
* **Contract-consistent** — no partial interface or behavioral changes; all dependent code must be updated within the same patch.
* **No transitional states** — no placeholders, incomplete refactors, or temporary inconsistencies.
**Invariant:** After any single patch, the repository remains fully functional and buildable.

View File

@ -1,6 +1,6 @@
[package] [package]
name = "telemt" name = "telemt"
version = "3.0.8" version = "3.0.11"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
@ -20,6 +20,7 @@ sha1 = "0.10"
md-5 = "0.10" md-5 = "0.10"
hmac = "0.12" hmac = "0.12"
crc32fast = "1.4" crc32fast = "1.4"
crc32c = "0.6"
zeroize = { version = "1.8", features = ["derive"] } zeroize = { version = "1.8", features = ["derive"] }
# Network # Network

206
README.md
View File

@ -10,41 +10,77 @@
### 🇷🇺 RU ### 🇷🇺 RU
18 февраля мы опубликовали `telemt 3.0.3`, он имеет: #### Драфтинг LTS и текущие улучшения
- улучшенный механизм Middle-End Health Check С 21 февраля мы начали подготовку LTS-версии.
- высокоскоростное восстановление инициализации Middle-End
- меньше задержек на hot-path
- более корректную работу в Dualstack, а именно - IPv6 Middle-End
- аккуратное переподключение клиента без дрифта сессий между Middle-End
- автоматическая деградация на Direct-DC при массовой (>2 ME-DC-групп) недоступности Middle-End
- автодетект IP за NAT, при возможности - будет выполнен хендшейк с ME, при неудаче - автодеградация
- единственный известный специальный DC=203 уже добавлен в код: медиа загружаются с CDN в Direct-DC режиме
[Здесь вы можете найти релиз](https://github.com/telemt/telemt/releases/tag/3.0.3) Мы внимательно анализируем весь доступный фидбек.
Наша цель — сделать LTS-кандидаты максимально стабильными, тщательно отлаженными и готовыми к long-run и highload production-сценариям.
Если у вас есть компетенции в асинхронных сетевых приложениях, анализе трафика, реверс-инжиниринге или сетевых расследованиях - мы открыты к идеям и pull requests! ---
#### Улучшения от 23 февраля
23 февраля были внесены улучшения производительности в режимах **DC** и **Middle-End (ME)**, с акцентом на обратный канал (путь клиент → DC / ME).
Дополнительно реализован ряд изменений, направленных на повышение устойчивости системы:
- Смягчение сетевой нестабильности
- Повышение устойчивости к десинхронизации криптографии
- Снижение дрейфа сессий при неблагоприятных условиях
- Улучшение обработки ошибок в edge-case транспортных сценариях
Релиз:
[3.0.9](https://github.com/telemt/telemt/releases/tag/3.0.9)
---
Если у вас есть компетенции в:
- Асинхронных сетевых приложениях
- Анализе трафика
- Реверс-инжиниринге
- Сетевых расследованиях
Мы открыты к архитектурным предложениям, идеям и pull requests
</td> </td>
<td width="50%" valign="top"> <td width="50%" valign="top">
### 🇬🇧 EN ### 🇬🇧 EN
On February 18, we released `telemt 3.0.3`. This version introduces: #### LTS Drafting and Ongoing Improvements
- improved Middle-End Health Check method Starting February 21, we began drafting the upcoming LTS version.
- high-speed recovery of Middle-End init
- reduced latency on the hot path
- correct Dualstack support: proper handling of IPv6 Middle-End
- *clean* client reconnection without session "drift" between Middle-End
- automatic degradation to Direct-DC mode in case of large-scale (>2 ME-DC groups) Middle-End unavailability
- automatic public IP detection behind NAT; first - Middle-End handshake is performed, otherwise automatic degradation is applied
- known special DC=203 is now handled natively: media is delivered from the CDN via Direct-DC mode
[Release is available here](https://github.com/telemt/telemt/releases/tag/3.0.3) We are carefully reviewing and analyzing all available feedback.
The goal is to ensure that LTS candidates are максимально stable, thoroughly debugged, and ready for long-run and high-load production scenarios.
If you have expertise in asynchronous network applications, traffic analysis, reverse engineering, or network forensics - we welcome ideas and pull requests! ---
#### February 23 Improvements
On February 23, we introduced performance improvements for both **DC** and **Middle-End (ME)** modes, specifically optimizing the reverse channel (client → DC / ME data path).
Additionally, we implemented a set of robustness enhancements designed to:
- Mitigate network-related instability
- Improve resilience against cryptographic desynchronization
- Reduce session drift under adverse conditions
- Improve error handling in edge-case transport scenarios
Release:
[3.0.9](https://github.com/telemt/telemt/releases/tag/3.0.9)
---
If you have expertise in:
- Asynchronous network applications
- Traffic analysis
- Reverse engineering
- Network forensics
We welcome ideas, architectural feedback, and pull requests.
</td> </td>
</tr> </tr>
</table> </table>
@ -178,147 +214,21 @@ then Ctrl+X -> Y -> Enter to save
```toml ```toml
# === General Settings === # === General Settings ===
[general] [general]
fast_mode = true
use_middle_proxy = true
# ad_tag = "00000000000000000000000000000000" # ad_tag = "00000000000000000000000000000000"
# Path to proxy-secret binary (auto-downloaded if missing).
proxy_secret_path = "proxy-secret"
# disable_colors = false # Disable colored output in logs (useful for files/systemd)
# === Log Level ===
# Log level: debug | verbose | normal | silent
# Can be overridden with --silent or --log-level CLI flags
# RUST_LOG env var takes absolute priority over all of these
log_level = "normal"
# === Middle Proxy - ME ===
# Public IP override for ME KDF when behind NAT; leave unset to auto-detect.
# middle_proxy_nat_ip = "203.0.113.10"
# Enable STUN probing to discover public IP:port for ME.
middle_proxy_nat_probe = true
# Primary STUN server (host:port); defaults to Telegram STUN when empty.
middle_proxy_nat_stun = "stun.l.google.com:19302"
# Optional fallback STUN servers list.
middle_proxy_nat_stun_servers = ["stun1.l.google.com:19302", "stun2.l.google.com:19302"]
# Desired number of concurrent ME writers in pool.
middle_proxy_pool_size = 16
# Pre-initialized warm-standby ME connections kept idle.
middle_proxy_warm_standby = 8
# Ignore STUN/interface mismatch and keep ME enabled even if IP differs.
stun_iface_mismatch_ignore = false
# Keepalive padding frames - fl==4
me_keepalive_enabled = true
me_keepalive_interval_secs = 25 # Period between keepalives
me_keepalive_jitter_secs = 5 # Jitter added to interval
me_keepalive_payload_random = true # Randomize 4-byte payload (vs zeros)
# Stagger extra ME connections on warmup to de-phase lifecycles.
me_warmup_stagger_enabled = true
me_warmup_step_delay_ms = 500 # Base delay between extra connects
me_warmup_step_jitter_ms = 300 # Jitter for warmup delay
# Reconnect policy knobs.
me_reconnect_max_concurrent_per_dc = 1 # Parallel reconnects per DC - EXPERIMENTAL! UNSTABLE!
me_reconnect_backoff_base_ms = 500 # Backoff start
me_reconnect_backoff_cap_ms = 30000 # Backoff cap
me_reconnect_fast_retry_count = 11 # Quick retries before backoff
[general.modes] [general.modes]
classic = false classic = false
secure = false secure = false
tls = true tls = true
[general.links]
show = "*"
# show = ["alice", "bob"] # Only show links for alice and bob
# show = "*" # Show links for all users
# public_host = "proxy.example.com" # Host (IP or domain) for tg:// links
# public_port = 443 # Port for tg:// links (default: server.port)
# === Network Parameters ===
[network]
# Enable/disable families: true/false/auto(None)
ipv4 = true
ipv6 = false # UNSTABLE WITH ME
# prefer = 4 or 6
prefer = 4
multipath = false # EXPERIMENTAL!
# === Server Binding ===
[server]
port = 443
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
# listen_unix_sock = "/var/run/telemt.sock" # Unix socket
# listen_unix_sock_perm = "0666" # Socket file permissions
# metrics_port = 9090
# metrics_whitelist = [
# "192.168.0.0/24",
# "172.16.0.0/12",
# "127.0.0.1/32",
# "::1/128"
#]
# Listen on multiple interfaces/IPs - IPv4
[[server.listeners]]
ip = "0.0.0.0"
# Listen on multiple interfaces/IPs - IPv6
[[server.listeners]]
ip = "::"
# === Timeouts (in seconds) ===
[timeouts]
client_handshake = 30
tg_connect = 10
client_keepalive = 60
client_ack = 300
# Quick ME reconnects for single-address DCs (count and per-attempt timeout, ms).
me_one_retry = 12
me_one_timeout_ms = 1200
# === Anti-Censorship & Masking === # === Anti-Censorship & Masking ===
[censorship] [censorship]
tls_domain = "petrovich.ru" tls_domain = "petrovich.ru"
mask = true
mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
# mask_unix_sock = "/var/run/nginx.sock" # Unix socket (mutually exclusive with mask_host)
fake_cert_len = 2048
# === Access Control & Users ===
[access]
replay_check_len = 65536
replay_window_secs = 1800
ignore_time_skew = false
[access.users] [access.users]
# format: "username" = "32_hex_chars_secret" # format: "username" = "32_hex_chars_secret"
hello = "00000000000000000000000000000000" hello = "00000000000000000000000000000000"
# [access.user_max_tcp_conns]
# hello = 50
# [access.user_max_unique_ips]
# hello = 5
# [access.user_data_quota]
# hello = 1073741824 # 1 GB
# === Upstreams & Routing ===
[[upstreams]]
type = "direct"
enabled = true
weight = 10
# [[upstreams]]
# type = "socks5"
# address = "127.0.0.1:1080"
# enabled = false
# weight = 1
# === DC Address Overrides ===
# [dc_overrides]
# "203" = "91.105.192.100:443"
``` ```
### Advanced ### Advanced
#### Adtag #### Adtag

View File

@ -213,6 +213,7 @@ listen_addr_ipv6 = "::"
[[server.listeners]] [[server.listeners]]
ip = "0.0.0.0" ip = "0.0.0.0"
# reuse_allow = false # Set true only when intentionally running multiple telemt instances on same port
[[server.listeners]] [[server.listeners]]
ip = "::" ip = "::"
@ -228,6 +229,7 @@ tls_domain = "{domain}"
mask = true mask = true
mask_port = 443 mask_port = 443
fake_cert_len = 2048 fake_cert_len = 2048
tls_full_cert_ttl_secs = 90
[access] [access]
replay_check_len = 65536 replay_check_len = 65536

View File

@ -122,6 +122,10 @@ pub(crate) fn default_tls_new_session_tickets() -> u8 {
0 0
} }
pub(crate) fn default_tls_full_cert_ttl_secs() -> u64 {
90
}
pub(crate) fn default_server_hello_delay_min_ms() -> u64 { pub(crate) fn default_server_hello_delay_min_ms() -> u64 {
0 0
} }

View File

@ -227,6 +227,7 @@ impl ProxyConfig {
announce: None, announce: None,
announce_ip: None, announce_ip: None,
proxy_protocol: None, proxy_protocol: None,
reuse_allow: false,
}); });
} }
if let Some(ipv6_str) = &config.server.listen_addr_ipv6 { if let Some(ipv6_str) = &config.server.listen_addr_ipv6 {
@ -236,6 +237,7 @@ impl ProxyConfig {
announce: None, announce: None,
announce_ip: None, announce_ip: None,
proxy_protocol: None, proxy_protocol: None,
reuse_allow: false,
}); });
} }
} }

View File

@ -74,8 +74,8 @@ pub struct ProxyModes {
impl Default for ProxyModes { impl Default for ProxyModes {
fn default() -> Self { fn default() -> Self {
Self { Self {
classic: true, classic: false,
secure: true, secure: false,
tls: true, tls: true,
} }
} }
@ -118,7 +118,7 @@ impl Default for NetworkConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
ipv4: true, ipv4: true,
ipv6: None, ipv6: Some(false),
prefer: 4, prefer: 4,
multipath: false, multipath: false,
stun_servers: default_stun_servers(), stun_servers: default_stun_servers(),
@ -291,7 +291,7 @@ impl Default for GeneralConfig {
middle_proxy_nat_stun: None, middle_proxy_nat_stun: None,
middle_proxy_nat_stun_servers: Vec::new(), middle_proxy_nat_stun_servers: Vec::new(),
middle_proxy_pool_size: default_pool_size(), middle_proxy_pool_size: default_pool_size(),
middle_proxy_warm_standby: 0, middle_proxy_warm_standby: 8,
me_keepalive_enabled: true, me_keepalive_enabled: true,
me_keepalive_interval_secs: default_keepalive_interval(), me_keepalive_interval_secs: default_keepalive_interval(),
me_keepalive_jitter_secs: default_keepalive_jitter(), me_keepalive_jitter_secs: default_keepalive_jitter(),
@ -299,10 +299,10 @@ impl Default for GeneralConfig {
me_warmup_stagger_enabled: true, me_warmup_stagger_enabled: true,
me_warmup_step_delay_ms: default_warmup_step_delay_ms(), me_warmup_step_delay_ms: default_warmup_step_delay_ms(),
me_warmup_step_jitter_ms: default_warmup_step_jitter_ms(), me_warmup_step_jitter_ms: default_warmup_step_jitter_ms(),
me_reconnect_max_concurrent_per_dc: 1, me_reconnect_max_concurrent_per_dc: 4,
me_reconnect_backoff_base_ms: default_reconnect_backoff_base_ms(), me_reconnect_backoff_base_ms: default_reconnect_backoff_base_ms(),
me_reconnect_backoff_cap_ms: default_reconnect_backoff_cap_ms(), me_reconnect_backoff_cap_ms: default_reconnect_backoff_cap_ms(),
me_reconnect_fast_retry_count: 1, me_reconnect_fast_retry_count: 8,
stun_iface_mismatch_ignore: false, stun_iface_mismatch_ignore: false,
unknown_dc_log_path: default_unknown_dc_log_path(), unknown_dc_log_path: default_unknown_dc_log_path(),
log_level: LogLevel::Normal, log_level: LogLevel::Normal,
@ -474,6 +474,12 @@ pub struct AntiCensorshipConfig {
#[serde(default = "default_tls_new_session_tickets")] #[serde(default = "default_tls_new_session_tickets")]
pub tls_new_session_tickets: u8, pub tls_new_session_tickets: u8,
/// TTL in seconds for sending full certificate payload per client IP.
/// First client connection per (SNI domain, client IP) gets full cert payload.
/// Subsequent handshakes within TTL use compact cert metadata payload.
#[serde(default = "default_tls_full_cert_ttl_secs")]
pub tls_full_cert_ttl_secs: u64,
/// Enforce ALPN echo of client preference. /// Enforce ALPN echo of client preference.
#[serde(default = "default_alpn_enforce")] #[serde(default = "default_alpn_enforce")]
pub alpn_enforce: bool, pub alpn_enforce: bool,
@ -494,6 +500,7 @@ impl Default for AntiCensorshipConfig {
server_hello_delay_min_ms: default_server_hello_delay_min_ms(), server_hello_delay_min_ms: default_server_hello_delay_min_ms(),
server_hello_delay_max_ms: default_server_hello_delay_max_ms(), server_hello_delay_max_ms: default_server_hello_delay_max_ms(),
tls_new_session_tickets: default_tls_new_session_tickets(), tls_new_session_tickets: default_tls_new_session_tickets(),
tls_full_cert_ttl_secs: default_tls_full_cert_ttl_secs(),
alpn_enforce: default_alpn_enforce(), alpn_enforce: default_alpn_enforce(),
} }
} }
@ -603,6 +610,10 @@ pub struct ListenerConfig {
/// Per-listener PROXY protocol override. When set, overrides global server.proxy_protocol. /// Per-listener PROXY protocol override. When set, overrides global server.proxy_protocol.
#[serde(default)] #[serde(default)]
pub proxy_protocol: Option<bool>, pub proxy_protocol: Option<bool>,
/// Allow multiple telemt instances to listen on the same IP:port (SO_REUSEPORT).
/// Default is false for safety.
#[serde(default)]
pub reuse_allow: bool,
} }
// ============= ShowLink ============= // ============= ShowLink =============

View File

@ -55,6 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 {
crc32fast::hash(data) crc32fast::hash(data)
} }
/// CRC32C (Castagnoli)
pub fn crc32c(data: &[u8]) -> u32 {
crc32c::crc32c(data)
}
/// Build the exact prekey buffer used by Telegram Middle Proxy KDF. /// Build the exact prekey buffer used by Telegram Middle Proxy KDF.
/// ///
/// Returned buffer layout (IPv4): /// Returned buffer layout (IPv4):

View File

@ -5,5 +5,8 @@ pub mod hash;
pub mod random; pub mod random;
pub use aes::{AesCtr, AesCbc}; pub use aes::{AesCtr, AesCbc};
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32, derive_middleproxy_keys, build_middleproxy_prekey}; pub use hash::{
build_middleproxy_prekey, crc32, crc32c, derive_middleproxy_keys, md5, sha1, sha256,
sha256_hmac,
};
pub use random::SecureRandom; pub use random::SecureRandom;

View File

@ -49,19 +49,32 @@ impl SecureRandom {
} }
} }
/// Generate random bytes /// Fill a caller-provided buffer with random bytes.
pub fn bytes(&self, len: usize) -> Vec<u8> { pub fn fill(&self, out: &mut [u8]) {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
const CHUNK_SIZE: usize = 512; const CHUNK_SIZE: usize = 512;
while inner.buffer.len() < len { let mut written = 0usize;
let mut chunk = vec![0u8; CHUNK_SIZE]; while written < out.len() {
inner.rng.fill_bytes(&mut chunk); if inner.buffer.is_empty() {
inner.cipher.apply(&mut chunk); let mut chunk = vec![0u8; CHUNK_SIZE];
inner.buffer.extend_from_slice(&chunk); inner.rng.fill_bytes(&mut chunk);
inner.cipher.apply(&mut chunk);
inner.buffer.extend_from_slice(&chunk);
}
let take = (out.len() - written).min(inner.buffer.len());
out[written..written + take].copy_from_slice(&inner.buffer[..take]);
inner.buffer.drain(..take);
written += take;
} }
}
inner.buffer.drain(..len).collect()
/// Generate random bytes
pub fn bytes(&self, len: usize) -> Vec<u8> {
let mut out = vec![0u8; len];
self.fill(&mut out);
out
} }
/// Generate random number in range [0, max) /// Generate random number in range [0, max)

View File

@ -38,7 +38,7 @@ use crate::stream::BufferPool;
use crate::transport::middle_proxy::{ use crate::transport::middle_proxy::{
MePool, fetch_proxy_config, run_me_ping, MePingFamily, MePingSample, format_sample_line, MePool, fetch_proxy_config, run_me_ping, MePingFamily, MePingSample, format_sample_line,
}; };
use crate::transport::{ListenOptions, UpstreamManager, create_listener}; use crate::transport::{ListenOptions, UpstreamManager, create_listener, find_listener_processes};
use crate::tls_front::TlsFrontCache; use crate::tls_front::TlsFrontCache;
fn parse_cli() -> (String, bool, Option<String>) { fn parse_cli() -> (String, bool, Option<String>) {
@ -265,7 +265,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
} }
// Connection concurrency limit // Connection concurrency limit
let _max_connections = Arc::new(Semaphore::new(10_000)); let max_connections = Arc::new(Semaphore::new(10_000));
if use_middle_proxy && !decision.ipv4_me && !decision.ipv6_me { if use_middle_proxy && !decision.ipv4_me && !decision.ipv6_me {
warn!("No usable IP family for Middle Proxy detected; falling back to direct DC"); warn!("No usable IP family for Middle Proxy detected; falling back to direct DC");
@ -715,6 +715,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
continue; continue;
} }
let options = ListenOptions { let options = ListenOptions {
reuse_port: listener_conf.reuse_allow,
ipv6_only: listener_conf.ip.is_ipv6(), ipv6_only: listener_conf.ip.is_ipv6(),
..Default::default() ..Default::default()
}; };
@ -753,7 +754,33 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
listeners.push((listener, listener_proxy_protocol)); listeners.push((listener, listener_proxy_protocol));
} }
Err(e) => { Err(e) => {
error!("Failed to bind to {}: {}", addr, e); if e.kind() == std::io::ErrorKind::AddrInUse {
let owners = find_listener_processes(addr);
if owners.is_empty() {
error!(
%addr,
"Failed to bind: address already in use (owner process unresolved)"
);
} else {
for owner in owners {
error!(
%addr,
pid = owner.pid,
process = %owner.process,
"Failed to bind: address already in use"
);
}
}
if !listener_conf.reuse_allow {
error!(
%addr,
"reuse_allow=false; set [[server.listeners]].reuse_allow=true to allow multi-instance listening"
);
}
} else {
error!("Failed to bind to {}: {}", addr, e);
}
} }
} }
} }
@ -817,6 +844,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
let me_pool = me_pool.clone(); let me_pool = me_pool.clone();
let tls_cache = tls_cache.clone(); let tls_cache = tls_cache.clone();
let ip_tracker = ip_tracker.clone(); let ip_tracker = ip_tracker.clone();
let max_connections_unix = max_connections.clone();
tokio::spawn(async move { tokio::spawn(async move {
let unix_conn_counter = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)); let unix_conn_counter = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1));
@ -824,6 +852,13 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
loop { loop {
match unix_listener.accept().await { match unix_listener.accept().await {
Ok((stream, _)) => { Ok((stream, _)) => {
let permit = match max_connections_unix.clone().acquire_owned().await {
Ok(permit) => permit,
Err(_) => {
error!("Connection limiter is closed");
break;
}
};
let conn_id = unix_conn_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); let conn_id = unix_conn_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let fake_peer = SocketAddr::from(([127, 0, 0, 1], (conn_id % 65535) as u16)); let fake_peer = SocketAddr::from(([127, 0, 0, 1], (conn_id % 65535) as u16));
@ -839,6 +874,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
let proxy_protocol_enabled = config.server.proxy_protocol; let proxy_protocol_enabled = config.server.proxy_protocol;
tokio::spawn(async move { tokio::spawn(async move {
let _permit = permit;
if let Err(e) = crate::proxy::client::handle_client_stream( if let Err(e) = crate::proxy::client::handle_client_stream(
stream, fake_peer, config, stats, stream, fake_peer, config, stats,
upstream_manager, replay_checker, buffer_pool, rng, upstream_manager, replay_checker, buffer_pool, rng,
@ -906,11 +942,19 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
let me_pool = me_pool.clone(); let me_pool = me_pool.clone();
let tls_cache = tls_cache.clone(); let tls_cache = tls_cache.clone();
let ip_tracker = ip_tracker.clone(); let ip_tracker = ip_tracker.clone();
let max_connections_tcp = max_connections.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
match listener.accept().await { match listener.accept().await {
Ok((stream, peer_addr)) => { Ok((stream, peer_addr)) => {
let permit = match max_connections_tcp.clone().acquire_owned().await {
Ok(permit) => permit,
Err(_) => {
error!("Connection limiter is closed");
break;
}
};
let config = config_rx.borrow_and_update().clone(); let config = config_rx.borrow_and_update().clone();
let stats = stats.clone(); let stats = stats.clone();
let upstream_manager = upstream_manager.clone(); let upstream_manager = upstream_manager.clone();
@ -923,6 +967,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
let proxy_protocol_enabled = listener_proxy_protocol; let proxy_protocol_enabled = listener_proxy_protocol;
tokio::spawn(async move { tokio::spawn(async move {
let _permit = permit;
if let Err(e) = ClientHandler::new( if let Err(e) = ClientHandler::new(
stream, stream,
peer_addr, peer_addr,

View File

@ -100,6 +100,14 @@ fn render_metrics(stats: &Stats) -> String {
let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter"); let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter");
let _ = writeln!(out, "telemt_me_keepalive_failed_total {}", stats.get_me_keepalive_failed()); let _ = writeln!(out, "telemt_me_keepalive_failed_total {}", stats.get_me_keepalive_failed());
let _ = writeln!(out, "# HELP telemt_me_keepalive_pong_total ME keepalive pong replies");
let _ = writeln!(out, "# TYPE telemt_me_keepalive_pong_total counter");
let _ = writeln!(out, "telemt_me_keepalive_pong_total {}", stats.get_me_keepalive_pong());
let _ = writeln!(out, "# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts");
let _ = writeln!(out, "# TYPE telemt_me_keepalive_timeout_total counter");
let _ = writeln!(out, "telemt_me_keepalive_timeout_total {}", stats.get_me_keepalive_timeout());
let _ = writeln!(out, "# HELP telemt_me_reconnect_attempts_total ME reconnect attempts"); let _ = writeln!(out, "# HELP telemt_me_reconnect_attempts_total ME reconnect attempts");
let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter"); let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter");
let _ = writeln!(out, "telemt_me_reconnect_attempts_total {}", stats.get_me_reconnect_attempts()); let _ = writeln!(out, "telemt_me_reconnect_attempts_total {}", stats.get_me_reconnect_attempts());
@ -108,6 +116,30 @@ fn render_metrics(stats: &Stats) -> String {
let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter"); let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter");
let _ = writeln!(out, "telemt_me_reconnect_success_total {}", stats.get_me_reconnect_success()); let _ = writeln!(out, "telemt_me_reconnect_success_total {}", stats.get_me_reconnect_success());
let _ = writeln!(out, "# HELP telemt_me_crc_mismatch_total ME CRC mismatches");
let _ = writeln!(out, "# TYPE telemt_me_crc_mismatch_total counter");
let _ = writeln!(out, "telemt_me_crc_mismatch_total {}", stats.get_me_crc_mismatch());
let _ = writeln!(out, "# HELP telemt_me_seq_mismatch_total ME sequence mismatches");
let _ = writeln!(out, "# TYPE telemt_me_seq_mismatch_total counter");
let _ = writeln!(out, "telemt_me_seq_mismatch_total {}", stats.get_me_seq_mismatch());
let _ = writeln!(out, "# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn");
let _ = writeln!(out, "# TYPE telemt_me_route_drop_no_conn_total counter");
let _ = writeln!(out, "telemt_me_route_drop_no_conn_total {}", stats.get_me_route_drop_no_conn());
let _ = writeln!(out, "# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed");
let _ = writeln!(out, "# TYPE telemt_me_route_drop_channel_closed_total counter");
let _ = writeln!(out, "telemt_me_route_drop_channel_closed_total {}", stats.get_me_route_drop_channel_closed());
let _ = writeln!(out, "# HELP telemt_me_route_drop_queue_full_total ME route drops: queue full");
let _ = writeln!(out, "# TYPE telemt_me_route_drop_queue_full_total counter");
let _ = writeln!(out, "telemt_me_route_drop_queue_full_total {}", stats.get_me_route_drop_queue_full());
let _ = writeln!(out, "# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths");
let _ = writeln!(out, "# TYPE telemt_secure_padding_invalid_total counter");
let _ = writeln!(out, "telemt_secure_padding_invalid_total {}", stats.get_secure_padding_invalid());
let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections"); let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections");
let _ = writeln!(out, "# TYPE telemt_user_connections_total counter"); let _ = writeln!(out, "# TYPE telemt_user_connections_total counter");
let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections"); let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections");

View File

@ -156,14 +156,28 @@ pub const MAX_TLS_RECORD_SIZE: usize = 16384;
/// RFC 8446 §5.2 allows up to 16384 + 256 bytes of ciphertext /// RFC 8446 §5.2 allows up to 16384 + 256 bytes of ciphertext
pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256; pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256;
/// Generate padding length for Secure Intermediate protocol. /// Secure Intermediate payload is expected to be 4-byte aligned.
/// Total (data + padding) must not be divisible by 4 per MTProto spec. pub fn is_valid_secure_payload_len(data_len: usize) -> bool {
pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize { data_len % 4 == 0
if data_len % 4 == 0 { }
(rng.range(3) + 1) as usize // 1-3
} else { /// Compute Secure Intermediate payload length from wire length.
rng.range(4) as usize // 0-3 /// Secure mode strips up to 3 random tail bytes by truncating to 4-byte boundary.
pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option<usize> {
if wire_len < 4 {
return None;
} }
Some(wire_len - (wire_len % 4))
}
/// Generate padding length for Secure Intermediate protocol.
/// Data must be 4-byte aligned; padding is 1..=3 so total is never divisible by 4.
pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize {
debug_assert!(
is_valid_secure_payload_len(data_len),
"Secure payload must be 4-byte aligned, got {data_len}"
);
(rng.range(3) + 1) as usize
} }
// ============= Timeouts ============= // ============= Timeouts =============
@ -297,6 +311,10 @@ pub mod rpc_flags {
pub const FLAG_ABRIDGED: u32 = 0x40000000; pub const FLAG_ABRIDGED: u32 = 0x40000000;
pub const FLAG_QUICKACK: u32 = 0x80000000; pub const FLAG_QUICKACK: u32 = 0x80000000;
} }
pub mod rpc_crypto_flags {
pub const USE_CRC32C: u32 = 0x800;
}
pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5; pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5;
pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10; pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
@ -332,4 +350,43 @@ mod tests {
assert_eq!(TG_DATACENTERS_V4.len(), 5); assert_eq!(TG_DATACENTERS_V4.len(), 5);
assert_eq!(TG_DATACENTERS_V6.len(), 5); assert_eq!(TG_DATACENTERS_V6.len(), 5);
} }
#[test]
fn secure_padding_never_produces_aligned_total() {
let rng = SecureRandom::new();
for data_len in (0..1000).step_by(4) {
for _ in 0..100 {
let padding = secure_padding_len(data_len, &rng);
assert!(
padding <= 3,
"padding out of range: data_len={data_len}, padding={padding}"
);
assert_ne!(
(data_len + padding) % 4,
0,
"invariant violated: data_len={data_len}, padding={padding}, total={}",
data_len + padding
);
}
}
}
#[test]
fn secure_wire_len_roundtrip_for_aligned_payload() {
for payload_len in (4..4096).step_by(4) {
for padding in 0..=3usize {
let wire_len = payload_len + padding;
let recovered = secure_payload_len_from_wire_len(wire_len);
assert_eq!(recovered, Some(payload_len));
}
}
}
#[test]
fn secure_wire_len_rejects_too_short_frames() {
assert_eq!(secure_payload_len_from_wire_len(0), None);
assert_eq!(secure_payload_len_from_wire_len(1), None);
assert_eq!(secure_payload_len_from_wire_len(2), None);
assert_eq!(secure_payload_len_from_wire_len(3), None);
}
} }

View File

@ -2,6 +2,7 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace, info}; use tracing::{debug, warn, trace, info};
use zeroize::Zeroize; use zeroize::Zeroize;
@ -108,11 +109,23 @@ where
let cached = if config.censorship.tls_emulation { let cached = if config.censorship.tls_emulation {
if let Some(cache) = tls_cache.as_ref() { if let Some(cache) = tls_cache.as_ref() {
if let Some(sni) = tls::extract_sni_from_client_hello(handshake) { let selected_domain = if let Some(sni) = tls::extract_sni_from_client_hello(handshake) {
Some(cache.get(&sni).await) if cache.contains_domain(&sni).await {
sni
} else {
config.censorship.tls_domain.clone()
}
} else { } else {
Some(cache.get(&config.censorship.tls_domain).await) config.censorship.tls_domain.clone()
} };
let cached_entry = cache.get(&selected_domain).await;
let use_full_cert_payload = cache
.take_full_cert_budget_for_ip(
peer.ip(),
Duration::from_secs(config.censorship.tls_full_cert_ttl_secs),
)
.await;
Some((cached_entry, use_full_cert_payload))
} else { } else {
None None
} }
@ -137,12 +150,13 @@ where
None None
}; };
let response = if let Some(cached_entry) = cached { let response = if let Some((cached_entry, use_full_cert_payload)) = cached {
emulator::build_emulated_server_hello( emulator::build_emulated_server_hello(
secret, secret,
&validation.digest, &validation.digest,
&validation.session_id, &validation.session_id,
&cached_entry, &cached_entry,
use_full_cert_payload,
rng, rng,
selected_alpn.clone(), selected_alpn.clone(),
config.censorship.tls_new_session_tickets, config.censorship.tls_new_session_tickets,
@ -253,7 +267,11 @@ where
let mode_ok = match proto_tag { let mode_ok = match proto_tag {
ProtoTag::Secure => { ProtoTag::Secure => {
if is_tls { config.general.modes.tls } else { config.general.modes.secure } if is_tls {
config.general.modes.tls || config.general.modes.secure
} else {
config.general.modes.secure || config.general.modes.tls
}
} }
ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
}; };

View File

@ -2,7 +2,7 @@ use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::oneshot; use tokio::sync::{mpsc, oneshot};
use tracing::{debug, info, trace, warn}; use tracing::{debug, info, trace, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
@ -14,6 +14,11 @@ use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
enum C2MeCommand {
Data { payload: Vec<u8>, flags: u32 },
Close,
}
pub(crate) async fn handle_via_middle_proxy<R, W>( pub(crate) async fn handle_via_middle_proxy<R, W>(
mut crypto_reader: CryptoReader<R>, mut crypto_reader: CryptoReader<R>,
crypto_writer: CryptoWriter<W>, crypto_writer: CryptoWriter<W>,
@ -59,6 +64,30 @@ where
let frame_limit = config.general.max_client_frame; let frame_limit = config.general.max_client_frame;
let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(1024);
let me_pool_c2me = me_pool.clone();
let c2me_sender = tokio::spawn(async move {
while let Some(cmd) = c2me_rx.recv().await {
match cmd {
C2MeCommand::Data { payload, flags } => {
me_pool_c2me.send_proxy_req(
conn_id,
success.dc_idx,
peer,
translated_local_addr,
&payload,
flags,
).await?;
}
C2MeCommand::Close => {
let _ = me_pool_c2me.send_close(conn_id).await;
return Ok(());
}
}
}
Ok(())
});
let (stop_tx, mut stop_rx) = oneshot::channel::<()>(); let (stop_tx, mut stop_rx) = oneshot::channel::<()>();
let mut me_rx_task = me_rx; let mut me_rx_task = me_rx;
let stats_clone = stats.clone(); let stats_clone = stats.clone();
@ -66,6 +95,7 @@ where
let user_clone = user.clone(); let user_clone = user.clone();
let me_writer = tokio::spawn(async move { let me_writer = tokio::spawn(async move {
let mut writer = crypto_writer; let mut writer = crypto_writer;
let mut frame_buf = Vec::with_capacity(16 * 1024);
loop { loop {
tokio::select! { tokio::select! {
msg = me_rx_task.recv() => { msg = me_rx_task.recv() => {
@ -73,7 +103,44 @@ where
Some(MeResponse::Data { flags, data }) => { Some(MeResponse::Data { flags, data }) => {
trace!(conn_id, bytes = data.len(), flags, "ME->C data"); trace!(conn_id, bytes = data.len(), flags, "ME->C data");
stats_clone.add_user_octets_to(&user_clone, data.len() as u64); stats_clone.add_user_octets_to(&user_clone, data.len() as u64);
write_client_payload(&mut writer, proto_tag, flags, &data, rng_clone.as_ref()).await?; write_client_payload(
&mut writer,
proto_tag,
flags,
&data,
rng_clone.as_ref(),
&mut frame_buf,
)
.await?;
// Drain all immediately queued ME responses and flush once.
while let Ok(next) = me_rx_task.try_recv() {
match next {
MeResponse::Data { flags, data } => {
trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)");
stats_clone.add_user_octets_to(&user_clone, data.len() as u64);
write_client_payload(
&mut writer,
proto_tag,
flags,
&data,
rng_clone.as_ref(),
&mut frame_buf,
).await?;
}
MeResponse::Ack(confirm) => {
trace!(conn_id, confirm, "ME->C quickack (batched)");
write_client_ack(&mut writer, proto_tag, confirm).await?;
}
MeResponse::Close => {
debug!(conn_id, "ME sent close (batched)");
let _ = writer.flush().await;
return Ok(());
}
}
}
writer.flush().await.map_err(ProxyError::Io)?;
} }
Some(MeResponse::Ack(confirm)) => { Some(MeResponse::Ack(confirm)) => {
trace!(conn_id, confirm, "ME->C quickack"); trace!(conn_id, confirm, "ME->C quickack");
@ -81,6 +148,7 @@ where
} }
Some(MeResponse::Close) => { Some(MeResponse::Close) => {
debug!(conn_id, "ME sent close"); debug!(conn_id, "ME sent close");
let _ = writer.flush().await;
return Ok(()); return Ok(());
} }
None => { None => {
@ -99,8 +167,16 @@ where
let mut main_result: Result<()> = Ok(()); let mut main_result: Result<()> = Ok(());
let mut client_closed = false; let mut client_closed = false;
let mut frame_counter: u64 = 0;
loop { loop {
match read_client_payload(&mut crypto_reader, proto_tag, frame_limit, &user).await { match read_client_payload(
&mut crypto_reader,
proto_tag,
frame_limit,
&user,
&mut frame_counter,
&stats,
).await {
Ok(Some((payload, quickack))) => { Ok(Some((payload, quickack))) => {
trace!(conn_id, bytes = payload.len(), "C->ME frame"); trace!(conn_id, bytes = payload.len(), "C->ME frame");
stats.add_user_octets_from(&user, payload.len() as u64); stats.add_user_octets_from(&user, payload.len() as u64);
@ -111,22 +187,20 @@ where
if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) { if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) {
flags |= RPC_FLAG_NOT_ENCRYPTED; flags |= RPC_FLAG_NOT_ENCRYPTED;
} }
if let Err(e) = me_pool.send_proxy_req( // Keep client read loop lightweight: route heavy ME send path via a dedicated task.
conn_id, if c2me_tx
success.dc_idx, .send(C2MeCommand::Data { payload, flags })
peer, .await
translated_local_addr, .is_err()
&payload, {
flags, main_result = Err(ProxyError::Proxy("ME sender channel closed".into()));
).await {
main_result = Err(e);
break; break;
} }
} }
Ok(None) => { Ok(None) => {
debug!(conn_id, "Client EOF"); debug!(conn_id, "Client EOF");
client_closed = true; client_closed = true;
let _ = me_pool.send_close(conn_id).await; let _ = c2me_tx.send(C2MeCommand::Close).await;
break; break;
} }
Err(e) => { Err(e) => {
@ -136,6 +210,11 @@ where
} }
} }
drop(c2me_tx);
let c2me_result = c2me_sender
.await
.unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME sender join error: {e}"))));
let _ = stop_tx.send(()); let _ = stop_tx.send(());
let mut writer_result = me_writer let mut writer_result = me_writer
.await .await
@ -151,10 +230,11 @@ where
} }
} }
let result = match (main_result, writer_result) { let result = match (main_result, c2me_result, writer_result) {
(Ok(()), Ok(())) => Ok(()), (Ok(()), Ok(()), Ok(())) => Ok(()),
(Err(e), _) => Err(e), (Err(e), _, _) => Err(e),
(_, Err(e)) => Err(e), (_, Err(e), _) => Err(e),
(_, _, Err(e)) => Err(e),
}; };
debug!(user = %user, conn_id, "ME relay cleanup"); debug!(user = %user, conn_id, "ME relay cleanup");
@ -168,73 +248,123 @@ async fn read_client_payload<R>(
proto_tag: ProtoTag, proto_tag: ProtoTag,
max_frame: usize, max_frame: usize,
user: &str, user: &str,
frame_counter: &mut u64,
stats: &Stats,
) -> Result<Option<(Vec<u8>, bool)>> ) -> Result<Option<(Vec<u8>, bool)>>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
{ {
let (len, quickack) = match proto_tag { loop {
ProtoTag::Abridged => { let (len, quickack, raw_len_bytes) = match proto_tag {
let mut first = [0u8; 1]; ProtoTag::Abridged => {
match client_reader.read_exact(&mut first).await { let mut first = [0u8; 1];
Ok(_) => {} match client_reader.read_exact(&mut first).await {
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), Ok(_) => {}
Err(e) => return Err(ProxyError::Io(e)), Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)),
}
let quickack = (first[0] & 0x80) != 0;
let len_words = if (first[0] & 0x7f) == 0x7f {
let mut ext = [0u8; 3];
client_reader
.read_exact(&mut ext)
.await
.map_err(ProxyError::Io)?;
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
} else {
(first[0] & 0x7f) as usize
};
let len = len_words
.checked_mul(4)
.ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?;
(len, quickack, None)
} }
ProtoTag::Intermediate | ProtoTag::Secure => {
let quickack = (first[0] & 0x80) != 0; let mut len_buf = [0u8; 4];
let len_words = if (first[0] & 0x7f) == 0x7f { match client_reader.read_exact(&mut len_buf).await {
let mut ext = [0u8; 3]; Ok(_) => {}
client_reader Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
.read_exact(&mut ext) Err(e) => return Err(ProxyError::Io(e)),
.await }
.map_err(ProxyError::Io)?; let quickack = (len_buf[3] & 0x80) != 0;
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize (
} else { (u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize,
(first[0] & 0x7f) as usize quickack,
}; Some(len_buf),
)
let len = len_words
.checked_mul(4)
.ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?;
(len, quickack)
}
ProtoTag::Intermediate | ProtoTag::Secure => {
let mut len_buf = [0u8; 4];
match client_reader.read_exact(&mut len_buf).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)),
} }
let quickack = (len_buf[3] & 0x80) != 0; };
((u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, quickack)
if len == 0 {
continue;
} }
}; if len < 4 && proto_tag != ProtoTag::Abridged {
warn!(
if len > max_frame { user = %user,
warn!( len,
user = %user, proto = ?proto_tag,
raw_len = len, "Frame too small — corrupt or probe"
raw_len_hex = format_args!("0x{:08x}", len), );
proto = ?proto_tag, return Err(ProxyError::Proxy(format!("Frame too small: {len}")));
"Frame too large — possible crypto desync or TLS record error"
);
return Err(ProxyError::Proxy(format!("Frame too large: {len} (max {max_frame})")));
}
let mut payload = vec![0u8; len];
client_reader
.read_exact(&mut payload)
.await
.map_err(ProxyError::Io)?;
// Secure Intermediate: remove random padding (last len%4 bytes)
if proto_tag == ProtoTag::Secure {
let rem = len % 4;
if rem != 0 && payload.len() >= rem {
payload.truncate(len - rem);
} }
if len > max_frame {
let len_buf = raw_len_bytes.unwrap_or((len as u32).to_le_bytes());
let looks_like_tls = raw_len_bytes
.map(|b| b[0] == 0x16 && b[1] == 0x03)
.unwrap_or(false);
let looks_like_http = raw_len_bytes
.map(|b| matches!(b[0], b'G' | b'P' | b'H' | b'C' | b'D'))
.unwrap_or(false);
warn!(
user = %user,
raw_len = len,
raw_len_hex = format_args!("0x{:08x}", len),
raw_bytes = format_args!(
"{:02x} {:02x} {:02x} {:02x}",
len_buf[0], len_buf[1], len_buf[2], len_buf[3]
),
proto = ?proto_tag,
tls_like = looks_like_tls,
http_like = looks_like_http,
frames_ok = *frame_counter,
"Frame too large — crypto desync forensics"
);
return Err(ProxyError::Proxy(format!(
"Frame too large: {len} (max {max_frame}), frames_ok={}",
*frame_counter
)));
}
let secure_payload_len = if proto_tag == ProtoTag::Secure {
match secure_payload_len_from_wire_len(len) {
Some(payload_len) => payload_len,
None => {
stats.increment_secure_padding_invalid();
return Err(ProxyError::Proxy(format!(
"Invalid secure frame length: {len}"
)));
}
}
} else {
len
};
let mut payload = vec![0u8; len];
client_reader
.read_exact(&mut payload)
.await
.map_err(ProxyError::Io)?;
// Secure Intermediate: strip validated trailing padding bytes.
if proto_tag == ProtoTag::Secure {
payload.truncate(secure_payload_len);
}
*frame_counter += 1;
return Ok(Some((payload, quickack)));
} }
Ok(Some((payload, quickack)))
} }
async fn write_client_payload<W>( async fn write_client_payload<W>(
@ -243,6 +373,7 @@ async fn write_client_payload<W>(
flags: u32, flags: u32,
data: &[u8], data: &[u8],
rng: &SecureRandom, rng: &SecureRandom,
frame_buf: &mut Vec<u8>,
) -> Result<()> ) -> Result<()>
where where
W: AsyncWrite + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static,
@ -264,8 +395,12 @@ where
if quickack { if quickack {
first |= 0x80; first |= 0x80;
} }
frame_buf.clear();
frame_buf.reserve(1 + data.len());
frame_buf.push(first);
frame_buf.extend_from_slice(data);
client_writer client_writer
.write_all(&[first]) .write_all(&frame_buf)
.await .await
.map_err(ProxyError::Io)?; .map_err(ProxyError::Io)?;
} else if len_words < (1 << 24) { } else if len_words < (1 << 24) {
@ -274,8 +409,12 @@ where
first |= 0x80; first |= 0x80;
} }
let lw = (len_words as u32).to_le_bytes(); let lw = (len_words as u32).to_le_bytes();
frame_buf.clear();
frame_buf.reserve(4 + data.len());
frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
frame_buf.extend_from_slice(data);
client_writer client_writer
.write_all(&[first, lw[0], lw[1], lw[2]]) .write_all(&frame_buf)
.await .await
.map_err(ProxyError::Io)?; .map_err(ProxyError::Io)?;
} else { } else {
@ -284,47 +423,40 @@ where
data.len() data.len()
))); )));
} }
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
} }
ProtoTag::Intermediate | ProtoTag::Secure => { ProtoTag::Intermediate | ProtoTag::Secure => {
let padding_len = if proto_tag == ProtoTag::Secure { let padding_len = if proto_tag == ProtoTag::Secure {
if !is_valid_secure_payload_len(data.len()) {
return Err(ProxyError::Proxy(format!(
"Secure payload must be 4-byte aligned, got {}",
data.len()
)));
}
secure_padding_len(data.len(), rng) secure_padding_len(data.len(), rng)
} else { } else {
0 0
}; };
let mut len = (data.len() + padding_len) as u32; let mut len_val = (data.len() + padding_len) as u32;
if quickack { if quickack {
len |= 0x8000_0000; len_val |= 0x8000_0000;
} }
client_writer let total = 4 + data.len() + padding_len;
.write_all(&len.to_le_bytes()) frame_buf.clear();
.await frame_buf.reserve(total);
.map_err(ProxyError::Io)?; frame_buf.extend_from_slice(&len_val.to_le_bytes());
client_writer frame_buf.extend_from_slice(data);
.write_all(data)
.await
.map_err(ProxyError::Io)?;
if padding_len > 0 { if padding_len > 0 {
let pad = rng.bytes(padding_len); let start = frame_buf.len();
client_writer frame_buf.resize(start + padding_len, 0);
.write_all(&pad) rng.fill(&mut frame_buf[start..]);
.await
.map_err(ProxyError::Io)?;
} }
client_writer
.write_all(&frame_buf)
.await
.map_err(ProxyError::Io)?;
} }
} }
// Avoid unconditional per-frame flush (throughput killer on large downloads).
// Flush only when low-latency ack semantics are requested or when
// CryptoWriter has buffered pending ciphertext that must be drained.
if quickack || client_writer.has_pending() {
client_writer.flush().await.map_err(ProxyError::Io)?;
}
Ok(()) Ok(())
} }

View File

@ -21,8 +21,16 @@ pub struct Stats {
handshake_timeouts: AtomicU64, handshake_timeouts: AtomicU64,
me_keepalive_sent: AtomicU64, me_keepalive_sent: AtomicU64,
me_keepalive_failed: AtomicU64, me_keepalive_failed: AtomicU64,
me_keepalive_pong: AtomicU64,
me_keepalive_timeout: AtomicU64,
me_reconnect_attempts: AtomicU64, me_reconnect_attempts: AtomicU64,
me_reconnect_success: AtomicU64, me_reconnect_success: AtomicU64,
me_crc_mismatch: AtomicU64,
me_seq_mismatch: AtomicU64,
me_route_drop_no_conn: AtomicU64,
me_route_drop_channel_closed: AtomicU64,
me_route_drop_queue_full: AtomicU64,
secure_padding_invalid: AtomicU64,
user_stats: DashMap<String, UserStats>, user_stats: DashMap<String, UserStats>,
start_time: parking_lot::RwLock<Option<Instant>>, start_time: parking_lot::RwLock<Option<Instant>>,
} }
@ -49,14 +57,45 @@ impl Stats {
pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); } pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_sent(&self) { self.me_keepalive_sent.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_keepalive_sent(&self) { self.me_keepalive_sent.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_failed(&self) { self.me_keepalive_failed.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_keepalive_failed(&self) { self.me_keepalive_failed.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_pong(&self) { self.me_keepalive_pong.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_timeout(&self) { self.me_keepalive_timeout.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_timeout_by(&self, value: u64) {
self.me_keepalive_timeout.fetch_add(value, Ordering::Relaxed);
}
pub fn increment_me_reconnect_attempt(&self) { self.me_reconnect_attempts.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_reconnect_attempt(&self) { self.me_reconnect_attempts.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_reconnect_success(&self) { self.me_reconnect_success.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_reconnect_success(&self) { self.me_reconnect_success.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_crc_mismatch(&self) { self.me_crc_mismatch.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_seq_mismatch(&self) { self.me_seq_mismatch.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_route_drop_no_conn(&self) { self.me_route_drop_no_conn.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_route_drop_channel_closed(&self) {
self.me_route_drop_channel_closed.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_me_route_drop_queue_full(&self) {
self.me_route_drop_queue_full.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_secure_padding_invalid(&self) {
self.secure_padding_invalid.fetch_add(1, Ordering::Relaxed);
}
pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) } pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) } pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
pub fn get_me_keepalive_sent(&self) -> u64 { self.me_keepalive_sent.load(Ordering::Relaxed) } pub fn get_me_keepalive_sent(&self) -> u64 { self.me_keepalive_sent.load(Ordering::Relaxed) }
pub fn get_me_keepalive_failed(&self) -> u64 { self.me_keepalive_failed.load(Ordering::Relaxed) } pub fn get_me_keepalive_failed(&self) -> u64 { self.me_keepalive_failed.load(Ordering::Relaxed) }
pub fn get_me_keepalive_pong(&self) -> u64 { self.me_keepalive_pong.load(Ordering::Relaxed) }
pub fn get_me_keepalive_timeout(&self) -> u64 { self.me_keepalive_timeout.load(Ordering::Relaxed) }
pub fn get_me_reconnect_attempts(&self) -> u64 { self.me_reconnect_attempts.load(Ordering::Relaxed) } pub fn get_me_reconnect_attempts(&self) -> u64 { self.me_reconnect_attempts.load(Ordering::Relaxed) }
pub fn get_me_reconnect_success(&self) -> u64 { self.me_reconnect_success.load(Ordering::Relaxed) } pub fn get_me_reconnect_success(&self) -> u64 { self.me_reconnect_success.load(Ordering::Relaxed) }
pub fn get_me_crc_mismatch(&self) -> u64 { self.me_crc_mismatch.load(Ordering::Relaxed) }
pub fn get_me_seq_mismatch(&self) -> u64 { self.me_seq_mismatch.load(Ordering::Relaxed) }
pub fn get_me_route_drop_no_conn(&self) -> u64 { self.me_route_drop_no_conn.load(Ordering::Relaxed) }
pub fn get_me_route_drop_channel_closed(&self) -> u64 {
self.me_route_drop_channel_closed.load(Ordering::Relaxed)
}
pub fn get_me_route_drop_queue_full(&self) -> u64 {
self.me_route_drop_queue_full.load(Ordering::Relaxed)
}
pub fn get_secure_padding_invalid(&self) -> u64 {
self.secure_padding_invalid.load(Ordering::Relaxed)
}
pub fn increment_user_connects(&self, user: &str) { pub fn increment_user_connects(&self, user: &str) {
self.user_stats.entry(user.to_string()).or_default() self.user_stats.entry(user.to_string()).or_default()
@ -70,7 +109,22 @@ impl Stats {
pub fn decrement_user_curr_connects(&self, user: &str) { pub fn decrement_user_curr_connects(&self, user: &str) {
if let Some(stats) = self.user_stats.get(user) { if let Some(stats) = self.user_stats.get(user) {
stats.curr_connects.fetch_sub(1, Ordering::Relaxed); let counter = &stats.curr_connects;
let mut current = counter.load(Ordering::Relaxed);
loop {
if current == 0 {
break;
}
match counter.compare_exchange_weak(
current,
current - 1,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(actual) => current = actual,
}
}
} }
} }

View File

@ -8,7 +8,9 @@ use std::io::{self, Error, ErrorKind};
use std::sync::Arc; use std::sync::Arc;
use tokio_util::codec::{Decoder, Encoder}; use tokio_util::codec::{Decoder, Encoder};
use crate::protocol::constants::ProtoTag; use crate::protocol::constants::{
ProtoTag, is_valid_secure_payload_len, secure_padding_len, secure_payload_len_from_wire_len,
};
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait}; use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
@ -274,13 +276,13 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
return Ok(None); return Ok(None);
} }
// Calculate padding (indicated by length not divisible by 4) let data_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
let padding_len = len % 4; Error::new(
let data_len = if padding_len != 0 { ErrorKind::InvalidData,
len - padding_len format!("invalid secure frame length: {len}"),
} else { )
len })?;
}; let padding_len = len - data_len;
meta.padding_len = padding_len as u8; meta.padding_len = padding_len as u8;
@ -303,14 +305,15 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R
return Ok(()); return Ok(());
} }
// Generate padding to make length not divisible by 4 if !is_valid_secure_payload_len(data.len()) {
let padding_len = if data.len() % 4 == 0 { return Err(Error::new(
// Add 1-3 bytes to make it non-aligned ErrorKind::InvalidData,
(rng.range(3) + 1) as usize format!("secure payload must be 4-byte aligned, got {}", data.len()),
} else { ));
// Already non-aligned, can add 0-3 }
rng.range(4) as usize
}; // Generate padding that keeps total length non-divisible by 4.
let padding_len = secure_padding_len(data.len(), rng);
let total_len = data.len() + padding_len; let total_len = data.len() + padding_len;
dst.reserve(4 + total_len); dst.reserve(4 + total_len);
@ -625,4 +628,4 @@ mod tests {
let result = codec.decode(&mut buf); let result = codec.decode(&mut buf);
assert!(result.is_err()); assert!(result.is_err());
} }
} }

View File

@ -232,11 +232,13 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
let mut data = vec![0u8; len]; let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?; self.upstream.read_exact(&mut data).await?;
// Strip padding (not aligned to 4) let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
if len % 4 != 0 { Error::new(
let actual_len = len - (len % 4); ErrorKind::InvalidData,
data.truncate(actual_len); format!("Invalid secure frame length: {len}"),
} )
})?;
data.truncate(payload_len);
Ok((Bytes::from(data), meta)) Ok((Bytes::from(data), meta))
} }
@ -267,6 +269,13 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
return Ok(()); return Ok(());
} }
if !is_valid_secure_payload_len(data.len()) {
return Err(Error::new(
ErrorKind::InvalidData,
format!("Secure payload must be 4-byte aligned, got {}", data.len()),
));
}
// Add padding so total length is never divisible by 4 (MTProto Secure) // Add padding so total length is never divisible by 4 (MTProto Secure)
let padding_len = secure_padding_len(data.len(), &self.rng); let padding_len = secure_padding_len(data.len(), &self.rng);
let padding = self.rng.bytes(padding_len); let padding = self.rng.bytes(padding_len);
@ -550,9 +559,7 @@ mod tests {
writer.flush().await.unwrap(); writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap(); let (received, _meta) = reader.read_frame().await.unwrap();
// Received should have padding stripped to align to 4 assert_eq!(received.len(), data.len());
let expected_len = (data.len() / 4) * 4;
assert_eq!(received.len(), expected_len);
} }
#[tokio::test] #[tokio::test]

View File

@ -1,7 +1,8 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::net::IpAddr;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use std::time::{SystemTime, Duration}; use std::time::{Duration, Instant, SystemTime};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tokio::time::sleep; use tokio::time::sleep;
@ -14,6 +15,7 @@ use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsFetchResult};
pub struct TlsFrontCache { pub struct TlsFrontCache {
memory: RwLock<HashMap<String, Arc<CachedTlsData>>>, memory: RwLock<HashMap<String, Arc<CachedTlsData>>>,
default: Arc<CachedTlsData>, default: Arc<CachedTlsData>,
full_cert_sent: RwLock<HashMap<IpAddr, Instant>>,
disk_path: PathBuf, disk_path: PathBuf,
} }
@ -31,6 +33,7 @@ impl TlsFrontCache {
let default = Arc::new(CachedTlsData { let default = Arc::new(CachedTlsData {
server_hello_template: default_template, server_hello_template: default_template,
cert_info: None, cert_info: None,
cert_payload: None,
app_data_records_sizes: vec![default_len], app_data_records_sizes: vec![default_len],
total_app_data_len: default_len, total_app_data_len: default_len,
fetched_at: SystemTime::now(), fetched_at: SystemTime::now(),
@ -45,6 +48,7 @@ impl TlsFrontCache {
Self { Self {
memory: RwLock::new(map), memory: RwLock::new(map),
default, default,
full_cert_sent: RwLock::new(HashMap::new()),
disk_path: disk_path.as_ref().to_path_buf(), disk_path: disk_path.as_ref().to_path_buf(),
} }
} }
@ -54,6 +58,45 @@ impl TlsFrontCache {
guard.get(sni).cloned().unwrap_or_else(|| self.default.clone()) guard.get(sni).cloned().unwrap_or_else(|| self.default.clone())
} }
pub async fn contains_domain(&self, domain: &str) -> bool {
self.memory.read().await.contains_key(domain)
}
/// Returns true when full cert payload should be sent for client_ip
/// according to TTL policy.
pub async fn take_full_cert_budget_for_ip(
&self,
client_ip: IpAddr,
ttl: Duration,
) -> bool {
if ttl.is_zero() {
self.full_cert_sent
.write()
.await
.insert(client_ip, Instant::now());
return true;
}
let now = Instant::now();
let mut guard = self.full_cert_sent.write().await;
guard.retain(|_, seen_at| now.duration_since(*seen_at) < ttl);
match guard.get_mut(&client_ip) {
Some(seen_at) => {
if now.duration_since(*seen_at) >= ttl {
*seen_at = now;
true
} else {
false
}
}
None => {
guard.insert(client_ip, now);
true
}
}
}
pub async fn set(&self, domain: &str, data: CachedTlsData) { pub async fn set(&self, domain: &str, data: CachedTlsData) {
let mut guard = self.memory.write().await; let mut guard = self.memory.write().await;
guard.insert(domain.to_string(), Arc::new(data)); guard.insert(domain.to_string(), Arc::new(data));
@ -142,6 +185,7 @@ impl TlsFrontCache {
let data = CachedTlsData { let data = CachedTlsData {
server_hello_template: fetched.server_hello_parsed, server_hello_template: fetched.server_hello_parsed,
cert_info: fetched.cert_info, cert_info: fetched.cert_info,
cert_payload: fetched.cert_payload,
app_data_records_sizes: fetched.app_data_records_sizes.clone(), app_data_records_sizes: fetched.app_data_records_sizes.clone(),
total_app_data_len: fetched.total_app_data_len, total_app_data_len: fetched.total_app_data_len,
fetched_at: SystemTime::now(), fetched_at: SystemTime::now(),
@ -161,3 +205,50 @@ impl TlsFrontCache {
&self.disk_path &self.disk_path
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_take_full_cert_budget_for_ip_uses_ttl() {
let cache = TlsFrontCache::new(
&["example.com".to_string()],
1024,
"tlsfront-test-cache",
);
let ip: IpAddr = "127.0.0.1".parse().expect("ip");
let ttl = Duration::from_millis(80);
assert!(cache
.take_full_cert_budget_for_ip(ip, ttl)
.await);
assert!(!cache
.take_full_cert_budget_for_ip(ip, ttl)
.await);
tokio::time::sleep(Duration::from_millis(90)).await;
assert!(cache
.take_full_cert_budget_for_ip(ip, ttl)
.await);
}
#[tokio::test]
async fn test_take_full_cert_budget_for_ip_zero_ttl_always_allows_full_payload() {
let cache = TlsFrontCache::new(
&["example.com".to_string()],
1024,
"tlsfront-test-cache",
);
let ip: IpAddr = "127.0.0.1".parse().expect("ip");
let ttl = Duration::ZERO;
assert!(cache
.take_full_cert_budget_for_ip(ip, ttl)
.await);
assert!(cache
.take_full_cert_budget_for_ip(ip, ttl)
.await);
}
}

View File

@ -3,7 +3,7 @@ use crate::protocol::constants::{
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION,
}; };
use crate::protocol::tls::{TLS_DIGEST_LEN, TLS_DIGEST_POS, gen_fake_x25519_key}; use crate::protocol::tls::{TLS_DIGEST_LEN, TLS_DIGEST_POS, gen_fake_x25519_key};
use crate::tls_front::types::CachedTlsData; use crate::tls_front::types::{CachedTlsData, ParsedCertificateInfo};
const MIN_APP_DATA: usize = 64; const MIN_APP_DATA: usize = 64;
const MAX_APP_DATA: usize = 16640; // RFC 8446 §5.2 allows up to 2^14 + 256 const MAX_APP_DATA: usize = 16640; // RFC 8446 §5.2 allows up to 2^14 + 256
@ -27,12 +27,81 @@ fn jitter_and_clamp_sizes(sizes: &[usize], rng: &SecureRandom) -> Vec<usize> {
.collect() .collect()
} }
fn app_data_body_capacity(sizes: &[usize]) -> usize {
sizes.iter().map(|&size| size.saturating_sub(17)).sum()
}
fn ensure_payload_capacity(mut sizes: Vec<usize>, payload_len: usize) -> Vec<usize> {
if payload_len == 0 {
return sizes;
}
let mut body_total = app_data_body_capacity(&sizes);
if body_total >= payload_len {
return sizes;
}
if let Some(last) = sizes.last_mut() {
let free = MAX_APP_DATA.saturating_sub(*last);
let grow = free.min(payload_len - body_total);
*last += grow;
body_total += grow;
}
while body_total < payload_len {
let remaining = payload_len - body_total;
let chunk = (remaining + 17).min(MAX_APP_DATA).max(MIN_APP_DATA);
sizes.push(chunk);
body_total += chunk.saturating_sub(17);
}
sizes
}
fn build_compact_cert_info_payload(cert_info: &ParsedCertificateInfo) -> Option<Vec<u8>> {
let mut fields = Vec::new();
if let Some(subject) = cert_info.subject_cn.as_deref() {
fields.push(format!("CN={subject}"));
}
if let Some(issuer) = cert_info.issuer_cn.as_deref() {
fields.push(format!("ISSUER={issuer}"));
}
if let Some(not_before) = cert_info.not_before_unix {
fields.push(format!("NB={not_before}"));
}
if let Some(not_after) = cert_info.not_after_unix {
fields.push(format!("NA={not_after}"));
}
if !cert_info.san_names.is_empty() {
let san = cert_info
.san_names
.iter()
.take(8)
.map(String::as_str)
.collect::<Vec<_>>()
.join(",");
fields.push(format!("SAN={san}"));
}
if fields.is_empty() {
return None;
}
let mut payload = fields.join(";").into_bytes();
if payload.len() > 512 {
payload.truncate(512);
}
Some(payload)
}
/// Build a ServerHello + CCS + ApplicationData sequence using cached TLS metadata. /// Build a ServerHello + CCS + ApplicationData sequence using cached TLS metadata.
pub fn build_emulated_server_hello( pub fn build_emulated_server_hello(
secret: &[u8], secret: &[u8],
client_digest: &[u8; TLS_DIGEST_LEN], client_digest: &[u8; TLS_DIGEST_LEN],
session_id: &[u8], session_id: &[u8],
cached: &CachedTlsData, cached: &CachedTlsData,
use_full_cert_payload: bool,
rng: &SecureRandom, rng: &SecureRandom,
alpn: Option<Vec<u8>>, alpn: Option<Vec<u8>>,
new_session_tickets: u8, new_session_tickets: u8,
@ -109,21 +178,60 @@ pub fn build_emulated_server_hello(
if sizes.is_empty() { if sizes.is_empty() {
sizes.push(cached.total_app_data_len.max(1024)); sizes.push(cached.total_app_data_len.max(1024));
} }
let sizes = jitter_and_clamp_sizes(&sizes, rng); let mut sizes = jitter_and_clamp_sizes(&sizes, rng);
let compact_payload = cached
.cert_info
.as_ref()
.and_then(build_compact_cert_info_payload);
let selected_payload: Option<&[u8]> = if use_full_cert_payload {
cached
.cert_payload
.as_ref()
.map(|payload| payload.certificate_message.as_slice())
.filter(|payload| !payload.is_empty())
.or_else(|| compact_payload.as_deref())
} else {
compact_payload.as_deref()
};
if let Some(payload) = selected_payload {
sizes = ensure_payload_capacity(sizes, payload.len());
}
let mut app_data = Vec::new(); let mut app_data = Vec::new();
let mut payload_offset = 0usize;
for size in sizes { for size in sizes {
let mut rec = Vec::with_capacity(5 + size); let mut rec = Vec::with_capacity(5 + size);
rec.push(TLS_RECORD_APPLICATION); rec.push(TLS_RECORD_APPLICATION);
rec.extend_from_slice(&TLS_VERSION); rec.extend_from_slice(&TLS_VERSION);
rec.extend_from_slice(&(size as u16).to_be_bytes()); rec.extend_from_slice(&(size as u16).to_be_bytes());
if size > 17 {
let body_len = size - 17; if let Some(payload) = selected_payload {
rec.extend_from_slice(&rng.bytes(body_len)); if size > 17 {
rec.push(0x16); // inner content type marker (handshake) let body_len = size - 17;
rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag let remaining = payload.len().saturating_sub(payload_offset);
let copy_len = remaining.min(body_len);
if copy_len > 0 {
rec.extend_from_slice(&payload[payload_offset..payload_offset + copy_len]);
payload_offset += copy_len;
}
if body_len > copy_len {
rec.extend_from_slice(&rng.bytes(body_len - copy_len));
}
rec.push(0x16); // inner content type marker (handshake)
rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag
} else {
rec.extend_from_slice(&rng.bytes(size));
}
} else { } else {
rec.extend_from_slice(&rng.bytes(size)); if size > 17 {
let body_len = size - 17;
rec.extend_from_slice(&rng.bytes(body_len));
rec.push(0x16); // inner content type marker (handshake)
rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag
} else {
rec.extend_from_slice(&rng.bytes(size));
}
} }
app_data.extend_from_slice(&rec); app_data.extend_from_slice(&rec);
} }
@ -158,3 +266,125 @@ pub fn build_emulated_server_hello(
response response
} }
#[cfg(test)]
mod tests {
use std::time::SystemTime;
use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsCertPayload};
use super::build_emulated_server_hello;
use crate::crypto::SecureRandom;
use crate::protocol::constants::{
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE,
};
fn first_app_data_payload(response: &[u8]) -> &[u8] {
let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_start = 5 + hello_len;
let ccs_len = u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
let app_start = ccs_start + 5 + ccs_len;
let app_len = u16::from_be_bytes([response[app_start + 3], response[app_start + 4]]) as usize;
&response[app_start + 5..app_start + 5 + app_len]
}
fn make_cached(cert_payload: Option<TlsCertPayload>) -> CachedTlsData {
CachedTlsData {
server_hello_template: ParsedServerHello {
version: [0x03, 0x03],
random: [0u8; 32],
session_id: Vec::new(),
cipher_suite: [0x13, 0x01],
compression: 0,
extensions: Vec::new(),
},
cert_info: None,
cert_payload,
app_data_records_sizes: vec![64],
total_app_data_len: 64,
fetched_at: SystemTime::now(),
domain: "example.com".to_string(),
}
}
#[test]
fn test_build_emulated_server_hello_uses_cached_cert_payload() {
let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd];
let cached = make_cached(Some(TlsCertPayload {
cert_chain_der: vec![vec![0x30, 0x01, 0x00]],
certificate_message: cert_msg.clone(),
}));
let rng = SecureRandom::new();
let response = build_emulated_server_hello(
b"secret",
&[0x11; 32],
&[0x22; 16],
&cached,
true,
&rng,
None,
0,
);
assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_start = 5 + hello_len;
assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
let app_start = ccs_start + 6;
assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
let payload = first_app_data_payload(&response);
assert!(payload.starts_with(&cert_msg));
}
#[test]
fn test_build_emulated_server_hello_random_fallback_when_no_cert_payload() {
let cached = make_cached(None);
let rng = SecureRandom::new();
let response = build_emulated_server_hello(
b"secret",
&[0x22; 32],
&[0x33; 16],
&cached,
true,
&rng,
None,
0,
);
let payload = first_app_data_payload(&response);
assert!(payload.len() >= 64);
assert_eq!(payload[payload.len() - 17], 0x16);
}
#[test]
fn test_build_emulated_server_hello_uses_compact_payload_after_first() {
let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd];
let mut cached = make_cached(Some(TlsCertPayload {
cert_chain_der: vec![vec![0x30, 0x01, 0x00]],
certificate_message: cert_msg,
}));
cached.cert_info = Some(crate::tls_front::types::ParsedCertificateInfo {
not_after_unix: Some(1_900_000_000),
not_before_unix: Some(1_700_000_000),
issuer_cn: Some("Issuer".to_string()),
subject_cn: Some("example.com".to_string()),
san_names: vec!["example.com".to_string(), "www.example.com".to_string()],
});
let rng = SecureRandom::new();
let response = build_emulated_server_hello(
b"secret",
&[0x44; 32],
&[0x55; 16],
&cached,
false,
&rng,
None,
0,
);
let payload = first_app_data_payload(&response);
assert!(payload.starts_with(b"CN=example.com"));
}
}

View File

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use anyhow::{Context, Result, anyhow}; use anyhow::{Result, anyhow};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::timeout; use tokio::time::timeout;
@ -19,7 +19,13 @@ use x509_parser::certificate::X509Certificate;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_HANDSHAKE}; use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_HANDSHAKE};
use crate::tls_front::types::{ParsedServerHello, TlsExtension, TlsFetchResult, ParsedCertificateInfo}; use crate::tls_front::types::{
ParsedCertificateInfo,
ParsedServerHello,
TlsCertPayload,
TlsExtension,
TlsFetchResult,
};
/// No-op verifier: accept any certificate (we only need lengths and metadata). /// No-op verifier: accept any certificate (we only need lengths and metadata).
#[derive(Debug)] #[derive(Debug)]
@ -315,6 +321,46 @@ fn parse_cert_info(certs: &[CertificateDer<'static>]) -> Option<ParsedCertificat
}) })
} }
fn u24_bytes(value: usize) -> Option<[u8; 3]> {
if value > 0x00ff_ffff {
return None;
}
Some([
((value >> 16) & 0xff) as u8,
((value >> 8) & 0xff) as u8,
(value & 0xff) as u8,
])
}
fn encode_tls13_certificate_message(cert_chain_der: &[Vec<u8>]) -> Option<Vec<u8>> {
if cert_chain_der.is_empty() {
return None;
}
let mut certificate_list = Vec::new();
for cert in cert_chain_der {
if cert.is_empty() {
return None;
}
certificate_list.extend_from_slice(&u24_bytes(cert.len())?);
certificate_list.extend_from_slice(cert);
certificate_list.extend_from_slice(&0u16.to_be_bytes()); // cert_entry extensions
}
// Certificate = context_len(1) + certificate_list_len(3) + entries
let body_len = 1usize
.checked_add(3)?
.checked_add(certificate_list.len())?;
let mut message = Vec::with_capacity(4 + body_len);
message.push(0x0b); // HandshakeType::certificate
message.extend_from_slice(&u24_bytes(body_len)?);
message.push(0x00); // certificate_request_context length
message.extend_from_slice(&u24_bytes(certificate_list.len())?);
message.extend_from_slice(&certificate_list);
Some(message)
}
async fn fetch_via_raw_tls( async fn fetch_via_raw_tls(
host: &str, host: &str,
port: u16, port: u16,
@ -368,26 +414,18 @@ async fn fetch_via_raw_tls(
}, },
total_app_data_len, total_app_data_len,
cert_info: None, cert_info: None,
cert_payload: None,
}) })
} }
/// Fetch real TLS metadata for the given SNI: negotiated cipher and cert lengths. async fn fetch_via_rustls(
pub async fn fetch_real_tls(
host: &str, host: &str,
port: u16, port: u16,
sni: &str, sni: &str,
connect_timeout: Duration, connect_timeout: Duration,
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>, upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
) -> Result<TlsFetchResult> { ) -> Result<TlsFetchResult> {
// Preferred path: raw TLS probe for accurate record sizing // rustls handshake path for certificate and basic negotiated metadata.
match fetch_via_raw_tls(host, port, sni, connect_timeout).await {
Ok(res) => return Ok(res),
Err(e) => {
warn!(sni = %sni, error = %e, "Raw TLS fetch failed, falling back to rustls");
}
}
// Fallback: rustls handshake to at least get certificate sizes
let stream = if let Some(manager) = upstream { let stream = if let Some(manager) = upstream {
// Resolve host to SocketAddr // Resolve host to SocketAddr
if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await { if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await {
@ -429,8 +467,19 @@ pub async fn fetch_real_tls(
.peer_certificates() .peer_certificates()
.map(|slice| slice.to_vec()) .map(|slice| slice.to_vec())
.unwrap_or_default(); .unwrap_or_default();
let cert_chain_der: Vec<Vec<u8>> = certs.iter().map(|c| c.as_ref().to_vec()).collect();
let cert_payload = encode_tls13_certificate_message(&cert_chain_der).map(|certificate_message| {
TlsCertPayload {
cert_chain_der: cert_chain_der.clone(),
certificate_message,
}
});
let total_cert_len: usize = certs.iter().map(|c| c.len()).sum::<usize>().max(1024); let total_cert_len = cert_payload
.as_ref()
.map(|payload| payload.certificate_message.len())
.unwrap_or_else(|| cert_chain_der.iter().map(Vec::len).sum::<usize>())
.max(1024);
let cert_info = parse_cert_info(&certs); let cert_info = parse_cert_info(&certs);
// Heuristic: split across two records if large to mimic real servers a bit. // Heuristic: split across two records if large to mimic real servers a bit.
@ -453,6 +502,7 @@ pub async fn fetch_real_tls(
sni = %sni, sni = %sni,
len = total_cert_len, len = total_cert_len,
cipher = format!("0x{:04x}", u16::from_be_bytes(cipher_suite)), cipher = format!("0x{:04x}", u16::from_be_bytes(cipher_suite)),
has_cert_payload = cert_payload.is_some(),
"Fetched TLS metadata via rustls" "Fetched TLS metadata via rustls"
); );
@ -461,5 +511,81 @@ pub async fn fetch_real_tls(
app_data_records_sizes: app_data_records_sizes.clone(), app_data_records_sizes: app_data_records_sizes.clone(),
total_app_data_len: app_data_records_sizes.iter().sum(), total_app_data_len: app_data_records_sizes.iter().sum(),
cert_info, cert_info,
cert_payload,
}) })
} }
/// Fetch real TLS metadata for the given SNI.
///
/// Strategy:
/// 1) Probe raw TLS for realistic ServerHello and ApplicationData record sizes.
/// 2) Fetch certificate chain via rustls to build cert payload.
/// 3) Merge both when possible; otherwise auto-fallback to whichever succeeded.
pub async fn fetch_real_tls(
host: &str,
port: u16,
sni: &str,
connect_timeout: Duration,
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
) -> Result<TlsFetchResult> {
let raw_result = match fetch_via_raw_tls(host, port, sni, connect_timeout).await {
Ok(res) => Some(res),
Err(e) => {
warn!(sni = %sni, error = %e, "Raw TLS fetch failed");
None
}
};
match fetch_via_rustls(host, port, sni, connect_timeout, upstream).await {
Ok(rustls_result) => {
if let Some(mut raw) = raw_result {
raw.cert_info = rustls_result.cert_info;
raw.cert_payload = rustls_result.cert_payload;
debug!(sni = %sni, "Fetched TLS metadata via raw probe + rustls cert chain");
Ok(raw)
} else {
Ok(rustls_result)
}
}
Err(e) => {
if let Some(raw) = raw_result {
warn!(sni = %sni, error = %e, "Rustls cert fetch failed, using raw TLS metadata only");
Ok(raw)
} else {
Err(e)
}
}
}
}
#[cfg(test)]
mod tests {
use super::encode_tls13_certificate_message;
fn read_u24(bytes: &[u8]) -> usize {
((bytes[0] as usize) << 16) | ((bytes[1] as usize) << 8) | (bytes[2] as usize)
}
#[test]
fn test_encode_tls13_certificate_message_single_cert() {
let cert = vec![0x30, 0x03, 0x02, 0x01, 0x01];
let message = encode_tls13_certificate_message(&[cert.clone()]).expect("message");
assert_eq!(message[0], 0x0b);
assert_eq!(read_u24(&message[1..4]), message.len() - 4);
assert_eq!(message[4], 0x00);
let cert_list_len = read_u24(&message[5..8]);
assert_eq!(cert_list_len, cert.len() + 5);
let cert_len = read_u24(&message[8..11]);
assert_eq!(cert_len, cert.len());
assert_eq!(&message[11..11 + cert.len()], cert.as_slice());
assert_eq!(&message[11 + cert.len()..13 + cert.len()], &[0x00, 0x00]);
}
#[test]
fn test_encode_tls13_certificate_message_empty_chain() {
assert!(encode_tls13_certificate_message(&[]).is_none());
}
}

View File

@ -29,11 +29,23 @@ pub struct ParsedCertificateInfo {
pub san_names: Vec<String>, pub san_names: Vec<String>,
} }
/// TLS certificate payload captured from profiled upstream.
///
/// `certificate_message` stores an encoded TLS 1.3 Certificate handshake
/// message body that can be replayed as opaque ApplicationData bytes in FakeTLS.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsCertPayload {
pub cert_chain_der: Vec<Vec<u8>>,
pub certificate_message: Vec<u8>,
}
/// Cached data per SNI used by the emulator. /// Cached data per SNI used by the emulator.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedTlsData { pub struct CachedTlsData {
pub server_hello_template: ParsedServerHello, pub server_hello_template: ParsedServerHello,
pub cert_info: Option<ParsedCertificateInfo>, pub cert_info: Option<ParsedCertificateInfo>,
#[serde(default)]
pub cert_payload: Option<TlsCertPayload>,
pub app_data_records_sizes: Vec<usize>, pub app_data_records_sizes: Vec<usize>,
pub total_app_data_len: usize, pub total_app_data_len: usize,
#[serde(default = "now_system_time", skip_serializing, skip_deserializing)] #[serde(default = "now_system_time", skip_serializing, skip_deserializing)]
@ -52,4 +64,5 @@ pub struct TlsFetchResult {
pub app_data_records_sizes: Vec<usize>, pub app_data_records_sizes: Vec<usize>,
pub total_app_data_len: usize, pub total_app_data_len: usize,
pub cert_info: Option<ParsedCertificateInfo>, pub cert_info: Option<ParsedCertificateInfo>,
pub cert_payload: Option<TlsCertPayload>,
} }

View File

@ -1,6 +1,6 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::crypto::{AesCbc, crc32}; use crate::crypto::{AesCbc, crc32, crc32c};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::protocol::constants::*; use crate::protocol::constants::*;
@ -8,17 +8,46 @@ use crate::protocol::constants::*;
pub(crate) enum WriterCommand { pub(crate) enum WriterCommand {
Data(Vec<u8>), Data(Vec<u8>),
DataAndFlush(Vec<u8>), DataAndFlush(Vec<u8>),
Keepalive,
Close, Close,
} }
pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec<u8> { #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RpcChecksumMode {
Crc32,
Crc32c,
}
impl RpcChecksumMode {
pub(crate) fn from_handshake_flags(flags: u32) -> Self {
if (flags & rpc_crypto_flags::USE_CRC32C) != 0 {
Self::Crc32c
} else {
Self::Crc32
}
}
pub(crate) fn advertised_flags(self) -> u32 {
match self {
Self::Crc32 => 0,
Self::Crc32c => rpc_crypto_flags::USE_CRC32C,
}
}
}
pub(crate) fn rpc_crc(mode: RpcChecksumMode, data: &[u8]) -> u32 {
match mode {
RpcChecksumMode::Crc32 => crc32(data),
RpcChecksumMode::Crc32c => crc32c(data),
}
}
pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8], crc_mode: RpcChecksumMode) -> Vec<u8> {
let total_len = (4 + 4 + payload.len() + 4) as u32; let total_len = (4 + 4 + payload.len() + 4) as u32;
let mut frame = Vec::with_capacity(total_len as usize); let mut frame = Vec::with_capacity(total_len as usize);
frame.extend_from_slice(&total_len.to_le_bytes()); frame.extend_from_slice(&total_len.to_le_bytes());
frame.extend_from_slice(&seq_no.to_le_bytes()); frame.extend_from_slice(&seq_no.to_le_bytes());
frame.extend_from_slice(payload); frame.extend_from_slice(payload);
let c = crc32(&frame); let c = rpc_crc(crc_mode, &frame);
frame.extend_from_slice(&c.to_le_bytes()); frame.extend_from_slice(&c.to_le_bytes());
frame frame
} }
@ -45,7 +74,7 @@ pub(crate) async fn read_rpc_frame_plaintext(
let crc_offset = total_len - 4; let crc_offset = total_len - 4;
let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap()); let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap());
let actual_crc = crc32(&full[..crc_offset]); let actual_crc = rpc_crc(RpcChecksumMode::Crc32, &full[..crc_offset]);
if expected_crc != actual_crc { if expected_crc != actual_crc {
return Err(ProxyError::InvalidHandshake(format!( return Err(ProxyError::InvalidHandshake(format!(
"CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}" "CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}"
@ -95,24 +124,52 @@ pub(crate) fn build_handshake_payload(
our_port: u16, our_port: u16,
peer_ip: [u8; 4], peer_ip: [u8; 4],
peer_port: u16, peer_port: u16,
flags: u32,
) -> [u8; 32] { ) -> [u8; 32] {
let mut p = [0u8; 32]; let mut p = [0u8; 32];
p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes()); p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes());
p[4..8].copy_from_slice(&flags.to_le_bytes());
// Keep C memory layout compatibility for PID IPv4 bytes. // process_id sender_pid
p[8..12].copy_from_slice(&our_ip); p[8..12].copy_from_slice(&our_ip);
p[12..14].copy_from_slice(&our_port.to_le_bytes()); p[12..14].copy_from_slice(&our_port.to_le_bytes());
let pid = (std::process::id() & 0xffff) as u16; p[14..16].copy_from_slice(&process_pid16().to_le_bytes());
p[14..16].copy_from_slice(&pid.to_le_bytes()); p[16..20].copy_from_slice(&process_utime().to_le_bytes());
// process_id peer_pid
p[20..24].copy_from_slice(&peer_ip);
p[24..26].copy_from_slice(&peer_port.to_le_bytes());
p[26..28].copy_from_slice(&0u16.to_le_bytes());
p[28..32].copy_from_slice(&0u32.to_le_bytes());
p
}
pub(crate) fn parse_handshake_flags(payload: &[u8]) -> Result<u32> {
if payload.len() != 32 {
return Err(ProxyError::InvalidHandshake(format!(
"Bad handshake payload len: {}",
payload.len()
)));
}
let hs_type = u32::from_le_bytes(payload[0..4].try_into().unwrap());
if hs_type != RPC_HANDSHAKE_U32 {
return Err(ProxyError::InvalidHandshake(format!(
"Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}"
)));
}
Ok(u32::from_le_bytes(payload[4..8].try_into().unwrap()))
}
fn process_pid16() -> u16 {
(std::process::id() & 0xffff) as u16
}
fn process_utime() -> u32 {
let utime = std::time::SystemTime::now() let utime = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default() .unwrap_or_default()
.as_secs() as u32; .as_secs() as u32;
p[16..20].copy_from_slice(&utime.to_le_bytes()); utime
p[20..24].copy_from_slice(&peer_ip);
p[24..26].copy_from_slice(&peer_port.to_le_bytes());
p
} }
pub(crate) fn cbc_encrypt_padded( pub(crate) fn cbc_encrypt_padded(
@ -160,12 +217,13 @@ pub(crate) struct RpcWriter {
pub(crate) key: [u8; 32], pub(crate) key: [u8; 32],
pub(crate) iv: [u8; 16], pub(crate) iv: [u8; 16],
pub(crate) seq_no: i32, pub(crate) seq_no: i32,
pub(crate) crc_mode: RpcChecksumMode,
} }
impl RpcWriter { impl RpcWriter {
pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> { pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> {
let frame = build_rpc_frame(self.seq_no, payload); let frame = build_rpc_frame(self.seq_no, payload, self.crc_mode);
self.seq_no += 1; self.seq_no = self.seq_no.wrapping_add(1);
let pad = (16 - (frame.len() % 16)) % 16; let pad = (16 - (frame.len() % 16)) % 16;
let mut buf = frame; let mut buf = frame;
@ -189,12 +247,4 @@ impl RpcWriter {
self.send(payload).await?; self.send(payload).await?;
self.writer.flush().await.map_err(ProxyError::Io) self.writer.flush().await.map_err(ProxyError::Io)
} }
pub(crate) async fn send_keepalive(&mut self, payload: [u8; 4]) -> Result<()> {
// Keepalive is a frame with fl == 4 and 4 bytes payload.
let mut frame = Vec::with_capacity(8);
frame.extend_from_slice(&4u32.to_le_bytes());
frame.extend_from_slice(&payload);
self.send(&frame).await
}
} }

View File

@ -18,13 +18,14 @@ use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_k
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::network::IpFamily; use crate::network::IpFamily;
use crate::protocol::constants::{ use crate::protocol::constants::{
ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32, ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32,
RPC_HANDSHAKE_U32, RPC_PING_U32, RPC_PONG_U32, RPC_NONCE_U32, RPC_HANDSHAKE_ERROR_U32, rpc_crypto_flags,
}; };
use super::codec::{ use super::codec::{
build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace, RpcChecksumMode, build_handshake_payload, build_nonce_payload, build_rpc_frame,
cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext, cbc_decrypt_inplace, cbc_encrypt_padded, parse_handshake_flags, parse_nonce_payload,
read_rpc_frame_plaintext, rpc_crc,
}; };
use super::wire::{extract_ip_material, IpMaterial}; use super::wire::{extract_ip_material, IpMaterial};
use super::MePool; use super::MePool;
@ -37,6 +38,7 @@ pub(crate) struct HandshakeOutput {
pub read_iv: [u8; 16], pub read_iv: [u8; 16],
pub write_key: [u8; 32], pub write_key: [u8; 32],
pub write_iv: [u8; 16], pub write_iv: [u8; 16],
pub crc_mode: RpcChecksumMode,
pub handshake_ms: f64, pub handshake_ms: f64,
} }
@ -146,7 +148,7 @@ impl MePool {
let ks = self.key_selector().await; let ks = self.key_selector().await;
let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce);
let nonce_frame = build_rpc_frame(-2, &nonce_payload); let nonce_frame = build_rpc_frame(-2, &nonce_payload, RpcChecksumMode::Crc32);
let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]); let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]);
debug!( debug!(
key_selector = format_args!("0x{ks:08x}"), key_selector = format_args!("0x{ks:08x}"),
@ -284,8 +286,15 @@ impl MePool {
srv_v6_opt.as_ref(), srv_v6_opt.as_ref(),
); );
let hs_payload = build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port()); let requested_crc_mode = RpcChecksumMode::Crc32c;
let hs_frame = build_rpc_frame(-1, &hs_payload); let hs_payload = build_handshake_payload(
hs_our_ip,
local_addr.port(),
hs_peer_ip,
peer_addr.port(),
requested_crc_mode.advertised_flags(),
);
let hs_frame = build_rpc_frame(-1, &hs_payload, RpcChecksumMode::Crc32);
if diag_level >= 1 { if diag_level >= 1 {
info!( info!(
write_key = %hex_dump(&wk), write_key = %hex_dump(&wk),
@ -314,7 +323,7 @@ impl MePool {
); );
} }
let (encrypted_hs, mut write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
if diag_level >= 1 { if diag_level >= 1 {
info!( info!(
hs_cipher = %hex_dump(&encrypted_hs), hs_cipher = %hex_dump(&encrypted_hs),
@ -328,6 +337,7 @@ impl MePool {
let mut enc_buf = BytesMut::with_capacity(256); let mut enc_buf = BytesMut::with_capacity(256);
let mut dec_buf = BytesMut::with_capacity(256); let mut dec_buf = BytesMut::with_capacity(256);
let mut read_iv = ri; let mut read_iv = ri;
let mut negotiated_crc_mode = RpcChecksumMode::Crc32;
let mut handshake_ok = false; let mut handshake_ok = false;
while Instant::now() < deadline && !handshake_ok { while Instant::now() < deadline && !handshake_ok {
@ -375,17 +385,23 @@ impl MePool {
let frame = dec_buf.split_to(fl); let frame = dec_buf.split_to(fl);
let pe = fl - 4; let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
let ac = crate::crypto::crc32(&frame[..pe]); let ac = rpc_crc(RpcChecksumMode::Crc32, &frame[..pe]);
if ec != ac { if ec != ac {
return Err(ProxyError::InvalidHandshake(format!( return Err(ProxyError::InvalidHandshake(format!(
"HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}" "HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}"
))); )));
} }
let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap()); let hs_payload = &frame[8..pe];
if hs_payload.len() < 4 {
return Err(ProxyError::InvalidHandshake(
"Handshake payload too short".to_string(),
));
}
let hs_type = u32::from_le_bytes(hs_payload[0..4].try_into().unwrap());
if hs_type == RPC_HANDSHAKE_ERROR_U32 { if hs_type == RPC_HANDSHAKE_ERROR_U32 {
let err_code = if frame.len() >= 16 { let err_code = if hs_payload.len() >= 8 {
i32::from_le_bytes(frame[12..16].try_into().unwrap()) i32::from_le_bytes(hs_payload[4..8].try_into().unwrap())
} else { } else {
-1 -1
}; };
@ -393,11 +409,21 @@ impl MePool {
"ME rejected handshake (error={err_code})" "ME rejected handshake (error={err_code})"
))); )));
} }
if hs_type != RPC_HANDSHAKE_U32 { let hs_flags = parse_handshake_flags(hs_payload)?;
if hs_flags & 0xff != 0 {
return Err(ProxyError::InvalidHandshake(format!( return Err(ProxyError::InvalidHandshake(format!(
"Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" "Unsupported handshake flags: 0x{hs_flags:08x}"
))); )));
} }
negotiated_crc_mode = if (hs_flags & requested_crc_mode.advertised_flags()) != 0 {
RpcChecksumMode::from_handshake_flags(hs_flags)
} else if (hs_flags & rpc_crypto_flags::USE_CRC32C) != 0 {
return Err(ProxyError::InvalidHandshake(format!(
"Peer negotiated unsupported CRC flags: 0x{hs_flags:08x}"
)));
} else {
RpcChecksumMode::Crc32
};
handshake_ok = true; handshake_ok = true;
break; break;
@ -418,6 +444,7 @@ impl MePool {
read_iv, read_iv,
write_key: wk, write_key: wk,
write_iv, write_iv,
crc_mode: negotiated_crc_mode,
handshake_ms, handshake_ms,
}) })
} }

View File

@ -17,13 +17,11 @@ use crate::network::IpFamily;
use crate::protocol::constants::*; use crate::protocol::constants::*;
use super::ConnRegistry; use super::ConnRegistry;
use super::registry::{BoundConn, ConnMeta}; use super::registry::BoundConn;
use super::codec::{RpcWriter, WriterCommand}; use super::codec::{RpcWriter, WriterCommand};
use super::reader::reader_loop; use super::reader::reader_loop;
use super::MeResponse;
const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_SECS: u64 = 25;
const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
const ME_KEEPALIVE_PAYLOAD_LEN: usize = 4;
#[derive(Clone)] #[derive(Clone)]
pub struct MeWriter { pub struct MeWriter {
@ -361,7 +359,6 @@ impl MePool {
// Additional connections up to pool_size total (round-robin across DCs), staggered to de-phase lifecycles. // Additional connections up to pool_size total (round-robin across DCs), staggered to de-phase lifecycles.
if self.me_warmup_stagger_enabled { if self.me_warmup_stagger_enabled {
let mut delay_ms = 0u64;
for (dc, addrs) in dc_addrs.iter() { for (dc, addrs) in dc_addrs.iter() {
for (ip, port) in addrs { for (ip, port) in addrs {
if self.connection_count() >= pool_size { if self.connection_count() >= pool_size {
@ -369,7 +366,7 @@ impl MePool {
} }
let addr = SocketAddr::new(*ip, *port); let addr = SocketAddr::new(*ip, *port);
let jitter = rand::rng().random_range(0..=self.me_warmup_step_jitter.as_millis() as u64); let jitter = rand::rng().random_range(0..=self.me_warmup_step_jitter.as_millis() as u64);
delay_ms = delay_ms.saturating_add(self.me_warmup_step_delay.as_millis() as u64 + jitter); let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter;
tokio::time::sleep(Duration::from_millis(delay_ms)).await; tokio::time::sleep(Duration::from_millis(delay_ms)).await;
if let Err(e) = self.connect_one(addr, rng.as_ref()).await { if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)"); debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)");
@ -418,14 +415,12 @@ impl MePool {
let degraded = Arc::new(AtomicBool::new(false)); let degraded = Arc::new(AtomicBool::new(false));
let draining = Arc::new(AtomicBool::new(false)); let draining = Arc::new(AtomicBool::new(false));
let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096); let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096);
let tx_for_keepalive = tx.clone();
let keepalive_random = self.me_keepalive_payload_random;
let stats = self.stats.clone();
let mut rpc_writer = RpcWriter { let mut rpc_writer = RpcWriter {
writer: hs.wr, writer: hs.wr,
key: hs.write_key, key: hs.write_key,
iv: hs.write_iv, iv: hs.write_iv,
seq_no: 0, seq_no: 0,
crc_mode: hs.crc_mode,
}; };
let cancel_wr = cancel.clone(); let cancel_wr = cancel.clone();
tokio::spawn(async move { tokio::spawn(async move {
@ -439,21 +434,6 @@ impl MePool {
Some(WriterCommand::DataAndFlush(payload)) => { Some(WriterCommand::DataAndFlush(payload)) => {
if rpc_writer.send_and_flush(&payload).await.is_err() { break; } if rpc_writer.send_and_flush(&payload).await.is_err() { break; }
} }
Some(WriterCommand::Keepalive) => {
let mut payload = [0u8; ME_KEEPALIVE_PAYLOAD_LEN];
if keepalive_random {
rand::rng().fill(&mut payload);
}
match rpc_writer.send_keepalive(payload).await {
Ok(()) => {
stats.increment_me_keepalive_sent();
}
Err(_) => {
stats.increment_me_keepalive_failed();
break;
}
}
}
Some(WriterCommand::Close) | None => break, Some(WriterCommand::Close) | None => break,
} }
} }
@ -471,12 +451,15 @@ impl MePool {
}; };
self.writers.write().await.push(writer.clone()); self.writers.write().await.push(writer.clone());
self.conn_count.fetch_add(1, Ordering::Relaxed); self.conn_count.fetch_add(1, Ordering::Relaxed);
self.writer_available.notify_waiters(); self.writer_available.notify_one();
let reg = self.registry.clone(); let reg = self.registry.clone();
let writers_arc = self.writers_arc(); let writers_arc = self.writers_arc();
let ping_tracker = self.ping_tracker.clone(); let ping_tracker = self.ping_tracker.clone();
let ping_tracker_reader = ping_tracker.clone();
let rtt_stats = self.rtt_stats.clone(); let rtt_stats = self.rtt_stats.clone();
let stats_reader = self.stats.clone();
let stats_ping = self.stats.clone();
let pool = Arc::downgrade(self); let pool = Arc::downgrade(self);
let cancel_ping = cancel.clone(); let cancel_ping = cancel.clone();
let tx_ping = tx.clone(); let tx_ping = tx.clone();
@ -489,19 +472,20 @@ impl MePool {
let keepalive_jitter = self.me_keepalive_jitter; let keepalive_jitter = self.me_keepalive_jitter;
let cancel_reader_token = cancel.clone(); let cancel_reader_token = cancel.clone();
let cancel_ping_token = cancel_ping.clone(); let cancel_ping_token = cancel_ping.clone();
let cancel_keepalive_token = cancel.clone();
tokio::spawn(async move { tokio::spawn(async move {
let res = reader_loop( let res = reader_loop(
hs.rd, hs.rd,
hs.read_key, hs.read_key,
hs.read_iv, hs.read_iv,
hs.crc_mode,
reg.clone(), reg.clone(),
BytesMut::new(), BytesMut::new(),
BytesMut::new(), BytesMut::new(),
tx.clone(), tx.clone(),
ping_tracker.clone(), ping_tracker_reader,
rtt_stats.clone(), rtt_stats.clone(),
stats_reader,
writer_id, writer_id,
degraded.clone(), degraded.clone(),
cancel_reader_token.clone(), cancel_reader_token.clone(),
@ -526,15 +510,40 @@ impl MePool {
let pool_ping = Arc::downgrade(self); let pool_ping = Arc::downgrade(self);
tokio::spawn(async move { tokio::spawn(async move {
let mut ping_id: i64 = rand::random::<i64>(); let mut ping_id: i64 = rand::random::<i64>();
loop { // Per-writer jittered start to avoid phase sync.
let startup_jitter = if keepalive_enabled {
let jitter_cap_ms = keepalive_interval.as_millis() / 2;
let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1);
Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))
} else {
let jitter = rand::rng() let jitter = rand::rng()
.random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
Duration::from_secs(wait)
};
tokio::select! {
_ = cancel_ping_token.cancelled() => return,
_ = tokio::time::sleep(startup_jitter) => {}
}
loop {
let wait = if keepalive_enabled {
let jitter_cap_ms = keepalive_interval.as_millis() / 2;
let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1);
keepalive_interval
+ Duration::from_millis(
rand::rng().random_range(0..=effective_jitter_ms as u64)
)
} else {
let jitter = rand::rng()
.random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
let secs = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
Duration::from_secs(secs)
};
tokio::select! { tokio::select! {
_ = cancel_ping_token.cancelled() => { _ = cancel_ping_token.cancelled() => {
break; break;
} }
_ = tokio::time::sleep(Duration::from_secs(wait)) => {} _ = tokio::time::sleep(wait) => {}
} }
let sent_id = ping_id; let sent_id = ping_id;
let mut p = Vec::with_capacity(12); let mut p = Vec::with_capacity(12);
@ -542,12 +551,19 @@ impl MePool {
p.extend_from_slice(&sent_id.to_le_bytes()); p.extend_from_slice(&sent_id.to_le_bytes());
{ {
let mut tracker = ping_tracker_ping.lock().await; let mut tracker = ping_tracker_ping.lock().await;
let before = tracker.len();
tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120));
let expired = before.saturating_sub(tracker.len());
if expired > 0 {
stats_ping.increment_me_keepalive_timeout_by(expired as u64);
}
tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
} }
ping_id = ping_id.wrapping_add(1); ping_id = ping_id.wrapping_add(1);
stats_ping.increment_me_keepalive_sent();
if tx_ping.send(WriterCommand::DataAndFlush(p)).await.is_err() { if tx_ping.send(WriterCommand::DataAndFlush(p)).await.is_err() {
debug!("Active ME ping failed, removing dead writer"); stats_ping.increment_me_keepalive_failed();
debug!("ME ping failed, removing dead writer");
cancel_ping.cancel(); cancel_ping.cancel();
if let Some(pool) = pool_ping.upgrade() { if let Some(pool) = pool_ping.upgrade() {
if cleanup_for_ping if cleanup_for_ping
@ -562,27 +578,6 @@ impl MePool {
} }
}); });
if keepalive_enabled {
let tx_keepalive = tx_for_keepalive;
let cancel_keepalive = cancel_keepalive_token;
tokio::spawn(async move {
// Per-writer jittered start to avoid phase sync.
let jitter_cap_ms = keepalive_interval.as_millis() / 2;
let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1);
let initial_jitter_ms = rand::rng().random_range(0..=effective_jitter_ms as u64);
tokio::time::sleep(Duration::from_millis(initial_jitter_ms)).await;
loop {
tokio::select! {
_ = cancel_keepalive.cancelled() => break,
_ = tokio::time::sleep(keepalive_interval + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))) => {}
}
if tx_keepalive.send(WriterCommand::Keepalive).await.is_err() {
break;
}
}
});
}
Ok(()) Ok(())
} }
@ -619,15 +614,19 @@ impl MePool {
} }
async fn remove_writer_only(&self, writer_id: u64) -> Vec<BoundConn> { async fn remove_writer_only(&self, writer_id: u64) -> Vec<BoundConn> {
let mut close_tx: Option<mpsc::Sender<WriterCommand>> = None;
{ {
let mut ws = self.writers.write().await; let mut ws = self.writers.write().await;
if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { if let Some(pos) = ws.iter().position(|w| w.id == writer_id) {
let w = ws.remove(pos); let w = ws.remove(pos);
w.cancel.cancel(); w.cancel.cancel();
let _ = w.tx.send(WriterCommand::Close).await; close_tx = Some(w.tx.clone());
self.conn_count.fetch_sub(1, Ordering::Relaxed); self.conn_count.fetch_sub(1, Ordering::Relaxed);
} }
} }
if let Some(tx) = close_tx {
let _ = tx.send(WriterCommand::Close).await;
}
self.rtt_stats.lock().await.remove(&writer_id); self.rtt_stats.lock().await.remove(&writer_id);
self.registry.writer_lost(writer_id).await self.registry.writer_lost(writer_id).await
} }

View File

@ -10,31 +10,33 @@ use tokio::sync::{Mutex, mpsc};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
use crate::crypto::{AesCbc, crc32}; use crate::crypto::AesCbc;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::stats::Stats;
use super::codec::WriterCommand; use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc};
use super::registry::RouteResult;
use super::{ConnRegistry, MeResponse}; use super::{ConnRegistry, MeResponse};
pub(crate) async fn reader_loop( pub(crate) async fn reader_loop(
mut rd: tokio::io::ReadHalf<TcpStream>, mut rd: tokio::io::ReadHalf<TcpStream>,
dk: [u8; 32], dk: [u8; 32],
mut div: [u8; 16], mut div: [u8; 16],
crc_mode: RpcChecksumMode,
reg: Arc<ConnRegistry>, reg: Arc<ConnRegistry>,
enc_leftover: BytesMut, enc_leftover: BytesMut,
mut dec: BytesMut, mut dec: BytesMut,
tx: mpsc::Sender<WriterCommand>, tx: mpsc::Sender<WriterCommand>,
ping_tracker: Arc<Mutex<HashMap<i64, (Instant, u64)>>>, ping_tracker: Arc<Mutex<HashMap<i64, (Instant, u64)>>>,
rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>, rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
stats: Arc<Stats>,
_writer_id: u64, _writer_id: u64,
degraded: Arc<AtomicBool>, degraded: Arc<AtomicBool>,
cancel: CancellationToken, cancel: CancellationToken,
) -> Result<()> { ) -> Result<()> {
let mut raw = enc_leftover; let mut raw = enc_leftover;
let mut expected_seq: i32 = 0; let mut expected_seq: i32 = 0;
let mut crc_errors = 0u32;
let mut seq_mismatch = 0u32;
loop { loop {
let mut tmp = [0u8; 16_384]; let mut tmp = [0u8; 16_384];
@ -80,26 +82,28 @@ pub(crate) async fn reader_loop(
let frame = dec.split_to(fl); let frame = dec.split_to(fl);
let pe = fl - 4; let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
if crc32(&frame[..pe]) != ec { let actual_crc = rpc_crc(crc_mode, &frame[..pe]);
warn!("CRC mismatch in data frame"); if actual_crc != ec {
crc_errors += 1; stats.increment_me_crc_mismatch();
if crc_errors > 3 { warn!(
return Err(ProxyError::Proxy("Too many CRC mismatches".into())); frame_len = fl,
} expected_crc = format_args!("0x{ec:08x}"),
continue; actual_crc = format_args!("0x{actual_crc:08x}"),
"CRC mismatch — CBC crypto desync, aborting ME connection"
);
return Err(ProxyError::Proxy("CRC mismatch (crypto desync)".into()));
} }
let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap()); let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap());
if seq_no != expected_seq { if seq_no != expected_seq {
stats.increment_me_seq_mismatch();
warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch"); warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch");
seq_mismatch += 1; return Err(ProxyError::SeqNoMismatch {
if seq_mismatch > 10 { expected: expected_seq,
return Err(ProxyError::Proxy("Too many seq mismatches".into())); got: seq_no,
} });
expected_seq = seq_no.wrapping_add(1);
} else {
expected_seq = expected_seq.wrapping_add(1);
} }
expected_seq = expected_seq.wrapping_add(1);
let payload = &frame[8..pe]; let payload = &frame[8..pe];
if payload.len() < 4 { if payload.len() < 4 {
@ -116,7 +120,13 @@ pub(crate) async fn reader_loop(
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
let routed = reg.route(cid, MeResponse::Data { flags, data }).await; let routed = reg.route(cid, MeResponse::Data { flags, data }).await;
if !routed { if !matches!(routed, RouteResult::Routed) {
match routed {
RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),
RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(),
RouteResult::QueueFull => stats.increment_me_route_drop_queue_full(),
RouteResult::Routed => {}
}
reg.unregister(cid).await; reg.unregister(cid).await;
send_close_conn(&tx, cid).await; send_close_conn(&tx, cid).await;
} }
@ -126,7 +136,13 @@ pub(crate) async fn reader_loop(
trace!(cid, cfm, "RPC_SIMPLE_ACK"); trace!(cid, cfm, "RPC_SIMPLE_ACK");
let routed = reg.route(cid, MeResponse::Ack(cfm)).await; let routed = reg.route(cid, MeResponse::Ack(cfm)).await;
if !routed { if !matches!(routed, RouteResult::Routed) {
match routed {
RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),
RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(),
RouteResult::QueueFull => stats.increment_me_route_drop_queue_full(),
RouteResult::Routed => {}
}
reg.unregister(cid).await; reg.unregister(cid).await;
send_close_conn(&tx, cid).await; send_close_conn(&tx, cid).await;
} }
@ -152,6 +168,7 @@ pub(crate) async fn reader_loop(
} }
} else if pt == RPC_PONG_U32 && body.len() >= 8 { } else if pt == RPC_PONG_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
stats.increment_me_keepalive_pong();
if let Some((sent, wid)) = { if let Some((sent, wid)) = {
let mut guard = ping_tracker.lock().await; let mut guard = ping_tracker.lock().await;
guard.remove(&ping_id) guard.remove(&ping_id)

View File

@ -1,13 +1,25 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::time::Duration;
use tokio::sync::{mpsc, Mutex, RwLock}; use tokio::sync::{mpsc, RwLock};
use tokio::sync::mpsc::error::TrySendError;
use super::codec::WriterCommand; use super::codec::WriterCommand;
use super::MeResponse; use super::MeResponse;
const ROUTE_CHANNEL_CAPACITY: usize = 4096;
const ROUTE_BACKPRESSURE_TIMEOUT: Duration = Duration::from_millis(25);
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouteResult {
Routed,
NoConn,
ChannelClosed,
QueueFull,
}
#[derive(Clone)] #[derive(Clone)]
pub struct ConnMeta { pub struct ConnMeta {
pub target_dc: i16, pub target_dc: i16,
@ -64,7 +76,7 @@ impl ConnRegistry {
pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) { pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) {
let id = self.next_id.fetch_add(1, Ordering::Relaxed); let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = mpsc::channel(1024); let (tx, rx) = mpsc::channel(ROUTE_CHANNEL_CAPACITY);
self.inner.write().await.map.insert(id, tx); self.inner.write().await.map.insert(id, tx);
(id, rx) (id, rx)
} }
@ -83,12 +95,27 @@ impl ConnRegistry {
None None
} }
pub async fn route(&self, id: u64, resp: MeResponse) -> bool { pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult {
let inner = self.inner.read().await; let tx = {
if let Some(tx) = inner.map.get(&id) { let inner = self.inner.read().await;
tx.try_send(resp).is_ok() inner.map.get(&id).cloned()
} else { };
false
let Some(tx) = tx else {
return RouteResult::NoConn;
};
match tx.try_send(resp) {
Ok(()) => RouteResult::Routed,
Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed,
Err(TrySendError::Full(resp)) => {
// Absorb short bursts without dropping/closing the session immediately.
match tokio::time::timeout(ROUTE_BACKPRESSURE_TIMEOUT, tx.send(resp)).await {
Ok(Ok(())) => RouteResult::Routed,
Ok(Err(_)) => RouteResult::ChannelClosed,
Err(_) => RouteResult::QueueFull,
}
}
} }
} }

View File

@ -62,6 +62,8 @@ impl MePool {
let mut writers_snapshot = { let mut writers_snapshot = {
let ws = self.writers.read().await; let ws = self.writers.read().await;
if ws.is_empty() { if ws.is_empty() {
// Create waiter before recovery attempts so notify_one permits are not missed.
let waiter = self.writer_available.notified();
drop(ws); drop(ws);
for family in self.family_order() { for family in self.family_order() {
let map = match family { let map = match family {
@ -72,13 +74,19 @@ impl MePool {
for (ip, port) in addrs { for (ip, port) in addrs {
let addr = SocketAddr::new(*ip, *port); let addr = SocketAddr::new(*ip, *port);
if self.connect_one(addr, self.rng.as_ref()).await.is_ok() { if self.connect_one(addr, self.rng.as_ref()).await.is_ok() {
self.writer_available.notify_waiters(); self.writer_available.notify_one();
break; break;
} }
} }
} }
} }
if tokio::time::timeout(Duration::from_secs(3), self.writer_available.notified()).await.is_err() { if !self.writers.read().await.is_empty() {
continue;
}
if tokio::time::timeout(Duration::from_secs(3), waiter).await.is_err() {
if !self.writers.read().await.is_empty() {
continue;
}
return Err(ProxyError::Proxy("All ME connections dead (waited 3s)".into())); return Err(ProxyError::Proxy("All ME connections dead (waited 3s)".into()));
} }
continue; continue;

View File

@ -1,5 +1,7 @@
//! TCP Socket Configuration //! TCP Socket Configuration
use std::collections::HashSet;
use std::fs;
use std::io::Result; use std::io::Result;
use std::net::{SocketAddr, IpAddr}; use std::net::{SocketAddr, IpAddr};
use std::time::Duration; use std::time::Duration;
@ -234,6 +236,133 @@ pub fn create_listener(addr: SocketAddr, options: &ListenOptions) -> Result<Sock
Ok(socket) Ok(socket)
} }
/// Best-effort process list for listeners occupying the same local TCP port.
#[derive(Debug, Clone)]
pub struct ListenerProcessInfo {
pub pid: u32,
pub process: String,
}
/// Find processes currently listening on the local TCP port of `addr`.
/// Returns an empty list when unsupported or when no owners can be resolved.
pub fn find_listener_processes(addr: SocketAddr) -> Vec<ListenerProcessInfo> {
#[cfg(target_os = "linux")]
{
find_listener_processes_linux(addr)
}
#[cfg(not(target_os = "linux"))]
{
let _ = addr;
Vec::new()
}
}
#[cfg(target_os = "linux")]
fn find_listener_processes_linux(addr: SocketAddr) -> Vec<ListenerProcessInfo> {
let inodes = listening_inodes_for_port(addr);
if inodes.is_empty() {
return Vec::new();
}
let mut out = Vec::new();
let proc_entries = match fs::read_dir("/proc") {
Ok(entries) => entries,
Err(_) => return out,
};
for entry in proc_entries.flatten() {
let pid = match entry.file_name().to_string_lossy().parse::<u32>() {
Ok(pid) => pid,
Err(_) => continue,
};
let fd_dir = entry.path().join("fd");
let fd_entries = match fs::read_dir(fd_dir) {
Ok(entries) => entries,
Err(_) => continue,
};
let mut matched = false;
for fd in fd_entries.flatten() {
let link_target = match fs::read_link(fd.path()) {
Ok(link) => link,
Err(_) => continue,
};
let link_str = link_target.to_string_lossy();
let Some(rest) = link_str.strip_prefix("socket:[") else {
continue;
};
let Some(inode_str) = rest.strip_suffix(']') else {
continue;
};
let Ok(inode) = inode_str.parse::<u64>() else {
continue;
};
if inodes.contains(&inode) {
matched = true;
break;
}
}
if matched {
let process = fs::read_to_string(entry.path().join("comm"))
.ok()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.unwrap_or_else(|| "unknown".to_string());
out.push(ListenerProcessInfo { pid, process });
}
}
out.sort_by_key(|p| p.pid);
out.dedup_by_key(|p| p.pid);
out
}
#[cfg(target_os = "linux")]
fn listening_inodes_for_port(addr: SocketAddr) -> HashSet<u64> {
let path = match addr {
SocketAddr::V4(_) => "/proc/net/tcp",
SocketAddr::V6(_) => "/proc/net/tcp6",
};
let mut inodes = HashSet::new();
let Ok(data) = fs::read_to_string(path) else {
return inodes;
};
for line in data.lines().skip(1) {
let cols: Vec<&str> = line.split_whitespace().collect();
if cols.len() < 10 {
continue;
}
// LISTEN state in /proc/net/tcp*
if cols[3] != "0A" {
continue;
}
let Some(port_hex) = cols[1].split(':').nth(1) else {
continue;
};
let Ok(port) = u16::from_str_radix(port_hex, 16) else {
continue;
};
if port != addr.port() {
continue;
}
if let Ok(inode) = cols[9].parse::<u64>() {
inodes.insert(inode);
}
}
inodes
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -24,6 +24,8 @@ const NUM_DCS: usize = 5;
/// Timeout for individual DC ping attempt /// Timeout for individual DC ping attempt
const DC_PING_TIMEOUT_SECS: u64 = 5; const DC_PING_TIMEOUT_SECS: u64 = 5;
/// Timeout for direct TG DC TCP connect readiness.
const DIRECT_CONNECT_TIMEOUT_SECS: u64 = 10;
// ============= RTT Tracking ============= // ============= RTT Tracking =============
@ -375,7 +377,16 @@ impl UpstreamManager {
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let stream = TcpStream::from_std(std_stream)?; let stream = TcpStream::from_std(std_stream)?;
stream.writable().await?; let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS);
match tokio::time::timeout(connect_timeout, stream.writable()).await {
Ok(Ok(())) => {}
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => {
return Err(ProxyError::ConnectionTimeout {
addr: target.to_string(),
});
}
}
if let Some(e) = stream.take_error()? { if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e)); return Err(ProxyError::Io(e));
} }
@ -383,6 +394,7 @@ impl UpstreamManager {
Ok(stream) Ok(stream)
}, },
UpstreamType::Socks4 { address, interface, user_id } => { UpstreamType::Socks4 { address, interface, user_id } => {
let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS);
// Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port // Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port
let mut stream = if let Ok(proxy_addr) = address.parse::<SocketAddr>() { let mut stream = if let Ok(proxy_addr) = address.parse::<SocketAddr>() {
// IP:port format - use socket with optional interface binding // IP:port format - use socket with optional interface binding
@ -405,7 +417,15 @@ impl UpstreamManager {
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let stream = TcpStream::from_std(std_stream)?; let stream = TcpStream::from_std(std_stream)?;
stream.writable().await?; match tokio::time::timeout(connect_timeout, stream.writable()).await {
Ok(Ok(())) => {}
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => {
return Err(ProxyError::ConnectionTimeout {
addr: proxy_addr.to_string(),
});
}
}
if let Some(e) = stream.take_error()? { if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e)); return Err(ProxyError::Io(e));
} }
@ -416,8 +436,15 @@ impl UpstreamManager {
if interface.is_some() { if interface.is_some() {
warn!("SOCKS4 interface binding is not supported for hostname addresses, ignoring"); warn!("SOCKS4 interface binding is not supported for hostname addresses, ignoring");
} }
TcpStream::connect(address).await match tokio::time::timeout(connect_timeout, TcpStream::connect(address)).await {
.map_err(ProxyError::Io)? Ok(Ok(stream)) => stream,
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => {
return Err(ProxyError::ConnectionTimeout {
addr: address.clone(),
});
}
}
}; };
// replace socks user_id with config.selected_scope, if set // replace socks user_id with config.selected_scope, if set
@ -425,10 +452,19 @@ impl UpstreamManager {
.filter(|s| !s.is_empty()); .filter(|s| !s.is_empty());
let _user_id: Option<&str> = scope.or(user_id.as_deref()); let _user_id: Option<&str> = scope.or(user_id.as_deref());
connect_socks4(&mut stream, target, _user_id).await?; match tokio::time::timeout(connect_timeout, connect_socks4(&mut stream, target, _user_id)).await {
Ok(Ok(())) => {}
Ok(Err(e)) => return Err(e),
Err(_) => {
return Err(ProxyError::ConnectionTimeout {
addr: target.to_string(),
});
}
}
Ok(stream) Ok(stream)
}, },
UpstreamType::Socks5 { address, interface, username, password } => { UpstreamType::Socks5 { address, interface, username, password } => {
let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS);
// Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port // Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port
let mut stream = if let Ok(proxy_addr) = address.parse::<SocketAddr>() { let mut stream = if let Ok(proxy_addr) = address.parse::<SocketAddr>() {
// IP:port format - use socket with optional interface binding // IP:port format - use socket with optional interface binding
@ -451,7 +487,15 @@ impl UpstreamManager {
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let stream = TcpStream::from_std(std_stream)?; let stream = TcpStream::from_std(std_stream)?;
stream.writable().await?; match tokio::time::timeout(connect_timeout, stream.writable()).await {
Ok(Ok(())) => {}
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => {
return Err(ProxyError::ConnectionTimeout {
addr: proxy_addr.to_string(),
});
}
}
if let Some(e) = stream.take_error()? { if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e)); return Err(ProxyError::Io(e));
} }
@ -462,8 +506,15 @@ impl UpstreamManager {
if interface.is_some() { if interface.is_some() {
warn!("SOCKS5 interface binding is not supported for hostname addresses, ignoring"); warn!("SOCKS5 interface binding is not supported for hostname addresses, ignoring");
} }
TcpStream::connect(address).await match tokio::time::timeout(connect_timeout, TcpStream::connect(address)).await {
.map_err(ProxyError::Io)? Ok(Ok(stream)) => stream,
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => {
return Err(ProxyError::ConnectionTimeout {
addr: address.clone(),
});
}
}
}; };
debug!(config = ?config, "Socks5 connection"); debug!(config = ?config, "Socks5 connection");
@ -473,7 +524,20 @@ impl UpstreamManager {
let _username: Option<&str> = scope.or(username.as_deref()); let _username: Option<&str> = scope.or(username.as_deref());
let _password: Option<&str> = scope.or(password.as_deref()); let _password: Option<&str> = scope.or(password.as_deref());
connect_socks5(&mut stream, target, _username, _password).await?; match tokio::time::timeout(
connect_timeout,
connect_socks5(&mut stream, target, _username, _password),
)
.await
{
Ok(Ok(())) => {}
Ok(Err(e)) => return Err(e),
Err(_) => {
return Err(ProxyError::ConnectionTimeout {
addr: target.to_string(),
});
}
}
Ok(stream) Ok(stream)
}, },
} }

396
tools/tlsearch.py Normal file
View File

@ -0,0 +1,396 @@
#!/usr/bin/env python3
"""
TLS Profile Inspector
Usage:
python3 tools/tlsearch.py
python3 tools/tlsearch.py tlsfront
python3 tools/tlsearch.py tlsfront/petrovich.ru.json
python3 tools/tlsearch.py tlsfront --only-current
"""
from __future__ import annotations
import argparse
import datetime as dt
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable
TLS_VERSIONS = {
0x0301: "TLS 1.0",
0x0302: "TLS 1.1",
0x0303: "TLS 1.2",
0x0304: "TLS 1.3",
}
EXT_NAMES = {
0: "server_name",
5: "status_request",
10: "supported_groups",
11: "ec_point_formats",
13: "signature_algorithms",
16: "alpn",
18: "signed_certificate_timestamp",
21: "padding",
23: "extended_master_secret",
35: "session_ticket",
43: "supported_versions",
45: "psk_key_exchange_modes",
51: "key_share",
}
CIPHER_NAMES = {
0x1301: "TLS_AES_128_GCM_SHA256",
0x1302: "TLS_AES_256_GCM_SHA384",
0x1303: "TLS_CHACHA20_POLY1305_SHA256",
0x1304: "TLS_AES_128_CCM_SHA256",
0x1305: "TLS_AES_128_CCM_8_SHA256",
0x009C: "TLS_RSA_WITH_AES_128_GCM_SHA256",
0x009D: "TLS_RSA_WITH_AES_256_GCM_SHA384",
0xC02F: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
0xC030: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
0xCCA8: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
0xCCA9: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
}
NAMED_GROUPS = {
0x001D: "x25519",
0x0017: "secp256r1",
0x0018: "secp384r1",
0x0019: "secp521r1",
0x0100: "ffdhe2048",
0x0101: "ffdhe3072",
0x0102: "ffdhe4096",
}
@dataclass
class ProfileRecognition:
schema: str
mode: str
has_cert_info: bool
has_full_cert_payload: bool
cert_message_len: int
cert_chain_count: int
cert_chain_total_len: int
issues: list[str]
def to_hex(data: Iterable[int]) -> str:
return "".join(f"{b:02x}" for b in data)
def read_u16be(data: list[int], off: int = 0) -> int:
return (data[off] << 8) | data[off + 1]
def normalize_u8_list(value: Any) -> list[int]:
if not isinstance(value, list):
return []
out: list[int] = []
for item in value:
if isinstance(item, int) and 0 <= item <= 0xFF:
out.append(item)
else:
return []
return out
def as_dict(value: Any) -> dict[str, Any]:
return value if isinstance(value, dict) else {}
def as_int(value: Any, default: int = 0) -> int:
return value if isinstance(value, int) else default
def decode_version_pair(v: list[int]) -> str:
if len(v) != 2:
return f"invalid({v})"
ver = read_u16be(v)
return f"0x{ver:04x} ({TLS_VERSIONS.get(ver, 'unknown')})"
def decode_cipher_suite(v: list[int]) -> str:
if len(v) != 2:
return f"invalid({v})"
cs = read_u16be(v)
name = CIPHER_NAMES.get(cs, "unknown")
return f"0x{cs:04x} ({name})"
def decode_supported_versions(data: list[int]) -> str:
if len(data) == 2:
ver = read_u16be(data)
return f"selected=0x{ver:04x} ({TLS_VERSIONS.get(ver, 'unknown')})"
if not data:
return "empty"
if len(data) < 3:
return f"raw={to_hex(data)}"
vec_len = data[0]
versions: list[str] = []
for i in range(1, min(1 + vec_len, len(data)), 2):
if i + 1 >= len(data):
break
ver = read_u16be(data, i)
versions.append(f"0x{ver:04x}({TLS_VERSIONS.get(ver, 'unknown')})")
return "offered=[" + ", ".join(versions) + "]"
def decode_key_share(data: list[int]) -> str:
if len(data) < 4:
return f"raw={to_hex(data)}"
group = read_u16be(data, 0)
key_len = read_u16be(data, 2)
key_hex = to_hex(data[4 : 4 + min(key_len, len(data) - 4)])
gname = NAMED_GROUPS.get(group, "unknown_group")
return f"group=0x{group:04x}({gname}), key_len={key_len}, key={key_hex}"
def decode_alpn(data: list[int]) -> str:
if len(data) < 3:
return f"raw={to_hex(data)}"
total = read_u16be(data, 0)
pos = 2
vals: list[str] = []
limit = min(len(data), 2 + total)
while pos < limit:
ln = data[pos]
pos += 1
if pos + ln > limit:
break
raw = bytes(data[pos : pos + ln])
pos += ln
try:
vals.append(raw.decode("ascii"))
except UnicodeDecodeError:
vals.append(raw.hex())
return "protocols=[" + ", ".join(vals) + "]"
def decode_extension(ext_type: int, data: list[int]) -> str:
if ext_type == 43:
return decode_supported_versions(data)
if ext_type == 51:
return decode_key_share(data)
if ext_type == 16:
return decode_alpn(data)
return f"raw={to_hex(data)}"
def ts_to_iso(ts: Any) -> str:
if not isinstance(ts, int):
return "-"
return dt.datetime.fromtimestamp(ts, tz=dt.timezone.utc).isoformat()
def recognize_profile(obj: dict[str, Any]) -> ProfileRecognition:
issues: list[str] = []
sh = as_dict(obj.get("server_hello_template"))
if not sh:
issues.append("missing server_hello_template")
version = normalize_u8_list(sh.get("version"))
if version and len(version) != 2:
issues.append("server_hello_template.version must have 2 bytes")
app_sizes = obj.get("app_data_records_sizes")
if not isinstance(app_sizes, list) or not app_sizes:
issues.append("missing app_data_records_sizes")
elif any((not isinstance(v, int) or v <= 0) for v in app_sizes):
issues.append("app_data_records_sizes contains invalid values")
if not isinstance(obj.get("total_app_data_len"), int):
issues.append("missing total_app_data_len")
cert_info = as_dict(obj.get("cert_info"))
has_cert_info = bool(
cert_info.get("subject_cn")
or cert_info.get("issuer_cn")
or cert_info.get("san_names")
or isinstance(cert_info.get("not_before_unix"), int)
or isinstance(cert_info.get("not_after_unix"), int)
)
cert_payload = as_dict(obj.get("cert_payload"))
cert_message_len = 0
cert_chain_count = 0
cert_chain_total_len = 0
has_full_cert_payload = False
if cert_payload:
cert_msg = normalize_u8_list(cert_payload.get("certificate_message"))
if not cert_msg:
issues.append("cert_payload.certificate_message is missing or invalid")
else:
cert_message_len = len(cert_msg)
chain_raw = cert_payload.get("cert_chain_der")
if not isinstance(chain_raw, list):
issues.append("cert_payload.cert_chain_der is missing or invalid")
else:
for entry in chain_raw:
cert = normalize_u8_list(entry)
if cert:
cert_chain_count += 1
cert_chain_total_len += len(cert)
else:
issues.append("cert_payload.cert_chain_der has invalid certificate entry")
break
has_full_cert_payload = cert_message_len > 0 and cert_chain_count > 0
elif obj.get("cert_payload") is not None:
issues.append("cert_payload is not an object")
if has_full_cert_payload:
schema = "current"
mode = "full-cert-payload"
elif has_cert_info:
schema = "current-compact"
mode = "compact-cert-info"
else:
schema = "legacy"
mode = "random-fallback"
if issues:
schema = f"{schema}+issues"
return ProfileRecognition(
schema=schema,
mode=mode,
has_cert_info=has_cert_info,
has_full_cert_payload=has_full_cert_payload,
cert_message_len=cert_message_len,
cert_chain_count=cert_chain_count,
cert_chain_total_len=cert_chain_total_len,
issues=issues,
)
def decode_profile(path: Path) -> tuple[str, ProfileRecognition]:
obj: dict[str, Any] = json.loads(path.read_text(encoding="utf-8"))
recognition = recognize_profile(obj)
sh = as_dict(obj.get("server_hello_template"))
version = normalize_u8_list(sh.get("version"))
cipher = normalize_u8_list(sh.get("cipher_suite"))
random_bytes = normalize_u8_list(sh.get("random"))
session_id = normalize_u8_list(sh.get("session_id"))
lines: list[str] = []
lines.append(f"[{path.name}]")
lines.append(f" domain: {obj.get('domain', '-')}")
lines.append(f" profile.schema: {recognition.schema}")
lines.append(f" profile.mode: {recognition.mode}")
lines.append(f" profile.has_full_cert_payload: {recognition.has_full_cert_payload}")
lines.append(f" profile.has_cert_info: {recognition.has_cert_info}")
if recognition.has_full_cert_payload:
lines.append(f" profile.cert_message_len: {recognition.cert_message_len}")
lines.append(f" profile.cert_chain_count: {recognition.cert_chain_count}")
lines.append(f" profile.cert_chain_total_len: {recognition.cert_chain_total_len}")
if recognition.issues:
lines.append(" profile.issues:")
for issue in recognition.issues:
lines.append(f" - {issue}")
lines.append(f" tls.version: {decode_version_pair(version)}")
lines.append(f" tls.cipher: {decode_cipher_suite(cipher)}")
lines.append(f" tls.compression: {sh.get('compression', '-')}")
lines.append(f" tls.random: {to_hex(random_bytes)}")
lines.append(f" tls.session_id_len: {len(session_id)}")
if session_id:
lines.append(f" tls.session_id: {to_hex(session_id)}")
app_sizes = obj.get("app_data_records_sizes", [])
if isinstance(app_sizes, list):
lines.append(" app_data_records_sizes: " + ", ".join(str(v) for v in app_sizes))
else:
lines.append(" app_data_records_sizes: -")
lines.append(f" total_app_data_len: {obj.get('total_app_data_len', '-')}")
cert = as_dict(obj.get("cert_info"))
if cert:
lines.append(" cert_info:")
lines.append(f" subject_cn: {cert.get('subject_cn') or '-'}")
lines.append(f" issuer_cn: {cert.get('issuer_cn') or '-'}")
lines.append(f" not_before: {ts_to_iso(cert.get('not_before_unix'))}")
lines.append(f" not_after: {ts_to_iso(cert.get('not_after_unix'))}")
sans = cert.get("san_names")
if isinstance(sans, list) and sans:
lines.append(" san_names: " + ", ".join(str(v) for v in sans))
else:
lines.append(" san_names: -")
else:
lines.append(" cert_info: -")
exts = sh.get("extensions", [])
if not isinstance(exts, list):
exts = []
lines.append(f" extensions[{len(exts)}]:")
for ext in exts:
ext_obj = as_dict(ext)
ext_type = as_int(ext_obj.get("ext_type"), -1)
data = normalize_u8_list(ext_obj.get("data"))
name = EXT_NAMES.get(ext_type, "unknown")
decoded = decode_extension(ext_type, data)
lines.append(f" - type={ext_type} ({name}), len={len(data)}: {decoded}")
lines.append("")
return ("\n".join(lines), recognition)
def collect_files(input_path: Path) -> list[Path]:
if input_path.is_file():
return [input_path]
return sorted(p for p in input_path.glob("*.json") if p.is_file())
def main() -> int:
parser = argparse.ArgumentParser(
description="Decode TLS profile JSON files and recognize current schema."
)
parser.add_argument(
"path",
nargs="?",
default="tlsfront",
help="Path to tlsfront directory or a single JSON file.",
)
parser.add_argument(
"--only-current",
action="store_true",
help="Show only profiles recognized as current/full-cert-payload.",
)
args = parser.parse_args()
base = Path(args.path)
if not base.exists():
print(f"Path not found: {base}")
return 1
files = collect_files(base)
if not files:
print(f"No JSON files found in: {base}")
return 1
printed = 0
for path in files:
try:
rendered, recognition = decode_profile(path)
if args.only_current and recognition.schema != "current":
continue
print(rendered, end="")
printed += 1
except Exception as e: # noqa: BLE001
print(f"[{path.name}] decode error: {e}\n")
if args.only_current and printed == 0:
print("No current profiles found.")
return 0
if __name__ == "__main__":
raise SystemExit(main())