# carbus_lib/carbus_async/remote/server.py
# (file-viewer metadata from the original listing: 251 lines, 8.2 KiB, Python)

from __future__ import annotations
import argparse
import asyncio
import json
import logging
import secrets
from dataclasses import dataclass, field
from typing import Dict, Optional
# Named logger shared by the relay-server code in this module.
log: logging.Logger = logging.getLogger("carbus_remote.server")
@dataclass
class AgentControl:
    """Registration record for an agent's control connection.

    Attributes:
        serial: Identifier the agent announced in its handshake.
        password: Secret a client must present to open a session to this agent.
        reader: Read side of the agent's control connection.
        writer: Write side, used to send commands such as ``open_session``.
    """

    serial: str
    # repr=False keeps the secret out of reprs, tracebacks and debug logs.
    password: str = field(repr=False)
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
@dataclass
class PendingSession:
    """A client connection waiting to be paired with an agent's data connection.

    Must be instantiated inside a running event loop: ``done`` is created on
    that loop.
    """

    serial: str
    client_reader: asyncio.StreamReader
    client_writer: asyncio.StreamWriter
    # Set once the agent's data connection (agent_data) has connected.
    ready: asyncio.Event = field(default_factory=asyncio.Event)
    # Resolved when the client<->agent pipe has finished.
    # get_running_loop() (not the deprecated get_event_loop()) binds the future
    # to the loop actually running this code and fails loudly outside one.
    done: asyncio.Future = field(
        default_factory=lambda: asyncio.get_running_loop().create_future()
    )
