238 lines
8.0 KiB
Python
238 lines
8.0 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import secrets
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, Optional
|
|
|
|
# Module-level logger. The name is fixed explicitly rather than derived from __name__.
log: logging.Logger = logging.getLogger("carbus_remote.server")
|
|
|
|
|
|
@dataclass
class AgentControl:
    """Registration record for an agent's long-lived control connection."""

    serial: str  # identifier clients use to reach this agent
    # Secret clients must present. It is excluded from repr() so that printing or
    # logging the record cannot leak it.
    password: str = field(repr=False)
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter  # the relay sends commands (e.g. open_session) here
|
|
|
|
|
|
@dataclass
class PendingSession:
    """A client connection waiting to be paired with its agent's data connection.

    Instances must be created while an event loop is running, because ``done``
    is created on that loop.
    """

    serial: str  # serial of the agent the client asked for
    client_reader: asyncio.StreamReader
    client_writer: asyncio.StreamWriter
    # Set when the agent_data connection for this session has connected.
    ready: asyncio.Event = field(default_factory=asyncio.Event)
    # Resolved when the pipe for this session has finished.
    # get_running_loop() binds the future to the loop actually running the relay.
    # It raises instead of silently creating or selecting another loop.
    done: asyncio.Future = field(
        default_factory=lambda: asyncio.get_running_loop().create_future()
    )
