diff --git a/tests/test_socks5_protocol.py b/tests/test_socks5_protocol.py new file mode 100644 index 0000000..964cd44 --- /dev/null +++ b/tests/test_socks5_protocol.py @@ -0,0 +1,129 @@ +import asyncio +import socket +import unittest +from unittest.mock import patch + +from proxy.tg_ws_proxy import _handle_client, _socks5_reply + + +class _FakeTransport: + def get_extra_info(self, name): + return None + + def get_write_buffer_size(self): + return 0 + + +class _FakeReader: + def __init__(self, payload: bytes): + self._payload = payload + self._offset = 0 + + async def readexactly(self, n: int) -> bytes: + end = self._offset + n + if end > len(self._payload): + partial = self._payload[self._offset:] + self._offset = len(self._payload) + raise asyncio.IncompleteReadError(partial, n) + chunk = self._payload[self._offset:end] + self._offset = end + return chunk + + +class _FakeWriter: + def __init__(self): + self.transport = _FakeTransport() + self.writes = [] + self.closed = False + self.close_calls = 0 + + def get_extra_info(self, name): + if name == "peername": + return ("127.0.0.1", 50000) + return None + + def write(self, data: bytes): + self.writes.append(data) + + async def drain(self): + return None + + def close(self): + self.closed = True + self.close_calls += 1 + + async def wait_closed(self): + return None + + +def _ipv4_connect_request(ip: str, port: int, cmd: int = 1) -> bytes: + return bytes([0x05, cmd, 0x00, 0x01]) + socket.inet_aton(ip) + port.to_bytes(2, "big") + + +def _domain_connect_request(domain: str, port: int, cmd: int = 1) -> bytes: + encoded = domain.encode("utf-8") + return ( + bytes([0x05, cmd, 0x00, 0x03, len(encoded)]) + + encoded + + port.to_bytes(2, "big") + ) + + +def _ipv6_connect_request(ip: str, port: int) -> bytes: + return ( + bytes([0x05, 0x01, 0x00, 0x04]) + + socket.inet_pton(socket.AF_INET6, ip) + + port.to_bytes(2, "big") + ) + + +class Socks5ProtocolTests(unittest.IsolatedAsyncioTestCase): + async def test_rejects_non_socks5_greeting(self): + reader = _FakeReader(b"\x04\x01") + writer = _FakeWriter() + + await _handle_client(reader, writer) + + self.assertEqual(writer.writes, []) + self.assertTrue(writer.closed) + + async def test_rejects_unsupported_command(self): + reader = _FakeReader(b"\x05\x01\x00" + _ipv4_connect_request("1.1.1.1", 443, cmd=2)) + writer = _FakeWriter() + + await _handle_client(reader, writer) + + self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x07)]) + self.assertTrue(writer.closed) + + async def test_rejects_unsupported_address_type(self): + reader = _FakeReader(b"\x05\x01\x00" + b"\x05\x01\x00\x02") + writer = _FakeWriter() + + await _handle_client(reader, writer) + + self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x08)]) + self.assertTrue(writer.closed) + + async def test_rejects_ipv6_destinations(self): + reader = _FakeReader(b"\x05\x01\x00" + _ipv6_connect_request("2001:db8::1", 443)) + writer = _FakeWriter() + + await _handle_client(reader, writer) + + self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x05)]) + self.assertTrue(writer.closed) + + async def test_passthrough_connect_failure_returns_error(self): + reader = _FakeReader(b"\x05\x01\x00" + _domain_connect_request("example.com", 443)) + writer = _FakeWriter() + + with patch("proxy.tg_ws_proxy.asyncio.open_connection", side_effect=OSError("boom")): + await _handle_client(reader, writer) + + self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x05)]) + self.assertTrue(writer.closed) + + +if __name__ == "__main__": + unittest.main()