class RelayServer:
    """TCP relay that pairs remote clients with agents registered by serial.

    Every connection opens with one JSON line (the "hello") whose ``role``
    selects the protocol:

    * ``agent``: long-lived control connection registering ``serial`` and
      ``password``; the relay later sends it ``open_session`` commands on it.
    * ``client``: asks for a session to ``serial`` (authenticated with the
      agent's ``password``); the relay tells the agent to dial back.
    * ``agent_data``: the agent's dial-back for a ``session`` token; once it
      arrives, client and agent sockets are piped together byte-for-byte.

    Replies are single JSON lines carrying an ``ok`` flag plus either a
    payload or an ``error`` code.

    Args:
        handshake_timeout: Seconds to wait for the hello line.
        agent_data_timeout: Seconds a client waits for the agent to dial back.
    """

    def __init__(
        self, *, handshake_timeout: float = 10.0, agent_data_timeout: float = 10.0
    ) -> None:
        # serial -> currently registered control connection
        self._agents: Dict[str, AgentControl] = {}
        # session token -> client waiting for the agent's data connection
        self._pending: Dict[str, PendingSession] = {}
        self._lock = asyncio.Lock()
        self._handshake_timeout = handshake_timeout
        self._agent_data_timeout = agent_data_timeout

    async def handle_conn(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """``asyncio.start_server`` callback: read the hello and dispatch by role.

        The connection is closed on return unless a role handler handed it to
        the byte pipe (marked with a ``_carbus_piped`` attribute), in which case
        the pipe is responsible for closing it.
        """
        peer = writer.get_extra_info("peername")
        try:
            try:
                line = await asyncio.wait_for(reader.readline(), timeout=self._handshake_timeout)
            except asyncio.TimeoutError:
                await self._send_json(writer, {"ok": False, "error": "handshake_timeout"})
                return
            if not line:
                return
            try:
                hello = json.loads(line.decode("utf-8", errors="ignore").strip())
            except Exception:
                # Any parse failure of untrusted input is a bad handshake.
                hello = None
            if not isinstance(hello, dict):
                await self._send_json(writer, {"ok": False, "error": "bad_handshake"})
                return
            role = hello.get("role")
            handlers = {
                "agent": self._handle_agent_control,
                "client": self._handle_client,
                "agent_data": self._handle_agent_data,
            }
            # isinstance guard: an unhashable role (e.g. a list) must not crash dict.get.
            handler = handlers.get(role) if isinstance(role, str) else None
            if handler is None:
                await self._send_json(writer, {"ok": False, "error": "bad_role"})
                return
            await handler(reader, writer, hello, peer)
        except Exception:
            log.exception("Connection error from %s", peer)
        finally:
            # No ``return`` in here: that would swallow in-flight exceptions,
            # including CancelledError during server shutdown.
            if not getattr(writer, "_carbus_piped", False):
                await self._close_writer(writer)

    async def _handle_agent_control(self, reader, writer, hello, peer) -> None:
        """Register an agent's control connection and hold it open until EOF.

        A new registration for the same serial replaces (and closes) the old
        one. Lines the agent sends afterwards are read and discarded.
        """
        serial = str(hello.get("serial", "")).strip()
        password = str(hello.get("password", "")).strip()
        if not serial or not password:
            await self._send_json(writer, {"ok": False, "error": "bad_handshake"})
            return
        async with self._lock:
            old = self._agents.get(serial)
            if old is not None:
                # NOTE(review): re-registration is not checked against the old
                # password, so any peer knowing the serial can displace the
                # current agent — confirm this is intended.
                log.info("Agent serial=%s re-registered from %s; dropping previous connection", serial, peer)
                try:
                    old.writer.close()
                except Exception:
                    pass
            self._agents[serial] = AgentControl(serial, password, reader, writer)
        # The ack is inside the try so a failed send still deregisters the agent.
        try:
            await self._send_json(writer, {"ok": True})
            log.info("Agent online serial=%s from %s", serial, peer)
            while await reader.readline():
                pass
        finally:
            async with self._lock:
                cur = self._agents.get(serial)
                # Only deregister if a newer connection has not replaced this one.
                if cur is not None and cur.writer is writer:
                    del self._agents[serial]
            log.info("Agent offline serial=%s", serial)

    async def _handle_client(self, reader, writer, hello, peer) -> None:
        """Open a session to agent ``serial`` and wait until its pipe finishes.

        Replies ``{"ok": True, "session": <token>}`` once ``open_session`` has
        been written to the agent. Otherwise it replies with one of the errors
        ``bad_handshake``, ``agent_offline``, ``unauthorized``,
        ``agent_write_failed`` or ``agent_data_timeout``.
        """
        serial = str(hello.get("serial", "")).strip()
        password = str(hello.get("password", "")).strip()
        if not serial or not password:
            await self._send_json(writer, {"ok": False, "error": "bad_handshake"})
            return

        error: Optional[str] = None
        async with self._lock:
            agent = self._agents.get(serial)
            if agent is None:
                error = "agent_offline"
            elif not self._passwords_match(agent.password, password):
                error = "unauthorized"
            else:
                session = secrets.token_hex(8)
                ps = PendingSession(serial=serial, client_reader=reader, client_writer=writer)
                self._pending[session] = ps
        # Reply outside the lock so a slow client cannot stall every other handshake.
        if error is not None:
            await self._send_json(writer, {"ok": False, "error": error})
            return

        try:
            try:
                await self._send_json(agent.writer, {"cmd": "open_session", "session": session})
            except Exception:
                self._pending.pop(session, None)
                await self._send_json(writer, {"ok": False, "error": "agent_write_failed"})
                return
            await self._send_json(writer, {"ok": True, "session": session})
            log.info("Client accepted serial=%s session=%s from %s", serial, session, peer)
            try:
                await asyncio.wait_for(ps.ready.wait(), timeout=self._agent_data_timeout)
            except asyncio.TimeoutError:
                async with self._lock:
                    unclaimed = self._pending.pop(session, None) is not None
                if unclaimed:
                    await self._send_json(writer, {"ok": False, "error": "agent_data_timeout"})
                    return
                # The agent's data connection claimed the session right at the
                # deadline, so the pipe already owns this client: wait for it.
            await ps.done
        finally:
            # Always drop the pending entry, including on cancellation or I/O errors.
            async with self._lock:
                self._pending.pop(session, None)

    async def _handle_agent_data(self, reader, writer, hello, peer) -> None:
        """Pair an agent's dial-back connection with its waiting client and pipe them.

        Replies ``{"ok": True}`` to the agent once paired, or ``bad_handshake``
        / ``unknown_session``.
        """
        session = str(hello.get("session", "")).strip()
        if not session:
            await self._send_json(writer, {"ok": False, "error": "bad_handshake"})
            return
        async with self._lock:
            # pop, not get: claim the token atomically so a second agent_data
            # connection cannot pipe the same client a second time.
            ps = self._pending.pop(session, None)
        if ps is None:
            await self._send_json(writer, {"ok": False, "error": "unknown_session"})
            return
        # Hand both sockets to the pipe before any await, so handle_conn never
        # closes one of them underneath it.
        ps.client_writer._carbus_piped = True
        writer._carbus_piped = True
        ps.ready.set()
        try:
            await self._send_json(writer, {"ok": True})
            log.info("Agent data connected session=%s from %s (pairing)", session, peer)
            await self._pipe(ps.client_reader, ps.client_writer, reader, writer)
        finally:
            # _pipe closes both sides itself; this covers a failed "ok" write.
            for w in (ps.client_writer, writer):
                if not w.is_closing():
                    w.close()
            if not ps.done.done():
                ps.done.set_result(None)

    async def _pipe(self, a_reader, a_writer, b_reader, b_writer, bufsize: int = 4096) -> None:
        """Copy bytes a->b and b->a until either direction ends, then close both.

        Args:
            a_reader, a_writer: Streams of the first connection.
            b_reader, b_writer: Streams of the second connection.
            bufsize: Maximum number of bytes read per chunk.
        """

        async def pump(src, dst):
            try:
                while True:
                    data = await src.read(bufsize)
                    if not data:
                        break
                    dst.write(data)
                    await dst.drain()
            except Exception as exc:
                # Best-effort relay: a reset or closed peer just ends this direction.
                log.debug("Relay pipe direction ended: %r", exc)
            finally:
                try:
                    dst.close()
                except Exception:
                    pass

        tasks = (
            asyncio.create_task(pump(a_reader, b_writer)),
            asyncio.create_task(pump(b_reader, a_writer)),
        )
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Runs even if this coroutine is cancelled, so pumps never outlive the pipe.
            for task in tasks:
                task.cancel()
            # return_exceptions=True collects the pumps' CancelledError (a
            # BaseException that ``except Exception`` would not catch) as a
            # result instead of letting it escape from this method.
            await asyncio.gather(*tasks, return_exceptions=True)
            for w in (a_writer, b_writer):
                try:
                    w.close()
                except Exception:
                    pass
            for w in (a_writer, b_writer):
                try:
                    await w.wait_closed()
                except Exception:
                    pass

    @staticmethod
    def _passwords_match(expected: str, given: str) -> bool:
        """Compare two secrets in constant time.

        Both are encoded to UTF-8 first because ``compare_digest`` rejects
        non-ASCII ``str`` arguments.
        """
        return secrets.compare_digest(expected.encode("utf-8"), given.encode("utf-8"))

    @staticmethod
    async def _close_writer(writer: asyncio.StreamWriter) -> None:
        """Close *writer* if it is not already closing and wait for it.

        Errors are deliberately ignored: the peer may already be gone.
        """
        if not writer.is_closing():
            writer.close()
        try:
            await writer.wait_closed()
        except Exception:
            pass

    @staticmethod
    async def _send_json(writer, obj: dict) -> None:
        """Write *obj* as one UTF-8 JSON line and wait for the writer to drain."""
        writer.write((json.dumps(obj) + "\n").encode("utf-8"))
        await writer.drain()
async def main(host: str = "0.0.0.0", port: int = 9000) -> None:
    """Start a relay listening on ``host:port`` and serve until cancelled."""
    relay = RelayServer()
    server = await asyncio.start_server(relay.handle_conn, host, port)
    bound = [str(sock.getsockname()) for sock in (server.sockets or [])]
    log.info("Relay server listening on %s", ", ".join(bound))
    async with server:
        await server.serve_forever()
def cli() -> None:
    """Command-line entry point: parse ``--host``/``--port``, set up logging, run the relay."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=9000)
    opts = parser.parse_args()

    log_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

    asyncio.run(main(host=opts.host, port=opts.port))


# Run the relay when executed as a script.
if __name__ == "__main__":
    cli()