ReRoute + Bnd-checks in API + Per-upstream Runtime Selftest + BSD-Support: merge pull request #394 from telemt/flow

ReRoute + Bnd-checks in API + Per-upstream Runtime Selftest + BSD-Support
This commit is contained in:
Alexey 2026-03-11 23:34:45 +03:00 committed by GitHub
commit 5bc7004e9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 947 additions and 29 deletions

View File

@ -0,0 +1,135 @@
---
description: 'Rust programming language coding conventions and best practices'
applyTo: '**/*.rs'
---
# Rust Coding Conventions and Best Practices
Follow idiomatic Rust practices and community standards when writing Rust code.
These instructions are based on [The Rust Book](https://doc.rust-lang.org/book/), [Rust API Guidelines](https://rust-lang.github.io/api-guidelines/), [RFC 430 naming conventions](https://github.com/rust-lang/rfcs/blob/master/text/0430-finalizing-naming-conventions.md), and the broader Rust community at [users.rust-lang.org](https://users.rust-lang.org).
## General Instructions
- Always prioritize readability, safety, and maintainability.
- Use strong typing and leverage Rust's ownership system for memory safety.
- Break down complex functions into smaller, more manageable functions.
- For algorithm-related code, include explanations of the approach used.
- Write code with good maintainability practices, including comments on why certain design decisions were made.
- Handle errors gracefully using `Result<T, E>` and provide meaningful error messages.
- For external dependencies, mention their usage and purpose in documentation.
- Use consistent naming conventions following [RFC 430](https://github.com/rust-lang/rfcs/blob/master/text/0430-finalizing-naming-conventions.md).
- Write idiomatic, safe, and efficient Rust code that follows the borrow checker's rules.
- Ensure code compiles without warnings.
## Patterns to Follow
- Use modules (`mod`) and public interfaces (`pub`) to encapsulate logic.
- Handle errors properly using `?`, `match`, or `if let`.
- Use `serde` for serialization and `thiserror` or `anyhow` for custom errors.
- Implement traits to abstract services or external dependencies.
- Structure async code using `async/await` and `tokio` or `async-std`.
- Prefer enums over flags and states for type safety.
- Use builders for complex object creation.
- Split binary and library code (`main.rs` vs `lib.rs`) for testability and reuse.
- Use `rayon` for data parallelism and CPU-bound tasks.
- Use iterators instead of index-based loops as they're often faster and safer.
- Use `&str` instead of `String` for function parameters when you don't need ownership.
- Prefer borrowing and zero-copy operations to avoid unnecessary allocations.
### Ownership, Borrowing, and Lifetimes
- Prefer borrowing (`&T`) over cloning unless ownership transfer is necessary.
- Use `&mut T` when you need to modify borrowed data.
- Explicitly annotate lifetimes when the compiler cannot infer them.
- Use `Rc<T>` for single-threaded reference counting and `Arc<T>` for thread-safe reference counting.
- Use `RefCell<T>` for interior mutability in single-threaded contexts and `Mutex<T>` or `RwLock<T>` for multi-threaded contexts.
## Patterns to Avoid
- Don't use `unwrap()` or `expect()` unless absolutely necessary—prefer proper error handling.
- Avoid panics in library code—return `Result` instead.
- Don't rely on global mutable state—use dependency injection or thread-safe containers.
- Avoid deeply nested logic—refactor with functions or combinators.
- Don't ignore warnings—treat them as errors during CI.
- Avoid `unsafe` unless required and fully documented.
- Don't overuse `clone()`, use borrowing instead of cloning unless ownership transfer is needed.
- Avoid premature `collect()`, keep iterators lazy until you actually need the collection.
- Avoid unnecessary allocations—prefer borrowing and zero-copy operations.
## Code Style and Formatting
- Follow the Rust Style Guide and use `rustfmt` for automatic formatting.
- Keep lines under 100 characters when possible.
- Place function and struct documentation immediately before the item using `///`.
- Use `cargo clippy` to catch common mistakes and enforce best practices.
## Error Handling
- Use `Result<T, E>` for recoverable errors and `panic!` only for unrecoverable errors.
- Prefer `?` operator over `unwrap()` or `expect()` for error propagation.
- Create custom error types using `thiserror` or implement `std::error::Error`.
- Use `Option<T>` for values that may or may not exist.
- Provide meaningful error messages and context.
- Error types should be meaningful and well-behaved (implement standard traits).
- Validate function arguments and return appropriate errors for invalid input.
## API Design Guidelines
### Common Traits Implementation
Eagerly implement common traits where appropriate:
- `Copy`, `Clone`, `Eq`, `PartialEq`, `Ord`, `PartialOrd`, `Hash`, `Debug`, `Display`, `Default`
- Use standard conversion traits: `From`, `AsRef`, `AsMut`
- Collections should implement `FromIterator` and `Extend`
- Note: `Send` and `Sync` are auto-implemented by the compiler when safe; avoid manual implementation unless using `unsafe` code
### Type Safety and Predictability
- Use newtypes to provide static distinctions
- Arguments should convey meaning through types; prefer specific types over generic `bool` parameters
- Use `Option<T>` appropriately for truly optional values
- Functions with a clear receiver should be methods
- Only smart pointers should implement `Deref` and `DerefMut`
### Future Proofing
- Use sealed traits to protect against downstream implementations
- Structs should have private fields
- Functions should validate their arguments
- All public types must implement `Debug`
## Testing and Documentation
- Write comprehensive unit tests using `#[cfg(test)]` modules and `#[test]` annotations.
- Use test modules alongside the code they test (`mod tests { ... }`).
- Write integration tests in `tests/` directory with descriptive filenames.
- Write clear and concise comments for each function, struct, enum, and complex logic.
- Ensure functions have descriptive names and include comprehensive documentation.
- Document all public APIs with rustdoc (`///` comments) following the [API Guidelines](https://rust-lang.github.io/api-guidelines/).
- Use `#[doc(hidden)]` to hide implementation details from public documentation.
- Document error conditions, panic scenarios, and safety considerations.
- Examples should use `?` operator, not `unwrap()` or deprecated `try!` macro.
## Project Organization
- Use semantic versioning in `Cargo.toml`.
- Include comprehensive metadata: `description`, `license`, `repository`, `keywords`, `categories`.
- Use feature flags for optional functionality.
- Organize code into modules using `mod.rs` or named files.
- Keep `main.rs` or `lib.rs` minimal - move logic to modules.
## Quality Checklist
Before publishing or reviewing Rust code, ensure:
### Core Requirements
- [ ] **Naming**: Follows RFC 430 naming conventions
- [ ] **Traits**: Implements `Debug`, `Clone`, `PartialEq` where appropriate
- [ ] **Error Handling**: Uses `Result<T, E>` and provides meaningful error types
- [ ] **Documentation**: All public items have rustdoc comments with examples
- [ ] **Testing**: Comprehensive test coverage including edge cases
### Safety and Quality
- [ ] **Safety**: No unnecessary `unsafe` code, proper error handling
- [ ] **Performance**: Efficient use of iterators, minimal allocations
- [ ] **API Design**: Functions are predictable, flexible, and type-safe
- [ ] **Future Proofing**: Private fields in structs, sealed traits where appropriate
- [ ] **Tooling**: Code passes `cargo fmt`, `cargo clippy`, and `cargo test`

View File

@ -0,0 +1,162 @@
---
description: 'Guidelines for GitHub Copilot to write comments to achieve self-explanatory code with less comments. Examples are in JavaScript but it should work on any language that has comments.'
applyTo: '**'
---
# Self-explanatory Code Commenting Instructions
## Core Principle
**Write code that speaks for itself. Comment only when necessary to explain WHY, not WHAT.**
We do not need comments most of the time.
## Commenting Guidelines
### ❌ AVOID These Comment Types
**Obvious Comments**
```javascript
// Bad: States the obvious
let counter = 0; // Initialize counter to zero
counter++; // Increment counter by one
```
**Redundant Comments**
```javascript
// Bad: Comment repeats the code
function getUserName() {
return user.name; // Return the user's name
}
```
**Outdated Comments**
```javascript
// Bad: Comment doesn't match the code
// Calculate tax at 5% rate
const tax = price * 0.08; // Actually 8%
```
### ✅ WRITE These Comment Types
**Complex Business Logic**
```javascript
// Good: Explains WHY this specific calculation
// Apply progressive tax brackets: 10% up to 10k, 20% above
const tax = calculateProgressiveTax(income, [0.10, 0.20], [10000]);
```
**Non-obvious Algorithms**
```javascript
// Good: Explains the algorithm choice
// Using Floyd-Warshall for all-pairs shortest paths
// because we need distances between all nodes
for (let k = 0; k < vertices; k++) {
for (let i = 0; i < vertices; i++) {
for (let j = 0; j < vertices; j++) {
// ... implementation
}
}
}
```
**Regex Patterns**
```javascript
// Good: Explains what the regex matches
// Match email format: username@domain.extension
const emailPattern = /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$/;
```
**API Constraints or Gotchas**
```javascript
// Good: Explains external constraint
// GitHub API rate limit: 5000 requests/hour for authenticated users
await rateLimiter.wait();
const response = await fetch(githubApiUrl);
```
## Decision Framework
Before writing a comment, ask:
1. **Is the code self-explanatory?** → No comment needed
2. **Would a better variable/function name eliminate the need?** → Refactor instead
3. **Does this explain WHY, not WHAT?** → Good comment
4. **Will this help future maintainers?** → Good comment
## Special Cases for Comments
### Public APIs
```javascript
/**
* Calculate compound interest using the standard formula.
*
* @param {number} principal - Initial amount invested
* @param {number} rate - Annual interest rate (as decimal, e.g., 0.05 for 5%)
* @param {number} time - Time period in years
* @param {number} compoundFrequency - How many times per year interest compounds (default: 1)
* @returns {number} Final amount after compound interest
*/
function calculateCompoundInterest(principal, rate, time, compoundFrequency = 1) {
// ... implementation
}
```
### Configuration and Constants
```javascript
// Good: Explains the source or reasoning
const MAX_RETRIES = 3; // Based on network reliability studies
const API_TIMEOUT = 5000; // AWS Lambda timeout is 15s, leaving buffer
```
### Annotations
```javascript
// TODO: Replace with proper user authentication after security review
// FIXME: Memory leak in production - investigate connection pooling
// HACK: Workaround for bug in library v2.1.0 - remove after upgrade
// NOTE: This implementation assumes UTC timezone for all calculations
// WARNING: This function modifies the original array instead of creating a copy
// PERF: Consider caching this result if called frequently in hot path
// SECURITY: Validate input to prevent SQL injection before using in query
// BUG: Edge case failure when array is empty - needs investigation
// REFACTOR: Extract this logic into separate utility function for reusability
// DEPRECATED: Use newApiFunction() instead - this will be removed in v3.0
```
## Anti-Patterns to Avoid
### Dead Code Comments
```javascript
// Bad: Don't comment out code
// const oldFunction = () => { ... };
const newFunction = () => { ... };
```
### Changelog Comments
```javascript
// Bad: Don't maintain history in comments
// Modified by John on 2023-01-15
// Fixed bug reported by Sarah on 2023-02-03
function processData() {
// ... implementation
}
```
### Divider Comments
```javascript
// Bad: Don't use decorative comments
//=====================================
// UTILITY FUNCTIONS
//=====================================
```
## Quality Checklist
Before committing, ensure your comments:
- [ ] Explain WHY, not WHAT
- [ ] Are grammatically correct and clear
- [ ] Will remain accurate as code evolves
- [ ] Add genuine value to code understanding
- [ ] Are placed appropriately (above the code they describe)
- [ ] Use proper spelling and professional language
## Summary
Remember: **The best comment is the one you don't need to write because the code is self-documenting.**

2
Cargo.lock generated
View File

@ -2087,7 +2087,7 @@ dependencies = [
[[package]] [[package]]
name = "telemt" name = "telemt"
version = "3.1.3" version = "3.3.15"
dependencies = [ dependencies = [
"aes", "aes",
"anyhow", "anyhow",

View File

@ -73,3 +73,6 @@ futures = "0.3"
[[bench]] [[bench]]
name = "crypto_bench" name = "crypto_bench"
harness = false harness = false
[profile.release]
lto = "thin"

View File

@ -264,6 +264,11 @@ git clone https://github.com/telemt/telemt
cd telemt cd telemt
# Starting Release Build # Starting Release Build
cargo build --release cargo build --release
# Low-RAM devices (1 GB, e.g. NanoPi Neo3 / Raspberry Pi Zero 2):
# release profile uses lto = "thin" to reduce peak linker memory.
# If your custom toolchain overrides profiles, avoid enabling fat LTO.
# Move to /bin # Move to /bin
mv ./target/release/telemt /bin mv ./target/release/telemt /bin
# Make executable # Make executable
@ -272,6 +277,12 @@ chmod +x /bin/telemt
telemt config.toml telemt config.toml
``` ```
### OpenBSD
- Build and service setup guide: [OpenBSD Guide (EN)](docs/OPENBSD.en.md)
- Example rc.d script: [contrib/openbsd/telemt.rcd](contrib/openbsd/telemt.rcd)
- Status: OpenBSD sandbox hardening with `pledge(2)` and `unveil(2)` is not implemented yet.
## Why Rust? ## Why Rust?
- Long-running reliability and idempotent behavior - Long-running reliability and idempotent behavior
- Rust's deterministic resource management - RAII - Rust's deterministic resource management - RAII

View File

@ -0,0 +1,16 @@
#!/bin/ksh
# /etc/rc.d/telemt
#
# rc.d(8) script for Telemt MTProxy daemon.
# Tokio runtime does not daemonize itself, so rc_bg=YES is used.
daemon="/usr/local/bin/telemt"
daemon_user="_telemt"
daemon_flags="/etc/telemt/config.toml"
. /etc/rc.d/rc.subr
rc_bg=YES
rc_reload=NO
rc_cmd $1

132
docs/OPENBSD.en.md Normal file
View File

@ -0,0 +1,132 @@
# Telemt on OpenBSD (Build, Run, and rc.d)
This guide covers a practical OpenBSD deployment flow for Telemt:
- build from source,
- install binary and config,
- run as an rc.d daemon,
- verify basic runtime behavior.
## 1. Prerequisites
Install required packages:
```sh
doas pkg_add rust git
```
Notes:
- Telemt release installer (`install.sh`) is Linux-only.
- On OpenBSD, use source build with `cargo`.
## 2. Build from source
```sh
git clone https://github.com/telemt/telemt
cd telemt
cargo build --release
./target/release/telemt --version
```
For low-RAM systems, this repository already uses `lto = "thin"` in release profile.
## 3. Install binary and config
```sh
doas install -d -m 0755 /usr/local/bin
doas install -m 0755 ./target/release/telemt /usr/local/bin/telemt
doas install -d -m 0750 /etc/telemt
doas install -m 0640 ./config.toml /etc/telemt/config.toml
```
## 4. Create runtime user
```sh
doas useradd -L daemon -s /sbin/nologin -d /var/empty _telemt
```
If `_telemt` already exists, continue.
## 5. Install rc.d service
Install the provided script:
```sh
doas install -m 0555 ./contrib/openbsd/telemt.rcd /etc/rc.d/telemt
```
Enable and start:
```sh
doas rcctl enable telemt
# Optional: send daemon output to syslog
#doas rcctl set telemt logger daemon.info
doas rcctl start telemt
```
Service controls:
```sh
doas rcctl check telemt
doas rcctl restart telemt
doas rcctl stop telemt
```
## 6. Resource limits (recommended)
OpenBSD rc.d can apply limits via login class. Add class `telemt` and assign it to `_telemt`.
Example class entry:
```text
telemt:\
:openfiles-cur=8192:openfiles-max=16384:\
:datasize-cur=768M:datasize-max=1024M:\
:coredumpsize=0:\
:tc=daemon:
```
These values are conservative defaults for small and medium deployments.
Increase `openfiles-*` only if logs show descriptor exhaustion under load.
Then rebuild database and assign class:
```sh
doas cap_mkdb /etc/login.conf
#doas usermod -L telemt _telemt
```
Uncomment `usermod` if you want this class bound to the Telemt user.
## 7. Functional smoke test
1. Validate service state:
```sh
doas rcctl check telemt
```
2. Check listener is present (replace 443 if needed):
```sh
netstat -n -f inet -p tcp | grep LISTEN | grep '\.443'
```
3. Verify process user:
```sh
ps -o user,pid,command -ax | grep telemt | grep -v grep
```
4. If startup fails, debug in foreground:
```sh
RUST_LOG=debug /usr/local/bin/telemt /etc/telemt/config.toml
```
## 8. OpenBSD-specific caveats
- OpenBSD does not support per-socket keepalive retries/interval tuning in the same way as Linux.
- Telemt source already uses target-aware cfg gates for keepalive setup.
- Use rc.d/rcctl, not systemd.

View File

@ -19,6 +19,15 @@ need_cmd() {
command -v "$1" >/dev/null 2>&1 || die "required command not found: $1" command -v "$1" >/dev/null 2>&1 || die "required command not found: $1"
} }
detect_os() {
os="$(uname -s)"
case "$os" in
Linux) printf 'linux\n' ;;
OpenBSD) printf 'openbsd\n' ;;
*) printf '%s\n' "$os" ;;
esac
}
detect_arch() { detect_arch() {
arch="$(uname -m)" arch="$(uname -m)"
case "$arch" in case "$arch" in
@ -68,6 +77,19 @@ need_cmd grep
need_cmd install need_cmd install
ARCH="$(detect_arch)" ARCH="$(detect_arch)"
OS="$(detect_os)"
if [ "$OS" != "linux" ]; then
case "$OS" in
openbsd)
die "install.sh installs only Linux release artifacts. On OpenBSD, build from source (see docs/OPENBSD.en.md)."
;;
*)
die "unsupported operating system for install.sh: $OS"
;;
esac
fi
LIBC="$(detect_libc)" LIBC="$(detect_libc)"
case "$VERSION" in case "$VERSION" in

View File

@ -16,6 +16,7 @@ use tracing::{debug, info, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::ip_tracker::UserIpTracker; use crate::ip_tracker::UserIpTracker;
use crate::proxy::route_mode::RouteRuntimeController;
use crate::startup::StartupTracker; use crate::startup::StartupTracker;
use crate::stats::Stats; use crate::stats::Stats;
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
@ -84,6 +85,7 @@ pub(super) struct ApiShared {
pub(super) request_id: Arc<AtomicU64>, pub(super) request_id: Arc<AtomicU64>,
pub(super) runtime_state: Arc<ApiRuntimeState>, pub(super) runtime_state: Arc<ApiRuntimeState>,
pub(super) startup_tracker: Arc<StartupTracker>, pub(super) startup_tracker: Arc<StartupTracker>,
pub(super) route_runtime: Arc<RouteRuntimeController>,
} }
impl ApiShared { impl ApiShared {
@ -101,6 +103,7 @@ pub async fn serve(
stats: Arc<Stats>, stats: Arc<Stats>,
ip_tracker: Arc<UserIpTracker>, ip_tracker: Arc<UserIpTracker>,
me_pool: Arc<RwLock<Option<Arc<MePool>>>>, me_pool: Arc<RwLock<Option<Arc<MePool>>>>,
route_runtime: Arc<RouteRuntimeController>,
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
config_rx: watch::Receiver<Arc<ProxyConfig>>, config_rx: watch::Receiver<Arc<ProxyConfig>>,
admission_rx: watch::Receiver<bool>, admission_rx: watch::Receiver<bool>,
@ -147,6 +150,7 @@ pub async fn serve(
request_id: Arc::new(AtomicU64::new(1)), request_id: Arc::new(AtomicU64::new(1)),
runtime_state: runtime_state.clone(), runtime_state: runtime_state.clone(),
startup_tracker, startup_tracker,
route_runtime,
}); });
spawn_runtime_watchers( spawn_runtime_watchers(
@ -338,7 +342,7 @@ async fn handle(
} }
("GET", "/v1/runtime/me-selftest") => { ("GET", "/v1/runtime/me-selftest") => {
let revision = current_revision(&shared.config_path).await?; let revision = current_revision(&shared.config_path).await?;
let data = build_runtime_me_selftest_data(shared.as_ref()).await; let data = build_runtime_me_selftest_data(shared.as_ref(), cfg.as_ref()).await;
Ok(success_response(StatusCode::OK, data, revision)) Ok(success_response(StatusCode::OK, data, revision))
} }
("GET", "/v1/runtime/connections/summary") => { ("GET", "/v1/runtime/connections/summary") => {

View File

@ -1,11 +1,14 @@
use std::net::IpAddr; use std::net::IpAddr;
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock}; use std::sync::{Mutex, OnceLock};
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use serde::Serialize; use serde::Serialize;
use crate::config::{ProxyConfig, UpstreamType};
use crate::network::probe::{detect_interface_ipv4, detect_interface_ipv6, is_bogon}; use crate::network::probe::{detect_interface_ipv4, detect_interface_ipv6, is_bogon};
use crate::transport::middle_proxy::{bnd_snapshot, timeskew_snapshot}; use crate::transport::middle_proxy::{bnd_snapshot, timeskew_snapshot, upstream_bnd_snapshots};
use crate::transport::UpstreamRouteKind;
use super::ApiShared; use super::ApiShared;
@ -65,13 +68,26 @@ pub(super) struct RuntimeMeSelftestBndData {
pub(super) last_seen_age_secs: Option<u64>, pub(super) last_seen_age_secs: Option<u64>,
} }
#[derive(Serialize)]
pub(super) struct RuntimeMeSelftestUpstreamData {
pub(super) upstream_id: usize,
pub(super) route_kind: &'static str,
pub(super) address: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) bnd: Option<RuntimeMeSelftestBndData>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) ip: Option<String>,
}
#[derive(Serialize)] #[derive(Serialize)]
pub(super) struct RuntimeMeSelftestPayload { pub(super) struct RuntimeMeSelftestPayload {
pub(super) kdf: RuntimeMeSelftestKdfData, pub(super) kdf: RuntimeMeSelftestKdfData,
pub(super) timeskew: RuntimeMeSelftestTimeskewData, pub(super) timeskew: RuntimeMeSelftestTimeskewData,
pub(super) ip: RuntimeMeSelftestIpData, pub(super) ip: RuntimeMeSelftestIpData,
pub(super) pid: RuntimeMeSelftestPidData, pub(super) pid: RuntimeMeSelftestPidData,
pub(super) bnd: RuntimeMeSelftestBndData, pub(super) bnd: Option<RuntimeMeSelftestBndData>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) upstreams: Option<Vec<RuntimeMeSelftestUpstreamData>>,
} }
#[derive(Serialize)] #[derive(Serialize)]
@ -98,7 +114,10 @@ fn kdf_ewma_state() -> &'static Mutex<KdfEwmaState> {
KDF_EWMA_STATE.get_or_init(|| Mutex::new(KdfEwmaState::default())) KDF_EWMA_STATE.get_or_init(|| Mutex::new(KdfEwmaState::default()))
} }
pub(super) async fn build_runtime_me_selftest_data(shared: &ApiShared) -> RuntimeMeSelftestData { pub(super) async fn build_runtime_me_selftest_data(
shared: &ApiShared,
cfg: &ProxyConfig,
) -> RuntimeMeSelftestData {
let now_epoch_secs = now_epoch_secs(); let now_epoch_secs = now_epoch_secs();
if shared.me_pool.read().await.is_none() { if shared.me_pool.read().await.is_none() {
return RuntimeMeSelftestData { return RuntimeMeSelftestData {
@ -139,7 +158,26 @@ pub(super) async fn build_runtime_me_selftest_data(shared: &ApiShared) -> Runtim
let pid = std::process::id(); let pid = std::process::id();
let pid_state = if pid == 1 { "one" } else { "non-one" }; let pid_state = if pid == 1 { "one" } else { "non-one" };
let bnd = bnd_snapshot(); let has_socks_upstreams = cfg.upstreams.iter().any(|upstream| {
upstream.enabled
&& matches!(
upstream.upstream_type,
UpstreamType::Socks4 { .. } | UpstreamType::Socks5 { .. }
)
});
let bnd = if has_socks_upstreams {
let snapshot = bnd_snapshot();
Some(RuntimeMeSelftestBndData {
addr_state: snapshot.addr_status,
port_state: snapshot.port_status,
last_addr: snapshot.last_addr.map(|value| value.to_string()),
last_seen_age_secs: snapshot.last_seen_age_secs,
})
} else {
None
};
let upstreams = build_upstream_selftest_data(shared);
RuntimeMeSelftestData { RuntimeMeSelftestData {
enabled: true, enabled: true,
@ -168,16 +206,41 @@ pub(super) async fn build_runtime_me_selftest_data(shared: &ApiShared) -> Runtim
pid, pid,
state: pid_state, state: pid_state,
}, },
bnd: RuntimeMeSelftestBndData { bnd,
addr_state: bnd.addr_status, upstreams,
port_state: bnd.port_status,
last_addr: bnd.last_addr.map(|value| value.to_string()),
last_seen_age_secs: bnd.last_seen_age_secs,
},
}), }),
} }
} }
fn build_upstream_selftest_data(shared: &ApiShared) -> Option<Vec<RuntimeMeSelftestUpstreamData>> {
let snapshot = shared.upstream_manager.try_api_snapshot()?;
if snapshot.summary.configured_total <= 1 {
return None;
}
let mut upstream_bnd_by_id: HashMap<usize, _> = upstream_bnd_snapshots()
.into_iter()
.map(|entry| (entry.upstream_id, entry))
.collect();
let mut rows = Vec::with_capacity(snapshot.upstreams.len());
for upstream in snapshot.upstreams {
let upstream_bnd = upstream_bnd_by_id.remove(&upstream.upstream_id);
rows.push(RuntimeMeSelftestUpstreamData {
upstream_id: upstream.upstream_id,
route_kind: map_route_kind(upstream.route_kind),
address: upstream.address,
bnd: upstream_bnd.as_ref().map(|entry| RuntimeMeSelftestBndData {
addr_state: entry.addr_status,
port_state: entry.port_status,
last_addr: entry.last_addr.map(|value| value.to_string()),
last_seen_age_secs: entry.last_seen_age_secs,
}),
ip: upstream_bnd.and_then(|entry| entry.last_ip.map(|value| value.to_string())),
});
}
Some(rows)
}
fn update_kdf_ewma(now_epoch_secs: u64, total_errors: u64) -> f64 { fn update_kdf_ewma(now_epoch_secs: u64, total_errors: u64) -> f64 {
let Ok(mut guard) = kdf_ewma_state().lock() else { let Ok(mut guard) = kdf_ewma_state().lock() else {
return 0.0; return 0.0;
@ -216,6 +279,14 @@ fn classify_ip(ip: IpAddr) -> &'static str {
"good" "good"
} }
fn map_route_kind(value: UpstreamRouteKind) -> &'static str {
match value {
UpstreamRouteKind::Direct => "direct",
UpstreamRouteKind::Socks4 => "socks4",
UpstreamRouteKind::Socks5 => "socks5",
}
}
fn round3(value: f64) -> f64 { fn round3(value: f64) -> f64 {
(value * 1000.0).round() / 1000.0 (value * 1000.0).round() / 1000.0
} }

View File

@ -3,6 +3,7 @@ use std::sync::atomic::Ordering;
use serde::Serialize; use serde::Serialize;
use crate::config::{MeFloorMode, MeWriterPickMode, ProxyConfig, UserMaxUniqueIpsMode}; use crate::config::{MeFloorMode, MeWriterPickMode, ProxyConfig, UserMaxUniqueIpsMode};
use crate::proxy::route_mode::RelayRouteMode;
use super::ApiShared; use super::ApiShared;
use super::runtime_init::build_runtime_startup_summary; use super::runtime_init::build_runtime_startup_summary;
@ -35,6 +36,10 @@ pub(super) struct RuntimeGatesData {
pub(super) me_runtime_ready: bool, pub(super) me_runtime_ready: bool,
pub(super) me2dc_fallback_enabled: bool, pub(super) me2dc_fallback_enabled: bool,
pub(super) use_middle_proxy: bool, pub(super) use_middle_proxy: bool,
pub(super) route_mode: &'static str,
pub(super) reroute_active: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) reroute_to_direct_at_epoch_secs: Option<u64>,
pub(super) startup_status: &'static str, pub(super) startup_status: &'static str,
pub(super) startup_stage: String, pub(super) startup_stage: String,
pub(super) startup_progress_pct: f64, pub(super) startup_progress_pct: f64,
@ -157,6 +162,16 @@ pub(super) async fn build_runtime_gates_data(
cfg: &ProxyConfig, cfg: &ProxyConfig,
) -> RuntimeGatesData { ) -> RuntimeGatesData {
let startup_summary = build_runtime_startup_summary(shared).await; let startup_summary = build_runtime_startup_summary(shared).await;
let route_state = shared.route_runtime.snapshot();
let route_mode = route_state.mode.as_str();
let reroute_active = cfg.general.use_middle_proxy
&& cfg.general.me2dc_fallback
&& matches!(route_state.mode, RelayRouteMode::Direct);
let reroute_to_direct_at_epoch_secs = if reroute_active {
shared.route_runtime.direct_since_epoch_secs()
} else {
None
};
let me_runtime_ready = if !cfg.general.use_middle_proxy { let me_runtime_ready = if !cfg.general.use_middle_proxy {
true true
} else { } else {
@ -175,6 +190,9 @@ pub(super) async fn build_runtime_gates_data(
me_runtime_ready, me_runtime_ready,
me2dc_fallback_enabled: cfg.general.me2dc_fallback, me2dc_fallback_enabled: cfg.general.me2dc_fallback,
use_middle_proxy: cfg.general.use_middle_proxy, use_middle_proxy: cfg.general.use_middle_proxy,
route_mode,
reroute_active,
reroute_to_direct_at_epoch_secs,
startup_status: startup_summary.status, startup_status: startup_summary.status,
startup_stage: startup_summary.stage, startup_stage: startup_summary.stage,
startup_progress_pct: startup_summary.progress_pct, startup_progress_pct: startup_summary.progress_pct,

View File

@ -220,6 +220,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
let ip_tracker_api = ip_tracker.clone(); let ip_tracker_api = ip_tracker.clone();
let me_pool_api = api_me_pool.clone(); let me_pool_api = api_me_pool.clone();
let upstream_manager_api = upstream_manager.clone(); let upstream_manager_api = upstream_manager.clone();
let route_runtime_api = route_runtime.clone();
let config_rx_api = api_config_rx.clone(); let config_rx_api = api_config_rx.clone();
let admission_rx_api = admission_rx.clone(); let admission_rx_api = admission_rx.clone();
let config_path_api = std::path::PathBuf::from(&config_path); let config_path_api = std::path::PathBuf::from(&config_path);
@ -231,6 +232,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
stats_api, stats_api,
ip_tracker_api, ip_tracker_api,
me_pool_api, me_pool_api,
route_runtime_api,
upstream_manager_api, upstream_manager_api,
config_rx_api, config_rx_api,
admission_rx_api, admission_rx_api,

View File

@ -1,6 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU64, Ordering}; use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
use std::time::Duration; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::watch; use tokio::sync::watch;
@ -43,6 +43,7 @@ pub(crate) struct RouteCutoverState {
pub(crate) struct RouteRuntimeController { pub(crate) struct RouteRuntimeController {
mode: Arc<AtomicU8>, mode: Arc<AtomicU8>,
generation: Arc<AtomicU64>, generation: Arc<AtomicU64>,
direct_since_epoch_secs: Arc<AtomicU64>,
tx: watch::Sender<RouteCutoverState>, tx: watch::Sender<RouteCutoverState>,
} }
@ -53,9 +54,15 @@ impl RouteRuntimeController {
generation: 0, generation: 0,
}; };
let (tx, _rx) = watch::channel(initial); let (tx, _rx) = watch::channel(initial);
let direct_since_epoch_secs = if matches!(initial_mode, RelayRouteMode::Direct) {
now_epoch_secs()
} else {
0
};
Self { Self {
mode: Arc::new(AtomicU8::new(initial_mode.as_u8())), mode: Arc::new(AtomicU8::new(initial_mode.as_u8())),
generation: Arc::new(AtomicU64::new(0)), generation: Arc::new(AtomicU64::new(0)),
direct_since_epoch_secs: Arc::new(AtomicU64::new(direct_since_epoch_secs)),
tx, tx,
} }
} }
@ -71,11 +78,22 @@ impl RouteRuntimeController {
self.tx.subscribe() self.tx.subscribe()
} }
pub(crate) fn direct_since_epoch_secs(&self) -> Option<u64> {
let value = self.direct_since_epoch_secs.load(Ordering::Relaxed);
(value > 0).then_some(value)
}
pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option<RouteCutoverState> { pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option<RouteCutoverState> {
let previous = self.mode.swap(mode.as_u8(), Ordering::Relaxed); let previous = self.mode.swap(mode.as_u8(), Ordering::Relaxed);
if previous == mode.as_u8() { if previous == mode.as_u8() {
return None; return None;
} }
if matches!(mode, RelayRouteMode::Direct) {
self.direct_since_epoch_secs
.store(now_epoch_secs(), Ordering::Relaxed);
} else {
self.direct_since_epoch_secs.store(0, Ordering::Relaxed);
}
let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1; let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1;
let next = RouteCutoverState { mode, generation }; let next = RouteCutoverState { mode, generation };
self.tx.send_replace(next); self.tx.send_replace(next);
@ -83,6 +101,13 @@ impl RouteRuntimeController {
} }
} }
fn now_epoch_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|value| value.as_secs())
.unwrap_or(0)
}
pub(crate) fn is_session_affected_by_cutover( pub(crate) fn is_session_affected_by_cutover(
current: RouteCutoverState, current: RouteCutoverState,
_session_mode: RelayRouteMode, _session_mode: RelayRouteMode,

View File

@ -33,7 +33,7 @@ use super::codec::{
cbc_decrypt_inplace, cbc_encrypt_padded, parse_handshake_flags, parse_nonce_payload, cbc_decrypt_inplace, cbc_encrypt_padded, parse_handshake_flags, parse_nonce_payload,
read_rpc_frame_plaintext, rpc_crc, read_rpc_frame_plaintext, rpc_crc,
}; };
use super::selftest::{BndAddrStatus, BndPortStatus, record_bnd_status}; use super::selftest::{BndAddrStatus, BndPortStatus, record_bnd_status, record_upstream_bnd_status};
use super::wire::{extract_ip_material, IpMaterial}; use super::wire::{extract_ip_material, IpMaterial};
use super::MePool; use super::MePool;
@ -199,10 +199,26 @@ impl MePool {
fn configure_keepalive(stream: &TcpStream) -> std::io::Result<()> { fn configure_keepalive(stream: &TcpStream) -> std::io::Result<()> {
let sock = SockRef::from(stream); let sock = SockRef::from(stream);
let ka = TcpKeepalive::new() let ka = TcpKeepalive::new().with_time(Duration::from_secs(30));
.with_time(Duration::from_secs(30))
.with_interval(Duration::from_secs(10)) // Mirror socket2 v0.5.10 target gate for with_retries(), the stricter method.
.with_retries(3); #[cfg(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "visionos",
target_os = "linux",
target_os = "macos",
target_os = "netbsd",
target_os = "tvos",
target_os = "watchos",
target_os = "cygwin",
))]
let ka = ka.with_interval(Duration::from_secs(10)).with_retries(3);
sock.set_tcp_keepalive(&ka)?; sock.set_tcp_keepalive(&ka)?;
sock.set_keepalive(true)?; sock.set_keepalive(true)?;
Ok(()) Ok(())
@ -299,6 +315,18 @@ impl MePool {
let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected); let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected);
let peer_addr_nat = SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port()); let peer_addr_nat = SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port());
if let Some(upstream_info) = upstream_egress {
let client_ip_for_kdf = socks_bound_addr
.map(|value| value.ip())
.unwrap_or(local_addr_nat.ip());
record_upstream_bnd_status(
upstream_info.upstream_id,
bnd_addr_status,
bnd_port_status,
raw_socks_bound_addr,
Some(client_ip_for_kdf),
);
}
let (mut rd, mut wr) = tokio::io::split(stream); let (mut rd, mut wr) = tokio::io::split(stream);
let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap(); let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap();
@ -685,3 +713,66 @@ fn hex_dump(data: &[u8]) -> String {
} }
out out
} }
#[cfg(test)]
mod tests {
use super::*;
use std::io::ErrorKind;
use tokio::net::{TcpListener, TcpStream};
#[tokio::test]
async fn test_configure_keepalive_loopback() {
let listener = match TcpListener::bind("127.0.0.1:0").await {
Ok(listener) => listener,
Err(error) if error.kind() == ErrorKind::PermissionDenied => return,
Err(error) => panic!("bind failed: {error}"),
};
let addr = match listener.local_addr() {
Ok(addr) => addr,
Err(error) => panic!("local_addr failed: {error}"),
};
let stream = match TcpStream::connect(addr).await {
Ok(stream) => stream,
Err(error) if error.kind() == ErrorKind::PermissionDenied => return,
Err(error) => panic!("connect failed: {error}"),
};
if let Err(error) = MePool::configure_keepalive(&stream) {
if error.kind() == ErrorKind::PermissionDenied {
return;
}
panic!("configure_keepalive failed: {error}");
}
}
#[test]
#[cfg(target_os = "openbsd")]
fn test_openbsd_keepalive_cfg_path_compiles() {
let _ka = TcpKeepalive::new().with_time(Duration::from_secs(30));
}
#[test]
#[cfg(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "visionos",
target_os = "linux",
target_os = "macos",
target_os = "netbsd",
target_os = "tvos",
target_os = "watchos",
target_os = "cygwin",
))]
fn test_retry_keepalive_cfg_path_compiles() {
let _ka = TcpKeepalive::new()
.with_time(Duration::from_secs(30))
.with_interval(Duration::from_secs(10))
.with_retries(3);
}
}

