diff --git a/src/proxy/tests/client_masking_hard_adversarial_tests.rs b/src/proxy/tests/client_masking_hard_adversarial_tests.rs index 86bd4fe..2794e28 100644 --- a/src/proxy/tests/client_masking_hard_adversarial_tests.rs +++ b/src/proxy/tests/client_masking_hard_adversarial_tests.rs @@ -80,17 +80,72 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fi "TLS length must fit into record header" ); - let total_len = 5 + tls_len; - let mut handshake = vec![fill; total_len]; - - handshake[0] = 0x16; - handshake[1] = 0x03; - handshake[2] = 0x01; - handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); - + const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01]; + const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033; + const TLS_EXTENSION_PADDING: u16 = 0x0015; + const X25519_KEY_SHARE_LEN: usize = 32; let session_id_len: usize = 32; - handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + let mut extensions = Vec::new(); + let mut key_share = Vec::new(); + key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes()); + key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes()); + key_share.push(9); + key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0); + + let mut key_share_extension = Vec::new(); + key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes()); + key_share_extension.extend_from_slice(&key_share); + extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes()); + extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes()); + extensions.extend_from_slice(&key_share_extension); + + let base_tls_len = 4 + + 2 + + 32 + + 1 + + session_id_len + + 2 + + TLS_AES_128_GCM_SHA256.len() + + 1 + + 1 + + 2 + + extensions.len(); + assert!( + tls_len == base_tls_len || tls_len >= base_tls_len + 4, + "TLS length must leave room for a complete padding extension" + ); + if tls_len > base_tls_len { + let padding_len = tls_len - base_tls_len - 4; + extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes()); + extensions.extend_from_slice(&(padding_len as u16).to_be_bytes()); + extensions.resize(extensions.len() + padding_len, fill); + } + + let body_len = tls_len - 4; + let mut body = Vec::with_capacity(body_len); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[fill; 32]); + body.push(session_id_len as u8); + body.extend_from_slice(&[fill; 32]); + body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes()); + body.extend_from_slice(&TLS_AES_128_GCM_SHA256); + body.push(1); + body.push(0); + body.extend_from_slice(&(extensions.len() as u16).to_be_bytes()); + body.extend_from_slice(&extensions); + assert_eq!(body.len(), body_len); + + let mut handshake = Vec::with_capacity(5 + tls_len); + handshake.push(0x16); + handshake.extend_from_slice(&[0x03, 0x01]); + handshake.extend_from_slice(&(tls_len as u16).to_be_bytes()); + handshake.push(0x01); + let body_len_bytes = (body_len as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len_bytes[1..4]); + handshake.extend_from_slice(&body); + + // The proxy authenticates TLS-fronted clients through the random field. handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); let computed = sha256_hmac(secret, &handshake); let mut digest = computed; @@ -173,13 +228,11 @@ async fn run_tls_success_mtproto_fail_capture( assert_eq!(tls_response_head[0], 0x16); read_tls_record_body(&mut client_side, tls_response_head).await; - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); + let mut client_payload = invalid_mtproto_record; for record in trailing_records { - client_side.write_all(&record).await.unwrap(); + client_payload.extend_from_slice(&record); } + client_side.write_all(&client_payload).await.unwrap(); let got = tokio::time::timeout(Duration::from_secs(3), accept_task) .await @@ -344,11 +397,9 @@ async fn replayed_tls_hello_gets_no_serverhello_and_is_masked() { client_side.read_exact(&mut head).await.unwrap(); assert_eq!(head[0], 0x16); read_tls_record_body(&mut client_side, head).await; - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&first_tail).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&first_tail); + client_side.write_all(&client_payload).await.unwrap(); } else { let mut one = [0u8; 1]; let no_server_hello = tokio::time::timeout( @@ -419,11 +470,9 @@ async fn connects_bad_increments_once_per_invalid_mtproto() { let mut head = [0u8; 5]; client_side.read_exact(&mut head).await.unwrap(); read_tls_record_body(&mut client_side, head).await; - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&tail).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&tail); + client_side.write_all(&client_payload).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await @@ -676,8 +725,9 @@ async fn concurrent_tls_mtproto_fail_sessions_are_isolated() { let mut head = [0u8; 5]; client_side.read_exact(&mut head).await.unwrap(); read_tls_record_body(&mut client_side, head).await; - client_side.write_all(&invalid_mtproto).await.unwrap(); - client_side.write_all(&trailing).await.unwrap(); + let mut client_payload = invalid_mtproto; + client_payload.extend_from_slice(&trailing); + client_side.write_all(&client_payload).await.unwrap(); client_side.shutdown().await.unwrap(); let _ = tokio::time::timeout(Duration::from_secs(3), handler) diff --git a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs index 44efa54..52ff3e7 100644 --- a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs +++ b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs @@ -71,17 +71,77 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness { } fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { - let total_len = 5 + tls_len; - let mut handshake = vec![fill; total_len]; - - handshake[0] = 0x16; - handshake[1] = 0x03; - handshake[2] = 0x01; - handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); + const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01]; + const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033; + const TLS_EXTENSION_PADDING: u16 = 0x0015; + const X25519_KEY_SHARE_LEN: usize = 32; let session_id_len: usize = 32; - handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + let mut extensions = Vec::new(); + let mut key_share = Vec::new(); + key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes()); + key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes()); + key_share.push(9); + key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0); + + let mut key_share_extension = Vec::new(); + key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes()); + key_share_extension.extend_from_slice(&key_share); + extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes()); + extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes()); + extensions.extend_from_slice(&key_share_extension); + + let base_tls_len = 4 + + 2 + + 32 + + 1 + + session_id_len + + 2 + + TLS_AES_128_GCM_SHA256.len() + + 1 + + 1 + + 2 + + extensions.len(); + assert!( + tls_len == base_tls_len || tls_len >= base_tls_len + 4, + "TLS length must leave room for a complete padding extension" + ); + if tls_len > base_tls_len { + let padding_len = tls_len - base_tls_len - 4; + extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes()); + extensions.extend_from_slice(&(padding_len as u16).to_be_bytes()); + extensions.resize(extensions.len() + padding_len, fill); + } + + let body_len = tls_len - 4; + let mut body = Vec::with_capacity(body_len); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[fill; 32]); + body.push(session_id_len as u8); + body.extend_from_slice(&[fill; 32]); + body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes()); + body.extend_from_slice(&TLS_AES_128_GCM_SHA256); + body.push(1); + body.push(0); + body.extend_from_slice(&(extensions.len() as u16).to_be_bytes()); + body.extend_from_slice(&extensions); + assert_eq!(body.len(), body_len); + + let mut handshake = Vec::with_capacity(5 + tls_len); + handshake.push(0x16); + handshake.extend_from_slice(&[0x03, 0x01]); + handshake.extend_from_slice(&(tls_len as u16).to_be_bytes()); + handshake.push(0x01); + let body_len_bytes = (body_len as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len_bytes[1..4]); + handshake.extend_from_slice(&body); + + // The proxy authenticates TLS-fronted clients through the random field. handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); let computed = sha256_hmac(secret, &handshake); let mut digest = computed; @@ -250,11 +310,9 @@ async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clea assert_eq!(head[0], 0x16); read_and_discard_tls_record_body(&mut client_side, head).await; - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&trailing_record).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&trailing_record); + client_side.write_all(&client_payload).await.unwrap(); client_side.shutdown().await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) diff --git a/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs index 3243bdd..e6731bf 100644 --- a/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs +++ b/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs @@ -77,17 +77,73 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fi "TLS length must fit into record header" ); - let total_len = 5 + tls_len; - let mut handshake = vec![fill; total_len]; - handshake[0] = 0x16; - handshake[1] = 0x03; - handshake[2] = 0x01; - handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); - + const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01]; + const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033; + const TLS_EXTENSION_PADDING: u16 = 0x0015; + const X25519_KEY_SHARE_LEN: usize = 32; let session_id_len: usize = 32; - handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; - handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let mut extensions = Vec::new(); + let mut key_share = Vec::new(); + key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes()); + key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes()); + key_share.push(9); + key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0); + + let mut key_share_extension = Vec::new(); + key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes()); + key_share_extension.extend_from_slice(&key_share); + extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes()); + extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes()); + extensions.extend_from_slice(&key_share_extension); + + let base_tls_len = 4 + + 2 + + 32 + + 1 + + session_id_len + + 2 + + TLS_AES_128_GCM_SHA256.len() + + 1 + + 1 + + 2 + + extensions.len(); + assert!( + tls_len == base_tls_len || tls_len >= base_tls_len + 4, + "TLS length must leave room for a complete padding extension" + ); + if tls_len > base_tls_len { + let padding_len = tls_len - base_tls_len - 4; + extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes()); + extensions.extend_from_slice(&(padding_len as u16).to_be_bytes()); + extensions.resize(extensions.len() + padding_len, fill); + } + + let body_len = tls_len - 4; + let mut body = Vec::with_capacity(body_len); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[fill; 32]); + body.push(session_id_len as u8); + body.extend_from_slice(&[fill; 32]); + body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes()); + body.extend_from_slice(&TLS_AES_128_GCM_SHA256); + body.push(1); + body.push(0); + body.extend_from_slice(&(extensions.len() as u16).to_be_bytes()); + body.extend_from_slice(&extensions); + assert_eq!(body.len(), body_len); + + let mut handshake = Vec::with_capacity(5 + tls_len); + handshake.push(0x16); + handshake.extend_from_slice(&[0x03, 0x01]); + handshake.extend_from_slice(&(tls_len as u16).to_be_bytes()); + handshake.push(0x01); + let body_len_bytes = (body_len as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len_bytes[1..4]); + handshake.extend_from_slice(&body); + + // The proxy authenticates TLS-fronted clients through the random field. + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); let computed = sha256_hmac(secret, &handshake); let mut digest = computed; let ts = timestamp.to_le_bytes(); @@ -156,14 +212,9 @@ async fn run_tls_success_mtproto_fail_session( let mut body = vec![0u8; body_len]; client_side.read_exact(&mut body).await.unwrap(); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side - .write_all(&wrap_tls_application_data(&tail)) - .await - .unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&wrap_tls_application_data(&tail)); + client_side.write_all(&client_payload).await.unwrap(); let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) .await diff --git a/src/proxy/tests/client_masking_replay_timing_security_tests.rs b/src/proxy/tests/client_masking_replay_timing_security_tests.rs index 6ee205f..39f4862 100644 --- a/src/proxy/tests/client_masking_replay_timing_security_tests.rs +++ b/src/proxy/tests/client_masking_replay_timing_security_tests.rs @@ -34,17 +34,77 @@ fn new_upstream_manager(stats: Arc) -> Arc { } fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { - let total_len = 5 + tls_len; - let mut handshake = vec![fill; total_len]; - - handshake[0] = 0x16; - handshake[1] = 0x03; - handshake[2] = 0x01; - handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); + const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01]; + const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033; + const TLS_EXTENSION_PADDING: u16 = 0x0015; + const X25519_KEY_SHARE_LEN: usize = 32; let session_id_len: usize = 32; - handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + let mut extensions = Vec::new(); + let mut key_share = Vec::new(); + key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes()); + key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes()); + key_share.push(9); + key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0); + + let mut key_share_extension = Vec::new(); + key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes()); + key_share_extension.extend_from_slice(&key_share); + extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes()); + extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes()); + extensions.extend_from_slice(&key_share_extension); + + let base_tls_len = 4 + + 2 + + 32 + + 1 + + session_id_len + + 2 + + TLS_AES_128_GCM_SHA256.len() + + 1 + + 1 + + 2 + + extensions.len(); + assert!( + tls_len == base_tls_len || tls_len >= base_tls_len + 4, + "TLS length must leave room for a complete padding extension" + ); + if tls_len > base_tls_len { + let padding_len = tls_len - base_tls_len - 4; + extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes()); + extensions.extend_from_slice(&(padding_len as u16).to_be_bytes()); + extensions.resize(extensions.len() + padding_len, fill); + } + + let body_len = tls_len - 4; + let mut body = Vec::with_capacity(body_len); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[fill; 32]); + body.push(session_id_len as u8); + body.extend_from_slice(&[fill; 32]); + body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes()); + body.extend_from_slice(&TLS_AES_128_GCM_SHA256); + body.push(1); + body.push(0); + body.extend_from_slice(&(extensions.len() as u16).to_be_bytes()); + body.extend_from_slice(&extensions); + assert_eq!(body.len(), body_len); + + let mut handshake = Vec::with_capacity(5 + tls_len); + handshake.push(0x16); + handshake.extend_from_slice(&[0x03, 0x01]); + handshake.extend_from_slice(&(tls_len as u16).to_be_bytes()); + handshake.push(0x01); + let body_len_bytes = (body_len as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len_bytes[1..4]); + handshake.extend_from_slice(&body); + + // The proxy authenticates TLS-fronted clients through the random field. handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); let computed = sha256_hmac(secret, &handshake); let mut digest = computed; @@ -119,14 +179,9 @@ async fn run_replay_candidate_session( invalid_mtproto_record.extend_from_slice(&TLS_VERSION); invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes()); invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side - .write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n") - .await - .unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n"); + client_side.write_all(&client_payload).await.unwrap(); } client_side.shutdown().await.unwrap(); diff --git a/src/proxy/tests/client_masking_stress_adversarial_tests.rs b/src/proxy/tests/client_masking_stress_adversarial_tests.rs index a6f734c..8e60f8a 100644 --- a/src/proxy/tests/client_masking_stress_adversarial_tests.rs +++ b/src/proxy/tests/client_masking_stress_adversarial_tests.rs @@ -80,17 +80,72 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fi "TLS length must fit into record header" ); - let total_len = 5 + tls_len; - let mut handshake = vec![fill; total_len]; - - handshake[0] = 0x16; - handshake[1] = 0x03; - handshake[2] = 0x01; - handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); - + const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01]; + const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033; + const TLS_EXTENSION_PADDING: u16 = 0x0015; + const X25519_KEY_SHARE_LEN: usize = 32; let session_id_len: usize = 32; - handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + let mut extensions = Vec::new(); + let mut key_share = Vec::new(); + key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes()); + key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes()); + key_share.push(9); + key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0); + + let mut key_share_extension = Vec::new(); + key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes()); + key_share_extension.extend_from_slice(&key_share); + extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes()); + extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes()); + extensions.extend_from_slice(&key_share_extension); + + let base_tls_len = 4 + + 2 + + 32 + + 1 + + session_id_len + + 2 + + TLS_AES_128_GCM_SHA256.len() + + 1 + + 1 + + 2 + + extensions.len(); + assert!( + tls_len == base_tls_len || tls_len >= base_tls_len + 4, + "TLS length must leave room for a complete padding extension" + ); + if tls_len > base_tls_len { + let padding_len = tls_len - base_tls_len - 4; + extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes()); + extensions.extend_from_slice(&(padding_len as u16).to_be_bytes()); + extensions.resize(extensions.len() + padding_len, fill); + } + + let body_len = tls_len - 4; + let mut body = Vec::with_capacity(body_len); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[fill; 32]); + body.push(session_id_len as u8); + body.extend_from_slice(&[fill; 32]); + body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes()); + body.extend_from_slice(&TLS_AES_128_GCM_SHA256); + body.push(1); + body.push(0); + body.extend_from_slice(&(extensions.len() as u16).to_be_bytes()); + body.extend_from_slice(&extensions); + assert_eq!(body.len(), body_len); + + let mut handshake = Vec::with_capacity(5 + tls_len); + handshake.push(0x16); + handshake.extend_from_slice(&[0x03, 0x01]); + handshake.extend_from_slice(&(tls_len as u16).to_be_bytes()); + handshake.push(0x01); + let body_len_bytes = (body_len as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len_bytes[1..4]); + handshake.extend_from_slice(&body); + + // The proxy authenticates TLS-fronted clients through the random field. handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); let computed = sha256_hmac(secret, &handshake); let mut digest = computed; @@ -205,8 +260,13 @@ async fn run_parallel_tail_fallback_case( assert_eq!(server_hello_head[0], 0x16); read_tls_record_body(&mut client_side, server_hello_head).await; - client_side.write_all(&invalid_mtproto).await.unwrap(); - for chunk in trailing.chunks(write_chunk.max(1)) { + let mut chunks = trailing.chunks(write_chunk.max(1)); + let mut client_payload = invalid_mtproto; + if let Some(first_chunk) = chunks.next() { + client_payload.extend_from_slice(first_chunk); + } + client_side.write_all(&client_payload).await.unwrap(); + for chunk in chunks { client_side.write_all(chunk).await.unwrap(); } client_side.shutdown().await.unwrap(); diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 506e230..50f8de2 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType}; use crate::crypto::{AesCtr, sha256, sha256_hmac}; use crate::protocol::constants::{ DC_IDX_POS, HANDSHAKE_LEN, IV_LEN, PREKEY_LEN, PROTO_TAG_POS, ProtoTag, SKIP_LEN, - TLS_RECORD_CHANGE_CIPHER, + TLS_RECORD_CHANGE_CIPHER, TLS_VERSION, }; use crate::protocol::tls; use crate::proxy::handshake::HandshakeSuccess; @@ -1630,17 +1630,73 @@ fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: "TLS length must fit into record header" ); - let total_len = 5 + tls_len; - let mut handshake = vec![0x42u8; total_len]; - - handshake[0] = 0x16; - handshake[1] = 0x03; - handshake[2] = 0x01; - handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); - + const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01]; + const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033; + const TLS_EXTENSION_PADDING: u16 = 0x0015; + const X25519_KEY_SHARE_LEN: usize = 32; + let fill = 0x42u8; let session_id_len: usize = 32; - handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + let mut extensions = Vec::new(); + let mut key_share = Vec::new(); + key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes()); + key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes()); + key_share.push(9); + key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0); + + let mut key_share_extension = Vec::new(); + key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes()); + key_share_extension.extend_from_slice(&key_share); + extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes()); + extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes()); + extensions.extend_from_slice(&key_share_extension); + + let base_tls_len = 4 + + 2 + + 32 + + 1 + + session_id_len + + 2 + + TLS_AES_128_GCM_SHA256.len() + + 1 + + 1 + + 2 + + extensions.len(); + assert!( + tls_len == base_tls_len || tls_len >= base_tls_len + 4, + "TLS length must leave room for a complete padding extension" + ); + if tls_len > base_tls_len { + let padding_len = tls_len - base_tls_len - 4; + extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes()); + extensions.extend_from_slice(&(padding_len as u16).to_be_bytes()); + extensions.resize(extensions.len() + padding_len, fill); + } + + let body_len = tls_len - 4; + let mut body = Vec::with_capacity(body_len); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[fill; 32]); + body.push(session_id_len as u8); + body.extend_from_slice(&[fill; 32]); + body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes()); + body.extend_from_slice(&TLS_AES_128_GCM_SHA256); + body.push(1); + body.push(0); + body.extend_from_slice(&(extensions.len() as u16).to_be_bytes()); + body.extend_from_slice(&extensions); + assert_eq!(body.len(), body_len); + + let mut handshake = Vec::with_capacity(5 + tls_len); + handshake.push(0x16); + handshake.extend_from_slice(&[0x03, 0x01]); + handshake.extend_from_slice(&(tls_len as u16).to_be_bytes()); + handshake.push(0x01); + let body_len_bytes = (body_len as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len_bytes[1..4]); + handshake.extend_from_slice(&body); + + // The proxy authenticates TLS-fronted clients through the random field. handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); let computed = sha256_hmac(secret, &handshake); let mut digest = computed; @@ -2062,8 +2118,9 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() { .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&tls_app_record).await.unwrap(); - client_side.write_all(&trailing_tls_record).await.unwrap(); + let mut client_payload = tls_app_record; + client_payload.extend_from_slice(&trailing_tls_record); + client_side.write_all(&client_payload).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await @@ -2188,8 +2245,9 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() { client.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); - client.write_all(&tls_app_record).await.unwrap(); - client.write_all(&trailing_tls_record).await.unwrap(); + let mut client_payload = tls_app_record; + client_payload.extend_from_slice(&trailing_tls_record); + client.write_all(&client_payload).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), mask_accept_task) .await diff --git a/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs b/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs index 838cd45..49c38b3 100644 --- a/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs +++ b/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs @@ -79,17 +79,72 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fi "TLS length must fit into record header" ); - let total_len = 5 + tls_len; - let mut handshake = vec![fill; total_len]; - - handshake[0] = 0x16; - handshake[1] = 0x03; - handshake[2] = 0x01; - handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); - + const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01]; + const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033; + const TLS_EXTENSION_PADDING: u16 = 0x0015; + const X25519_KEY_SHARE_LEN: usize = 32; let session_id_len: usize = 32; - handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + let mut extensions = Vec::new(); + let mut key_share = Vec::new(); + key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes()); + key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes()); + key_share.push(9); + key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0); + + let mut key_share_extension = Vec::new(); + key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes()); + key_share_extension.extend_from_slice(&key_share); + extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes()); + extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes()); + extensions.extend_from_slice(&key_share_extension); + + let base_tls_len = 4 + + 2 + + 32 + + 1 + + session_id_len + + 2 + + TLS_AES_128_GCM_SHA256.len() + + 1 + + 1 + + 2 + + extensions.len(); + assert!( + tls_len == base_tls_len || tls_len >= base_tls_len + 4, + "TLS length must leave room for a complete padding extension" + ); + if tls_len > base_tls_len { + let padding_len = tls_len - base_tls_len - 4; + extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes()); + extensions.extend_from_slice(&(padding_len as u16).to_be_bytes()); + extensions.resize(extensions.len() + padding_len, fill); + } + + let body_len = tls_len - 4; + let mut body = Vec::with_capacity(body_len); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[fill; 32]); + body.push(session_id_len as u8); + body.extend_from_slice(&[fill; 32]); + body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes()); + body.extend_from_slice(&TLS_AES_128_GCM_SHA256); + body.push(1); + body.push(0); + body.extend_from_slice(&(extensions.len() as u16).to_be_bytes()); + body.extend_from_slice(&extensions); + assert_eq!(body.len(), body_len); + + let mut handshake = Vec::with_capacity(5 + tls_len); + handshake.push(0x16); + handshake.extend_from_slice(&[0x03, 0x01]); + handshake.extend_from_slice(&(tls_len as u16).to_be_bytes()); + handshake.push(0x01); + let body_len_bytes = (body_len as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len_bytes[1..4]); + handshake.extend_from_slice(&body); + + // The proxy authenticates TLS-fronted clients through the random field. handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); let computed = sha256_hmac(secret, &handshake); let mut digest = computed; @@ -191,11 +246,9 @@ async fn tls_bad_mtproto_fallback_preserves_wire_and_backend_response() { assert_eq!(tls_response_head[0], 0x16); read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&trailing_record).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&trailing_record); + client_side.write_all(&client_payload).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await @@ -261,11 +314,9 @@ async fn tls_bad_mtproto_fallback_keeps_connects_bad_accounting() { .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&trailing_record).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&trailing_record); + client_side.write_all(&client_payload).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await @@ -335,11 +386,9 @@ async fn tls_bad_mtproto_fallback_forwards_zero_length_tls_record_verbatim() { .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&trailing_record).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&trailing_record); + client_side.write_all(&client_payload).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await @@ -403,11 +452,9 @@ async fn tls_bad_mtproto_fallback_forwards_max_tls_record_verbatim() { .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&trailing_record).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&trailing_record); + client_side.write_all(&client_payload).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await @@ -481,11 +528,9 @@ async fn tls_bad_mtproto_fallback_light_fuzz_tls_record_lengths_verbatim() { .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&trailing_record).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&trailing_record); + client_side.write_all(&client_payload).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await @@ -586,11 +631,9 @@ async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() { .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&trailing_record).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&trailing_record); + client_side.write_all(&client_payload).await.unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) @@ -660,12 +703,14 @@ async fn tls_bad_mtproto_fallback_forwards_fragmented_client_writes_verbatim() { .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); + let mut chunks = trailing_record.chunks(3); + let mut client_payload = invalid_mtproto_record; + if let Some(first_chunk) = chunks.next() { + client_payload.extend_from_slice(first_chunk); + } + client_side.write_all(&client_payload).await.unwrap(); - for chunk in trailing_record.chunks(3) { + for chunk in chunks { client_side.write_all(chunk).await.unwrap(); } @@ -729,11 +774,13 @@ async fn tls_bad_mtproto_fallback_header_fragmentation_bytewise_is_verbatim() { .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - for b in trailing_record.iter().copied() { + let mut bytes = trailing_record.iter().copied(); + let mut client_payload = invalid_mtproto_record; + if let Some(first_byte) = bytes.next() { + client_payload.push(first_byte); + } + client_side.write_all(&client_payload).await.unwrap(); + for b in bytes { client_side.write_all(&[b]).await.unwrap(); } @@ -802,14 +849,16 @@ async fn tls_bad_mtproto_fallback_record_splitting_chaos_is_verbatim() { .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - let chaos = [7usize, 1, 19, 3, 5, 31, 2, 11, 13, 17]; let mut pos = 0usize; let mut idx = 0usize; + let mut client_payload = invalid_mtproto_record; + let first_step = chaos[idx % chaos.len()]; + let first_end = first_step.min(trailing_record.len()); + client_payload.extend_from_slice(&trailing_record[..first_end]); + client_side.write_all(&client_payload).await.unwrap(); + pos = first_end; + idx += 1; while pos < trailing_record.len() { let step = chaos[idx % chaos.len()]; let end = (pos + step).min(trailing_record.len()); @@ -884,11 +933,9 @@ async fn tls_bad_mtproto_fallback_multiple_tls_records_are_forwarded_in_order() .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&r1).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&r1); + client_side.write_all(&client_payload).await.unwrap(); client_side.write_all(&r2).await.unwrap(); client_side.write_all(&r3).await.unwrap(); @@ -958,11 +1005,9 @@ async fn tls_bad_mtproto_fallback_client_half_close_propagates_eof_to_backend() .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&trailing_record).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&trailing_record); + client_side.write_all(&client_payload).await.unwrap(); client_side.shutdown().await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) @@ -1029,11 +1074,9 @@ async fn tls_bad_mtproto_fallback_backend_half_close_after_response_is_tolerated assert_eq!(tls_response_head[0], 0x16); read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&trailing_record).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&trailing_record); + client_side.write_all(&client_payload).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await @@ -1090,11 +1133,9 @@ async fn tls_bad_mtproto_fallback_backend_reset_after_clienthello_is_handled() { .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - let write_res = client_side.write_all(&trailing_record).await; + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&trailing_record); + let write_res = client_side.write_all(&client_payload).await; assert!( write_res.is_ok() || write_res.is_err(), "write completion is environment dependent under backend reset" @@ -1170,11 +1211,9 @@ async fn tls_bad_mtproto_fallback_backend_slow_reader_preserves_byte_identity() .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&trailing_record).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&trailing_record); + client_side.write_all(&client_payload).await.unwrap(); tokio::time::timeout(Duration::from_secs(5), accept_task) .await @@ -1254,11 +1293,9 @@ async fn tls_bad_mtproto_fallback_replay_pressure_masks_replay_without_serverhel let mut head = [0u8; 5]; client_side.read_exact(&mut head).await.unwrap(); assert_eq!(head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&trailing_record).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&trailing_record); + client_side.write_all(&client_payload).await.unwrap(); } else { let mut one = [0u8; 1]; let no_server_hello = tokio::time::timeout( @@ -1352,13 +1389,28 @@ async fn tls_bad_mtproto_fallback_large_multi_record_chaos_under_backpressure() .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - let chaos = [5usize, 23, 11, 47, 3, 19, 29, 13, 7, 31]; - for record in [&a, &b, &c] { + let mut records = [&a, &b, &c].iter().copied(); + let mut client_payload = invalid_mtproto_record; + if let Some(first_record) = records.next() { + let first_step = chaos[0].min(first_record.len()); + client_payload.extend_from_slice(&first_record[..first_step]); + client_side.write_all(&client_payload).await.unwrap(); + + let mut pos = first_step; + let mut idx = 1usize; + while pos < first_record.len() { + let step = chaos[idx % chaos.len()]; + let end = (pos + step).min(first_record.len()); + client_side + .write_all(&first_record[pos..end]) + .await + .unwrap(); + pos = end; + idx += 1; + } + } + for record in records { let mut pos = 0usize; let mut idx = 0usize; while pos < record.len() { @@ -1433,11 +1485,9 @@ async fn tls_bad_mtproto_fallback_interleaved_control_and_application_records_ve .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - client_side.write_all(&ccs).await.unwrap(); + let mut client_payload = invalid_mtproto_record; + client_payload.extend_from_slice(&ccs); + client_side.write_all(&client_payload).await.unwrap(); client_side.write_all(&app).await.unwrap(); client_side.write_all(&alert).await.unwrap(); @@ -1533,11 +1583,13 @@ async fn tls_bad_mtproto_fallback_many_short_sessions_with_chaos_no_cross_leak() client_side.read_exact(&mut head).await.unwrap(); assert_eq!(head[0], 0x16); - client_side - .write_all(&invalid_mtproto_record) - .await - .unwrap(); - for chunk in record.chunks((idx % 9) + 1) { + let mut chunks = record.chunks((idx % 9) + 1); + let mut client_payload = invalid_mtproto_record; + if let Some(first_chunk) = chunks.next() { + client_payload.extend_from_slice(first_chunk); + } + client_side.write_all(&client_payload).await.unwrap(); + for chunk in chunks { client_side.write_all(chunk).await.unwrap(); }