From 3599fced7a379d9e9cb1d028fc7738ed76e93a91 Mon Sep 17 00:00:00 2001 From: amurcanov Date: Fri, 27 Mar 2026 03:48:33 +0300 Subject: [PATCH] Update v1.0.2 --- tg-ws-proxy.go | 642 +++++++++++++++++++++++++++---------------------- 1 file changed, 350 insertions(+), 292 deletions(-) diff --git a/tg-ws-proxy.go b/tg-ws-proxy.go index 6006e98..db5af01 100644 --- a/tg-ws-proxy.go +++ b/tg-ws-proxy.go @@ -17,7 +17,7 @@ static void androidLogProxy(char *msg) { import "C" import ( - "bytes" + "bufio" "context" "crypto/aes" "crypto/cipher" @@ -29,7 +29,6 @@ import ( "io" "log" "math" - "math/big" mrand "math/rand/v2" "net" "os" @@ -53,19 +52,19 @@ const ( tcpNodelay = true defaultRecvBuf = 256 * 1024 defaultSendBuf = 256 * 1024 - wsPoolSize = 4 - wsPoolMaxAge = 60.0 // seconds — reduced from 90 to 60 for VPNs - wsBridgeIdle = 120.0 // seconds — max idle time before bridge considers WS dead + defaultPoolSz = 4 + wsPoolMaxAge = 60.0 + wsBridgeIdle = 120.0 - dcFailCooldown = 30.0 // seconds - wsFailTimeout = 2.0 // seconds - poolMaintainInterval = 15 // seconds — frequent to send pings/keep-alives + dcFailCooldown = 30.0 + wsFailTimeout = 2.0 + poolMaintainInterval = 15 ) var ( recvBuf = defaultRecvBuf sendBuf = defaultSendBuf - poolSize = wsPoolSize + poolSize = defaultPoolSz logVerbose = false ) @@ -83,7 +82,7 @@ var ( type androidLogWriter struct{} func (w androidLogWriter) Write(p []byte) (n int, err error) { - os.Stderr.Write(p) + _, _ = os.Stderr.Write(p) cs := C.CString(string(p)) C.androidLogProxy(cs) C.free(unsafe.Pointer(cs)) @@ -206,22 +205,18 @@ var validProtos = map[uint32]bool{ // --------------------------------------------------------------------------- var ( - dcOpt map[int]string // dc -> target IP (empty string means not configured) - dcOptMu sync.RWMutex - wsBlackMu sync.RWMutex - wsBlacklist = make(map[[2]int]bool) // [dc, isMedia(0/1)] + dcOpt map[int]string + dcOptMu sync.RWMutex + + wsBlackMu sync.RWMutex + wsBlacklist = make(map[[2]int]bool) dcFailMu sync.RWMutex - dcFailUntil = make(map[[2]int]float64) // monotonic time + dcFailUntil = make(map[[2]int]float64) zero64 = make([]byte, 64) ) -// TLS config (skip verify, like Python version) -var tlsConfig = &tls.Config{ - InsecureSkipVerify: true, -} - // --------------------------------------------------------------------------- // Stats // --------------------------------------------------------------------------- @@ -243,7 +238,7 @@ func (s *Stats) Summary() string { ph := s.poolHits.Load() pm := s.poolMisses.Load() return fmt.Sprintf( - "total=%d ws=%d tcp_fb=%d http_skip=%d pass=%d err=%d pool_hits=%d/%d up=%s down=%s", + "total=%d ws=%d tcp_fb=%d http_skip=%d pass=%d err=%d pool=%d/%d up=%s down=%s", s.connectionsTotal.Load(), s.connectionsWs.Load(), s.connectionsTcpFallback.Load(), @@ -256,6 +251,19 @@ func (s *Stats) Summary() string { ) } +func (s *Stats) Reset() { + s.connectionsTotal.Store(0) + s.connectionsWs.Store(0) + s.connectionsTcpFallback.Store(0) + s.connectionsHttpReject.Store(0) + s.connectionsPassthrough.Store(0) + s.wsErrors.Store(0) + s.bytesUp.Store(0) + s.bytesDown.Store(0) + s.poolHits.Store(0) + s.poolMisses.Store(0) +} + var stats Stats func humanBytes(n int64) string { @@ -297,20 +305,58 @@ func setSockOpts(conn net.Conn) { } // --------------------------------------------------------------------------- -// XOR mask (WebSocket masking) +// XOR mask — optimized 8-byte processing // --------------------------------------------------------------------------- func xorMask(data, mask []byte) []byte { - if len(data) == 0 { + n := len(data) + if n == 0 { return data } - result := make([]byte, len(data)) - for i := range data { - result[i] = data[i] ^ mask[i%4] + + result := make([]byte, n) + + // Build 8-byte mask + mask8 := uint64(mask[0]) | uint64(mask[1])<<8 | + uint64(mask[2])<<16 | uint64(mask[3])<<24 | + uint64(mask[0])<<32 | uint64(mask[1])<<40 | + uint64(mask[2])<<48 | uint64(mask[3])<<56 + + i := 0 + // Process 8 bytes at a time + for ; i+8 <= n; i += 8 { + v := binary.LittleEndian.Uint64(data[i:]) + binary.LittleEndian.PutUint64(result[i:], v^mask8) + } + // Process remaining bytes + for ; i < n; i++ { + result[i] = data[i] ^ mask[i&3] } return result } +// xorMaskInPlace modifies data in place +func xorMaskInPlace(data, mask []byte) { + n := len(data) + if n == 0 { + return + } + + mask8 := uint64(mask[0]) | uint64(mask[1])<<8 | + uint64(mask[2])<<16 | uint64(mask[3])<<24 | + uint64(mask[0])<<32 | uint64(mask[1])<<40 | + uint64(mask[2])<<48 | uint64(mask[3])<<56 + + i := 0 + for ; i+8 <= n; i += 8 { + v := binary.LittleEndian.Uint64(data[i:]) + binary.LittleEndian.PutUint64(data[i:], v^mask8) + } + for ; i < n; i++ { + data[i] ^= mask[i&3] + } +} + // --------------------------------------------------------------------------- // WsHandshakeError // --------------------------------------------------------------------------- @@ -348,9 +394,10 @@ const ( ) type RawWebSocket struct { - conn net.Conn - mu sync.Mutex // write lock - closed atomic.Bool + conn net.Conn + bufReader *bufio.Reader + writeMu sync.Mutex + closed atomic.Bool } func wsConnect(ip, domain, path string, timeout float64) (*RawWebSocket, error) { @@ -410,30 +457,26 @@ func wsConnect(ip, domain, path string, timeout float64) (*RawWebSocket, error) } _ = rawConn.SetWriteDeadline(time.Time{}) - // Read HTTP response headers + // Use buffered reader for efficient header parsing + bufReader := bufio.NewReaderSize(rawConn, 4096) + _ = rawConn.SetReadDeadline(time.Now().Add(time.Duration(timeout * float64(time.Second)))) var responseLines []string - buf := make([]byte, 0, 4096) - tmp := make([]byte, 1) for { - _, err := rawConn.Read(tmp) + line, err := bufReader.ReadString('\n') if err != nil { rawConn.Close() return nil, err } - buf = append(buf, tmp[0]) - if len(buf) >= 2 && buf[len(buf)-2] == '\r' && buf[len(buf)-1] == '\n' { - line := string(buf[:len(buf)-2]) - buf = buf[:0] - if line == "" { - break - } - responseLines = append(responseLines, line) + line = strings.TrimRight(line, "\r\n") + if line == "" { + break } - if len(buf) > 16384 { + responseLines = append(responseLines, line) + if len(responseLines) > 100 { rawConn.Close() - return nil, fmt.Errorf("HTTP header too large") + return nil, fmt.Errorf("too many HTTP headers") } } _ = rawConn.SetReadDeadline(time.Time{}) @@ -451,7 +494,10 @@ func wsConnect(ip, domain, path string, timeout float64) (*RawWebSocket, error) } if statusCode == 101 { - ws := &RawWebSocket{conn: rawConn} + ws := &RawWebSocket{ + conn: rawConn, + bufReader: bufReader, + } return ws, nil } @@ -477,10 +523,10 @@ func (ws *RawWebSocket) Send(data []byte) error { if ws.closed.Load() { return fmt.Errorf("WebSocket closed") } - frame := buildFrame(opBinary, data, true) - ws.mu.Lock() - defer ws.mu.Unlock() + frame := ws.buildFrame(opBinary, data, true) + ws.writeMu.Lock() _, err := ws.conn.Write(frame) + ws.writeMu.Unlock() return err } @@ -488,20 +534,27 @@ func (ws *RawWebSocket) SendBatch(parts [][]byte) error { if ws.closed.Load() { return fmt.Errorf("WebSocket closed") } - ws.mu.Lock() - defer ws.mu.Unlock() + ws.writeMu.Lock() + defer ws.writeMu.Unlock() for _, part := range parts { - frame := buildFrame(opBinary, part, true) - _, err := ws.conn.Write(frame) - if err != nil { + frame := ws.buildFrame(opBinary, part, true) + if _, err := ws.conn.Write(frame); err != nil { return err } } - return err_nil_hack() + return nil } -// err_nil_hack is a workaround to return nil error -func err_nil_hack() error { return nil } +func (ws *RawWebSocket) SendPing() error { + if ws.closed.Load() { + return fmt.Errorf("WebSocket closed") + } + frame := ws.buildFrame(opPing, nil, true) + ws.writeMu.Lock() + _, err := ws.conn.Write(frame) + ws.writeMu.Unlock() + return err +} func (ws *RawWebSocket) Recv() ([]byte, error) { for !ws.closed.Load() { @@ -518,17 +571,17 @@ func (ws *RawWebSocket) Recv() ([]byte, error) { if len(closePayload) > 2 { closePayload = closePayload[:2] } - reply := buildFrame(opClose, closePayload, true) - ws.mu.Lock() + reply := ws.buildFrame(opClose, closePayload, true) + ws.writeMu.Lock() _, _ = ws.conn.Write(reply) - ws.mu.Unlock() + ws.writeMu.Unlock() return nil, io.EOF case opPing: - pong := buildFrame(opPong, payload, true) - ws.mu.Lock() + pong := ws.buildFrame(opPong, payload, true) + ws.writeMu.Lock() _, _ = ws.conn.Write(pong) - ws.mu.Unlock() + ws.writeMu.Unlock() continue case opPong: @@ -548,82 +601,90 @@ func (ws *RawWebSocket) Close() { if ws.closed.Swap(true) { return } - frame := buildFrame(opClose, nil, true) - ws.mu.Lock() + frame := ws.buildFrame(opClose, nil, true) + ws.writeMu.Lock() _, _ = ws.conn.Write(frame) - ws.mu.Unlock() + ws.writeMu.Unlock() _ = ws.conn.Close() } -func (ws *RawWebSocket) SendPing() error { - if ws.closed.Load() { - return fmt.Errorf("WebSocket closed") - } - frame := buildFrame(opPing, []byte{}, true) - ws.mu.Lock() - defer ws.mu.Unlock() - _, err := ws.conn.Write(frame) - return err +// SetReadDeadline exposes deadline control for the bridge +func (ws *RawWebSocket) SetReadDeadline(t time.Time) error { + return ws.conn.SetReadDeadline(t) } -func buildFrame(opcode int, data []byte, mask bool) []byte { +// buildFrame creates a WebSocket frame with minimal allocations +func (ws *RawWebSocket) buildFrame(opcode int, data []byte, mask bool) []byte { length := len(data) fb := byte(0x80 | opcode) - var header []byte - - if !mask { - if length < 126 { - header = []byte{fb, byte(length)} - } else if length < 65536 { - header = make([]byte, 4) - header[0] = fb - header[1] = 126 - binary.BigEndian.PutUint16(header[2:], uint16(length)) - } else { - header = make([]byte, 10) - header[0] = fb - header[1] = 127 - binary.BigEndian.PutUint64(header[2:], uint64(length)) - } - result := make([]byte, len(header)+length) - copy(result, header) - copy(result[len(header):], data) - return result + // Calculate total size + headerSize := 2 + if mask { + headerSize += 4 + } + if length >= 126 && length < 65536 { + headerSize += 2 + } else if length >= 65536 { + headerSize += 8 } - maskKey := make([]byte, 4) - _, _ = rand.Read(maskKey) - masked := xorMask(data, maskKey) + totalSize := headerSize + length + result := make([]byte, totalSize) + pos := 0 + + result[pos] = fb + pos++ + + var maskKey [4]byte + if mask { + _, _ = rand.Read(maskKey[:]) + } if length < 126 { - header = make([]byte, 6) - header[0] = fb - header[1] = byte(0x80 | length) - copy(header[2:6], maskKey) + lb := byte(length) + if mask { + lb |= 0x80 + } + result[pos] = lb + pos++ } else if length < 65536 { - header = make([]byte, 8) - header[0] = fb - header[1] = byte(0x80 | 126) - binary.BigEndian.PutUint16(header[2:4], uint16(length)) - copy(header[4:8], maskKey) + lb := byte(126) + if mask { + lb |= 0x80 + } + result[pos] = lb + pos++ + binary.BigEndian.PutUint16(result[pos:], uint16(length)) + pos += 2 } else { - header = make([]byte, 14) - header[0] = fb - header[1] = byte(0x80 | 127) - binary.BigEndian.PutUint64(header[2:10], uint64(length)) - copy(header[10:14], maskKey) + lb := byte(127) + if mask { + lb |= 0x80 + } + result[pos] = lb + pos++ + binary.BigEndian.PutUint64(result[pos:], uint64(length)) + pos += 8 + } + + if mask { + copy(result[pos:], maskKey[:]) + pos += 4 + // XOR directly into result buffer + payloadStart := pos + copy(result[payloadStart:], data) + xorMaskInPlace(result[payloadStart:payloadStart+length], maskKey[:]) + } else { + copy(result[pos:], data) } - result := make([]byte, len(header)+len(masked)) - copy(result, header) - copy(result[len(header):], masked) return result } func (ws *RawWebSocket) readFrame() (int, []byte, error) { hdr := make([]byte, 2) - if _, err := io.ReadFull(ws.conn, hdr); err != nil { + if _, err := io.ReadFull(ws.bufReader, hdr); err != nil { return 0, nil, err } @@ -632,13 +693,13 @@ func (ws *RawWebSocket) readFrame() (int, []byte, error) { if length == 126 { buf := make([]byte, 2) - if _, err := io.ReadFull(ws.conn, buf); err != nil { + if _, err := io.ReadFull(ws.bufReader, buf); err != nil { return 0, nil, err } length = uint64(binary.BigEndian.Uint16(buf)) } else if length == 127 { buf := make([]byte, 8) - if _, err := io.ReadFull(ws.conn, buf); err != nil { + if _, err := io.ReadFull(ws.bufReader, buf); err != nil { return 0, nil, err } length = binary.BigEndian.Uint64(buf) @@ -648,20 +709,20 @@ func (ws *RawWebSocket) readFrame() (int, []byte, error) { var maskKey []byte if hasMask { maskKey = make([]byte, 4) - if _, err := io.ReadFull(ws.conn, maskKey); err != nil { + if _, err := io.ReadFull(ws.bufReader, maskKey); err != nil { return 0, nil, err } } payload := make([]byte, length) if length > 0 { - if _, err := io.ReadFull(ws.conn, payload); err != nil { + if _, err := io.ReadFull(ws.bufReader, payload); err != nil { return 0, nil, err } } if hasMask { - payload = xorMask(payload, maskKey) + xorMaskInPlace(payload, maskKey) } return opcode, payload, nil @@ -693,7 +754,6 @@ func dcFromInit(data []byte) (dc int, isMedia bool, ok bool) { keystream := make([]byte, 64) stream.XORKeyStream(keystream, zero64) - // XOR bytes 56..64 of data with keystream plain := make([]byte, 8) for i := 0; i < 8; i++ { plain[i] = data[56+i] ^ keystream[56+i] @@ -762,7 +822,6 @@ func newMsgSplitter(initData []byte) (*MsgSplitter, error) { if err != nil { return nil, err } - // skip init packet (64 bytes of keystream) skip := make([]byte, 64) stream.XORKeyStream(skip, zero64) @@ -784,13 +843,12 @@ func (s *MsgSplitter) Split(chunk []byte) [][]byte { if pos+4 > plainLen { break } - // 3-byte little-endian length msgLen = int(uint32(plain[pos+1]) | uint32(plain[pos+2])<<8 | uint32(plain[pos+3])<<16) msgLen *= 4 pos += 4 } else { msgLen = int(first) * 4 - pos += 1 + pos++ } if msgLen == 0 || pos+msgLen > plainLen { break @@ -843,12 +901,12 @@ func wsDomains(dc int, isMedia *bool) []string { type poolEntry struct { ws *RawWebSocket - created float64 // monotonic seconds + created float64 } type WsPool struct { mu sync.Mutex - idle map[[2]int][]poolEntry // [dc, isMedia01] + idle map[[2]int][]poolEntry refilling map[[2]int]bool } @@ -875,6 +933,8 @@ func (p *WsPool) Get(dc int, isMedia bool, targetIP string, domains []string) *R now := monoNow() p.mu.Lock() + defer p.mu.Unlock() + bucket := p.idle[key] for len(bucket) > 0 { entry := bucket[0] @@ -890,19 +950,17 @@ func (p *WsPool) Get(dc int, isMedia bool, targetIP string, domains []string) *R stats.poolHits.Add(1) logDebug.Printf("WS pool hit for DC%d%s (age=%.1fs, left=%d)", dc, mediaTag(isMedia), age, len(bucket)) - p.scheduleRefill(key, targetIP, domains) - p.mu.Unlock() + p.scheduleRefillLocked(key, targetIP, domains) return entry.ws } - p.mu.Unlock() stats.poolMisses.Add(1) - p.scheduleRefill(key, targetIP, domains) + p.scheduleRefillLocked(key, targetIP, domains) return nil } -func (p *WsPool) scheduleRefill(key [2]int, targetIP string, domains []string) { - // Must be called with p.mu held or be safe +// scheduleRefillLocked must be called with p.mu held +func (p *WsPool) scheduleRefillLocked(key [2]int, targetIP string, domains []string) { if p.refilling[key] { return } @@ -971,6 +1029,9 @@ func connectOneWS(targetIP string, domains []string) *RawWebSocket { } func (p *WsPool) Warmup(dcOptMap map[int]string) { + p.mu.Lock() + defer p.mu.Unlock() + for dc, targetIP := range dcOptMap { if targetIP == "" { continue @@ -978,62 +1039,65 @@ func (p *WsPool) Warmup(dcOptMap map[int]string) { for _, isMedia := range []bool{false, true} { domains := wsDomains(dc, &isMedia) key := [2]int{dc, isMediaInt(isMedia)} - p.mu.Lock() - p.scheduleRefill(key, targetIP, domains) - p.mu.Unlock() + p.scheduleRefillLocked(key, targetIP, domains) } } logInfo.Printf("WS pool warmup started for %d DC(s)", len(dcOptMap)) } -// Maintain periodically evicts stale WS connections and refills pools func (p *WsPool) Maintain(ctx context.Context, dcOptMap map[int]string) { ticker := time.NewTicker(poolMaintainInterval * time.Second) defer ticker.Stop() + for { select { case <-ctx.Done(): return case <-ticker.C: - now := monoNow() - p.mu.Lock() - for key, bucket := range p.idle { - var fresh []poolEntry - for _, e := range bucket { - age := now - e.created - if age > wsPoolMaxAge || e.ws.closed.Load() { - go e.ws.Close() - } else { - // Send ping to keep VPN NAT session alive, and actively close if network is dead - go func(ws *RawWebSocket) { - if err := ws.SendPing(); err != nil { - ws.Close() - } - }(e.ws) - fresh = append(fresh, e) - } - } - p.idle[key] = fresh - } - p.mu.Unlock() - - // Refill all known DCs - for dc, targetIP := range dcOptMap { - if targetIP == "" { - continue - } - for _, isMedia := range []bool{false, true} { - domains := wsDomains(dc, &isMedia) - key := [2]int{dc, isMediaInt(isMedia)} - p.mu.Lock() - p.scheduleRefill(key, targetIP, domains) - p.mu.Unlock() - } - } + p.maintainOnce(dcOptMap) } } } +func (p *WsPool) maintainOnce(dcOptMap map[int]string) { + now := monoNow() + + p.mu.Lock() + for key, bucket := range p.idle { + var fresh []poolEntry + for _, e := range bucket { + age := now - e.created + if age > wsPoolMaxAge || e.ws.closed.Load() { + go e.ws.Close() + } else { + // Send ping to keep connection alive + go func(ws *RawWebSocket) { + if err := ws.SendPing(); err != nil { + ws.Close() + } + }(e.ws) + fresh = append(fresh, e) + } + } + p.idle[key] = fresh + } + p.mu.Unlock() + + // Refill all known DCs + p.mu.Lock() + for dc, targetIP := range dcOptMap { + if targetIP == "" { + continue + } + for _, isMedia := range []bool{false, true} { + domains := wsDomains(dc, &isMedia) + key := [2]int{dc, isMediaInt(isMedia)} + p.scheduleRefillLocked(key, targetIP, domains) + } + } + p.mu.Unlock() +} + func (p *WsPool) IdleCount() int { p.mu.Lock() defer p.mu.Unlock() @@ -1044,6 +1108,17 @@ func (p *WsPool) IdleCount() int { return count } +func (p *WsPool) CloseAll() { + p.mu.Lock() + defer p.mu.Unlock() + for key, bucket := range p.idle { + for _, e := range bucket { + go e.ws.Close() + } + delete(p.idle, key) + } +} + var wsPool = newWsPool() // --------------------------------------------------------------------------- @@ -1057,26 +1132,35 @@ func mediaTag(isMedia bool) string { return "" } -func boolPtr(b bool) *bool { - return &b -} - // --------------------------------------------------------------------------- // HTTP detection // --------------------------------------------------------------------------- func isHTTPTransport(data []byte) bool { - return bytes.HasPrefix(data, []byte("POST ")) || - bytes.HasPrefix(data, []byte("GET ")) || - bytes.HasPrefix(data, []byte("HEAD ")) || - bytes.HasPrefix(data, []byte("OPTIONS ")) + if len(data) < 4 { + return false + } + return string(data[:4]) == "POST" || + string(data[:3]) == "GET" || + string(data[:4]) == "HEAD" || + string(data[:7]) == "OPTIONS" } // --------------------------------------------------------------------------- // SOCKS5 reply // --------------------------------------------------------------------------- +var socks5Replies = map[byte][]byte{ + 0x00: {0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}, + 0x05: {0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0}, + 0x07: {0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0}, + 0x08: {0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0}, +} + func socks5Reply(status byte) []byte { + if r, ok := socks5Replies[status]; ok { + return r + } return []byte{0x05, status, 0x00, 0x01, 0, 0, 0, 0, 0, 0} } @@ -1095,7 +1179,14 @@ func bridgeWS(ctx context.Context, conn net.Conn, ws *RawWebSocket, startTime := time.Now() ctx2, cancel := context.WithCancel(ctx) - defer cancel() + + // Critical: close connections when context is cancelled + // This unblocks the Read() calls in goroutines + go func() { + <-ctx2.Done() + _ = conn.Close() + ws.Close() + }() var wg sync.WaitGroup wg.Add(2) @@ -1106,12 +1197,6 @@ func bridgeWS(ctx context.Context, conn net.Conn, ws *RawWebSocket, defer cancel() buf := make([]byte, 65536) for { - select { - case <-ctx2.Done(): - return - default: - } - _ = conn.SetReadDeadline(time.Now().Add(time.Duration(wsBridgeIdle * float64(time.Second)))) n, err := conn.Read(buf) if n > 0 { chunk := buf[:n] @@ -1119,21 +1204,19 @@ func bridgeWS(ctx context.Context, conn net.Conn, ws *RawWebSocket, upBytes += int64(n) upPkts++ + var sendErr error if splitter != nil { parts := splitter.Split(chunk) if len(parts) > 1 { - if err2 := ws.SendBatch(parts); err2 != nil { - return - } + sendErr = ws.SendBatch(parts) } else { - if err2 := ws.Send(parts[0]); err2 != nil { - return - } + sendErr = ws.Send(parts[0]) } } else { - if err2 := ws.Send(chunk); err2 != nil { - return - } + sendErr = ws.Send(chunk) + } + if sendErr != nil { + return } } if err != nil { @@ -1147,12 +1230,6 @@ func bridgeWS(ctx context.Context, conn net.Conn, ws *RawWebSocket, defer wg.Done() defer cancel() for { - select { - case <-ctx2.Done(): - return - default: - } - _ = ws.conn.SetReadDeadline(time.Now().Add(time.Duration(wsBridgeIdle * float64(time.Second)))) data, err := ws.Recv() if err != nil || data == nil { return @@ -1175,9 +1252,6 @@ func bridgeWS(ctx context.Context, conn net.Conn, ws *RawWebSocket, humanBytes(upBytes), upPkts, humanBytes(downBytes), downPkts, elapsed) - - ws.Close() - conn.Close() } // --------------------------------------------------------------------------- @@ -1188,7 +1262,13 @@ func bridgeTCP(ctx context.Context, client, remote net.Conn, label string, dc int, dst string, port int, isMedia bool) { ctx2, cancel := context.WithCancel(ctx) - defer cancel() + + // Close connections when context cancelled + go func() { + <-ctx2.Done() + _ = client.Close() + _ = remote.Close() + }() var wg sync.WaitGroup wg.Add(2) @@ -1198,11 +1278,6 @@ func bridgeTCP(ctx context.Context, client, remote net.Conn, defer cancel() buf := make([]byte, 65536) for { - select { - case <-ctx2.Done(): - return - default: - } n, err := src.Read(buf) if n > 0 { if isUp { @@ -1224,8 +1299,6 @@ func bridgeTCP(ctx context.Context, client, remote net.Conn, go forward(remote, client, false) wg.Wait() - client.Close() - remote.Close() } // --------------------------------------------------------------------------- @@ -1253,7 +1326,8 @@ func tcpFallback(ctx context.Context, client net.Conn, dst string, port int, // Pipe (non-Telegram passthrough) // --------------------------------------------------------------------------- -func pipe(ctx context.Context, src, dst net.Conn) { +func pipe(ctx context.Context, src, dst net.Conn, done chan<- struct{}) { + defer func() { done <- struct{}{} }() buf := make([]byte, 65536) for { select { @@ -1280,7 +1354,7 @@ func pipe(ctx context.Context, src, dst net.Conn) { func readExactly(conn net.Conn, n int, timeout time.Duration) ([]byte, error) { if timeout > 0 { _ = conn.SetReadDeadline(time.Now().Add(timeout)) - defer conn.SetReadDeadline(time.Time{}) + defer func() { _ = conn.SetReadDeadline(time.Time{}) }() } buf := make([]byte, n) _, err := io.ReadFull(conn, buf) @@ -1294,9 +1368,7 @@ func handleClient(ctx context.Context, conn net.Conn) { setSockOpts(conn) - defer func() { - conn.Close() - }() + defer conn.Close() // -- SOCKS5 greeting -- hdr, err := readExactly(conn, 2, 10*time.Second) @@ -1325,7 +1397,7 @@ func handleClient(ctx context.Context, conn net.Conn) { atyp := req[3] if cmd != 1 { - conn.Write(socks5Reply(0x07)) + _, _ = conn.Write(socks5Reply(0x07)) return } @@ -1354,7 +1426,7 @@ func handleClient(ctx context.Context, conn net.Conn) { } dst = net.IP(raw).String() default: - conn.Write(socks5Reply(0x08)) + _, _ = conn.Write(socks5Reply(0x08)) return } @@ -1369,7 +1441,7 @@ func handleClient(ctx context.Context, conn net.Conn) { "IPv6 addresses are not supported; "+ "disable IPv6 to continue using the proxy.", label, dst, port) - conn.Write(socks5Reply(0x05)) + _, _ = conn.Write(socks5Reply(0x05)) return } @@ -1382,26 +1454,34 @@ func handleClient(ctx context.Context, conn net.Conn) { remote, err := dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", dst, port)) if err != nil { logWarn.Printf("[%s] passthrough failed to %s: %T: %v", label, dst, err, err) - conn.Write(socks5Reply(0x05)) + _, _ = conn.Write(socks5Reply(0x05)) return } - conn.Write(socks5Reply(0x00)) + _, _ = conn.Write(socks5Reply(0x00)) ctx2, cancel := context.WithCancel(ctx) defer cancel() - var wg sync.WaitGroup - wg.Add(2) - go func() { defer wg.Done(); pipe(ctx2, conn, remote); cancel() }() - go func() { defer wg.Done(); pipe(ctx2, remote, conn); cancel() }() - wg.Wait() - remote.Close() + // Close connections when context done + go func() { + <-ctx2.Done() + _ = conn.Close() + _ = remote.Close() + }() + + done := make(chan struct{}, 2) + go pipe(ctx2, conn, remote, done) + go pipe(ctx2, remote, conn, done) + <-done + cancel() + <-done + _ = remote.Close() return } // -- Telegram DC: accept SOCKS, read init -- - conn.Write(socks5Reply(0x00)) + _, _ = conn.Write(socks5Reply(0x00)) init, err := readExactly(conn, 64, 15*time.Second) if err != nil { @@ -1437,11 +1517,10 @@ func handleClient(ctx context.Context, conn net.Conn) { dcOptMu.RUnlock() if hasDC { - signedDC := dc + // media -> positive dc, non-media -> negative dc + signedDC := -dc if isMedia { signedDC = dc - } else { - signedDC = -dc } init = patchInitDC(init, signedDC) initPatched = true @@ -1593,7 +1672,7 @@ func handleClient(ctx context.Context, conn net.Conn) { // Send init packet if err := ws.Send(init); err != nil { - logDebug.Printf("[%s] reconnecting via TCP fallback (WS broken by NAT): %v", label, err) + logDebug.Printf("[%s] reconnecting via TCP fallback (WS broken): %v", label, err) ws.Close() tcpFallback(ctx, conn, dst, port, init, label, dc, isMedia) return @@ -1607,13 +1686,6 @@ func handleClient(ctx context.Context, conn net.Conn) { // Server // --------------------------------------------------------------------------- -type ProxyServer struct { - listener net.Listener - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup -} - func runProxy(ctx context.Context, host string, port int, dcOptMap map[int]string) error { dcOptMu.Lock() dcOpt = dcOptMap @@ -1627,7 +1699,6 @@ func runProxy(ctx context.Context, host string, port int, dcOptMap map[int]strin return fmt.Errorf("listen on %s: %w", addr, err) } - // Set TCP_NODELAY on listening socket if possible if tcpL, ok := listener.(*net.TCPListener); ok { raw, err := tcpL.SyscallConn() if err == nil { @@ -1638,14 +1709,10 @@ func runProxy(ctx context.Context, host string, port int, dcOptMap map[int]strin } srvCtx, srvCancel := context.WithCancel(ctx) - srv := &ProxyServer{ - listener: listener, - ctx: srvCtx, - cancel: srvCancel, - } + defer srvCancel() logInfo.Println(strings.Repeat("=", 60)) - logInfo.Println(" Telegram WS Bridge Proxy") + logInfo.Println(" Telegram WS Bridge Proxy (Go)") logInfo.Printf(" Listening on %s:%d", host, port) logInfo.Println(" Target DC IPs:") for dc, ip := range dcOptMap { @@ -1680,7 +1747,7 @@ func runProxy(ctx context.Context, host string, port int, dcOptMap map[int]strin bl = strings.Join(blParts, ", ") } idleCount := wsPool.IdleCount() - logInfo.Printf("stats: %s idle_conns=%d | ws_bl: %s", stats.Summary(), idleCount, bl) + logInfo.Printf("stats: %s idle=%d | ws_bl: %s", stats.Summary(), idleCount, bl) } } }() @@ -1691,6 +1758,9 @@ func runProxy(ctx context.Context, host string, port int, dcOptMap map[int]strin // Periodic pool maintenance go wsPool.Maintain(srvCtx, dcOptMap) + // Track active connections for graceful shutdown + var activeConns sync.WaitGroup + // Accept loop go func() { for { @@ -1707,23 +1777,23 @@ func runProxy(ctx context.Context, host string, port int, dcOptMap map[int]strin return } } - srv.wg.Add(1) + activeConns.Add(1) go func() { - defer srv.wg.Done() + defer activeConns.Done() handleClient(srvCtx, conn) }() } }() - // Wait for context cancellation (graceful shutdown) + // Wait for context cancellation <-srvCtx.Done() logInfo.Println("Shutting down proxy server...") - listener.Close() + _ = listener.Close() - // Wait for all active connections with a timeout + // Wait for active connections with timeout done := make(chan struct{}) go func() { - srv.wg.Wait() + activeConns.Wait() close(done) }() @@ -1731,18 +1801,29 @@ func runProxy(ctx context.Context, host string, port int, dcOptMap map[int]strin case <-done: logInfo.Println("All connections closed gracefully") case <-time.After(30 * time.Second): - logWarn.Println("Graceful shutdown timed out after 30s, forcing exit") + logWarn.Println("Graceful shutdown timed out after 30s") } + // Close pool connections + wsPool.CloseAll() + logInfo.Printf("Final stats: %s", stats.Summary()) return nil } // --------------------------------------------------------------------------- -// Parse DC:IP list +// Parse DC:IP list / CIDR pool // --------------------------------------------------------------------------- func randomIPFromCIDR(cidr string) (string, error) { + // If it's just an IP (no /), return it as-is + if !strings.Contains(cidr, "/") { + if ip := net.ParseIP(cidr); ip != nil { + return cidr, nil + } + return "", fmt.Errorf("invalid IP: %s", cidr) + } + ip, ipnet, err := net.ParseCIDR(cidr) if err != nil { return "", err @@ -1751,16 +1832,16 @@ func randomIPFromCIDR(cidr string) (string, error) { if ip == nil { return "", fmt.Errorf("not ipv4") } - + start := binary.BigEndian.Uint32(ip) mask := binary.BigEndian.Uint32(ipnet.Mask) - + wildcard := ^mask offset := uint32(1) if wildcard > 1 { - offset = 1 + mrand.Uint32N(wildcard-1) + offset = 1 + mrand.Uint32N(wildcard-1) } - + randIP := start + offset res := make(net.IP, 4) binary.BigEndian.PutUint32(res, randIP) @@ -1778,7 +1859,7 @@ func parseCIDRPool(cidrsStr string) (map[int]string, error) { } } if len(validCIDRs) == 0 { - validCIDRs = []string{"149.154.167.220/32"} // Fallback + validCIDRs = []string{"149.154.167.220/32"} } dcs := []int{1, 2, 3, 4, 5, 203} @@ -1787,13 +1868,11 @@ func parseCIDRPool(cidrsStr string) (map[int]string, error) { ipStr, err := randomIPFromCIDR(cidr) if err == nil { result[dc] = ipStr + } else if net.ParseIP(cidr) != nil { + result[dc] = cidr } else { - if net.ParseIP(cidr) != nil { - result[dc] = cidr - } else { - result[dc] = "149.154.167.220" - } - } + result[dc] = "149.154.167.220" + } } return result, nil } @@ -1814,8 +1893,7 @@ func StartProxy(cHost *C.char, port C.int, cDcIps *C.char, verbose C.int) C.int defer globalMu.Unlock() if globalCancel != nil { - // Already running - return -1 + return -1 // Already running } host := C.GoString(cHost) @@ -1825,7 +1903,6 @@ func StartProxy(cHost *C.char, port C.int, cDcIps *C.char, verbose C.int) C.int initLogging(isVerbose) - // Passed string is a comma-separated list of CIDRs dcOptMap, err := parseCIDRPool(dcIpsStr) if err != nil { logError.Printf("parseCIDRPool: %v", err) @@ -1856,19 +1933,9 @@ func StopProxy() C.int { globalCancel = nil globalCtx = nil - // Reset stats for next run - stats.connectionsTotal.Store(0) - stats.connectionsWs.Store(0) - stats.connectionsTcpFallback.Store(0) - stats.connectionsHttpReject.Store(0) - stats.connectionsPassthrough.Store(0) - stats.wsErrors.Store(0) - stats.bytesUp.Store(0) - stats.bytesDown.Store(0) - stats.poolHits.Store(0) - stats.poolMisses.Store(0) + // Reset state + stats.Reset() - // Reset blacklists and fail timers wsBlackMu.Lock() wsBlacklist = make(map[[2]int]bool) wsBlackMu.Unlock() @@ -1907,17 +1974,14 @@ func FreeString(p *C.char) { } // --------------------------------------------------------------------------- -// Standalone main (for testing; noop when built as c-shared) +// Standalone main // --------------------------------------------------------------------------- func main() { - // When built as c-shared, main() is not called. - // For standalone testing: runtime.LockOSThread() initLogging(false) - // Default DC IPs dcOptMap := map[int]string{ 2: "149.154.167.220", 4: "149.154.167.220", @@ -1926,7 +1990,6 @@ func main() { host := "127.0.0.1" port := defaultPort - // Parse simple command line args := os.Args[1:] for i := 0; i < len(args); i++ { switch args[i] { @@ -1963,7 +2026,6 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) - // Graceful shutdown on SIGINT/SIGTERM sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) go func() { @@ -1977,7 +2039,3 @@ func main() { os.Exit(1) } } - -// Ensure mrand is used to avoid import errors -var _ = mrand.Int -var _ = big.NewInt