diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index 21ba025..bc3ca22 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -651,7 +651,9 @@ async def _bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label, await ws.send(parts[0]) else: await ws.send(chunk) - except (asyncio.CancelledError, ConnectionError, OSError): + except asyncio.CancelledError: + raise + except (ConnectionError, OSError): return except Exception as e: log.debug("[%s] tcp->ws ended: %s", label, e) @@ -671,7 +673,9 @@ async def _bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label, data = clt_encryptor.update(plain) writer.write(data) await writer.drain() - except (asyncio.CancelledError, ConnectionError, OSError): + except asyncio.CancelledError: + raise + except (ConnectionError, OSError): return except Exception as e: log.debug("[%s] ws->tcp ended: %s", label, e) @@ -682,12 +686,9 @@ async def _bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label, await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) finally: for t in tasks: - t.cancel() - for t in tasks: - try: - await t - except BaseException: - pass + if not t.done(): + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) elapsed = asyncio.get_running_loop().time() - start_time log.info("[%s] %s WS session closed: " "^%s (%d pkts) v%s (%d pkts) in %.1fs", @@ -730,7 +731,7 @@ async def _bridge_tcp_reencrypt(reader, writer, remote_reader, remote_writer, dst_w.write(data) await dst_w.drain() except asyncio.CancelledError: - pass + raise except Exception as e: log.debug("[%s] forward ended: %s", label, e) @@ -742,12 +743,9 @@ async def _bridge_tcp_reencrypt(reader, writer, remote_reader, remote_writer, await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) finally: for t in tasks: - t.cancel() - for t in tasks: - try: - await t - except BaseException: - pass + if not t.done(): + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) for w in (writer, remote_writer): try: w.close() @@ -990,6 +988,7 @@ async def _handle_client(reader, writer, secret: bytes): log.debug("[%s] client disconnected", label) except asyncio.CancelledError: log.debug("[%s] cancelled", label) + raise except ConnectionResetError: log.debug("[%s] connection reset", label) except OSError as exc: