diff --git a/Cargo.toml b/Cargo.toml index 16586e0..324af49 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.1.6" +version = "3.2.0" edition = "2024" [dependencies] diff --git a/docs/API.md b/docs/API.md new file mode 100644 index 0000000..bd8f892 --- /dev/null +++ b/docs/API.md @@ -0,0 +1,427 @@ +# Telemt Control API + +## Purpose +Control-plane HTTP API for runtime visibility and user/config management. +Data-plane MTProto traffic is out of scope. + +## Runtime Configuration +API runtime is configured in `[server.api]`. + +| Field | Type | Default | Description | +| --- | --- | --- | --- | +| `enabled` | `bool` | `false` | Enables REST API listener. | +| `listen` | `string` (`IP:PORT`) | `127.0.0.1:9091` | API bind address. | +| `whitelist` | `CIDR[]` | `127.0.0.1/32, ::1/128` | Source IP allowlist. Empty list means allow all. | +| `auth_header` | `string` | `""` | Exact value for `Authorization` header. Empty disables header auth. | +| `request_body_limit_bytes` | `usize` | `65536` | Maximum request body size. | +| `minimal_runtime_enabled` | `bool` | `false` | Enables runtime snapshot endpoints requiring ME pool read-lock aggregation. | +| `minimal_runtime_cache_ttl_ms` | `u64` | `1000` | Cache TTL for minimal snapshots. `0` disables cache. | +| `read_only` | `bool` | `false` | Disables mutating endpoints. | + +`server.admin_api` is accepted as an alias for backward compatibility. + +## Protocol Contract + +| Item | Value | +| --- | --- | +| Transport | HTTP/1.1 | +| Content type | `application/json; charset=utf-8` | +| Prefix | `/v1` | +| Optimistic concurrency | `If-Match: ` on mutating requests (optional) | +| Revision format | SHA-256 hex of current `config.toml` content | + +### Success Envelope +```json +{ + "ok": true, + "data": {}, + "revision": "sha256-hex" +} +``` + +### Error Envelope +```json +{ + "ok": false, + "error": { + "code": "machine_code", + "message": "human-readable" + }, + "request_id": 1 +} +``` + +## Endpoint Matrix + +| Method | Path | Body | Success | `data` contract | +| --- | --- | --- | --- | --- | +| `GET` | `/v1/health` | none | `200` | `HealthData` | +| `GET` | `/v1/stats/summary` | none | `200` | `SummaryData` | +| `GET` | `/v1/stats/zero/all` | none | `200` | `ZeroAllData` | +| `GET` | `/v1/stats/minimal/all` | none | `200` | `MinimalAllData` | +| `GET` | `/v1/stats/me-writers` | none | `200` | `MeWritersData` | +| `GET` | `/v1/stats/dcs` | none | `200` | `DcStatusData` | +| `GET` | `/v1/stats/users` | none | `200` | `UserInfo[]` | +| `GET` | `/v1/users` | none | `200` | `UserInfo[]` | +| `POST` | `/v1/users` | `CreateUserRequest` | `201` | `CreateUserResponse` | +| `GET` | `/v1/users/{username}` | none | `200` | `UserInfo` | +| `PATCH` | `/v1/users/{username}` | `PatchUserRequest` | `200` | `UserInfo` | +| `DELETE` | `/v1/users/{username}` | none | `200` | `string` (deleted username) | +| `POST` | `/v1/users/{username}/rotate-secret` | `RotateSecretRequest` or empty body | `200` | `CreateUserResponse` | + +## Common Error Codes + +| HTTP | `error.code` | Trigger | +| --- | --- | --- | +| `400` | `bad_request` | Invalid JSON, validation failures, malformed request body. | +| `401` | `unauthorized` | Missing/invalid `Authorization` when `auth_header` is configured. | +| `403` | `forbidden` | Source IP is not allowed by whitelist. | +| `403` | `read_only` | Mutating endpoint called while `read_only=true`. | +| `404` | `not_found` | Unknown route or unknown user. | +| `405` | `method_not_allowed` | Unsupported method for an existing user route. | +| `409` | `revision_conflict` | `If-Match` revision mismatch. | +| `409` | `user_exists` | User already exists on create. | +| `409` | `last_user_forbidden` | Attempt to delete last configured user. | +| `413` | `payload_too_large` | Body exceeds `request_body_limit_bytes`. | +| `500` | `internal_error` | Internal error (I/O, serialization, config load/save). | +| `503` | `api_disabled` | API disabled in config. | + +## Request Contracts + +### `CreateUserRequest` +| Field | Type | Required | Description | +| --- | --- | --- | --- | +| `username` | `string` | yes | `[A-Za-z0-9_.-]`, length `1..64`. | +| `secret` | `string` | no | Exactly 32 hex chars. If missing, generated automatically. | +| `user_ad_tag` | `string` | no | Exactly 32 hex chars. | +| `max_tcp_conns` | `usize` | no | Per-user concurrent TCP limit. | +| `expiration_rfc3339` | `string` | no | RFC3339 expiration timestamp. | +| `data_quota_bytes` | `u64` | no | Per-user traffic quota. | +| `max_unique_ips` | `usize` | no | Per-user unique source IP limit. | + +### `PatchUserRequest` +| Field | Type | Required | Description | +| --- | --- | --- | --- | +| `secret` | `string` | no | Exactly 32 hex chars. | +| `user_ad_tag` | `string` | no | Exactly 32 hex chars. | +| `max_tcp_conns` | `usize` | no | Per-user concurrent TCP limit. | +| `expiration_rfc3339` | `string` | no | RFC3339 expiration timestamp. | +| `data_quota_bytes` | `u64` | no | Per-user traffic quota. | +| `max_unique_ips` | `usize` | no | Per-user unique source IP limit. | + +### `RotateSecretRequest` +| Field | Type | Required | Description | +| --- | --- | --- | --- | +| `secret` | `string` | no | Exactly 32 hex chars. If missing, generated automatically. | + +## Response Data Contracts + +### `HealthData` +| Field | Type | Description | +| --- | --- | --- | +| `status` | `string` | Always `"ok"`. | +| `read_only` | `bool` | Mirrors current API `read_only` mode. | + +### `SummaryData` +| Field | Type | Description | +| --- | --- | --- | +| `uptime_seconds` | `f64` | Process uptime in seconds. | +| `connections_total` | `u64` | Total accepted client connections. | +| `connections_bad_total` | `u64` | Failed/invalid client connections. | +| `handshake_timeouts_total` | `u64` | Handshake timeout count. | +| `configured_users` | `usize` | Number of configured users in config. | + +### `ZeroAllData` +| Field | Type | Description | +| --- | --- | --- | +| `generated_at_epoch_secs` | `u64` | Snapshot time (Unix epoch seconds). | +| `core` | `ZeroCoreData` | Core counters and telemetry policy snapshot. | +| `upstream` | `ZeroUpstreamData` | Upstream connect counters/histogram buckets. | +| `middle_proxy` | `ZeroMiddleProxyData` | ME protocol/health counters. | +| `pool` | `ZeroPoolData` | ME pool lifecycle counters. | +| `desync` | `ZeroDesyncData` | Frame desync counters. | + +#### `ZeroCoreData` +| Field | Type | Description | +| --- | --- | --- | +| `uptime_seconds` | `f64` | Process uptime. | +| `connections_total` | `u64` | Total accepted connections. | +| `connections_bad_total` | `u64` | Failed/invalid connections. | +| `handshake_timeouts_total` | `u64` | Handshake timeouts. | +| `configured_users` | `usize` | Configured user count. | +| `telemetry_core_enabled` | `bool` | Core telemetry toggle. | +| `telemetry_user_enabled` | `bool` | User telemetry toggle. | +| `telemetry_me_level` | `string` | ME telemetry level (`off|normal|verbose`). | + +#### `ZeroUpstreamData` +| Field | Type | Description | +| --- | --- | --- | +| `connect_attempt_total` | `u64` | Total upstream connect attempts. | +| `connect_success_total` | `u64` | Successful upstream connects. | +| `connect_fail_total` | `u64` | Failed upstream connects. | +| `connect_failfast_hard_error_total` | `u64` | Fail-fast hard errors. | +| `connect_attempts_bucket_1` | `u64` | Connect attempts resolved in 1 try. | +| `connect_attempts_bucket_2` | `u64` | Connect attempts resolved in 2 tries. | +| `connect_attempts_bucket_3_4` | `u64` | Connect attempts resolved in 3-4 tries. | +| `connect_attempts_bucket_gt_4` | `u64` | Connect attempts requiring more than 4 tries. | +| `connect_duration_success_bucket_le_100ms` | `u64` | Successful connects <=100 ms. | +| `connect_duration_success_bucket_101_500ms` | `u64` | Successful connects 101-500 ms. | +| `connect_duration_success_bucket_501_1000ms` | `u64` | Successful connects 501-1000 ms. | +| `connect_duration_success_bucket_gt_1000ms` | `u64` | Successful connects >1000 ms. | +| `connect_duration_fail_bucket_le_100ms` | `u64` | Failed connects <=100 ms. | +| `connect_duration_fail_bucket_101_500ms` | `u64` | Failed connects 101-500 ms. | +| `connect_duration_fail_bucket_501_1000ms` | `u64` | Failed connects 501-1000 ms. | +| `connect_duration_fail_bucket_gt_1000ms` | `u64` | Failed connects >1000 ms. | + +#### `ZeroMiddleProxyData` +| Field | Type | Description | +| --- | --- | --- | +| `keepalive_sent_total` | `u64` | ME keepalive packets sent. | +| `keepalive_failed_total` | `u64` | ME keepalive send failures. | +| `keepalive_pong_total` | `u64` | Keepalive pong responses received. | +| `keepalive_timeout_total` | `u64` | Keepalive timeout events. | +| `rpc_proxy_req_signal_sent_total` | `u64` | RPC proxy activity signals sent. | +| `rpc_proxy_req_signal_failed_total` | `u64` | RPC proxy activity signal failures. | +| `rpc_proxy_req_signal_skipped_no_meta_total` | `u64` | Signals skipped due to missing metadata. | +| `rpc_proxy_req_signal_response_total` | `u64` | RPC proxy signal responses received. | +| `rpc_proxy_req_signal_close_sent_total` | `u64` | RPC proxy close signals sent. | +| `reconnect_attempt_total` | `u64` | ME reconnect attempts. | +| `reconnect_success_total` | `u64` | Successful reconnects. | +| `handshake_reject_total` | `u64` | ME handshake rejects. | +| `handshake_error_codes` | `ZeroCodeCount[]` | Handshake rejects grouped by code. | +| `reader_eof_total` | `u64` | ME reader EOF events. | +| `idle_close_by_peer_total` | `u64` | Idle closes initiated by peer. | +| `route_drop_no_conn_total` | `u64` | Route drops due to missing bound connection. | +| `route_drop_channel_closed_total` | `u64` | Route drops due to closed channel. | +| `route_drop_queue_full_total` | `u64` | Route drops due to full queue (total). | +| `route_drop_queue_full_base_total` | `u64` | Route drops in base queue mode. | +| `route_drop_queue_full_high_total` | `u64` | Route drops in high queue mode. | +| `socks_kdf_strict_reject_total` | `u64` | SOCKS KDF strict rejects. | +| `socks_kdf_compat_fallback_total` | `u64` | SOCKS KDF compat fallbacks. | +| `endpoint_quarantine_total` | `u64` | Endpoint quarantine activations. | +| `kdf_drift_total` | `u64` | KDF drift detections. | +| `kdf_port_only_drift_total` | `u64` | KDF port-only drift detections. | +| `hardswap_pending_reuse_total` | `u64` | Pending hardswap reused events. | +| `hardswap_pending_ttl_expired_total` | `u64` | Pending hardswap TTL expiry events. | +| `single_endpoint_outage_enter_total` | `u64` | Entered single-endpoint outage mode. | +| `single_endpoint_outage_exit_total` | `u64` | Exited single-endpoint outage mode. | +| `single_endpoint_outage_reconnect_attempt_total` | `u64` | Reconnect attempts in outage mode. | +| `single_endpoint_outage_reconnect_success_total` | `u64` | Reconnect successes in outage mode. | +| `single_endpoint_quarantine_bypass_total` | `u64` | Quarantine bypasses in outage mode. | +| `single_endpoint_shadow_rotate_total` | `u64` | Shadow writer rotations. | +| `single_endpoint_shadow_rotate_skipped_quarantine_total` | `u64` | Shadow rotations skipped because of quarantine. | +| `floor_mode_switch_total` | `u64` | Total floor mode switches. | +| `floor_mode_switch_static_to_adaptive_total` | `u64` | Static -> adaptive switches. | +| `floor_mode_switch_adaptive_to_static_total` | `u64` | Adaptive -> static switches. | + +#### `ZeroCodeCount` +| Field | Type | Description | +| --- | --- | --- | +| `code` | `i32` | Handshake error code. | +| `total` | `u64` | Events with this code. | + +#### `ZeroPoolData` +| Field | Type | Description | +| --- | --- | --- | +| `pool_swap_total` | `u64` | Pool swap count. | +| `pool_drain_active` | `u64` | Current active draining pools. | +| `pool_force_close_total` | `u64` | Forced pool closes by timeout. | +| `pool_stale_pick_total` | `u64` | Stale writer picks for binding. | +| `writer_removed_total` | `u64` | Writer removals total. | +| `writer_removed_unexpected_total` | `u64` | Unexpected writer removals. | +| `refill_triggered_total` | `u64` | Refill triggers. | +| `refill_skipped_inflight_total` | `u64` | Refill skipped because refill already in-flight. | +| `refill_failed_total` | `u64` | Refill failures. | +| `writer_restored_same_endpoint_total` | `u64` | Restores on same endpoint. | +| `writer_restored_fallback_total` | `u64` | Restores on fallback endpoint. | + +#### `ZeroDesyncData` +| Field | Type | Description | +| --- | --- | --- | +| `secure_padding_invalid_total` | `u64` | Invalid secure padding events. | +| `desync_total` | `u64` | Desync events total. | +| `desync_full_logged_total` | `u64` | Fully logged desync events. | +| `desync_suppressed_total` | `u64` | Suppressed desync logs. | +| `desync_frames_bucket_0` | `u64` | Desync frames bucket 0. | +| `desync_frames_bucket_1_2` | `u64` | Desync frames bucket 1-2. | +| `desync_frames_bucket_3_10` | `u64` | Desync frames bucket 3-10. | +| `desync_frames_bucket_gt_10` | `u64` | Desync frames bucket >10. | + +### `MinimalAllData` +| Field | Type | Description | +| --- | --- | --- | +| `enabled` | `bool` | Whether minimal runtime snapshots are enabled by config. | +| `reason` | `string?` | `feature_disabled` or `source_unavailable` when applicable. | +| `generated_at_epoch_secs` | `u64` | Snapshot generation time. | +| `data` | `MinimalAllPayload?` | Null when disabled; fallback payload when source unavailable. | + +#### `MinimalAllPayload` +| Field | Type | Description | +| --- | --- | --- | +| `me_writers` | `MeWritersData` | ME writer status block. | +| `dcs` | `DcStatusData` | DC aggregate status block. | +| `me_runtime` | `MinimalMeRuntimeData?` | Runtime ME control snapshot. | +| `network_path` | `MinimalDcPathData[]` | Active IP path selection per DC. | + +#### `MinimalMeRuntimeData` +| Field | Type | Description | +| --- | --- | --- | +| `active_generation` | `u64` | Active pool generation. | +| `warm_generation` | `u64` | Warm pool generation. | +| `pending_hardswap_generation` | `u64` | Pending hardswap generation. | +| `pending_hardswap_age_secs` | `u64?` | Pending hardswap age in seconds. | +| `hardswap_enabled` | `bool` | Hardswap mode toggle. | +| `floor_mode` | `string` | Writer floor mode. | +| `adaptive_floor_idle_secs` | `u64` | Idle threshold for adaptive floor. | +| `adaptive_floor_min_writers_single_endpoint` | `u8` | Minimum writers for single-endpoint DC in adaptive mode. | +| `adaptive_floor_recover_grace_secs` | `u64` | Grace period for floor recovery. | +| `me_keepalive_enabled` | `bool` | ME keepalive toggle. | +| `me_keepalive_interval_secs` | `u64` | Keepalive period. | +| `me_keepalive_jitter_secs` | `u64` | Keepalive jitter. | +| `me_keepalive_payload_random` | `bool` | Randomized keepalive payload toggle. | +| `rpc_proxy_req_every_secs` | `u64` | Period for RPC proxy request signal. | +| `me_reconnect_max_concurrent_per_dc` | `u32` | Reconnect concurrency per DC. | +| `me_reconnect_backoff_base_ms` | `u64` | Base reconnect backoff. | +| `me_reconnect_backoff_cap_ms` | `u64` | Max reconnect backoff. | +| `me_reconnect_fast_retry_count` | `u32` | Fast retry attempts before normal backoff. | +| `me_pool_drain_ttl_secs` | `u64` | Pool drain TTL. | +| `me_pool_force_close_secs` | `u64` | Hard close timeout for draining writers. | +| `me_pool_min_fresh_ratio` | `f32` | Minimum fresh ratio before swap. | +| `me_bind_stale_mode` | `string` | Stale writer bind policy. | +| `me_bind_stale_ttl_secs` | `u64` | Stale writer TTL. | +| `me_single_endpoint_shadow_writers` | `u8` | Shadow writers for single-endpoint DCs. | +| `me_single_endpoint_outage_mode_enabled` | `bool` | Outage mode toggle for single-endpoint DCs. | +| `me_single_endpoint_outage_disable_quarantine` | `bool` | Quarantine behavior in outage mode. | +| `me_single_endpoint_outage_backoff_min_ms` | `u64` | Outage mode min reconnect backoff. | +| `me_single_endpoint_outage_backoff_max_ms` | `u64` | Outage mode max reconnect backoff. | +| `me_single_endpoint_shadow_rotate_every_secs` | `u64` | Shadow rotation interval. | +| `me_deterministic_writer_sort` | `bool` | Deterministic writer ordering toggle. | +| `me_socks_kdf_policy` | `string` | Current SOCKS KDF policy mode. | +| `quarantined_endpoints_total` | `usize` | Total quarantined endpoints. | +| `quarantined_endpoints` | `MinimalQuarantineData[]` | Quarantine details. | + +#### `MinimalQuarantineData` +| Field | Type | Description | +| --- | --- | --- | +| `endpoint` | `string` | Endpoint (`ip:port`). | +| `remaining_ms` | `u64` | Remaining quarantine duration. | + +#### `MinimalDcPathData` +| Field | Type | Description | +| --- | --- | --- | +| `dc` | `i16` | Telegram DC identifier. | +| `ip_preference` | `string?` | Runtime IP family preference. | +| `selected_addr_v4` | `string?` | Selected IPv4 endpoint for this DC. | +| `selected_addr_v6` | `string?` | Selected IPv6 endpoint for this DC. | + +### `MeWritersData` +| Field | Type | Description | +| --- | --- | --- | +| `middle_proxy_enabled` | `bool` | `false` when minimal runtime is disabled or source unavailable. | +| `reason` | `string?` | `feature_disabled` or `source_unavailable` when not fully available. | +| `generated_at_epoch_secs` | `u64` | Snapshot generation time. | +| `summary` | `MeWritersSummary` | Coverage/availability summary. | +| `writers` | `MeWriterStatus[]` | Per-writer statuses. | + +#### `MeWritersSummary` +| Field | Type | Description | +| --- | --- | --- | +| `configured_dc_groups` | `usize` | Number of configured DC groups. | +| `configured_endpoints` | `usize` | Total configured ME endpoints. | +| `available_endpoints` | `usize` | Endpoints currently available. | +| `available_pct` | `f64` | `available_endpoints / configured_endpoints * 100`. | +| `required_writers` | `usize` | Required writers based on current floor policy. | +| `alive_writers` | `usize` | Writers currently alive. | +| `coverage_pct` | `f64` | `alive_writers / required_writers * 100`. | + +#### `MeWriterStatus` +| Field | Type | Description | +| --- | --- | --- | +| `writer_id` | `u64` | Runtime writer identifier. | +| `dc` | `i16?` | DC id if mapped. | +| `endpoint` | `string` | Endpoint (`ip:port`). | +| `generation` | `u64` | Pool generation owning this writer. | +| `state` | `string` | Writer state (`warm`, `active`, `draining`). | +| `draining` | `bool` | Draining flag. | +| `degraded` | `bool` | Degraded flag. | +| `bound_clients` | `usize` | Number of currently bound clients. | +| `idle_for_secs` | `u64?` | Idle age in seconds if idle. | +| `rtt_ema_ms` | `f64?` | RTT exponential moving average. | + +### `DcStatusData` +| Field | Type | Description | +| --- | --- | --- | +| `middle_proxy_enabled` | `bool` | `false` when minimal runtime is disabled or source unavailable. | +| `reason` | `string?` | `feature_disabled` or `source_unavailable` when not fully available. | +| `generated_at_epoch_secs` | `u64` | Snapshot generation time. | +| `dcs` | `DcStatus[]` | Per-DC status rows. | + +#### `DcStatus` +| Field | Type | Description | +| --- | --- | --- | +| `dc` | `i16` | Telegram DC id. | +| `endpoints` | `string[]` | Endpoints in this DC (`ip:port`). | +| `available_endpoints` | `usize` | Endpoints currently available in this DC. | +| `available_pct` | `f64` | `available_endpoints / endpoints_total * 100`. | +| `required_writers` | `usize` | Required writer count for this DC. | +| `alive_writers` | `usize` | Alive writers in this DC. | +| `coverage_pct` | `f64` | `alive_writers / required_writers * 100`. | +| `rtt_ms` | `f64?` | Aggregated RTT for DC. | +| `load` | `usize` | Active client sessions bound to this DC. | + +### `UserInfo` +| Field | Type | Description | +| --- | --- | --- | +| `username` | `string` | Username. | +| `user_ad_tag` | `string?` | Optional ad tag (32 hex chars). | +| `max_tcp_conns` | `usize?` | Optional max concurrent TCP limit. | +| `expiration_rfc3339` | `string?` | Optional expiration timestamp. | +| `data_quota_bytes` | `u64?` | Optional data quota. | +| `max_unique_ips` | `usize?` | Optional unique IP limit. | +| `current_connections` | `u64` | Current live connections. | +| `active_unique_ips` | `usize` | Current active unique source IPs. | +| `total_octets` | `u64` | Total traffic octets for this user. | +| `links` | `UserLinks` | Active connection links derived from current config. | + +#### `UserLinks` +| Field | Type | Description | +| --- | --- | --- | +| `classic` | `string[]` | Active `tg://proxy` links for classic mode. | +| `secure` | `string[]` | Active `tg://proxy` links for secure/DD mode. | +| `tls` | `string[]` | Active `tg://proxy` links for EE-TLS mode (for each host+TLS domain). | + +Link generation uses active config and enabled modes: +- `[general.links].public_host/public_port` have priority. +- Fallback host sources: listener `announce`, `announce_ip`, explicit listener `ip`. +- Legacy fallback: `listen_addr_ipv4` and `listen_addr_ipv6` when routable. + +### `CreateUserResponse` +| Field | Type | Description | +| --- | --- | --- | +| `user` | `UserInfo` | Created or updated user view. | +| `secret` | `string` | Effective user secret. | + +## Mutation Semantics + +| Endpoint | Notes | +| --- | --- | +| `POST /v1/users` | Creates user and validates resulting config before atomic save. | +| `PATCH /v1/users/{username}` | Partial update of provided fields only. Missing fields remain unchanged. | +| `POST /v1/users/{username}/rotate-secret` | Replaces secret. Empty body is allowed and auto-generates secret. | +| `DELETE /v1/users/{username}` | Deletes user and related optional settings. Last user deletion is blocked. | + +All mutating endpoints: +- Respect `read_only` mode. +- Accept optional `If-Match` for optimistic concurrency. +- Return new `revision` after successful write. + +## Operational Notes + +| Topic | Details | +| --- | --- | +| API startup | API binds only when `[server.api].enabled=true`. | +| Restart requirements | Changes in `server.api` settings require process restart. | +| Runtime apply path | Successful writes are picked up by existing config watcher/hot-reload path. | +| Exposure | Built-in TLS/mTLS is not provided. Use loopback bind + reverse proxy if needed. | +| Pagination | User list currently has no pagination/filtering. | +| Serialization side effect | Config comments/manual formatting are not preserved on write. | diff --git a/src/api/config_store.rs b/src/api/config_store.rs new file mode 100644 index 0000000..e7fbbca --- /dev/null +++ b/src/api/config_store.rs @@ -0,0 +1,107 @@ +use std::io::Write; +use std::path::{Path, PathBuf}; + +use hyper::header::IF_MATCH; +use sha2::{Digest, Sha256}; + +use crate::config::ProxyConfig; + +use super::model::ApiFailure; + +pub(super) fn parse_if_match(headers: &hyper::HeaderMap) -> Option { + headers + .get(IF_MATCH) + .and_then(|value| value.to_str().ok()) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(|value| value.trim_matches('"').to_string()) +} + +pub(super) async fn ensure_expected_revision( + config_path: &Path, + expected_revision: Option<&str>, +) -> Result<(), ApiFailure> { + let Some(expected) = expected_revision else { + return Ok(()); + }; + let current = current_revision(config_path).await?; + if current != expected { + return Err(ApiFailure::new( + hyper::StatusCode::CONFLICT, + "revision_conflict", + "Config revision mismatch", + )); + } + Ok(()) +} + +pub(super) async fn current_revision(config_path: &Path) -> Result { + let content = tokio::fs::read_to_string(config_path) + .await + .map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?; + Ok(compute_revision(&content)) +} + +pub(super) fn compute_revision(content: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(content.as_bytes()); + hex::encode(hasher.finalize()) +} + +pub(super) async fn load_config_from_disk(config_path: &Path) -> Result { + let config_path = config_path.to_path_buf(); + tokio::task::spawn_blocking(move || ProxyConfig::load(config_path)) + .await + .map_err(|e| ApiFailure::internal(format!("failed to join config loader: {}", e)))? + .map_err(|e| ApiFailure::internal(format!("failed to load config: {}", e))) +} + +pub(super) async fn save_config_to_disk( + config_path: &Path, + cfg: &ProxyConfig, +) -> Result { + let serialized = toml::to_string_pretty(cfg) + .map_err(|e| ApiFailure::internal(format!("failed to serialize config: {}", e)))?; + write_atomic(config_path.to_path_buf(), serialized.clone()).await?; + Ok(compute_revision(&serialized)) +} + +async fn write_atomic(path: PathBuf, contents: String) -> Result<(), ApiFailure> { + tokio::task::spawn_blocking(move || write_atomic_sync(&path, &contents)) + .await + .map_err(|e| ApiFailure::internal(format!("failed to join writer: {}", e)))? + .map_err(|e| ApiFailure::internal(format!("failed to write config: {}", e))) +} + +fn write_atomic_sync(path: &Path, contents: &str) -> std::io::Result<()> { + let parent = path.parent().unwrap_or_else(|| Path::new(".")); + std::fs::create_dir_all(parent)?; + + let tmp_name = format!( + ".{}.tmp-{}", + path.file_name() + .and_then(|s| s.to_str()) + .unwrap_or("config.toml"), + rand::random::() + ); + let tmp_path = parent.join(tmp_name); + + let write_result = (|| { + let mut file = std::fs::OpenOptions::new() + .create_new(true) + .write(true) + .open(&tmp_path)?; + file.write_all(contents.as_bytes())?; + file.sync_all()?; + std::fs::rename(&tmp_path, path)?; + if let Ok(dir) = std::fs::File::open(parent) { + let _ = dir.sync_all(); + } + Ok(()) + })(); + + if write_result.is_err() { + let _ = std::fs::remove_file(&tmp_path); + } + write_result +} diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..299d5a1 --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,414 @@ +use std::convert::Infallible; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +use http_body_util::{BodyExt, Full}; +use hyper::body::{Bytes, Incoming}; +use hyper::header::AUTHORIZATION; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper::{Method, Request, Response, StatusCode}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use tokio::net::TcpListener; +use tokio::sync::{Mutex, watch}; +use tracing::{debug, info, warn}; + +use crate::config::ProxyConfig; +use crate::ip_tracker::UserIpTracker; +use crate::stats::Stats; +use crate::transport::middle_proxy::MePool; + +mod config_store; +mod model; +mod runtime_stats; +mod users; + +use config_store::{current_revision, parse_if_match}; +use model::{ + ApiFailure, CreateUserRequest, ErrorBody, ErrorResponse, HealthData, PatchUserRequest, + RotateSecretRequest, SuccessResponse, SummaryData, +}; +use runtime_stats::{ + MinimalCacheEntry, build_dcs_data, build_me_writers_data, build_minimal_all_data, + build_zero_all_data, +}; +use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config}; + +#[derive(Clone)] +pub(super) struct ApiShared { + pub(super) stats: Arc, + pub(super) ip_tracker: Arc, + pub(super) me_pool: Option>, + pub(super) config_path: PathBuf, + pub(super) mutation_lock: Arc>, + pub(super) minimal_cache: Arc>>, + pub(super) request_id: Arc, +} + +impl ApiShared { + fn next_request_id(&self) -> u64 { + self.request_id.fetch_add(1, Ordering::Relaxed) + } +} + +pub async fn serve( + listen: SocketAddr, + stats: Arc, + ip_tracker: Arc, + me_pool: Option>, + config_rx: watch::Receiver>, + config_path: PathBuf, +) { + let listener = match TcpListener::bind(listen).await { + Ok(listener) => listener, + Err(error) => { + warn!( + error = %error, + listen = %listen, + "Failed to bind API listener" + ); + return; + } + }; + + info!("API endpoint: http://{}/v1/*", listen); + + let shared = Arc::new(ApiShared { + stats, + ip_tracker, + me_pool, + config_path, + mutation_lock: Arc::new(Mutex::new(())), + minimal_cache: Arc::new(Mutex::new(None)), + request_id: Arc::new(AtomicU64::new(1)), + }); + + loop { + let (stream, peer) = match listener.accept().await { + Ok(v) => v, + Err(error) => { + warn!(error = %error, "API accept error"); + continue; + } + }; + + let shared_conn = shared.clone(); + let config_rx_conn = config_rx.clone(); + tokio::spawn(async move { + let svc = service_fn(move |req: Request| { + let shared_req = shared_conn.clone(); + let config_rx_req = config_rx_conn.clone(); + async move { handle(req, peer, shared_req, config_rx_req).await } + }); + if let Err(error) = http1::Builder::new() + .serve_connection(hyper_util::rt::TokioIo::new(stream), svc) + .await + { + debug!(error = %error, "API connection error"); + } + }); + } +} + +async fn handle( + req: Request, + peer: SocketAddr, + shared: Arc, + config_rx: watch::Receiver>, +) -> Result>, Infallible> { + let request_id = shared.next_request_id(); + let cfg = config_rx.borrow().clone(); + let api_cfg = &cfg.server.api; + + if !api_cfg.enabled { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::SERVICE_UNAVAILABLE, + "api_disabled", + "API is disabled", + ), + )); + } + + if !api_cfg.whitelist.is_empty() + && !api_cfg + .whitelist + .iter() + .any(|net| net.contains(peer.ip())) + { + return Ok(error_response( + request_id, + ApiFailure::new(StatusCode::FORBIDDEN, "forbidden", "Source IP is not allowed"), + )); + } + + if !api_cfg.auth_header.is_empty() { + let auth_ok = req + .headers() + .get(AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .map(|v| v == api_cfg.auth_header) + .unwrap_or(false); + if !auth_ok { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::UNAUTHORIZED, + "unauthorized", + "Missing or invalid Authorization header", + ), + )); + } + } + + let method = req.method().clone(); + let path = req.uri().path().to_string(); + let body_limit = api_cfg.request_body_limit_bytes; + + let result: Result>, ApiFailure> = async { + match (method.as_str(), path.as_str()) { + ("GET", "/v1/health") => { + let revision = current_revision(&shared.config_path).await?; + let data = HealthData { + status: "ok", + read_only: api_cfg.read_only, + }; + Ok(success_response(StatusCode::OK, data, revision)) + } + ("GET", "/v1/stats/summary") => { + let revision = current_revision(&shared.config_path).await?; + let data = SummaryData { + uptime_seconds: shared.stats.uptime_secs(), + connections_total: shared.stats.get_connects_all(), + connections_bad_total: shared.stats.get_connects_bad(), + handshake_timeouts_total: shared.stats.get_handshake_timeouts(), + configured_users: cfg.access.users.len(), + }; + Ok(success_response(StatusCode::OK, data, revision)) + } + ("GET", "/v1/stats/zero/all") => { + let revision = current_revision(&shared.config_path).await?; + let data = build_zero_all_data(&shared.stats, cfg.access.users.len()); + Ok(success_response(StatusCode::OK, data, revision)) + } + ("GET", "/v1/stats/minimal/all") => { + let revision = current_revision(&shared.config_path).await?; + let data = build_minimal_all_data(shared.as_ref(), api_cfg).await; + Ok(success_response(StatusCode::OK, data, revision)) + } + ("GET", "/v1/stats/me-writers") => { + let revision = current_revision(&shared.config_path).await?; + let data = build_me_writers_data(shared.as_ref(), api_cfg).await; + Ok(success_response(StatusCode::OK, data, revision)) + } + ("GET", "/v1/stats/dcs") => { + let revision = current_revision(&shared.config_path).await?; + let data = build_dcs_data(shared.as_ref(), api_cfg).await; + Ok(success_response(StatusCode::OK, data, revision)) + } + ("GET", "/v1/stats/users") | ("GET", "/v1/users") => { + let revision = current_revision(&shared.config_path).await?; + let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + Ok(success_response(StatusCode::OK, users, revision)) + } + ("POST", "/v1/users") => { + if api_cfg.read_only { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::FORBIDDEN, + "read_only", + "API runs in read-only mode", + ), + )); + } + let expected_revision = parse_if_match(req.headers()); + let body = read_json::(req.into_body(), body_limit).await?; + let (data, revision) = create_user(body, expected_revision, &shared).await?; + Ok(success_response(StatusCode::CREATED, data, revision)) + } + _ => { + if let Some(user) = path.strip_prefix("/v1/users/") + && !user.is_empty() + && !user.contains('/') + { + if method == Method::GET { + let revision = current_revision(&shared.config_path).await?; + let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + if let Some(user_info) = users.into_iter().find(|entry| entry.username == user) + { + return Ok(success_response(StatusCode::OK, user_info, revision)); + } + return Ok(error_response( + request_id, + ApiFailure::new(StatusCode::NOT_FOUND, "not_found", "User not found"), + )); + } + if method == Method::PATCH { + if api_cfg.read_only { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::FORBIDDEN, + "read_only", + "API runs in read-only mode", + ), + )); + } + let expected_revision = parse_if_match(req.headers()); + let body = read_json::(req.into_body(), body_limit).await?; + let (data, revision) = + patch_user(user, body, expected_revision, &shared).await?; + return Ok(success_response(StatusCode::OK, data, revision)); + } + if method == Method::DELETE { + if api_cfg.read_only { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::FORBIDDEN, + "read_only", + "API runs in read-only mode", + ), + )); + } + let expected_revision = parse_if_match(req.headers()); + let (deleted_user, revision) = + delete_user(user, expected_revision, &shared).await?; + return Ok(success_response(StatusCode::OK, deleted_user, revision)); + } + if method == Method::POST + && let Some(base_user) = user.strip_suffix("/rotate-secret") + && !base_user.is_empty() + && !base_user.contains('/') + { + if api_cfg.read_only { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::FORBIDDEN, + "read_only", + "API runs in read-only mode", + ), + )); + } + let expected_revision = parse_if_match(req.headers()); + let body = + read_optional_json::(req.into_body(), body_limit) + .await?; + let (data, revision) = + rotate_secret(base_user, body.unwrap_or_default(), expected_revision, &shared) + .await?; + return Ok(success_response(StatusCode::OK, data, revision)); + } + if method == Method::POST { + return Ok(error_response( + request_id, + ApiFailure::new(StatusCode::NOT_FOUND, "not_found", "Route not found"), + )); + } + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::METHOD_NOT_ALLOWED, + "method_not_allowed", + "Unsupported HTTP method for this route", + ), + )); + } + Ok(error_response( + request_id, + ApiFailure::new(StatusCode::NOT_FOUND, "not_found", "Route not found"), + )) + } + } + } + .await; + + match result { + Ok(resp) => Ok(resp), + Err(error) => Ok(error_response(request_id, error)), + } +} + +fn success_response( + status: StatusCode, + data: T, + revision: String, +) -> Response> { + let payload = SuccessResponse { + ok: true, + data, + revision, + }; + let body = serde_json::to_vec(&payload).unwrap_or_else(|_| b"{\"ok\":false}".to_vec()); + Response::builder() + .status(status) + .header("content-type", "application/json; charset=utf-8") + .body(Full::new(Bytes::from(body))) + .unwrap() +} + +fn error_response(request_id: u64, failure: ApiFailure) -> Response> { + let payload = ErrorResponse { + ok: false, + error: ErrorBody { + code: failure.code, + message: failure.message, + }, + request_id, + }; + let body = serde_json::to_vec(&payload).unwrap_or_else(|_| { + format!( + "{{\"ok\":false,\"error\":{{\"code\":\"internal_error\",\"message\":\"serialization failed\"}},\"request_id\":{}}}", + request_id + ) + .into_bytes() + }); + Response::builder() + .status(failure.status) + .header("content-type", "application/json; charset=utf-8") + .body(Full::new(Bytes::from(body))) + .unwrap() +} + +async fn read_json(body: Incoming, limit: usize) -> Result { + let bytes = read_body_with_limit(body, limit).await?; + serde_json::from_slice(&bytes).map_err(|_| ApiFailure::bad_request("Invalid JSON body")) +} + +async fn read_optional_json( + body: Incoming, + limit: usize, +) -> Result, ApiFailure> { + let bytes = read_body_with_limit(body, limit).await?; + if bytes.is_empty() { + return Ok(None); + } + serde_json::from_slice(&bytes) + .map(Some) + .map_err(|_| ApiFailure::bad_request("Invalid JSON body")) +} + +async fn read_body_with_limit(body: Incoming, limit: usize) -> Result, ApiFailure> { + let mut collected = Vec::new(); + let mut body = body; + while let Some(frame_result) = body.frame().await { + let frame = frame_result.map_err(|_| ApiFailure::bad_request("Invalid request body"))?; + if let Some(chunk) = frame.data_ref() { + if collected.len().saturating_add(chunk.len()) > limit { + return Err(ApiFailure::new( + StatusCode::PAYLOAD_TOO_LARGE, + "payload_too_large", + format!("Body exceeds {} bytes", limit), + )); + } + collected.extend_from_slice(chunk); + } + } + Ok(collected) +} diff --git a/src/api/model.rs b/src/api/model.rs new file mode 100644 index 0000000..be76c4e --- /dev/null +++ b/src/api/model.rs @@ -0,0 +1,395 @@ +use chrono::{DateTime, Utc}; +use hyper::StatusCode; +use rand::Rng; +use serde::{Deserialize, Serialize}; + +const MAX_USERNAME_LEN: usize = 64; + +#[derive(Debug)] +pub(super) struct ApiFailure { + pub(super) status: StatusCode, + pub(super) code: &'static str, + pub(super) message: String, +} + +impl ApiFailure { + pub(super) fn new(status: StatusCode, code: &'static str, message: impl Into) -> Self { + Self { + status, + code, + message: message.into(), + } + } + + pub(super) fn internal(message: impl Into) -> Self { + Self::new(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", message) + } + + pub(super) fn bad_request(message: impl Into) -> Self { + Self::new(StatusCode::BAD_REQUEST, "bad_request", message) + } +} + +#[derive(Serialize)] +pub(super) struct ErrorBody { + pub(super) code: &'static str, + pub(super) message: String, +} + +#[derive(Serialize)] +pub(super) struct ErrorResponse { + pub(super) ok: bool, + pub(super) error: ErrorBody, + pub(super) request_id: u64, +} + +#[derive(Serialize)] +pub(super) struct SuccessResponse { + pub(super) ok: bool, + pub(super) data: T, + pub(super) revision: String, +} + +#[derive(Serialize)] +pub(super) struct HealthData { + pub(super) status: &'static str, + pub(super) read_only: bool, +} + +#[derive(Serialize)] +pub(super) struct SummaryData { + pub(super) uptime_seconds: f64, + pub(super) connections_total: u64, + pub(super) connections_bad_total: u64, + pub(super) handshake_timeouts_total: u64, + pub(super) configured_users: usize, +} + +#[derive(Serialize, Clone)] +pub(super) struct ZeroCodeCount { + pub(super) code: i32, + pub(super) total: u64, +} + +#[derive(Serialize, Clone)] +pub(super) struct ZeroCoreData { + pub(super) uptime_seconds: f64, + pub(super) connections_total: u64, + pub(super) connections_bad_total: u64, + pub(super) handshake_timeouts_total: u64, + pub(super) configured_users: usize, + pub(super) telemetry_core_enabled: bool, + pub(super) telemetry_user_enabled: bool, + pub(super) telemetry_me_level: String, +} + +#[derive(Serialize, Clone)] +pub(super) struct ZeroUpstreamData { + pub(super) connect_attempt_total: u64, + pub(super) connect_success_total: u64, + pub(super) connect_fail_total: u64, + pub(super) connect_failfast_hard_error_total: u64, + pub(super) connect_attempts_bucket_1: u64, + pub(super) connect_attempts_bucket_2: u64, + pub(super) connect_attempts_bucket_3_4: u64, + pub(super) connect_attempts_bucket_gt_4: u64, + pub(super) connect_duration_success_bucket_le_100ms: u64, + pub(super) connect_duration_success_bucket_101_500ms: u64, + pub(super) connect_duration_success_bucket_501_1000ms: u64, + pub(super) connect_duration_success_bucket_gt_1000ms: u64, + pub(super) connect_duration_fail_bucket_le_100ms: u64, + pub(super) connect_duration_fail_bucket_101_500ms: u64, + pub(super) connect_duration_fail_bucket_501_1000ms: u64, + pub(super) connect_duration_fail_bucket_gt_1000ms: u64, +} + +#[derive(Serialize, Clone)] +pub(super) struct ZeroMiddleProxyData { + pub(super) keepalive_sent_total: u64, + pub(super) keepalive_failed_total: u64, + pub(super) keepalive_pong_total: u64, + pub(super) keepalive_timeout_total: u64, + pub(super) rpc_proxy_req_signal_sent_total: u64, + pub(super) rpc_proxy_req_signal_failed_total: u64, + pub(super) rpc_proxy_req_signal_skipped_no_meta_total: u64, + pub(super) rpc_proxy_req_signal_response_total: u64, + pub(super) rpc_proxy_req_signal_close_sent_total: u64, + pub(super) reconnect_attempt_total: u64, + pub(super) reconnect_success_total: u64, + pub(super) handshake_reject_total: u64, + pub(super) handshake_error_codes: Vec, + pub(super) reader_eof_total: u64, + pub(super) idle_close_by_peer_total: u64, + pub(super) route_drop_no_conn_total: u64, + pub(super) route_drop_channel_closed_total: u64, + pub(super) route_drop_queue_full_total: u64, + pub(super) route_drop_queue_full_base_total: u64, + pub(super) route_drop_queue_full_high_total: u64, + pub(super) socks_kdf_strict_reject_total: u64, + pub(super) socks_kdf_compat_fallback_total: u64, + pub(super) endpoint_quarantine_total: u64, + pub(super) kdf_drift_total: u64, + pub(super) kdf_port_only_drift_total: u64, + pub(super) hardswap_pending_reuse_total: u64, + pub(super) hardswap_pending_ttl_expired_total: u64, + pub(super) single_endpoint_outage_enter_total: u64, + pub(super) single_endpoint_outage_exit_total: u64, + pub(super) single_endpoint_outage_reconnect_attempt_total: u64, + pub(super) single_endpoint_outage_reconnect_success_total: u64, + pub(super) single_endpoint_quarantine_bypass_total: u64, + pub(super) single_endpoint_shadow_rotate_total: u64, + pub(super) single_endpoint_shadow_rotate_skipped_quarantine_total: u64, + pub(super) floor_mode_switch_total: u64, + pub(super) floor_mode_switch_static_to_adaptive_total: u64, + pub(super) floor_mode_switch_adaptive_to_static_total: u64, +} + +#[derive(Serialize, Clone)] +pub(super) struct ZeroPoolData { + pub(super) pool_swap_total: u64, + pub(super) pool_drain_active: u64, + pub(super) pool_force_close_total: u64, + pub(super) pool_stale_pick_total: u64, + pub(super) writer_removed_total: u64, + pub(super) writer_removed_unexpected_total: u64, + pub(super) refill_triggered_total: u64, + pub(super) refill_skipped_inflight_total: u64, + pub(super) refill_failed_total: u64, + pub(super) writer_restored_same_endpoint_total: u64, + pub(super) writer_restored_fallback_total: u64, +} + +#[derive(Serialize, Clone)] +pub(super) struct ZeroDesyncData { + pub(super) secure_padding_invalid_total: u64, + pub(super) desync_total: u64, + pub(super) desync_full_logged_total: u64, + pub(super) desync_suppressed_total: u64, + pub(super) desync_frames_bucket_0: u64, + pub(super) desync_frames_bucket_1_2: u64, + pub(super) desync_frames_bucket_3_10: u64, + pub(super) desync_frames_bucket_gt_10: u64, +} + +#[derive(Serialize, Clone)] +pub(super) struct ZeroAllData { + pub(super) generated_at_epoch_secs: u64, + pub(super) core: ZeroCoreData, + pub(super) upstream: ZeroUpstreamData, + pub(super) middle_proxy: ZeroMiddleProxyData, + pub(super) pool: ZeroPoolData, + pub(super) desync: ZeroDesyncData, +} + +#[derive(Serialize, Clone)] +pub(super) struct MeWritersSummary { + pub(super) configured_dc_groups: usize, + pub(super) configured_endpoints: usize, + pub(super) available_endpoints: usize, + pub(super) available_pct: f64, + pub(super) required_writers: usize, + pub(super) alive_writers: usize, + pub(super) coverage_pct: f64, +} + +#[derive(Serialize, Clone)] +pub(super) struct MeWriterStatus { + pub(super) writer_id: u64, + pub(super) dc: Option, + pub(super) endpoint: String, + pub(super) generation: u64, + pub(super) state: &'static str, + pub(super) draining: bool, + pub(super) degraded: bool, + pub(super) bound_clients: usize, + pub(super) idle_for_secs: Option, + pub(super) rtt_ema_ms: Option, +} + +#[derive(Serialize, Clone)] +pub(super) struct MeWritersData { + pub(super) middle_proxy_enabled: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) reason: Option<&'static str>, + pub(super) generated_at_epoch_secs: u64, + pub(super) summary: MeWritersSummary, + pub(super) writers: Vec, +} + +#[derive(Serialize, Clone)] +pub(super) struct DcStatus { + pub(super) dc: i16, + pub(super) endpoints: Vec, + pub(super) available_endpoints: usize, + pub(super) available_pct: f64, + pub(super) required_writers: usize, + pub(super) alive_writers: usize, + pub(super) coverage_pct: f64, + pub(super) rtt_ms: Option, + pub(super) load: usize, +} + +#[derive(Serialize, Clone)] +pub(super) struct DcStatusData { + pub(super) middle_proxy_enabled: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) reason: Option<&'static str>, + pub(super) generated_at_epoch_secs: u64, + pub(super) dcs: Vec, +} + +#[derive(Serialize, Clone)] +pub(super) struct MinimalQuarantineData { + pub(super) endpoint: String, + pub(super) remaining_ms: u64, +} + +#[derive(Serialize, Clone)] +pub(super) struct MinimalDcPathData { + pub(super) dc: i16, + pub(super) ip_preference: Option<&'static str>, + pub(super) selected_addr_v4: Option, + pub(super) selected_addr_v6: Option, +} + +#[derive(Serialize, Clone)] +pub(super) struct MinimalMeRuntimeData { + pub(super) active_generation: u64, + pub(super) warm_generation: u64, + pub(super) pending_hardswap_generation: u64, + pub(super) pending_hardswap_age_secs: Option, + pub(super) hardswap_enabled: bool, + pub(super) floor_mode: &'static str, + pub(super) adaptive_floor_idle_secs: u64, + pub(super) adaptive_floor_min_writers_single_endpoint: u8, + pub(super) adaptive_floor_recover_grace_secs: u64, + pub(super) me_keepalive_enabled: bool, + pub(super) me_keepalive_interval_secs: u64, + pub(super) me_keepalive_jitter_secs: u64, + pub(super) me_keepalive_payload_random: bool, + pub(super) rpc_proxy_req_every_secs: u64, + pub(super) me_reconnect_max_concurrent_per_dc: u32, + pub(super) me_reconnect_backoff_base_ms: u64, + pub(super) me_reconnect_backoff_cap_ms: u64, + pub(super) me_reconnect_fast_retry_count: u32, + pub(super) me_pool_drain_ttl_secs: u64, + pub(super) me_pool_force_close_secs: u64, + pub(super) me_pool_min_fresh_ratio: f32, + pub(super) me_bind_stale_mode: &'static str, + pub(super) me_bind_stale_ttl_secs: u64, + pub(super) me_single_endpoint_shadow_writers: u8, + pub(super) me_single_endpoint_outage_mode_enabled: bool, + pub(super) me_single_endpoint_outage_disable_quarantine: bool, + pub(super) me_single_endpoint_outage_backoff_min_ms: u64, + pub(super) me_single_endpoint_outage_backoff_max_ms: u64, + pub(super) me_single_endpoint_shadow_rotate_every_secs: u64, + pub(super) me_deterministic_writer_sort: bool, + pub(super) me_socks_kdf_policy: &'static str, + pub(super) quarantined_endpoints_total: usize, + pub(super) quarantined_endpoints: Vec, +} + +#[derive(Serialize, Clone)] +pub(super) struct MinimalAllPayload { + pub(super) me_writers: MeWritersData, + pub(super) dcs: DcStatusData, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) me_runtime: Option, + pub(super) network_path: Vec, +} + +#[derive(Serialize, Clone)] +pub(super) struct MinimalAllData { + pub(super) enabled: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) reason: Option<&'static str>, + pub(super) generated_at_epoch_secs: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) data: Option, +} + +#[derive(Serialize)] +pub(super) struct UserLinks { + pub(super) classic: Vec, + pub(super) secure: Vec, + pub(super) tls: Vec, +} + +#[derive(Serialize)] +pub(super) struct UserInfo { + pub(super) username: String, + pub(super) user_ad_tag: Option, + pub(super) max_tcp_conns: Option, + pub(super) expiration_rfc3339: Option, + pub(super) data_quota_bytes: Option, + pub(super) max_unique_ips: Option, + pub(super) current_connections: u64, + pub(super) active_unique_ips: usize, + pub(super) total_octets: u64, + pub(super) links: UserLinks, +} + +#[derive(Serialize)] +pub(super) struct CreateUserResponse { + pub(super) user: UserInfo, + pub(super) secret: String, +} + +#[derive(Deserialize)] +pub(super) struct CreateUserRequest { + pub(super) username: String, + pub(super) secret: Option, + pub(super) user_ad_tag: Option, + pub(super) max_tcp_conns: Option, + pub(super) expiration_rfc3339: Option, + pub(super) data_quota_bytes: Option, + pub(super) max_unique_ips: Option, +} + +#[derive(Deserialize)] +pub(super) struct PatchUserRequest { + pub(super) secret: Option, + pub(super) user_ad_tag: Option, + pub(super) max_tcp_conns: Option, + pub(super) expiration_rfc3339: Option, + pub(super) data_quota_bytes: Option, + pub(super) max_unique_ips: Option, +} + +#[derive(Default, Deserialize)] +pub(super) struct RotateSecretRequest { + pub(super) secret: Option, +} + +pub(super) fn parse_optional_expiration( + value: Option<&str>, +) -> Result>, ApiFailure> { + let Some(raw) = value else { + return Ok(None); + }; + let parsed = DateTime::parse_from_rfc3339(raw) + .map_err(|_| ApiFailure::bad_request("expiration_rfc3339 must be valid RFC3339"))?; + Ok(Some(parsed.with_timezone(&Utc))) +} + +pub(super) fn is_valid_user_secret(secret: &str) -> bool { + secret.len() == 32 && secret.chars().all(|c| c.is_ascii_hexdigit()) +} + +pub(super) fn is_valid_ad_tag(tag: &str) -> bool { + tag.len() == 32 && tag.chars().all(|c| c.is_ascii_hexdigit()) +} + +pub(super) fn is_valid_username(user: &str) -> bool { + !user.is_empty() + && user.len() <= MAX_USERNAME_LEN + && user + .chars() + .all(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-' | '.')) +} + +pub(super) fn random_user_secret() -> String { + let mut bytes = [0u8; 16]; + rand::rng().fill(&mut bytes); + hex::encode(bytes) +} diff --git a/src/api/runtime_stats.rs b/src/api/runtime_stats.rs new file mode 100644 index 0000000..53fdeff --- /dev/null +++ b/src/api/runtime_stats.rs @@ -0,0 +1,392 @@ +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +use crate::config::ApiConfig; +use crate::stats::Stats; + +use super::ApiShared; +use super::model::{ + DcStatus, DcStatusData, MeWriterStatus, MeWritersData, MeWritersSummary, MinimalAllData, + MinimalAllPayload, MinimalDcPathData, MinimalMeRuntimeData, MinimalQuarantineData, + ZeroAllData, ZeroCodeCount, ZeroCoreData, ZeroDesyncData, ZeroMiddleProxyData, ZeroPoolData, + ZeroUpstreamData, +}; + +const FEATURE_DISABLED_REASON: &str = "feature_disabled"; +const SOURCE_UNAVAILABLE_REASON: &str = "source_unavailable"; + +#[derive(Clone)] +pub(crate) struct MinimalCacheEntry { + pub(super) expires_at: Instant, + pub(super) payload: MinimalAllPayload, + pub(super) generated_at_epoch_secs: u64, +} + +pub(super) fn build_zero_all_data(stats: &Stats, configured_users: usize) -> ZeroAllData { + let telemetry = stats.telemetry_policy(); + let handshake_error_codes = stats + .get_me_handshake_error_code_counts() + .into_iter() + .map(|(code, total)| ZeroCodeCount { code, total }) + .collect(); + + ZeroAllData { + generated_at_epoch_secs: now_epoch_secs(), + core: ZeroCoreData { + uptime_seconds: stats.uptime_secs(), + connections_total: stats.get_connects_all(), + connections_bad_total: stats.get_connects_bad(), + handshake_timeouts_total: stats.get_handshake_timeouts(), + configured_users, + telemetry_core_enabled: telemetry.core_enabled, + telemetry_user_enabled: telemetry.user_enabled, + telemetry_me_level: telemetry.me_level.to_string(), + }, + upstream: ZeroUpstreamData { + connect_attempt_total: stats.get_upstream_connect_attempt_total(), + connect_success_total: stats.get_upstream_connect_success_total(), + connect_fail_total: stats.get_upstream_connect_fail_total(), + connect_failfast_hard_error_total: stats.get_upstream_connect_failfast_hard_error_total(), + connect_attempts_bucket_1: stats.get_upstream_connect_attempts_bucket_1(), + connect_attempts_bucket_2: stats.get_upstream_connect_attempts_bucket_2(), + connect_attempts_bucket_3_4: stats.get_upstream_connect_attempts_bucket_3_4(), + connect_attempts_bucket_gt_4: stats.get_upstream_connect_attempts_bucket_gt_4(), + connect_duration_success_bucket_le_100ms: stats + .get_upstream_connect_duration_success_bucket_le_100ms(), + connect_duration_success_bucket_101_500ms: stats + .get_upstream_connect_duration_success_bucket_101_500ms(), + connect_duration_success_bucket_501_1000ms: stats + .get_upstream_connect_duration_success_bucket_501_1000ms(), + connect_duration_success_bucket_gt_1000ms: stats + .get_upstream_connect_duration_success_bucket_gt_1000ms(), + connect_duration_fail_bucket_le_100ms: stats + .get_upstream_connect_duration_fail_bucket_le_100ms(), + connect_duration_fail_bucket_101_500ms: stats + .get_upstream_connect_duration_fail_bucket_101_500ms(), + connect_duration_fail_bucket_501_1000ms: stats + .get_upstream_connect_duration_fail_bucket_501_1000ms(), + connect_duration_fail_bucket_gt_1000ms: stats + .get_upstream_connect_duration_fail_bucket_gt_1000ms(), + }, + middle_proxy: ZeroMiddleProxyData { + keepalive_sent_total: stats.get_me_keepalive_sent(), + keepalive_failed_total: stats.get_me_keepalive_failed(), + keepalive_pong_total: stats.get_me_keepalive_pong(), + keepalive_timeout_total: stats.get_me_keepalive_timeout(), + rpc_proxy_req_signal_sent_total: stats.get_me_rpc_proxy_req_signal_sent_total(), + rpc_proxy_req_signal_failed_total: stats.get_me_rpc_proxy_req_signal_failed_total(), + rpc_proxy_req_signal_skipped_no_meta_total: stats + .get_me_rpc_proxy_req_signal_skipped_no_meta_total(), + rpc_proxy_req_signal_response_total: stats.get_me_rpc_proxy_req_signal_response_total(), + rpc_proxy_req_signal_close_sent_total: stats + .get_me_rpc_proxy_req_signal_close_sent_total(), + reconnect_attempt_total: stats.get_me_reconnect_attempts(), + reconnect_success_total: stats.get_me_reconnect_success(), + handshake_reject_total: stats.get_me_handshake_reject_total(), + handshake_error_codes, + reader_eof_total: stats.get_me_reader_eof_total(), + idle_close_by_peer_total: stats.get_me_idle_close_by_peer_total(), + route_drop_no_conn_total: stats.get_me_route_drop_no_conn(), + route_drop_channel_closed_total: stats.get_me_route_drop_channel_closed(), + route_drop_queue_full_total: stats.get_me_route_drop_queue_full(), + route_drop_queue_full_base_total: stats.get_me_route_drop_queue_full_base(), + route_drop_queue_full_high_total: stats.get_me_route_drop_queue_full_high(), + socks_kdf_strict_reject_total: stats.get_me_socks_kdf_strict_reject(), + socks_kdf_compat_fallback_total: stats.get_me_socks_kdf_compat_fallback(), + endpoint_quarantine_total: stats.get_me_endpoint_quarantine_total(), + kdf_drift_total: stats.get_me_kdf_drift_total(), + kdf_port_only_drift_total: stats.get_me_kdf_port_only_drift_total(), + hardswap_pending_reuse_total: stats.get_me_hardswap_pending_reuse_total(), + hardswap_pending_ttl_expired_total: stats.get_me_hardswap_pending_ttl_expired_total(), + single_endpoint_outage_enter_total: stats.get_me_single_endpoint_outage_enter_total(), + single_endpoint_outage_exit_total: stats.get_me_single_endpoint_outage_exit_total(), + single_endpoint_outage_reconnect_attempt_total: stats + .get_me_single_endpoint_outage_reconnect_attempt_total(), + single_endpoint_outage_reconnect_success_total: stats + .get_me_single_endpoint_outage_reconnect_success_total(), + single_endpoint_quarantine_bypass_total: stats + .get_me_single_endpoint_quarantine_bypass_total(), + single_endpoint_shadow_rotate_total: stats.get_me_single_endpoint_shadow_rotate_total(), + single_endpoint_shadow_rotate_skipped_quarantine_total: stats + .get_me_single_endpoint_shadow_rotate_skipped_quarantine_total(), + floor_mode_switch_total: stats.get_me_floor_mode_switch_total(), + floor_mode_switch_static_to_adaptive_total: stats + .get_me_floor_mode_switch_static_to_adaptive_total(), + floor_mode_switch_adaptive_to_static_total: stats + .get_me_floor_mode_switch_adaptive_to_static_total(), + }, + pool: ZeroPoolData { + pool_swap_total: stats.get_pool_swap_total(), + pool_drain_active: stats.get_pool_drain_active(), + pool_force_close_total: stats.get_pool_force_close_total(), + pool_stale_pick_total: stats.get_pool_stale_pick_total(), + writer_removed_total: stats.get_me_writer_removed_total(), + writer_removed_unexpected_total: stats.get_me_writer_removed_unexpected_total(), + refill_triggered_total: stats.get_me_refill_triggered_total(), + refill_skipped_inflight_total: stats.get_me_refill_skipped_inflight_total(), + refill_failed_total: stats.get_me_refill_failed_total(), + writer_restored_same_endpoint_total: stats.get_me_writer_restored_same_endpoint_total(), + writer_restored_fallback_total: stats.get_me_writer_restored_fallback_total(), + }, + desync: ZeroDesyncData { + secure_padding_invalid_total: stats.get_secure_padding_invalid(), + desync_total: stats.get_desync_total(), + desync_full_logged_total: stats.get_desync_full_logged(), + desync_suppressed_total: stats.get_desync_suppressed(), + desync_frames_bucket_0: stats.get_desync_frames_bucket_0(), + desync_frames_bucket_1_2: stats.get_desync_frames_bucket_1_2(), + desync_frames_bucket_3_10: stats.get_desync_frames_bucket_3_10(), + desync_frames_bucket_gt_10: stats.get_desync_frames_bucket_gt_10(), + }, + } +} + +pub(super) async fn build_minimal_all_data( + shared: &ApiShared, + api_cfg: &ApiConfig, +) -> MinimalAllData { + let now = now_epoch_secs(); + if !api_cfg.minimal_runtime_enabled { + return MinimalAllData { + enabled: false, + reason: Some(FEATURE_DISABLED_REASON), + generated_at_epoch_secs: now, + data: None, + }; + } + + let Some((generated_at_epoch_secs, payload)) = + get_minimal_payload_cached(shared, api_cfg.minimal_runtime_cache_ttl_ms).await + else { + return MinimalAllData { + enabled: true, + reason: Some(SOURCE_UNAVAILABLE_REASON), + generated_at_epoch_secs: now, + data: Some(MinimalAllPayload { + me_writers: disabled_me_writers(now, SOURCE_UNAVAILABLE_REASON), + dcs: disabled_dcs(now, SOURCE_UNAVAILABLE_REASON), + me_runtime: None, + network_path: Vec::new(), + }), + }; + }; + + MinimalAllData { + enabled: true, + reason: None, + generated_at_epoch_secs, + data: Some(payload), + } +} + +pub(super) async fn build_me_writers_data( + shared: &ApiShared, + api_cfg: &ApiConfig, +) -> MeWritersData { + let now = now_epoch_secs(); + if !api_cfg.minimal_runtime_enabled { + return disabled_me_writers(now, FEATURE_DISABLED_REASON); + } + + let Some((_, payload)) = + get_minimal_payload_cached(shared, api_cfg.minimal_runtime_cache_ttl_ms).await + else { + return disabled_me_writers(now, SOURCE_UNAVAILABLE_REASON); + }; + payload.me_writers +} + +pub(super) async fn build_dcs_data(shared: &ApiShared, api_cfg: &ApiConfig) -> DcStatusData { + let now = now_epoch_secs(); + if !api_cfg.minimal_runtime_enabled { + return disabled_dcs(now, FEATURE_DISABLED_REASON); + } + + let Some((_, payload)) = + get_minimal_payload_cached(shared, api_cfg.minimal_runtime_cache_ttl_ms).await + else { + return disabled_dcs(now, SOURCE_UNAVAILABLE_REASON); + }; + payload.dcs +} + +async fn get_minimal_payload_cached( + shared: &ApiShared, + cache_ttl_ms: u64, +) -> Option<(u64, MinimalAllPayload)> { + if cache_ttl_ms > 0 { + let now = Instant::now(); + let cached = shared.minimal_cache.lock().await.clone(); + if let Some(entry) = cached + && now < entry.expires_at + { + return Some((entry.generated_at_epoch_secs, entry.payload)); + } + } + + let pool = shared.me_pool.as_ref()?; + let status = pool.api_status_snapshot().await; + let runtime = pool.api_runtime_snapshot().await; + let generated_at_epoch_secs = status.generated_at_epoch_secs; + + let me_writers = MeWritersData { + middle_proxy_enabled: true, + reason: None, + generated_at_epoch_secs, + summary: MeWritersSummary { + configured_dc_groups: status.configured_dc_groups, + configured_endpoints: status.configured_endpoints, + available_endpoints: status.available_endpoints, + available_pct: status.available_pct, + required_writers: status.required_writers, + alive_writers: status.alive_writers, + coverage_pct: status.coverage_pct, + }, + writers: status + .writers + .into_iter() + .map(|entry| MeWriterStatus { + writer_id: entry.writer_id, + dc: entry.dc, + endpoint: entry.endpoint.to_string(), + generation: entry.generation, + state: entry.state, + draining: entry.draining, + degraded: entry.degraded, + bound_clients: entry.bound_clients, + idle_for_secs: entry.idle_for_secs, + rtt_ema_ms: entry.rtt_ema_ms, + }) + .collect(), + }; + let dcs = DcStatusData { + middle_proxy_enabled: true, + reason: None, + generated_at_epoch_secs, + dcs: status + .dcs + .into_iter() + .map(|entry| DcStatus { + dc: entry.dc, + endpoints: entry + .endpoints + .into_iter() + .map(|value| value.to_string()) + .collect(), + available_endpoints: entry.available_endpoints, + available_pct: entry.available_pct, + required_writers: entry.required_writers, + alive_writers: entry.alive_writers, + coverage_pct: entry.coverage_pct, + rtt_ms: entry.rtt_ms, + load: entry.load, + }) + .collect(), + }; + let me_runtime = MinimalMeRuntimeData { + active_generation: runtime.active_generation, + warm_generation: runtime.warm_generation, + pending_hardswap_generation: runtime.pending_hardswap_generation, + pending_hardswap_age_secs: runtime.pending_hardswap_age_secs, + hardswap_enabled: runtime.hardswap_enabled, + floor_mode: runtime.floor_mode, + adaptive_floor_idle_secs: runtime.adaptive_floor_idle_secs, + adaptive_floor_min_writers_single_endpoint: runtime + .adaptive_floor_min_writers_single_endpoint, + adaptive_floor_recover_grace_secs: runtime.adaptive_floor_recover_grace_secs, + me_keepalive_enabled: runtime.me_keepalive_enabled, + me_keepalive_interval_secs: runtime.me_keepalive_interval_secs, + me_keepalive_jitter_secs: runtime.me_keepalive_jitter_secs, + me_keepalive_payload_random: runtime.me_keepalive_payload_random, + rpc_proxy_req_every_secs: runtime.rpc_proxy_req_every_secs, + me_reconnect_max_concurrent_per_dc: runtime.me_reconnect_max_concurrent_per_dc, + me_reconnect_backoff_base_ms: runtime.me_reconnect_backoff_base_ms, + me_reconnect_backoff_cap_ms: runtime.me_reconnect_backoff_cap_ms, + me_reconnect_fast_retry_count: runtime.me_reconnect_fast_retry_count, + me_pool_drain_ttl_secs: runtime.me_pool_drain_ttl_secs, + me_pool_force_close_secs: runtime.me_pool_force_close_secs, + me_pool_min_fresh_ratio: runtime.me_pool_min_fresh_ratio, + me_bind_stale_mode: runtime.me_bind_stale_mode, + me_bind_stale_ttl_secs: runtime.me_bind_stale_ttl_secs, + me_single_endpoint_shadow_writers: runtime.me_single_endpoint_shadow_writers, + me_single_endpoint_outage_mode_enabled: runtime.me_single_endpoint_outage_mode_enabled, + me_single_endpoint_outage_disable_quarantine: runtime + .me_single_endpoint_outage_disable_quarantine, + me_single_endpoint_outage_backoff_min_ms: runtime.me_single_endpoint_outage_backoff_min_ms, + me_single_endpoint_outage_backoff_max_ms: runtime.me_single_endpoint_outage_backoff_max_ms, + me_single_endpoint_shadow_rotate_every_secs: runtime + .me_single_endpoint_shadow_rotate_every_secs, + me_deterministic_writer_sort: runtime.me_deterministic_writer_sort, + me_socks_kdf_policy: runtime.me_socks_kdf_policy, + quarantined_endpoints_total: runtime.quarantined_endpoints.len(), + quarantined_endpoints: runtime + .quarantined_endpoints + .into_iter() + .map(|entry| MinimalQuarantineData { + endpoint: entry.endpoint.to_string(), + remaining_ms: entry.remaining_ms, + }) + .collect(), + }; + let network_path = runtime + .network_path + .into_iter() + .map(|entry| MinimalDcPathData { + dc: entry.dc, + ip_preference: entry.ip_preference, + selected_addr_v4: entry.selected_addr_v4.map(|value| value.to_string()), + selected_addr_v6: entry.selected_addr_v6.map(|value| value.to_string()), + }) + .collect(); + + let payload = MinimalAllPayload { + me_writers, + dcs, + me_runtime: Some(me_runtime), + network_path, + }; + + if cache_ttl_ms > 0 { + let entry = MinimalCacheEntry { + expires_at: Instant::now() + Duration::from_millis(cache_ttl_ms), + payload: payload.clone(), + generated_at_epoch_secs, + }; + *shared.minimal_cache.lock().await = Some(entry); + } + + Some((generated_at_epoch_secs, payload)) +} + +fn disabled_me_writers(now_epoch_secs: u64, reason: &'static str) -> MeWritersData { + MeWritersData { + middle_proxy_enabled: false, + reason: Some(reason), + generated_at_epoch_secs: now_epoch_secs, + summary: MeWritersSummary { + configured_dc_groups: 0, + configured_endpoints: 0, + available_endpoints: 0, + available_pct: 0.0, + required_writers: 0, + alive_writers: 0, + coverage_pct: 0.0, + }, + writers: Vec::new(), + } +} + +fn disabled_dcs(now_epoch_secs: u64, reason: &'static str) -> DcStatusData { + DcStatusData { + middle_proxy_enabled: false, + reason: Some(reason), + generated_at_epoch_secs: now_epoch_secs, + dcs: Vec::new(), + } +} + +fn now_epoch_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} diff --git a/src/api/users.rs b/src/api/users.rs new file mode 100644 index 0000000..9fc03e9 --- /dev/null +++ b/src/api/users.rs @@ -0,0 +1,435 @@ +use std::collections::HashMap; +use std::net::IpAddr; + +use hyper::StatusCode; + +use crate::config::ProxyConfig; +use crate::ip_tracker::UserIpTracker; +use crate::stats::Stats; + +use super::ApiShared; +use super::config_store::{ + ensure_expected_revision, load_config_from_disk, save_config_to_disk, +}; +use super::model::{ + ApiFailure, CreateUserRequest, CreateUserResponse, PatchUserRequest, RotateSecretRequest, + UserInfo, UserLinks, is_valid_ad_tag, is_valid_user_secret, is_valid_username, + parse_optional_expiration, random_user_secret, +}; + +pub(super) async fn create_user( + body: CreateUserRequest, + expected_revision: Option, + shared: &ApiShared, +) -> Result<(CreateUserResponse, String), ApiFailure> { + if !is_valid_username(&body.username) { + return Err(ApiFailure::bad_request( + "username must match [A-Za-z0-9_.-] and be 1..64 chars", + )); + } + + let secret = match body.secret { + Some(secret) => { + if !is_valid_user_secret(&secret) { + return Err(ApiFailure::bad_request( + "secret must be exactly 32 hex characters", + )); + } + secret + } + None => random_user_secret(), + }; + + if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) { + return Err(ApiFailure::bad_request( + "user_ad_tag must be exactly 32 hex characters", + )); + } + + let expiration = parse_optional_expiration(body.expiration_rfc3339.as_deref())?; + let _guard = shared.mutation_lock.lock().await; + let mut cfg = load_config_from_disk(&shared.config_path).await?; + ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?; + + if cfg.access.users.contains_key(&body.username) { + return Err(ApiFailure::new( + StatusCode::CONFLICT, + "user_exists", + "User already exists", + )); + } + + cfg.access.users.insert(body.username.clone(), secret.clone()); + if let Some(ad_tag) = body.user_ad_tag { + cfg.access.user_ad_tags.insert(body.username.clone(), ad_tag); + } + if let Some(limit) = body.max_tcp_conns { + cfg.access.user_max_tcp_conns.insert(body.username.clone(), limit); + } + if let Some(expiration) = expiration { + cfg.access + .user_expirations + .insert(body.username.clone(), expiration); + } + if let Some(quota) = body.data_quota_bytes { + cfg.access.user_data_quota.insert(body.username.clone(), quota); + } + + let updated_limit = body.max_unique_ips; + if let Some(limit) = updated_limit { + cfg.access + .user_max_unique_ips + .insert(body.username.clone(), limit); + } + + cfg.validate() + .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; + + let revision = save_config_to_disk(&shared.config_path, &cfg).await?; + drop(_guard); + + if let Some(limit) = updated_limit { + shared.ip_tracker.set_user_limit(&body.username, limit).await; + } + + let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + let user = users + .into_iter() + .find(|entry| entry.username == body.username) + .unwrap_or(UserInfo { + username: body.username.clone(), + user_ad_tag: None, + max_tcp_conns: None, + expiration_rfc3339: None, + data_quota_bytes: None, + max_unique_ips: updated_limit, + current_connections: 0, + active_unique_ips: 0, + total_octets: 0, + links: build_user_links(&cfg, &secret), + }); + + Ok((CreateUserResponse { user, secret }, revision)) +} + +pub(super) async fn patch_user( + user: &str, + body: PatchUserRequest, + expected_revision: Option, + shared: &ApiShared, +) -> Result<(UserInfo, String), ApiFailure> { + if let Some(secret) = body.secret.as_ref() && !is_valid_user_secret(secret) { + return Err(ApiFailure::bad_request( + "secret must be exactly 32 hex characters", + )); + } + if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) { + return Err(ApiFailure::bad_request( + "user_ad_tag must be exactly 32 hex characters", + )); + } + let expiration = parse_optional_expiration(body.expiration_rfc3339.as_deref())?; + let _guard = shared.mutation_lock.lock().await; + let mut cfg = load_config_from_disk(&shared.config_path).await?; + ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?; + + if !cfg.access.users.contains_key(user) { + return Err(ApiFailure::new( + StatusCode::NOT_FOUND, + "not_found", + "User not found", + )); + } + + if let Some(secret) = body.secret { + cfg.access.users.insert(user.to_string(), secret); + } + if let Some(ad_tag) = body.user_ad_tag { + cfg.access.user_ad_tags.insert(user.to_string(), ad_tag); + } + if let Some(limit) = body.max_tcp_conns { + cfg.access.user_max_tcp_conns.insert(user.to_string(), limit); + } + if let Some(expiration) = expiration { + cfg.access.user_expirations.insert(user.to_string(), expiration); + } + if let Some(quota) = body.data_quota_bytes { + cfg.access.user_data_quota.insert(user.to_string(), quota); + } + + let mut updated_limit = None; + if let Some(limit) = body.max_unique_ips { + cfg.access.user_max_unique_ips.insert(user.to_string(), limit); + updated_limit = Some(limit); + } + + cfg.validate() + .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; + + let revision = save_config_to_disk(&shared.config_path, &cfg).await?; + drop(_guard); + if let Some(limit) = updated_limit { + shared.ip_tracker.set_user_limit(user, limit).await; + } + let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + let user_info = users + .into_iter() + .find(|entry| entry.username == user) + .ok_or_else(|| ApiFailure::internal("failed to build updated user view"))?; + + Ok((user_info, revision)) +} + +pub(super) async fn rotate_secret( + user: &str, + body: RotateSecretRequest, + expected_revision: Option, + shared: &ApiShared, +) -> Result<(CreateUserResponse, String), ApiFailure> { + let secret = body.secret.unwrap_or_else(random_user_secret); + if !is_valid_user_secret(&secret) { + return Err(ApiFailure::bad_request( + "secret must be exactly 32 hex characters", + )); + } + + let _guard = shared.mutation_lock.lock().await; + let mut cfg = load_config_from_disk(&shared.config_path).await?; + ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?; + + if !cfg.access.users.contains_key(user) { + return Err(ApiFailure::new( + StatusCode::NOT_FOUND, + "not_found", + "User not found", + )); + } + + cfg.access.users.insert(user.to_string(), secret.clone()); + cfg.validate() + .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; + let revision = save_config_to_disk(&shared.config_path, &cfg).await?; + drop(_guard); + + let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + let user_info = users + .into_iter() + .find(|entry| entry.username == user) + .ok_or_else(|| ApiFailure::internal("failed to build updated user view"))?; + + Ok(( + CreateUserResponse { + user: user_info, + secret, + }, + revision, + )) +} + +pub(super) async fn delete_user( + user: &str, + expected_revision: Option, + shared: &ApiShared, +) -> Result<(String, String), ApiFailure> { + let _guard = shared.mutation_lock.lock().await; + let mut cfg = load_config_from_disk(&shared.config_path).await?; + ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?; + + if !cfg.access.users.contains_key(user) { + return Err(ApiFailure::new( + StatusCode::NOT_FOUND, + "not_found", + "User not found", + )); + } + if cfg.access.users.len() <= 1 { + return Err(ApiFailure::new( + StatusCode::CONFLICT, + "last_user_forbidden", + "Cannot delete the last configured user", + )); + } + + cfg.access.users.remove(user); + cfg.access.user_ad_tags.remove(user); + cfg.access.user_max_tcp_conns.remove(user); + cfg.access.user_expirations.remove(user); + cfg.access.user_data_quota.remove(user); + cfg.access.user_max_unique_ips.remove(user); + + cfg.validate() + .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; + let revision = save_config_to_disk(&shared.config_path, &cfg).await?; + drop(_guard); + shared.ip_tracker.clear_user_ips(user).await; + + Ok((user.to_string(), revision)) +} + +pub(super) async fn users_from_config( + cfg: &ProxyConfig, + stats: &Stats, + ip_tracker: &UserIpTracker, +) -> Vec { + let ip_counts = ip_tracker + .get_stats() + .await + .into_iter() + .map(|(user, count, _)| (user, count)) + .collect::>(); + + let mut names = cfg.access.users.keys().cloned().collect::>(); + names.sort(); + + let mut users = Vec::with_capacity(names.len()); + for username in names { + let links = cfg + .access + .users + .get(&username) + .map(|secret| build_user_links(cfg, secret)) + .unwrap_or(UserLinks { + classic: Vec::new(), + secure: Vec::new(), + tls: Vec::new(), + }); + users.push(UserInfo { + user_ad_tag: cfg.access.user_ad_tags.get(&username).cloned(), + max_tcp_conns: cfg.access.user_max_tcp_conns.get(&username).copied(), + expiration_rfc3339: cfg + .access + .user_expirations + .get(&username) + .map(chrono::DateTime::::to_rfc3339), + data_quota_bytes: cfg.access.user_data_quota.get(&username).copied(), + max_unique_ips: cfg.access.user_max_unique_ips.get(&username).copied(), + current_connections: stats.get_user_curr_connects(&username), + active_unique_ips: ip_counts.get(&username).copied().unwrap_or(0), + total_octets: stats.get_user_total_octets(&username), + links, + username, + }); + } + users +} + +fn build_user_links(cfg: &ProxyConfig, secret: &str) -> UserLinks { + let hosts = resolve_link_hosts(cfg); + let port = cfg.general.links.public_port.unwrap_or(cfg.server.port); + let tls_domains = resolve_tls_domains(cfg); + + let mut classic = Vec::new(); + let mut secure = Vec::new(); + let mut tls = Vec::new(); + + for host in &hosts { + if cfg.general.modes.classic { + classic.push(format!( + "tg://proxy?server={}&port={}&secret={}", + host, port, secret + )); + } + if cfg.general.modes.secure { + secure.push(format!( + "tg://proxy?server={}&port={}&secret=dd{}", + host, port, secret + )); + } + if cfg.general.modes.tls { + for domain in &tls_domains { + let domain_hex = hex::encode(domain); + tls.push(format!( + "tg://proxy?server={}&port={}&secret=ee{}{}", + host, port, secret, domain_hex + )); + } + } + } + + UserLinks { + classic, + secure, + tls, + } +} + +fn resolve_link_hosts(cfg: &ProxyConfig) -> Vec { + if let Some(host) = cfg + .general + .links + .public_host + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + { + return vec![host.to_string()]; + } + + let mut hosts = Vec::new(); + for listener in &cfg.server.listeners { + if let Some(host) = listener + .announce + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + { + push_unique_host(&mut hosts, host); + continue; + } + if let Some(ip) = listener.announce_ip { + if !ip.is_unspecified() { + push_unique_host(&mut hosts, &ip.to_string()); + } + continue; + } + if !listener.ip.is_unspecified() { + push_unique_host(&mut hosts, &listener.ip.to_string()); + } + } + + if hosts.is_empty() { + if let Some(host) = cfg.server.listen_addr_ipv4.as_deref() { + push_host_from_legacy_listen(&mut hosts, host); + } + if let Some(host) = cfg.server.listen_addr_ipv6.as_deref() { + push_host_from_legacy_listen(&mut hosts, host); + } + } + + hosts +} + +fn push_host_from_legacy_listen(hosts: &mut Vec, raw: &str) { + let candidate = raw.trim(); + if candidate.is_empty() { + return; + } + + match candidate.parse::() { + Ok(ip) if ip.is_unspecified() => {} + Ok(ip) => push_unique_host(hosts, &ip.to_string()), + Err(_) => push_unique_host(hosts, candidate), + } +} + +fn push_unique_host(hosts: &mut Vec, candidate: &str) { + if !hosts.iter().any(|existing| existing == candidate) { + hosts.push(candidate.to_string()); + } +} + +fn resolve_tls_domains(cfg: &ProxyConfig) -> Vec<&str> { + let mut domains = Vec::with_capacity(1 + cfg.censorship.tls_domains.len()); + let primary = cfg.censorship.tls_domain.as_str(); + if !primary.is_empty() { + domains.push(primary); + } + for domain in &cfg.censorship.tls_domains { + let value = domain.as_str(); + if value.is_empty() || domains.contains(&value) { + continue; + } + domains.push(value); + } + domains +} diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 41573a4..86f569b 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -92,6 +92,26 @@ pub(crate) fn default_metrics_whitelist() -> Vec { ] } +pub(crate) fn default_api_listen() -> String { + "127.0.0.1:9091".to_string() +} + +pub(crate) fn default_api_whitelist() -> Vec { + default_metrics_whitelist() +} + +pub(crate) fn default_api_request_body_limit_bytes() -> usize { + 64 * 1024 +} + +pub(crate) fn default_api_minimal_runtime_enabled() -> bool { + false +} + +pub(crate) fn default_api_minimal_runtime_cache_ttl_ms() -> u64 { + 1000 +} + pub(crate) fn default_prefer_4() -> u8 { 4 } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 902811c..d752d45 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -115,6 +115,18 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig) { old.server.port, new.server.port ); } + if old.server.api.enabled != new.server.api.enabled + || old.server.api.listen != new.server.api.listen + || old.server.api.whitelist != new.server.api.whitelist + || old.server.api.auth_header != new.server.api.auth_header + || old.server.api.request_body_limit_bytes != new.server.api.request_body_limit_bytes + || old.server.api.minimal_runtime_enabled != new.server.api.minimal_runtime_enabled + || old.server.api.minimal_runtime_cache_ttl_ms + != new.server.api.minimal_runtime_cache_ttl_ms + || old.server.api.read_only != new.server.api.read_only + { + warn!("config reload: server.api changed; restart required"); + } if old.censorship.tls_domain != new.censorship.tls_domain { warn!( "config reload: censorship.tls_domain changed ('{}' → '{}'); restart required", diff --git a/src/config/load.rs b/src/config/load.rs index c051b8e..b469299 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1,7 +1,7 @@ #![allow(deprecated)] use std::collections::HashMap; -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; use std::path::Path; use rand::Rng; @@ -398,6 +398,24 @@ impl ProxyConfig { )); } + if config.server.api.request_body_limit_bytes == 0 { + return Err(ProxyError::Config( + "server.api.request_body_limit_bytes must be > 0".to_string(), + )); + } + + if config.server.api.minimal_runtime_cache_ttl_ms > 60_000 { + return Err(ProxyError::Config( + "server.api.minimal_runtime_cache_ttl_ms must be within [0, 60000]".to_string(), + )); + } + + if config.server.api.listen.parse::().is_err() { + return Err(ProxyError::Config( + "server.api.listen must be in IP:PORT format".to_string(), + )); + } + if config.general.effective_me_pool_force_close_secs() > 0 && config.general.effective_me_pool_force_close_secs() < config.general.me_pool_drain_ttl_secs @@ -695,6 +713,20 @@ mod tests { assert_eq!(cfg.general.update_every, default_update_every()); assert_eq!(cfg.server.listen_addr_ipv4, default_listen_addr_ipv4()); assert_eq!(cfg.server.listen_addr_ipv6, default_listen_addr_ipv6_opt()); + assert_eq!(cfg.server.api.listen, default_api_listen()); + assert_eq!(cfg.server.api.whitelist, default_api_whitelist()); + assert_eq!( + cfg.server.api.request_body_limit_bytes, + default_api_request_body_limit_bytes() + ); + assert_eq!( + cfg.server.api.minimal_runtime_enabled, + default_api_minimal_runtime_enabled() + ); + assert_eq!( + cfg.server.api.minimal_runtime_cache_ttl_ms, + default_api_minimal_runtime_cache_ttl_ms() + ); assert_eq!(cfg.access.users, default_access_users()); } @@ -776,6 +808,20 @@ mod tests { let server = ServerConfig::default(); assert_eq!(server.listen_addr_ipv6, Some(default_listen_addr_ipv6())); + assert_eq!(server.api.listen, default_api_listen()); + assert_eq!(server.api.whitelist, default_api_whitelist()); + assert_eq!( + server.api.request_body_limit_bytes, + default_api_request_body_limit_bytes() + ); + assert_eq!( + server.api.minimal_runtime_enabled, + default_api_minimal_runtime_enabled() + ); + assert_eq!( + server.api.minimal_runtime_cache_ttl_ms, + default_api_minimal_runtime_cache_ttl_ms() + ); let access = AccessConfig::default(); assert_eq!(access.users, default_access_users()); @@ -1322,6 +1368,28 @@ mod tests { let _ = std::fs::remove_file(path); } + #[test] + fn api_minimal_runtime_cache_ttl_out_of_range_is_rejected() { + let toml = r#" + [server.api] + enabled = true + listen = "127.0.0.1:9091" + minimal_runtime_cache_ttl_ms = 70000 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_api_minimal_runtime_cache_ttl_invalid_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("server.api.minimal_runtime_cache_ttl_ms must be within [0, 60000]")); + let _ = std::fs::remove_file(path); + } + #[test] fn force_close_bumped_when_below_drain_ttl() { let toml = r#" diff --git a/src/config/types.rs b/src/config/types.rs index 64be729..ee17108 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -793,6 +793,58 @@ impl Default for LinksConfig { } } +/// API settings for control-plane endpoints. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ApiConfig { + /// Enable or disable REST API. + #[serde(default)] + pub enabled: bool, + + /// Listen address for API in `IP:PORT` format. + #[serde(default = "default_api_listen")] + pub listen: String, + + /// CIDR whitelist allowed to access API. + #[serde(default = "default_api_whitelist")] + pub whitelist: Vec, + + /// Optional static value for `Authorization` header validation. + /// Empty string disables header auth. + #[serde(default)] + pub auth_header: String, + + /// Maximum accepted HTTP request body size in bytes. + #[serde(default = "default_api_request_body_limit_bytes")] + pub request_body_limit_bytes: usize, + + /// Enable runtime snapshots that require read-lock aggregation on API request path. + #[serde(default = "default_api_minimal_runtime_enabled")] + pub minimal_runtime_enabled: bool, + + /// Cache TTL for minimal runtime snapshots in milliseconds (0 disables caching). + #[serde(default = "default_api_minimal_runtime_cache_ttl_ms")] + pub minimal_runtime_cache_ttl_ms: u64, + + /// Read-only mode: mutating endpoints are rejected. + #[serde(default)] + pub read_only: bool, +} + +impl Default for ApiConfig { + fn default() -> Self { + Self { + enabled: false, + listen: default_api_listen(), + whitelist: default_api_whitelist(), + auth_header: String::new(), + request_body_limit_bytes: default_api_request_body_limit_bytes(), + minimal_runtime_enabled: default_api_minimal_runtime_enabled(), + minimal_runtime_cache_ttl_ms: default_api_minimal_runtime_cache_ttl_ms(), + read_only: false, + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ServerConfig { #[serde(default = "default_port")] @@ -828,6 +880,9 @@ pub struct ServerConfig { #[serde(default = "default_metrics_whitelist")] pub metrics_whitelist: Vec, + #[serde(default, alias = "admin_api")] + pub api: ApiConfig, + #[serde(default)] pub listeners: Vec, } @@ -844,6 +899,7 @@ impl Default for ServerConfig { proxy_protocol: false, metrics_port: None, metrics_whitelist: default_metrics_whitelist(), + api: ApiConfig::default(), listeners: Vec::new(), } } diff --git a/src/main.rs b/src/main.rs index f7f9239..c4f0e68 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,6 +15,7 @@ use tracing_subscriber::{EnvFilter, fmt, prelude::*, reload}; use tokio::net::UnixListener; mod cli; +mod api; mod config; mod crypto; mod error; @@ -1152,6 +1153,38 @@ async fn main() -> std::result::Result<(), Box> { }); } + if config.server.api.enabled { + let listen = match config.server.api.listen.parse::() { + Ok(listen) => listen, + Err(error) => { + warn!( + error = %error, + listen = %config.server.api.listen, + "Invalid server.api.listen; API is disabled" + ); + SocketAddr::from(([127, 0, 0, 1], 0)) + } + }; + if listen.port() != 0 { + let stats = stats.clone(); + let ip_tracker_api = ip_tracker.clone(); + let me_pool_api = me_pool.clone(); + let config_rx_api = config_rx.clone(); + let config_path_api = std::path::PathBuf::from(&config_path); + tokio::spawn(async move { + api::serve( + listen, + stats, + ip_tracker_api, + me_pool_api, + config_rx_api, + config_path_api, + ) + .await; + }); + } + } + for (listener, listener_proxy_protocol) in listeners { let mut config_rx: tokio::sync::watch::Receiver> = config_rx.clone(); let stats = stats.clone(); diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 26c58a6..e7c7957 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -18,6 +18,7 @@ mod rotation; mod send; mod secret; mod wire; +mod pool_status; use bytes::Bytes; diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs new file mode 100644 index 0000000..c01f74b --- /dev/null +++ b/src/transport/middle_proxy/pool_status.rs @@ -0,0 +1,424 @@ +use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::net::SocketAddr; +use std::sync::atomic::Ordering; +use std::time::Instant; + +use super::pool::{MePool, WriterContour}; +use crate::config::{MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy}; +use crate::transport::upstream::IpPreference; + +#[derive(Clone, Debug)] +pub(crate) struct MeApiWriterStatusSnapshot { + pub writer_id: u64, + pub dc: Option, + pub endpoint: SocketAddr, + pub generation: u64, + pub state: &'static str, + pub draining: bool, + pub degraded: bool, + pub bound_clients: usize, + pub idle_for_secs: Option, + pub rtt_ema_ms: Option, +} + +#[derive(Clone, Debug)] +pub(crate) struct MeApiDcStatusSnapshot { + pub dc: i16, + pub endpoints: Vec, + pub available_endpoints: usize, + pub available_pct: f64, + pub required_writers: usize, + pub alive_writers: usize, + pub coverage_pct: f64, + pub rtt_ms: Option, + pub load: usize, +} + +#[derive(Clone, Debug)] +pub(crate) struct MeApiStatusSnapshot { + pub generated_at_epoch_secs: u64, + pub configured_dc_groups: usize, + pub configured_endpoints: usize, + pub available_endpoints: usize, + pub available_pct: f64, + pub required_writers: usize, + pub alive_writers: usize, + pub coverage_pct: f64, + pub writers: Vec, + pub dcs: Vec, +} + +#[derive(Clone, Debug)] +pub(crate) struct MeApiQuarantinedEndpointSnapshot { + pub endpoint: SocketAddr, + pub remaining_ms: u64, +} + +#[derive(Clone, Debug)] +pub(crate) struct MeApiDcPathSnapshot { + pub dc: i16, + pub ip_preference: Option<&'static str>, + pub selected_addr_v4: Option, + pub selected_addr_v6: Option, +} + +#[derive(Clone, Debug)] +pub(crate) struct MeApiRuntimeSnapshot { + pub active_generation: u64, + pub warm_generation: u64, + pub pending_hardswap_generation: u64, + pub pending_hardswap_age_secs: Option, + pub hardswap_enabled: bool, + pub floor_mode: &'static str, + pub adaptive_floor_idle_secs: u64, + pub adaptive_floor_min_writers_single_endpoint: u8, + pub adaptive_floor_recover_grace_secs: u64, + pub me_keepalive_enabled: bool, + pub me_keepalive_interval_secs: u64, + pub me_keepalive_jitter_secs: u64, + pub me_keepalive_payload_random: bool, + pub rpc_proxy_req_every_secs: u64, + pub me_reconnect_max_concurrent_per_dc: u32, + pub me_reconnect_backoff_base_ms: u64, + pub me_reconnect_backoff_cap_ms: u64, + pub me_reconnect_fast_retry_count: u32, + pub me_pool_drain_ttl_secs: u64, + pub me_pool_force_close_secs: u64, + pub me_pool_min_fresh_ratio: f32, + pub me_bind_stale_mode: &'static str, + pub me_bind_stale_ttl_secs: u64, + pub me_single_endpoint_shadow_writers: u8, + pub me_single_endpoint_outage_mode_enabled: bool, + pub me_single_endpoint_outage_disable_quarantine: bool, + pub me_single_endpoint_outage_backoff_min_ms: u64, + pub me_single_endpoint_outage_backoff_max_ms: u64, + pub me_single_endpoint_shadow_rotate_every_secs: u64, + pub me_deterministic_writer_sort: bool, + pub me_socks_kdf_policy: &'static str, + pub quarantined_endpoints: Vec, + pub network_path: Vec, +} + +impl MePool { + pub(crate) async fn api_status_snapshot(&self) -> MeApiStatusSnapshot { + let now_epoch_secs = Self::now_epoch_secs(); + + let mut endpoints_by_dc = BTreeMap::>::new(); + if self.decision.ipv4_me { + let map = self.proxy_map_v4.read().await.clone(); + for (dc, addrs) in map { + let abs_dc = dc.abs(); + if abs_dc == 0 { + continue; + } + let Ok(dc_idx) = i16::try_from(abs_dc) else { + continue; + }; + let entry = endpoints_by_dc.entry(dc_idx).or_default(); + for (ip, port) in addrs { + entry.insert(SocketAddr::new(ip, port)); + } + } + } + if self.decision.ipv6_me { + let map = self.proxy_map_v6.read().await.clone(); + for (dc, addrs) in map { + let abs_dc = dc.abs(); + if abs_dc == 0 { + continue; + } + let Ok(dc_idx) = i16::try_from(abs_dc) else { + continue; + }; + let entry = endpoints_by_dc.entry(dc_idx).or_default(); + for (ip, port) in addrs { + entry.insert(SocketAddr::new(ip, port)); + } + } + } + + let mut endpoint_to_dc = HashMap::::new(); + for (dc, endpoints) in &endpoints_by_dc { + for endpoint in endpoints { + endpoint_to_dc.entry(*endpoint).or_insert(*dc); + } + } + + let configured_dc_groups = endpoints_by_dc.len(); + let configured_endpoints = endpoints_by_dc.values().map(BTreeSet::len).sum(); + + let required_writers = endpoints_by_dc + .values() + .map(|endpoints| self.required_writers_for_dc_with_floor_mode(endpoints.len(), false)) + .sum(); + + let idle_since = self.registry.writer_idle_since_snapshot().await; + let activity = self.registry.writer_activity_snapshot().await; + let rtt = self.rtt_stats.lock().await.clone(); + let writers = self.writers.read().await.clone(); + + let mut live_writers_by_endpoint = HashMap::::new(); + let mut live_writers_by_dc = HashMap::::new(); + let mut dc_rtt_agg = HashMap::::new(); + let mut writer_rows = Vec::::with_capacity(writers.len()); + + for writer in writers { + let endpoint = writer.addr; + let dc = endpoint_to_dc.get(&endpoint).copied(); + let draining = writer.draining.load(Ordering::Relaxed); + let degraded = writer.degraded.load(Ordering::Relaxed); + let bound_clients = activity + .bound_clients_by_writer + .get(&writer.id) + .copied() + .unwrap_or(0); + let idle_for_secs = idle_since + .get(&writer.id) + .map(|idle_ts| now_epoch_secs.saturating_sub(*idle_ts)); + let rtt_ema_ms = rtt.get(&writer.id).map(|(_, ema)| *ema); + let state = match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { + WriterContour::Warm => "warm", + WriterContour::Active => "active", + WriterContour::Draining => "draining", + }; + + if !draining { + *live_writers_by_endpoint.entry(endpoint).or_insert(0) += 1; + if let Some(dc_idx) = dc { + *live_writers_by_dc.entry(dc_idx).or_insert(0) += 1; + if let Some(ema_ms) = rtt_ema_ms { + let entry = dc_rtt_agg.entry(dc_idx).or_insert((0.0, 0)); + entry.0 += ema_ms; + entry.1 += 1; + } + } + } + + writer_rows.push(MeApiWriterStatusSnapshot { + writer_id: writer.id, + dc, + endpoint, + generation: writer.generation, + state, + draining, + degraded, + bound_clients, + idle_for_secs, + rtt_ema_ms, + }); + } + + writer_rows.sort_by_key(|row| (row.dc.unwrap_or(i16::MAX), row.endpoint, row.writer_id)); + + let mut dcs = Vec::::with_capacity(endpoints_by_dc.len()); + let mut available_endpoints = 0usize; + let mut alive_writers = 0usize; + for (dc, endpoints) in endpoints_by_dc { + let endpoint_count = endpoints.len(); + let dc_available_endpoints = endpoints + .iter() + .filter(|endpoint| live_writers_by_endpoint.contains_key(endpoint)) + .count(); + let dc_required_writers = + self.required_writers_for_dc_with_floor_mode(endpoint_count, false); + let dc_alive_writers = live_writers_by_dc.get(&dc).copied().unwrap_or(0); + let dc_load = activity + .active_sessions_by_target_dc + .get(&dc) + .copied() + .unwrap_or(0); + let dc_rtt_ms = dc_rtt_agg + .get(&dc) + .and_then(|(sum, count)| (*count > 0).then_some(*sum / (*count as f64))); + + available_endpoints += dc_available_endpoints; + alive_writers += dc_alive_writers; + + dcs.push(MeApiDcStatusSnapshot { + dc, + endpoints: endpoints.into_iter().collect(), + available_endpoints: dc_available_endpoints, + available_pct: ratio_pct(dc_available_endpoints, endpoint_count), + required_writers: dc_required_writers, + alive_writers: dc_alive_writers, + coverage_pct: ratio_pct(dc_alive_writers, dc_required_writers), + rtt_ms: dc_rtt_ms, + load: dc_load, + }); + } + + MeApiStatusSnapshot { + generated_at_epoch_secs: now_epoch_secs, + configured_dc_groups, + configured_endpoints, + available_endpoints, + available_pct: ratio_pct(available_endpoints, configured_endpoints), + required_writers, + alive_writers, + coverage_pct: ratio_pct(alive_writers, required_writers), + writers: writer_rows, + dcs, + } + } + + pub(crate) async fn api_runtime_snapshot(&self) -> MeApiRuntimeSnapshot { + let now = Instant::now(); + let now_epoch_secs = Self::now_epoch_secs(); + let pending_started_at = self + .pending_hardswap_started_at_epoch_secs + .load(Ordering::Relaxed); + let pending_hardswap_age_secs = (pending_started_at > 0) + .then_some(now_epoch_secs.saturating_sub(pending_started_at)); + + let mut quarantined_endpoints = Vec::::new(); + { + let guard = self.endpoint_quarantine.lock().await; + for (endpoint, expires_at) in guard.iter() { + if *expires_at <= now { + continue; + } + let remaining_ms = expires_at.duration_since(now).as_millis() as u64; + quarantined_endpoints.push(MeApiQuarantinedEndpointSnapshot { + endpoint: *endpoint, + remaining_ms, + }); + } + } + quarantined_endpoints.sort_by_key(|entry| entry.endpoint); + + let mut network_path = Vec::::new(); + if let Some(upstream) = &self.upstream { + for dc in 1..=5 { + let dc_idx = dc as i16; + let ip_preference = upstream + .get_dc_ip_preference(dc_idx) + .await + .map(ip_preference_label); + let selected_addr_v4 = upstream.get_dc_addr(dc_idx, false).await; + let selected_addr_v6 = upstream.get_dc_addr(dc_idx, true).await; + network_path.push(MeApiDcPathSnapshot { + dc: dc_idx, + ip_preference, + selected_addr_v4, + selected_addr_v6, + }); + } + } + + MeApiRuntimeSnapshot { + active_generation: self.active_generation.load(Ordering::Relaxed), + warm_generation: self.warm_generation.load(Ordering::Relaxed), + pending_hardswap_generation: self.pending_hardswap_generation.load(Ordering::Relaxed), + pending_hardswap_age_secs, + hardswap_enabled: self.hardswap.load(Ordering::Relaxed), + floor_mode: floor_mode_label(self.floor_mode()), + adaptive_floor_idle_secs: self.me_adaptive_floor_idle_secs.load(Ordering::Relaxed), + adaptive_floor_min_writers_single_endpoint: self + .me_adaptive_floor_min_writers_single_endpoint + .load(Ordering::Relaxed), + adaptive_floor_recover_grace_secs: self + .me_adaptive_floor_recover_grace_secs + .load(Ordering::Relaxed), + me_keepalive_enabled: self.me_keepalive_enabled, + me_keepalive_interval_secs: self.me_keepalive_interval.as_secs(), + me_keepalive_jitter_secs: self.me_keepalive_jitter.as_secs(), + me_keepalive_payload_random: self.me_keepalive_payload_random, + rpc_proxy_req_every_secs: self.rpc_proxy_req_every_secs.load(Ordering::Relaxed), + me_reconnect_max_concurrent_per_dc: self.me_reconnect_max_concurrent_per_dc, + me_reconnect_backoff_base_ms: self.me_reconnect_backoff_base.as_millis() as u64, + me_reconnect_backoff_cap_ms: self.me_reconnect_backoff_cap.as_millis() as u64, + me_reconnect_fast_retry_count: self.me_reconnect_fast_retry_count, + me_pool_drain_ttl_secs: self.me_pool_drain_ttl_secs.load(Ordering::Relaxed), + me_pool_force_close_secs: self.me_pool_force_close_secs.load(Ordering::Relaxed), + me_pool_min_fresh_ratio: Self::permille_to_ratio( + self.me_pool_min_fresh_ratio_permille.load(Ordering::Relaxed), + ), + me_bind_stale_mode: bind_stale_mode_label(self.bind_stale_mode()), + me_bind_stale_ttl_secs: self.me_bind_stale_ttl_secs.load(Ordering::Relaxed), + me_single_endpoint_shadow_writers: self + .me_single_endpoint_shadow_writers + .load(Ordering::Relaxed), + me_single_endpoint_outage_mode_enabled: self + .me_single_endpoint_outage_mode_enabled + .load(Ordering::Relaxed), + me_single_endpoint_outage_disable_quarantine: self + .me_single_endpoint_outage_disable_quarantine + .load(Ordering::Relaxed), + me_single_endpoint_outage_backoff_min_ms: self + .me_single_endpoint_outage_backoff_min_ms + .load(Ordering::Relaxed), + me_single_endpoint_outage_backoff_max_ms: self + .me_single_endpoint_outage_backoff_max_ms + .load(Ordering::Relaxed), + me_single_endpoint_shadow_rotate_every_secs: self + .me_single_endpoint_shadow_rotate_every_secs + .load(Ordering::Relaxed), + me_deterministic_writer_sort: self + .me_deterministic_writer_sort + .load(Ordering::Relaxed), + me_socks_kdf_policy: socks_kdf_policy_label(self.socks_kdf_policy()), + quarantined_endpoints, + network_path, + } + } +} + +fn ratio_pct(part: usize, total: usize) -> f64 { + if total == 0 { + return 0.0; + } + let pct = ((part as f64) / (total as f64)) * 100.0; + pct.clamp(0.0, 100.0) +} + +fn floor_mode_label(mode: MeFloorMode) -> &'static str { + match mode { + MeFloorMode::Static => "static", + MeFloorMode::Adaptive => "adaptive", + } +} + +fn bind_stale_mode_label(mode: MeBindStaleMode) -> &'static str { + match mode { + MeBindStaleMode::Never => "never", + MeBindStaleMode::Ttl => "ttl", + MeBindStaleMode::Always => "always", + } +} + +fn socks_kdf_policy_label(policy: MeSocksKdfPolicy) -> &'static str { + match policy { + MeSocksKdfPolicy::Strict => "strict", + MeSocksKdfPolicy::Compat => "compat", + } +} + +fn ip_preference_label(preference: IpPreference) -> &'static str { + match preference { + IpPreference::Unknown => "unknown", + IpPreference::PreferV6 => "prefer_v6", + IpPreference::PreferV4 => "prefer_v4", + IpPreference::BothWork => "both", + IpPreference::Unavailable => "unavailable", + } +} + +#[cfg(test)] +mod tests { + use super::ratio_pct; + + #[test] + fn ratio_pct_is_zero_when_denominator_is_zero() { + assert_eq!(ratio_pct(1, 0), 0.0); + } + + #[test] + fn ratio_pct_is_capped_at_100() { + assert_eq!(ratio_pct(7, 3), 100.0); + } + + #[test] + fn ratio_pct_reports_expected_value() { + assert_eq!(ratio_pct(1, 4), 25.0); + } +} diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 4a66654..869030a 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -45,6 +45,12 @@ pub struct ConnWriter { pub tx: mpsc::Sender, } +#[derive(Clone, Debug, Default)] +pub(super) struct WriterActivitySnapshot { + pub bound_clients_by_writer: HashMap, + pub active_sessions_by_target_dc: HashMap, +} + struct RegistryInner { map: HashMap>, writers: HashMap>, @@ -241,6 +247,30 @@ impl ConnRegistry { inner.writer_idle_since_epoch_secs.clone() } + pub(super) async fn writer_activity_snapshot(&self) -> WriterActivitySnapshot { + let inner = self.inner.read().await; + let mut bound_clients_by_writer = HashMap::::new(); + let mut active_sessions_by_target_dc = HashMap::::new(); + + for (writer_id, conn_ids) in &inner.conns_for_writer { + bound_clients_by_writer.insert(*writer_id, conn_ids.len()); + } + for conn_meta in inner.meta.values() { + let dc_u16 = conn_meta.target_dc.unsigned_abs(); + if dc_u16 == 0 { + continue; + } + if let Ok(dc) = i16::try_from(dc_u16) { + *active_sessions_by_target_dc.entry(dc).or_insert(0) += 1; + } + } + + WriterActivitySnapshot { + bound_clients_by_writer, + active_sessions_by_target_dc, + } + } + pub async fn get_writer(&self, conn_id: u64) -> Option { let inner = self.inner.read().await; let writer_id = inner.writer_for_conn.get(&conn_id).cloned()?; @@ -288,3 +318,69 @@ impl ConnRegistry { .unwrap_or(true) } } + +#[cfg(test)] +mod tests { + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + + use super::ConnMeta; + use super::ConnRegistry; + + #[tokio::test] + async fn writer_activity_snapshot_tracks_writer_and_dc_load() { + let registry = ConnRegistry::new(); + + let (conn_a, _rx_a) = registry.register().await; + let (conn_b, _rx_b) = registry.register().await; + let (conn_c, _rx_c) = registry.register().await; + let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); + let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); + + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); + registry + .bind_writer( + conn_a, + 10, + writer_tx_a.clone(), + ConnMeta { + target_dc: 2, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await; + registry + .bind_writer( + conn_b, + 10, + writer_tx_a, + ConnMeta { + target_dc: -2, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await; + registry + .bind_writer( + conn_c, + 20, + writer_tx_b, + ConnMeta { + target_dc: 4, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await; + + let snapshot = registry.writer_activity_snapshot().await; + assert_eq!(snapshot.bound_clients_by_writer.get(&10), Some(&2)); + assert_eq!(snapshot.bound_clients_by_writer.get(&20), Some(&1)); + assert_eq!(snapshot.active_sessions_by_target_dc.get(&2), Some(&2)); + assert_eq!(snapshot.active_sessions_by_target_dc.get(&4), Some(&1)); + } +}