// Package transport - websocket.go 实现 bridge.BridgeTransport 的 WebSocket 协议. // // 协议约定: // - 客户端 GET /ws?session_id= 建立双向 WS 通道 // - 上行: 客户端发 WS text frame, payload 是 bridge.ClientMessage JSON // - 下行: 服务端发 WS text frame, payload 是 bridge.BridgeEvent JSON // - Keepalive: 服务端每 PingInterval 主动发 Ping, 客户端回 Pong // // 与 SSE 的差异: // - 单端点双向, 不需要分离 /events (GET 流) 和 /messages (POST) // - 低延迟 (无 POST 往返成本), 但企业代理/WAF 可能拦 Upgrade 帧 // - 不原生支持 Last-Event-ID 断线重连 -- 本最小版未实现 ResumeWindow, // 客户端必须自行保留未确认事件序号并在重连 query 里传 last_event_id. // SSE 环形缓冲设计迁移 WS 需要新 ack 协议, 非零成本, 留作独立 TODO. // // 升华改进(ELEVATED): 复用 pkg/websocket.UpgradeHTTP + WebSocketConn 的 825 行 // 底层 RFC 6455 实现, bridge 层只做 SessionConn 契约适配, 协议栈零重复. // 替代方案(路径 B, 否决): 让 websocket 包反向 import bridge 自暴露 BridgeTransport -- // 破坏 websocket 包自包含性且依赖层级倒挂 (bridge 消费 websocket 是正向, 反过来不是). package transport import ( "context" "encoding/json" "fmt" "log" "net/http" "strings" "sync" "time" "git.flytoex.net/yuanwei/flyto-agent/pkg/bridge" "git.flytoex.net/yuanwei/flyto-agent/pkg/websocket" ) // WS 帧操作码本地副本. // 精妙之处(CLEVER): websocket 包把这些 const 定义为未导出 (opText/opClose 等), // 避免污染公开 API. transport 层需要手写操作码逻辑 (readLoop switch/pingLoop 主动发 Ping), // 直接抄一份比让 websocket 包导出这组细节更清洁 -- 只有 bridge 层需要. const ( wsOpText = 0x1 wsOpBinary = 0x2 wsOpClose = 0x8 wsOpPing = 0x9 wsOpPong = 0xA ) // WebSocketTransport 实现 bridge.BridgeTransport, 使用 WebSocket 协议. type WebSocketTransport struct { cfg bridge.BridgeConfig prefix string mu sync.RWMutex conns map[string]*WebSocketSessionConn acceptCh chan bridge.SessionConn closed bool closedCh chan struct{} } // NewWebSocketTransport 创建 WebSocket 传输层. // // prefix 默认 "/bridge", 端点 GET /ws?session_id=. func NewWebSocketTransport(prefix string, cfg bridge.BridgeConfig) *WebSocketTransport { if prefix == "" { prefix = "/bridge" } return &WebSocketTransport{ cfg: cfg, prefix: strings.TrimRight(prefix, "/"), conns: make(map[string]*WebSocketSessionConn), acceptCh: make(chan bridge.SessionConn, 64), closedCh: make(chan struct{}), } } // ServeHTTP 实现 http.Handler, 路由 WS 升级端点. func (t *WebSocketTransport) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := strings.TrimPrefix(r.URL.Path, t.prefix) if r.Method == http.MethodGet && path == "/ws" { t.handleConnect(w, r) return } http.NotFound(w, r) } // Accept 实现 bridge.BridgeTransport. func (t *WebSocketTransport) Accept() <-chan bridge.SessionConn { return t.acceptCh } // Close 关闭传输层, 断开所有连接. func (t *WebSocketTransport) Close() error { t.mu.Lock() defer t.mu.Unlock() if t.closed { return nil } t.closed = true close(t.closedCh) for _, c := range t.conns { c.closeInternal() } close(t.acceptCh) return nil } func (t *WebSocketTransport) handleConnect(w http.ResponseWriter, r *http.Request) { t.mu.RLock() if t.closed { t.mu.RUnlock() http.Error(w, "transport closed", http.StatusServiceUnavailable) return } t.mu.RUnlock() sessionID := r.URL.Query().Get("session_id") if sessionID == "" { http.Error(w, "missing session_id", http.StatusBadRequest) return } wsConn, err := websocket.UpgradeHTTP(w, r) if err != nil { // UpgradeHTTP 已 Hijack 成功后的错误无法再写 http.Error. // 未 Hijack 的早期错误 (校验 header) 由底层返回后这里只能 log. log.Printf("ws: upgrade failed for session %s: %v", sessionID, err) return } conn := t.getOrCreateConn(sessionID, wsConn) select { case t.acceptCh <- conn: default: // acceptCh 满了代表 DaemonManager 没有及时消费, 记录日志但不阻塞. log.Printf("ws: acceptCh full, dropping conn for session %s", sessionID) } go conn.readLoop() go conn.pingLoop() // 阻塞直到连接断开; 释放 HTTP handler 让 http.Server 能清理请求资源. select { case <-conn.Done(): case <-t.closedCh: conn.closeInternal() case <-r.Context().Done(): conn.closeInternal() } } // getOrCreateConn 按 sessionID 注册连接. // 同 sessionID 已有连接 (重连场景) -- 关旧的, 换新的. func (t *WebSocketTransport) getOrCreateConn(sessionID string, wsConn *websocket.WebSocketConn) *WebSocketSessionConn { t.mu.Lock() defer t.mu.Unlock() if old, exists := t.conns[sessionID]; exists { old.closeInternal() } conn := newWebSocketSessionConn(sessionID, wsConn, t.cfg) t.conns[sessionID] = conn return conn } // ─── WebSocketSessionConn ───────────────────────────────────────────────────── // WebSocketSessionConn 实现 bridge.SessionConn, 基于 WebSocket 协议. type WebSocketSessionConn struct { sessionID string conn *websocket.WebSocketConn cfg bridge.BridgeConfig recvCh chan bridge.ClientMessage doneCh chan struct{} closeOnce sync.Once } func newWebSocketSessionConn(sessionID string, conn *websocket.WebSocketConn, cfg bridge.BridgeConfig) *WebSocketSessionConn { bufSize := cfg.EventBufferSize if bufSize <= 0 { bufSize = 256 } return &WebSocketSessionConn{ sessionID: sessionID, conn: conn, cfg: cfg, recvCh: make(chan bridge.ClientMessage, bufSize), doneCh: make(chan struct{}), } } func (c *WebSocketSessionConn) SessionID() string { return c.sessionID } // Send 将 BridgeEvent 序列化为 JSON 通过 WS text frame 推送给客户端. func (c *WebSocketSessionConn) Send(ctx context.Context, evt bridge.BridgeEvent) error { select { case <-c.doneCh: return bridge.ErrConnClosed default: } payload, err := json.Marshal(evt) if err != nil { return fmt.Errorf("marshal bridge event: %w", err) } // 底层 WebSocketConn.WriteMessage 自带 10s 写 deadline. // 包一层 goroutine 让 ctx cancellation 可提前返回. done := make(chan error, 1) go func() { done <- c.conn.WriteMessage(wsOpText, payload) }() select { case err := <-done: if err != nil { c.closeInternal() return bridge.ErrConnClosed } return nil case <-ctx.Done(): return ctx.Err() case <-c.doneCh: return bridge.ErrConnClosed } } func (c *WebSocketSessionConn) Recv() <-chan bridge.ClientMessage { return c.recvCh } func (c *WebSocketSessionConn) Done() <-chan struct{} { return c.doneCh } func (c *WebSocketSessionConn) Close() error { c.closeInternal() return nil } func (c *WebSocketSessionConn) closeInternal() { c.closeOnce.Do(func() { // 先关 doneCh 让上下游感知, 再关底层 WS (发 close frame + close TCP). close(c.doneCh) _ = c.conn.Close() }) } // readLoop 持续读客户端帧, 将 text/binary 解析为 ClientMessage 投递给 recvCh. // 自动处理 Ping/Pong/Close 控制帧. func (c *WebSocketSessionConn) readLoop() { defer c.closeInternal() for { opcode, payload, err := c.conn.ReadMessage() if err != nil { return } switch opcode { case wsOpText, wsOpBinary: var msg bridge.ClientMessage if err := json.Unmarshal(payload, &msg); err != nil { // 单条恶意/格式错误消息不能导致整个 session 断连. log.Printf("ws: malformed ClientMessage on session %s: %v", c.sessionID, err) continue } // 客户端可省略 session_id (帧已绑定 session), 用连接的 sessionID 兜底. if msg.SessionID == "" { msg.SessionID = c.sessionID } select { case c.recvCh <- msg: case <-c.doneCh: return } case wsOpPing: // RFC 6455: 收到 Ping 必须回 Pong 并带相同 payload. _ = c.conn.WriteMessage(wsOpPong, payload) case wsOpPong: // 客户端回的 Pong, 保活信号, 忽略. case wsOpClose: return } } } // pingLoop 定期主动发 Ping 维持连接 (防负载均衡/代理 idle timeout). func (c *WebSocketSessionConn) pingLoop() { interval := c.cfg.PingInterval if interval <= 0 { interval = 15 * time.Second } ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-c.doneCh: return case <-ticker.C: if err := c.conn.WriteMessage(wsOpPing, nil); err != nil { c.closeInternal() return } } } }