|
|
|
|
|
|
class RelayServer:
    """Line-oriented TCP relay that pairs remote clients with agents by serial.

    Every connection starts with one JSON object on a single line. Its ``role``
    selects the flow:

    * ``"agent"``: a long-lived control connection registering ``serial`` and
      ``password``. It stays registered until it closes or is replaced.
    * ``"client"``: asks for the agent with ``serial``. After the password check,
      the relay sends that agent an ``open_session`` command with a fresh token.
    * ``"agent_data"``: the agent's per-session data connection. Presenting the
      token pairs it with the waiting client, and bytes are piped both ways.

    Every reply is one JSON line with an ``ok`` flag and, on failure, an
    ``error`` code.
    """

    def __init__(self, *, handshake_timeout: float = 10.0, pairing_timeout: float = 10.0) -> None:
        """Create a relay with no registered agents.

        Args:
            handshake_timeout: Seconds to wait for a connection's first (handshake) line.
            pairing_timeout: Seconds a client waits for the agent's data connection.
        """
        self._agents: Dict[str, AgentControl] = {}  # serial -> current control connection
        self._pending: Dict[str, PendingSession] = {}  # session token -> unpaired client
        # Guards _agents and _pending. It is never held across network I/O, so one
        # slow peer cannot stall every other handshake.
        self._lock = asyncio.Lock()
        self._handshake_timeout = handshake_timeout
        self._pairing_timeout = pairing_timeout

    async def handle_conn(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Read the handshake and dispatch the connection by its ``role``.

        Used as the connection callback for ``asyncio.start_server``. The
        connection is closed on return unless it was handed over to ``_pipe``,
        which then owns closing it.
        """
        peer = writer.get_extra_info("peername")
        try:
            hello = await self._read_hello(reader, writer)
            if hello is None:
                return

            role = hello.get("role")
            if role == "agent":
                await self._handle_agent_control(reader, writer, hello, peer)
            elif role == "client":
                await self._handle_client(reader, writer, hello, peer)
            elif role == "agent_data":
                await self._handle_agent_data(reader, writer, hello, peer)
            else:
                await self._send_json(writer, {"ok": False, "error": "bad_role"})
        except Exception:
            log.exception("Connection error from %s", peer)
        finally:
            # There is deliberately no `return` in this finally. It would swallow any
            # in-flight exception, including task cancellation.
            if not getattr(writer, "_carbus_piped", False):
                await self._close_writer(writer)

    async def _read_hello(self, reader, writer) -> Optional[dict]:
        """Read and validate the handshake line.

        Returns the decoded JSON object. On failure, replies with an error where
        appropriate and returns None.
        """
        try:
            line = await asyncio.wait_for(reader.readline(), timeout=self._handshake_timeout)
        except asyncio.TimeoutError:
            await self._send_json(writer, {"ok": False, "error": "handshake_timeout"})
            return None
        if not line:
            return None  # peer closed before sending a handshake

        try:
            hello = json.loads(line.decode("utf-8", errors="ignore").strip())
        except (ValueError, RecursionError):  # malformed or absurdly nested JSON
            hello = None
        # Valid JSON that is not an object (e.g. a list) is just as malformed.
        if not isinstance(hello, dict):
            await self._send_json(writer, {"ok": False, "error": "bad_handshake"})
            return None
        return hello

    async def _handle_agent_control(self, reader, writer, hello, peer) -> None:
        """Register an agent's control connection and hold it open until EOF.

        A newer registration for the same serial replaces and closes the old one.
        The registration is removed when this connection ends, unless it has
        already been replaced.
        """
        serial = str(hello.get("serial", "")).strip()
        password = str(hello.get("password", "")).strip()
        if not serial or not password:
            await self._send_json(writer, {"ok": False, "error": "bad_handshake"})
            return

        async with self._lock:
            old = self._agents.get(serial)
            if old is not None:
                # A reconnecting agent supersedes its stale control connection.
                # NOTE(review): any peer that knows the serial can also displace the
                # live agent this way; consider checking the password before replacing.
                old.writer.close()
            self._agents[serial] = AgentControl(serial, password, reader, writer)

        try:
            # The reply is inside the try so that a failed write still unregisters
            # this connection.
            await self._send_json(writer, {"ok": True})
            log.info("Agent online serial=%s from %s", serial, peer)
            # Inbound lines are not interpreted; reading only detects disconnection.
            while await reader.readline():
                pass
        finally:
            async with self._lock:
                cur = self._agents.get(serial)
                if cur is not None and cur.writer is writer:
                    del self._agents[serial]
            log.info("Agent offline serial=%s", serial)

    async def _handle_client(self, reader, writer, hello, peer) -> None:
        """Authenticate a client, ask its agent for a data connection, then wait.

        On success, replies ``{"ok": true, "session": <token>}``. On failure, replies
        ``{"ok": false, "error": ...}`` with ``bad_handshake``, ``agent_offline``,
        ``unauthorized``, ``agent_write_failed`` or ``agent_data_timeout``. Returns
        after the paired pipe has finished or pairing failed.
        """
        serial = str(hello.get("serial", "")).strip()
        password = str(hello.get("password", "")).strip()
        if not serial or not password:
            await self._send_json(writer, {"ok": False, "error": "bad_handshake"})
            return

        # Decide and register under the lock; reply only after releasing it.
        error: Optional[str] = None
        async with self._lock:
            agent = self._agents.get(serial)
            if agent is None:
                error = "agent_offline"
            elif not self._password_matches(agent.password, password):
                error = "unauthorized"
            else:
                session = secrets.token_hex(8)
                ps = PendingSession(serial=serial, client_reader=reader, client_writer=writer)
                self._pending[session] = ps
        if error is not None:
            await self._send_json(writer, {"ok": False, "error": error})
            return

        try:
            try:
                await self._send_json(agent.writer, {"cmd": "open_session", "session": session})
            except Exception:
                log.warning("Failed to send open_session to agent serial=%s", serial, exc_info=True)
                await self._send_json(writer, {"ok": False, "error": "agent_write_failed"})
                return

            await self._send_json(writer, {"ok": True, "session": session})
            log.info("Client accepted serial=%s session=%s from %s", serial, session, peer)

            try:
                await asyncio.wait_for(ps.ready.wait(), timeout=self._pairing_timeout)
            except asyncio.TimeoutError:
                # _handle_agent_data claims a session by removing it under the lock,
                # so the entry's presence decides who won a race at the deadline.
                async with self._lock:
                    unclaimed = self._pending.pop(session, None) is not None
                if unclaimed:
                    await self._send_json(writer, {"ok": False, "error": "agent_data_timeout"})
                    return

            await ps.done
        finally:
            # Drop the token on every exit path so a late agent_data cannot attach.
            async with self._lock:
                self._pending.pop(session, None)

    async def _handle_agent_data(self, reader, writer, hello, peer) -> None:
        """Pair an agent's data connection with the client waiting on ``session``.

        The pending entry is removed under the lock, so each token can be claimed
        by at most one data connection. A second attempt gets ``unknown_session``
        instead of a second pipe on the same client stream.
        """
        session = str(hello.get("session", "")).strip()
        if not session:
            await self._send_json(writer, {"ok": False, "error": "bad_handshake"})
            return

        async with self._lock:
            ps = self._pending.pop(session, None)
        if ps is None:
            await self._send_json(writer, {"ok": False, "error": "unknown_session"})
            return

        # Wake the client handler. It then waits on ps.done, which the finally
        # below resolves on every exit path.
        ps.ready.set()
        try:
            await self._send_json(writer, {"ok": True})
            log.info("Agent data connected session=%s from %s (pairing)", session, peer)
            # From here on, _pipe owns both connections and handle_conn must not close them.
            ps.client_writer._carbus_piped = True
            writer._carbus_piped = True
            await self._pipe(ps.client_reader, ps.client_writer, reader, writer)
        finally:
            if not ps.done.done():
                ps.done.set_result(None)

    async def _pipe(self, a_reader, a_writer, b_reader, b_writer, bufsize: int = 4096) -> None:
        """Copy bytes in both directions until either side ends, then close both.

        ``a_reader`` feeds ``b_writer`` and ``b_reader`` feeds ``a_writer``. When one
        direction reaches EOF or fails, the other direction is cancelled and both
        writers are closed and awaited, even if this coroutine is itself cancelled.
        """

        async def pump(src, dst) -> None:
            try:
                while True:
                    data = await src.read(bufsize)
                    if not data:
                        break
                    dst.write(data)
                    await dst.drain()
            except OSError:
                pass  # connection reset / broken pipe: a normal way for a session to end
            finally:
                dst.close()

        tasks = (
            asyncio.create_task(pump(a_reader, b_writer)),
            asyncio.create_task(pump(b_reader, a_writer)),
        )
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()
            # return_exceptions=True absorbs the pumps' CancelledError (a
            # BaseException, which `except Exception` cannot catch) and any failure.
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, Exception):
                    log.warning("Pipe direction failed: %r", result)
            await asyncio.gather(self._close_writer(a_writer), self._close_writer(b_writer))

    @staticmethod
    def _password_matches(expected: str, given: str) -> bool:
        """Compare passwords in constant time (UTF-8 bytes, so any text is accepted)."""
        return secrets.compare_digest(expected.encode("utf-8"), given.encode("utf-8"))

    @staticmethod
    async def _close_writer(writer) -> None:
        """Close *writer*, if still open, and wait for the transport to finish closing."""
        if not writer.is_closing():
            writer.close()
        try:
            await writer.wait_closed()
        except Exception:
            # Best effort: the peer may already have reset the connection.
            pass

    @staticmethod
    async def _send_json(writer, obj: dict) -> None:
        """Write *obj* as one UTF-8 JSON line and wait for the buffer to drain."""
        writer.write((json.dumps(obj) + "\n").encode("utf-8"))
        await writer.drain()
|
|
|
|
|
|
async def main(host: str = "0.0.0.0", port: int = 9000) -> None:
    """Start a RelayServer listening on ``host:port`` and serve until cancelled."""
    relay = RelayServer()
    server = await asyncio.start_server(relay.handle_conn, host, port)

    # server.sockets may be empty/None; log every bound address that exists.
    bound = [str(sock.getsockname()) for sock in (server.sockets or [])]
    log.info("Relay server listening on %s", ", ".join(bound))

    async with server:
        await server.serve_forever()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: configure root logging, then run the relay until interrupted.
    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        level=logging.INFO,
    )
    asyncio.run(main())
|