View File

@ -38,7 +38,9 @@ pub use config_updater::{
me_config_updater, save_proxy_config_cache, me_config_updater, save_proxy_config_cache,
}; };
pub use rotation::{MeReinitTrigger, me_reinit_scheduler, me_rotation_task}; pub use rotation::{MeReinitTrigger, me_reinit_scheduler, me_rotation_task};
pub(crate) use selftest::{bnd_snapshot, timeskew_snapshot}; pub(crate) use selftest::{
bnd_snapshot, timeskew_snapshot, upstream_bnd_snapshots,
};
pub use wire::proto_flags_for_tag; pub use wire::proto_flags_for_tag;
#[derive(Debug)] #[derive(Debug)]

View File

@ -1,5 +1,5 @@
use std::collections::VecDeque; use std::collections::{HashMap, VecDeque};
use std::net::SocketAddr; use std::net::{IpAddr, SocketAddr};
use std::sync::{Mutex, OnceLock}; use std::sync::{Mutex, OnceLock};
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
@ -45,6 +45,16 @@ pub(crate) struct MeBndSnapshot {
pub last_seen_age_secs: Option<u64>, pub last_seen_age_secs: Option<u64>,
} }
#[derive(Clone, Debug)]
pub(crate) struct MeUpstreamBndSnapshot {
pub upstream_id: usize,
pub addr_status: &'static str,
pub port_status: &'static str,
pub last_addr: Option<SocketAddr>,
pub last_ip: Option<IpAddr>,
pub last_seen_age_secs: Option<u64>,
}
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub(crate) struct MeTimeskewSnapshot { pub(crate) struct MeTimeskewSnapshot {
pub max_skew_secs_15m: Option<u64>, pub max_skew_secs_15m: Option<u64>,
@ -67,9 +77,19 @@ struct MeSelftestState {
bnd_port_status: BndPortStatus, bnd_port_status: BndPortStatus,
bnd_last_addr: Option<SocketAddr>, bnd_last_addr: Option<SocketAddr>,
bnd_last_seen_epoch_secs: Option<u64>, bnd_last_seen_epoch_secs: Option<u64>,
upstream_bnd: HashMap<usize, UpstreamBndState>,
timeskew_samples: VecDeque<MeTimeskewSample>, timeskew_samples: VecDeque<MeTimeskewSample>,
} }
#[derive(Clone, Copy, Debug)]
struct UpstreamBndState {
addr_status: BndAddrStatus,
port_status: BndPortStatus,
last_addr: Option<SocketAddr>,
last_ip: Option<IpAddr>,
last_seen_epoch_secs: Option<u64>,
}
impl Default for MeSelftestState { impl Default for MeSelftestState {
fn default() -> Self { fn default() -> Self {
Self { Self {
@ -77,6 +97,7 @@ impl Default for MeSelftestState {
bnd_port_status: BndPortStatus::Error, bnd_port_status: BndPortStatus::Error,
bnd_last_addr: None, bnd_last_addr: None,
bnd_last_seen_epoch_secs: None, bnd_last_seen_epoch_secs: None,
upstream_bnd: HashMap::new(),
timeskew_samples: VecDeque::new(), timeskew_samples: VecDeque::new(),
} }
} }
@ -126,6 +147,51 @@ pub(crate) fn bnd_snapshot() -> MeBndSnapshot {
} }
} }
pub(crate) fn record_upstream_bnd_status(
upstream_id: usize,
addr_status: BndAddrStatus,
port_status: BndPortStatus,
last_addr: Option<SocketAddr>,
last_ip: Option<IpAddr>,
) {
let now_epoch_secs = now_epoch_secs();
let Ok(mut guard) = state().lock() else {
return;
};
guard.upstream_bnd.insert(
upstream_id,
UpstreamBndState {
addr_status,
port_status,
last_addr,
last_ip,
last_seen_epoch_secs: Some(now_epoch_secs),
},
);
}
pub(crate) fn upstream_bnd_snapshots() -> Vec<MeUpstreamBndSnapshot> {
let now_epoch_secs = now_epoch_secs();
let Ok(guard) = state().lock() else {
return Vec::new();
};
let mut out = Vec::with_capacity(guard.upstream_bnd.len());
for (upstream_id, entry) in &guard.upstream_bnd {
out.push(MeUpstreamBndSnapshot {
upstream_id: *upstream_id,
addr_status: entry.addr_status.as_str(),
port_status: entry.port_status.as_str(),
last_addr: entry.last_addr,
last_ip: entry.last_ip,
last_seen_age_secs: entry
.last_seen_epoch_secs
.map(|value| now_epoch_secs.saturating_sub(value)),
});
}
out.sort_by_key(|entry| entry.upstream_id);
out
}
pub(crate) fn record_timeskew_sample(source: &'static str, skew_secs: u64) { pub(crate) fn record_timeskew_sample(source: &'static str, skew_secs: u64) {
let now_epoch_secs = now_epoch_secs(); let now_epoch_secs = now_epoch_secs();
let Ok(mut guard) = state().lock() else { let Ok(mut guard) = state().lock() else {

View File

@ -1,6 +1,8 @@
//! TCP Socket Configuration //! TCP Socket Configuration
#[cfg(target_os = "linux")]
use std::collections::HashSet; use std::collections::HashSet;
#[cfg(target_os = "linux")]
use std::fs; use std::fs;
use std::io::Result; use std::io::Result;
use std::net::{SocketAddr, IpAddr}; use std::net::{SocketAddr, IpAddr};
@ -44,6 +46,7 @@ pub fn configure_tcp_socket(
pub fn configure_client_socket( pub fn configure_client_socket(
stream: &TcpStream, stream: &TcpStream,
keepalive_secs: u64, keepalive_secs: u64,
#[cfg_attr(not(target_os = "linux"), allow(unused_variables))]
ack_timeout_secs: u64, ack_timeout_secs: u64,
) -> Result<()> { ) -> Result<()> {
let socket = socket2::SockRef::from(stream); let socket = socket2::SockRef::from(stream);
@ -65,17 +68,27 @@ pub fn configure_client_socket(
// is implemented in relay_bidirectional instead // is implemented in relay_bidirectional instead
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
{ {
use std::io::{Error, ErrorKind};
use std::os::unix::io::AsRawFd; use std::os::unix::io::AsRawFd;
let fd = stream.as_raw_fd(); let fd = stream.as_raw_fd();
let timeout_ms = (ack_timeout_secs * 1000) as libc::c_int; let timeout_ms_u64 = ack_timeout_secs
unsafe { .checked_mul(1000)
.ok_or_else(|| Error::new(ErrorKind::InvalidInput, "ack_timeout_secs is too large"))?;
let timeout_ms = i32::try_from(timeout_ms_u64)
.map_err(|_| Error::new(ErrorKind::InvalidInput, "ack_timeout_secs exceeds TCP_USER_TIMEOUT range"))?;
let rc = unsafe {
libc::setsockopt( libc::setsockopt(
fd, fd,
libc::IPPROTO_TCP, libc::IPPROTO_TCP,
libc::TCP_USER_TIMEOUT, libc::TCP_USER_TIMEOUT,
&timeout_ms as *const _ as *const libc::c_void, &timeout_ms as *const libc::c_int as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t, std::mem::size_of::<libc::c_int>() as libc::socklen_t,
); )
};
if rc != 0 {
return Err(Error::last_os_error());
} }
} }
@ -373,6 +386,7 @@ fn listening_inodes_for_port(addr: SocketAddr) -> HashSet<u64> {
mod tests { mod tests {
use super::*; use super::*;
use std::io::ErrorKind; use std::io::ErrorKind;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
#[tokio::test] #[tokio::test]
@ -397,6 +411,142 @@ mod tests {
} }
} }
#[tokio::test]
async fn test_configure_client_socket() {
let listener = match TcpListener::bind("127.0.0.1:0").await {
Ok(l) => l,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("bind failed: {e}"),
};
let addr = match listener.local_addr() {
Ok(addr) => addr,
Err(e) => panic!("local_addr failed: {e}"),
};
let stream = match TcpStream::connect(addr).await {
Ok(s) => s,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("connect failed: {e}"),
};
if let Err(e) = configure_client_socket(&stream, 30, 30) {
if e.kind() == ErrorKind::PermissionDenied {
return;
}
panic!("configure_client_socket failed: {e}");
}
}
#[tokio::test]
async fn test_configure_client_socket_zero_ack_timeout() {
let listener = match TcpListener::bind("127.0.0.1:0").await {
Ok(l) => l,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("bind failed: {e}"),
};
let addr = match listener.local_addr() {
Ok(addr) => addr,
Err(e) => panic!("local_addr failed: {e}"),
};
let stream = match TcpStream::connect(addr).await {
Ok(s) => s,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("connect failed: {e}"),
};
if let Err(e) = configure_client_socket(&stream, 30, 0) {
if e.kind() == ErrorKind::PermissionDenied {
return;
}
panic!("configure_client_socket with zero ack timeout failed: {e}");
}
}
#[tokio::test]
async fn test_configure_client_socket_roundtrip_io() {
let listener = match TcpListener::bind("127.0.0.1:0").await {
Ok(l) => l,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("bind failed: {e}"),
};
let addr = match listener.local_addr() {
Ok(addr) => addr,
Err(e) => panic!("local_addr failed: {e}"),
};
let server_task = tokio::spawn(async move {
let (mut accepted, _) = match listener.accept().await {
Ok(v) => v,
Err(e) => panic!("accept failed: {e}"),
};
let mut payload = [0u8; 4];
if let Err(e) = accepted.read_exact(&mut payload).await {
panic!("server read_exact failed: {e}");
}
if let Err(e) = accepted.write_all(b"pong").await {
panic!("server write_all failed: {e}");
}
payload
});
let mut stream = match TcpStream::connect(addr).await {
Ok(s) => s,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("connect failed: {e}"),
};
if let Err(e) = configure_client_socket(&stream, 30, 30) {
if e.kind() == ErrorKind::PermissionDenied {
return;
}
panic!("configure_client_socket failed: {e}");
}
if let Err(e) = stream.write_all(b"ping").await {
panic!("client write_all failed: {e}");
}
let mut reply = [0u8; 4];
if let Err(e) = stream.read_exact(&mut reply).await {
panic!("client read_exact failed: {e}");
}
assert_eq!(&reply, b"pong");
let server_seen = match server_task.await {
Ok(value) => value,
Err(e) => panic!("server task join failed: {e}"),
};
assert_eq!(&server_seen, b"ping");
}
#[cfg(target_os = "linux")]
#[tokio::test]
async fn test_configure_client_socket_ack_timeout_overflow_rejected() {
let listener = match TcpListener::bind("127.0.0.1:0").await {
Ok(l) => l,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("bind failed: {e}"),
};
let addr = match listener.local_addr() {
Ok(addr) => addr,
Err(e) => panic!("local_addr failed: {e}"),
};
let stream = match TcpStream::connect(addr).await {
Ok(s) => s,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("connect failed: {e}"),
};
let too_large_secs = (i32::MAX as u64 / 1000) + 1;
let err = match configure_client_socket(&stream, 30, too_large_secs) {
Ok(()) => panic!("expected overflow validation error"),
Err(e) => e,
};
assert_eq!(err.kind(), ErrorKind::InvalidInput);
}
#[test] #[test]
fn test_normalize_ip() { fn test_normalize_ip() {
// IPv4 stays IPv4 // IPv4 stays IPv4

View File

@ -213,6 +213,7 @@ pub struct UpstreamApiPolicySnapshot {
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UpstreamEgressInfo { pub struct UpstreamEgressInfo {
pub upstream_id: usize,
pub route_kind: UpstreamRouteKind, pub route_kind: UpstreamRouteKind,
pub local_addr: Option<SocketAddr>, pub local_addr: Option<SocketAddr>,
pub direct_bind_ip: Option<IpAddr>, pub direct_bind_ip: Option<IpAddr>,
@ -672,7 +673,7 @@ impl UpstreamManager {
self.stats.increment_upstream_connect_attempt_total(); self.stats.increment_upstream_connect_attempt_total();
let start = Instant::now(); let start = Instant::now();
match self match self
.connect_via_upstream(&upstream, target, bind_rr.clone(), attempt_timeout) .connect_via_upstream(idx, &upstream, target, bind_rr.clone(), attempt_timeout)
.await .await
{ {
Ok((stream, egress)) => { Ok((stream, egress)) => {
@ -779,6 +780,7 @@ impl UpstreamManager {
async fn connect_via_upstream( async fn connect_via_upstream(
&self, &self,
upstream_id: usize,
config: &UpstreamConfig, config: &UpstreamConfig,
target: SocketAddr, target: SocketAddr,
bind_rr: Option<Arc<AtomicUsize>>, bind_rr: Option<Arc<AtomicUsize>>,
@ -828,6 +830,7 @@ impl UpstreamManager {
Ok(( Ok((
stream, stream,
UpstreamEgressInfo { UpstreamEgressInfo {
upstream_id,
route_kind: UpstreamRouteKind::Direct, route_kind: UpstreamRouteKind::Direct,
local_addr, local_addr,
direct_bind_ip: bind_ip, direct_bind_ip: bind_ip,
@ -906,6 +909,7 @@ impl UpstreamManager {
Ok(( Ok((
stream, stream,
UpstreamEgressInfo { UpstreamEgressInfo {
upstream_id,
route_kind: UpstreamRouteKind::Socks4, route_kind: UpstreamRouteKind::Socks4,
local_addr, local_addr,
direct_bind_ip: None, direct_bind_ip: None,
@ -986,6 +990,7 @@ impl UpstreamManager {
Ok(( Ok((
stream, stream,
UpstreamEgressInfo { UpstreamEgressInfo {
upstream_id,
route_kind: UpstreamRouteKind::Socks5, route_kind: UpstreamRouteKind::Socks5,
local_addr, local_addr,
direct_bind_ip: None, direct_bind_ip: None,
@ -1048,7 +1053,7 @@ impl UpstreamManager {
let result = tokio::time::timeout( let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS), Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(upstream_config, Some(bind_rr.clone()), addr_v6) self.ping_single_dc(*upstream_idx, upstream_config, Some(bind_rr.clone()), addr_v6)
).await; ).await;
let ping_result = match result { let ping_result = match result {
@ -1099,7 +1104,7 @@ impl UpstreamManager {
let result = tokio::time::timeout( let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS), Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(upstream_config, Some(bind_rr.clone()), addr_v4) self.ping_single_dc(*upstream_idx, upstream_config, Some(bind_rr.clone()), addr_v4)
).await; ).await;
let ping_result = match result { let ping_result = match result {
@ -1162,7 +1167,7 @@ impl UpstreamManager {
} }
let result = tokio::time::timeout( let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS), Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(upstream_config, Some(bind_rr.clone()), addr) self.ping_single_dc(*upstream_idx, upstream_config, Some(bind_rr.clone()), addr)
).await; ).await;
let ping_result = match result { let ping_result = match result {
@ -1233,6 +1238,7 @@ impl UpstreamManager {
async fn ping_single_dc( async fn ping_single_dc(
&self, &self,
upstream_id: usize,
config: &UpstreamConfig, config: &UpstreamConfig,
bind_rr: Option<Arc<AtomicUsize>>, bind_rr: Option<Arc<AtomicUsize>>,
target: SocketAddr, target: SocketAddr,
@ -1240,6 +1246,7 @@ impl UpstreamManager {
let start = Instant::now(); let start = Instant::now();
let _ = self let _ = self
.connect_via_upstream( .connect_via_upstream(
upstream_id,
config, config,
target, target,
bind_rr, bind_rr,
@ -1418,6 +1425,7 @@ impl UpstreamManager {
let result = tokio::time::timeout( let result = tokio::time::timeout(
Duration::from_secs(HEALTH_CHECK_CONNECT_TIMEOUT_SECS), Duration::from_secs(HEALTH_CHECK_CONNECT_TIMEOUT_SECS),
self.connect_via_upstream( self.connect_via_upstream(
i,
&config, &config,
endpoint, endpoint,
Some(bind_rr.clone()), Some(bind_rr.clone()),