carbus_lib/carbus_async/remote/agent.py

208 lines
6.4 KiB
Python

from __future__ import annotations
import argparse
import asyncio
import json
import logging
from typing import Optional
import serial_asyncio
log = logging.getLogger("carbus_remote.agent")
def parse_hostport(s: str) -> tuple[str, int]:
if ":" not in s:
raise ValueError("server must be host:port")
host, p = s.rsplit(":", 1)
return host, int(p)
async def pipe_bidirectional_streams(
a_reader: asyncio.StreamReader,
a_writer: asyncio.StreamWriter,
b_reader: asyncio.StreamReader,
b_writer: asyncio.StreamWriter,
*,
bufsize: int = 4096,
) -> None:
async def pump(src: asyncio.StreamReader, dst: asyncio.StreamWriter) -> None:
try:
while True:
data = await src.read(bufsize)
if not data:
break
dst.write(data)
await dst.drain()
except Exception:
pass
finally:
try:
dst.close()
except Exception:
pass
t1 = asyncio.create_task(pump(a_reader, b_writer), name="pipe_a_to_b")
t2 = asyncio.create_task(pump(b_reader, a_writer), name="pipe_b_to_a")
await asyncio.wait({t1, t2}, return_when=asyncio.FIRST_COMPLETED)
for w in (a_writer, b_writer):
try:
w.close()
except Exception:
pass
await asyncio.gather(t1, t2, return_exceptions=True)
for w in (a_writer, b_writer):
try:
await w.wait_closed()
except Exception:
pass
async def open_serial_with_retry(port: str, baudrate: int, *, attempts: int = 3) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
last: Optional[BaseException] = None
for i in range(attempts):
try:
return await serial_asyncio.open_serial_connection(url=port, baudrate=baudrate)
except Exception as e:
last = e
await asyncio.sleep(0.25 + 0.25 * i)
assert last is not None
raise last
async def agent_run(
*,
port: str,
baudrate: int,
server: str,
serial: str,
password: str,
) -> None:
server_host, server_port = parse_hostport(server)
session_lock = asyncio.Lock()
async def open_data_session(session: str) -> None:
async with session_lock:
net_r: Optional[asyncio.StreamReader] = None
net_w: Optional[asyncio.StreamWriter] = None
dev_r: Optional[asyncio.StreamReader] = None
dev_w: Optional[asyncio.StreamWriter] = None
try:
log.info("Opening data session %s", session)
net_r, net_w = await asyncio.wait_for(
asyncio.open_connection(server_host, server_port),
timeout=5.0,
)
net_w.write((json.dumps({"role": "agent_data", "session": session}) + "\n").encode("utf-8"))
await net_w.drain()
line = await asyncio.wait_for(net_r.readline(), timeout=5.0)
resp = json.loads(line.decode("utf-8", errors="ignore") or "{}")
if not resp.get("ok"):
log.error("Data session refused %s: %s", session, resp)
return
log.info("Opening COM %s @ %d for session %s", port, baudrate, session)
dev_r, dev_w = await open_serial_with_retry(port, baudrate, attempts=3)
log.info("Session %s accepted. Piping bytes (COM <-> relay).", session)
await pipe_bidirectional_streams(dev_r, dev_w, net_r, net_w)
log.info("Session %s finished.", session)
except Exception:
log.exception("Data session %s crashed", session)
finally:
if net_w is not None:
try:
net_w.close()
await net_w.wait_closed()
except Exception:
pass
if dev_w is not None:
try:
dev_w.close()
await dev_w.wait_closed()
except Exception:
pass
log.info("Connecting to relay %s:%d (control)", server_host, server_port)
ctrl_reader, ctrl_writer = await asyncio.open_connection(server_host, server_port)
ctrl_writer.write((json.dumps({"role": "agent", "serial": serial, "password": password}) + "\n").encode("utf-8"))
await ctrl_writer.drain()
line = await asyncio.wait_for(ctrl_reader.readline(), timeout=10.0)
resp = json.loads(line.decode("utf-8", errors="ignore") or "{}")
if not resp.get("ok"):
ctrl_writer.close()
try:
await ctrl_writer.wait_closed()
except Exception:
pass
raise RuntimeError(f"relay refused agent: {resp}")
log.info("Agent registered OK. Waiting for sessions... (serial=%s)", serial)
try:
while True:
line = await ctrl_reader.readline()
if not line:
log.warning("Control connection closed by relay.")
break
try:
msg = json.loads(line.decode("utf-8", errors="ignore") or "{}")
except Exception:
continue
if msg.get("cmd") == "open_session":
session = str(msg.get("session", "")).strip()
if session:
asyncio.create_task(open_data_session(session), name=f"agent_data_session_{session}")
finally:
try:
ctrl_writer.close()
await ctrl_writer.wait_closed()
except Exception:
pass
log.info("Agent stopped.")
async def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--port", required=True, help="COM port, e.g. COM6")
ap.add_argument("--baudrate", type=int, default=115200)
ap.add_argument("--server", required=True, help="relay host:port, e.g. 1.2.3.4:9000")
ap.add_argument("--serial", required=True, help="device serial number (string/int)")
ap.add_argument("--password", required=True, help="shared password for this serial")
args = ap.parse_args()
await agent_run(
port=args.port,
baudrate=args.baudrate,
server=args.server,
serial=str(args.serial),
password=str(args.password),
)
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
asyncio.run(main())