package websocket // WebSocket 传输层实现. // // 基于标准库 net/http + hijack 实现 WebSocket 协议(RFC 6455). // 支持: // - 文本和二进制消息 // - Ping/Pong 心跳(30 秒间隔) // - 自动重连(指数退避) // - 消息缓冲区(断线期间缓存消息,重连后重放) // - 最大消息大小限制(1MB) // - Close 握手 import ( "bufio" "context" "crypto/rand" "crypto/sha1" "crypto/tls" "encoding/base64" "encoding/binary" "fmt" "io" mrand "math/rand" "net" "net/http" "os" "strings" "sync" "time" ) // WebSocket 帧操作码 const ( opContinuation = 0x0 opText = 0x1 opBinary = 0x2 opClose = 0x8 opPing = 0x9 opPong = 0xA ) // WebSocket 常量 const ( // wsMaxMessageSize 最大消息大小:1MB wsMaxMessageSize = 1 * 1024 * 1024 // wsPingInterval Ping/Pong 心跳间隔 wsPingInterval = 30 * time.Second // wsPongTimeout 等待 Pong 响应的超时 wsPongTimeout = 10 * time.Second // wsBufferSize 消息缓冲区大小(断线缓存) wsBufferSize = 256 // wsMaxReconnectAttempts 最大重连次数 wsMaxReconnectAttempts = 10 // wsReconnectBaseDelay 重连基础延迟 wsReconnectBaseDelay = 1 * time.Second // wsReconnectMaxDelay 重连最大延迟 wsReconnectMaxDelay = 30 * time.Second // wsMagicGUID WebSocket 协议握手用的 Magic GUID wsMagicGUID = "258EAFA5-E914-47DA-95CA-5AB5DC508065" ) // WebSocketConfig 是 WebSocket 传输层的配置. type WebSocketConfig struct { // URL WebSocket 服务器地址(ws:// 或 wss://) URL string // MaxMessageSize 最大消息大小(字节),默认 1MB MaxMessageSize int // PingInterval Ping 心跳间隔,默认 30 秒 PingInterval time.Duration // AutoReconnect 是否启用自动重连 AutoReconnect bool // MaxReconnectAttempts 最大重连次数,默认 10 MaxReconnectAttempts int // BufferSize 消息缓冲区大小,默认 256 BufferSize int // OnMessage 消息回调(可选) OnMessage func(msgType int, data []byte) } // WebSocketTransport 是 WebSocket 传输层实现. type WebSocketTransport struct { cfg WebSocketConfig conn net.Conn reader *bufio.Reader writer *bufio.Writer receiveCh chan []byte sendBuffer [][]byte // 断线缓存 mu sync.Mutex connected bool closed bool closeCh chan struct{} closeOnce sync.Once ctx context.Context cancelFunc context.CancelFunc } // NewWebSocket 创建 WebSocket 传输层实例. func NewWebSocket(cfg WebSocketConfig) *WebSocketTransport { if cfg.MaxMessageSize <= 0 { cfg.MaxMessageSize = wsMaxMessageSize } if cfg.PingInterval <= 0 { cfg.PingInterval = wsPingInterval } if cfg.MaxReconnectAttempts <= 0 { cfg.MaxReconnectAttempts = wsMaxReconnectAttempts } if cfg.BufferSize <= 0 { cfg.BufferSize = wsBufferSize } return &WebSocketTransport{ cfg: cfg, receiveCh: make(chan []byte, cfg.BufferSize), closeCh: make(chan struct{}), } } // Connect 建立 WebSocket 连接. // 执行 HTTP 升级握手,然后切换到 WebSocket 帧协议. func (ws *WebSocketTransport) Connect(ctx context.Context) error { ws.mu.Lock() if ws.closed { ws.mu.Unlock() return fmt.Errorf("websocket: transport is closed") } ws.mu.Unlock() ws.ctx, ws.cancelFunc = context.WithCancel(ctx) if err := ws.dial(); err != nil { return err } // 启动读取循环和心跳 go ws.readLoop() go ws.pingLoop() return nil } // Send 发送消息. // 如果连接断开且启用了自动重连,消息会被缓存. func (ws *WebSocketTransport) Send(data []byte) error { ws.mu.Lock() defer ws.mu.Unlock() if ws.closed { return fmt.Errorf("websocket: transport is closed") } if !ws.connected { if ws.cfg.AutoReconnect { // 缓存消息,等重连后重放 if len(ws.sendBuffer) < ws.cfg.BufferSize { buf := make([]byte, len(data)) copy(buf, data) ws.sendBuffer = append(ws.sendBuffer, buf) return nil } return fmt.Errorf("websocket: send buffer full") } return fmt.Errorf("websocket: not connected") } return ws.writeFrame(opText, data) } // SendBinary 发送二进制消息. func (ws *WebSocketTransport) SendBinary(data []byte) error { ws.mu.Lock() defer ws.mu.Unlock() if ws.closed { return fmt.Errorf("websocket: transport is closed") } if !ws.connected { return fmt.Errorf("websocket: not connected") } return ws.writeFrame(opBinary, data) } // Receive 返回接收消息的 channel. func (ws *WebSocketTransport) Receive() <-chan []byte { return ws.receiveCh } // Close 优雅关闭 WebSocket 连接. // 发送 Close 帧并等待对端确认. func (ws *WebSocketTransport) Close() error { ws.closeOnce.Do(func() { ws.mu.Lock() ws.closed = true ws.mu.Unlock() close(ws.closeCh) if ws.cancelFunc != nil { ws.cancelFunc() } // 发送 Close 帧 ws.mu.Lock() if ws.connected && ws.conn != nil { // 关闭码 1000 = 正常关闭 payload := make([]byte, 2) binary.BigEndian.PutUint16(payload, 1000) ws.writeFrame(opClose, payload) ws.conn.Close() ws.connected = false } ws.mu.Unlock() }) return nil } // --- 内部实现 --- // dial 执行 WebSocket 握手. func (ws *WebSocketTransport) dial() error { // 解析 URL url := ws.cfg.URL useTLS := false if strings.HasPrefix(url, "ws://") { url = strings.TrimPrefix(url, "ws://") } else if strings.HasPrefix(url, "wss://") { // 升华改进(ELEVATED): wss:// TLS 支持--早期方案只支持 ws:// 明文连接, // 生产环境通常要求 wss://(如 HTTPS 反向代理后端). // 用 crypto/tls 标准库,无需外部依赖,保持零依赖原则. url = strings.TrimPrefix(url, "wss://") useTLS = true } // 分离 host 和 path host := url path := "/" if idx := strings.Index(url, "/"); idx >= 0 { host = url[:idx] path = url[idx:] } // 如果没有端口,添加默认端口 if !strings.Contains(host, ":") { if useTLS { host = host + ":443" } else { host = host + ":80" } } // 建立连接(TLS 或明文) var conn net.Conn var err error if useTLS { // 提取 SNI hostname(不含端口) serverName := host if idx := strings.LastIndex(host, ":"); idx >= 0 { serverName = host[:idx] } conn, err = tls.DialWithDialer( &net.Dialer{Timeout: 10 * time.Second}, "tcp", host, &tls.Config{ServerName: serverName}, ) } else { conn, err = net.DialTimeout("tcp", host, 10*time.Second) } if err != nil { return fmt.Errorf("websocket: dial %s: %w", host, err) } // 生成随机 key keyBytes := make([]byte, 16) if _, err := rand.Read(keyBytes); err != nil { conn.Close() return fmt.Errorf("websocket: generate key: %w", err) } key := base64.StdEncoding.EncodeToString(keyBytes) // 发送 HTTP 升级请求(Host 头省略默认端口:ws=80, wss=443) reqHost := host if strings.HasSuffix(reqHost, ":80") || strings.HasSuffix(reqHost, ":443") { if idx := strings.LastIndex(reqHost, ":"); idx >= 0 { reqHost = reqHost[:idx] } } upgradeReq := fmt.Sprintf( "GET %s HTTP/1.1\r\n"+ "Host: %s\r\n"+ "Upgrade: websocket\r\n"+ "Connection: Upgrade\r\n"+ "Sec-WebSocket-Key: %s\r\n"+ "Sec-WebSocket-Version: 13\r\n"+ "\r\n", path, reqHost, key) if _, err := conn.Write([]byte(upgradeReq)); err != nil { conn.Close() return fmt.Errorf("websocket: write upgrade request: %w", err) } // 读取 HTTP 响应 reader := bufio.NewReader(conn) resp, err := http.ReadResponse(reader, nil) if err != nil { conn.Close() return fmt.Errorf("websocket: read upgrade response: %w", err) } resp.Body.Close() if resp.StatusCode != 101 { conn.Close() return fmt.Errorf("websocket: upgrade failed with status %d", resp.StatusCode) } // 验证 Sec-WebSocket-Accept expectedAccept := computeAcceptKey(key) actualAccept := resp.Header.Get("Sec-WebSocket-Accept") if actualAccept != expectedAccept { conn.Close() return fmt.Errorf("websocket: invalid Sec-WebSocket-Accept") } ws.mu.Lock() ws.conn = conn ws.reader = reader ws.writer = bufio.NewWriter(conn) ws.connected = true ws.mu.Unlock() // 重放缓存的消息 ws.replayBuffer() return nil } // computeAcceptKey 计算 Sec-WebSocket-Accept 值. func computeAcceptKey(key string) string { h := sha1.New() h.Write([]byte(key + wsMagicGUID)) return base64.StdEncoding.EncodeToString(h.Sum(nil)) } // readLoop 持续读取 WebSocket 帧. func (ws *WebSocketTransport) readLoop() { defer func() { if r := recover(); r != nil { fmt.Fprintf(os.Stderr, "websocket: readLoop panic: %v\n", r) } }() defer func() { ws.mu.Lock() ws.connected = false ws.mu.Unlock() // 如果需要自动重连 if ws.cfg.AutoReconnect { go ws.reconnectLoop() } }() for { select { case <-ws.closeCh: return default: } opcode, payload, err := ws.readFrame() if err != nil { if ws.isClosed() { return } // 连接断开 return } switch opcode { case opText, opBinary: if ws.cfg.OnMessage != nil { ws.cfg.OnMessage(int(opcode), payload) } // 推送到接收 channel(非阻塞) select { case ws.receiveCh <- payload: default: // 精妙之处(CLEVER): 缓冲区满时丢最旧消息而非最新消息-- // 消费者处理慢时,旧消息通常已过时(如光标位置更新),保留最新的更有意义. // 替代方案:<丢新消息(原方案)> - 高背压时消费者始终看到旧状态,UX 更差. select { case <-ws.receiveCh: // 丢最旧 default: } select { case ws.receiveCh <- payload: // 写最新 default: } fmt.Fprintf(os.Stderr, "websocket: receive buffer full, dropped oldest message (buffer=%d)\n", len(ws.receiveCh)) } case opPing: // 回复 Pong ws.mu.Lock() ws.writeFrame(opPong, payload) ws.mu.Unlock() case opPong: // 心跳回复,不需要处理 case opClose: // 收到关闭帧 if !ws.isClosed() { // 回复关闭帧 ws.mu.Lock() ws.writeFrame(opClose, payload) ws.mu.Unlock() } return } } } // readFrame 读取一个 WebSocket 帧. // 返回操作码,负载和错误. func (ws *WebSocketTransport) readFrame() (byte, []byte, error) { ws.mu.Lock() reader := ws.reader ws.mu.Unlock() if reader == nil { return 0, nil, fmt.Errorf("websocket: no reader") } // 读取第一个字节:FIN + opcode b1, err := reader.ReadByte() if err != nil { return 0, nil, err } opcode := b1 & 0x0F // 读取第二个字节:MASK + payload length b2, err := reader.ReadByte() if err != nil { return 0, nil, err } masked := (b2 & 0x80) != 0 payloadLen := int64(b2 & 0x7F) // 扩展长度 switch payloadLen { case 126: var buf [2]byte if _, err := io.ReadFull(reader, buf[:]); err != nil { return 0, nil, err } payloadLen = int64(binary.BigEndian.Uint16(buf[:])) case 127: var buf [8]byte if _, err := io.ReadFull(reader, buf[:]); err != nil { return 0, nil, err } payloadLen = int64(binary.BigEndian.Uint64(buf[:])) } // 检查消息大小限制 if payloadLen > int64(ws.cfg.MaxMessageSize) { return 0, nil, fmt.Errorf("websocket: message too large: %d > %d", payloadLen, ws.cfg.MaxMessageSize) } // 读取 masking key(如果有) var maskKey [4]byte if masked { if _, err := io.ReadFull(reader, maskKey[:]); err != nil { return 0, nil, err } } // 读取负载 payload := make([]byte, payloadLen) if payloadLen > 0 { if _, err := io.ReadFull(reader, payload); err != nil { return 0, nil, err } } // 解码 mask if masked { for i := range payload { payload[i] ^= maskKey[i%4] } } return opcode, payload, nil } // writeFrame 写入一个 WebSocket 帧. // 调用者需要持有 ws.mu 锁. func (ws *WebSocketTransport) writeFrame(opcode byte, payload []byte) error { if ws.writer == nil || ws.conn == nil { return fmt.Errorf("websocket: no writer") } // 设置写入超时 ws.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) defer ws.conn.SetWriteDeadline(time.Time{}) // 第一个字节:FIN=1 + opcode ws.writer.WriteByte(0x80 | opcode) // 客户端发送的帧必须 mask maskBit := byte(0x80) // 负载长度 payloadLen := len(payload) switch { case payloadLen <= 125: ws.writer.WriteByte(maskBit | byte(payloadLen)) case payloadLen <= 65535: ws.writer.WriteByte(maskBit | 126) var buf [2]byte binary.BigEndian.PutUint16(buf[:], uint16(payloadLen)) ws.writer.Write(buf[:]) default: ws.writer.WriteByte(maskBit | 127) var buf [8]byte binary.BigEndian.PutUint64(buf[:], uint64(payloadLen)) ws.writer.Write(buf[:]) } // Masking key(随机生成) var maskKey [4]byte rand.Read(maskKey[:]) ws.writer.Write(maskKey[:]) // 写入 masked 负载 for i, b := range payload { ws.writer.WriteByte(b ^ maskKey[i%4]) } return ws.writer.Flush() } // pingLoop 定时发送 Ping 帧. func (ws *WebSocketTransport) pingLoop() { defer func() { if r := recover(); r != nil { fmt.Fprintf(os.Stderr, "websocket: pingLoop panic: %v\n", r) } }() ticker := time.NewTicker(ws.cfg.PingInterval) defer ticker.Stop() for { select { case <-ws.closeCh: return case <-ticker.C: ws.mu.Lock() if ws.connected { ws.writeFrame(opPing, []byte("ping")) } ws.mu.Unlock() } } } // reconnectLoop 自动重连循环(指数退避). func (ws *WebSocketTransport) reconnectLoop() { defer func() { if r := recover(); r != nil { fmt.Fprintf(os.Stderr, "websocket: reconnectLoop panic: %v\n", r) } }() delay := wsReconnectBaseDelay for attempt := 0; attempt < ws.cfg.MaxReconnectAttempts; attempt++ { select { case <-ws.closeCh: return case <-time.After(delay): } if ws.isClosed() { return } if err := ws.dial(); err != nil { // 指数退避 + ±25% jitter,防止多客户端同时重连惊群(与 mcp/manager.go 保持一致) delay = delay * 2 if delay > wsReconnectMaxDelay { delay = wsReconnectMaxDelay } jitter := delay / 4 delay = delay - jitter/2 + time.Duration(mrand.Int63n(int64(jitter)+1)) continue } // 重连成功,重启读取循环和心跳 go ws.readLoop() go ws.pingLoop() return } } // replayBuffer 重连后重放缓存的消息. func (ws *WebSocketTransport) replayBuffer() { ws.mu.Lock() buf := ws.sendBuffer ws.sendBuffer = nil ws.mu.Unlock() for _, data := range buf { ws.mu.Lock() ws.writeFrame(opText, data) ws.mu.Unlock() } } // isClosed 检查是否已关闭. func (ws *WebSocketTransport) isClosed() bool { select { case <-ws.closeCh: return true default: return false } } // --- 服务端支持:HTTP Upgrade 辅助函数 --- // UpgradeHTTP 将 HTTP 请求升级为 WebSocket 连接(服务端用). // 使用 net/http 的 Hijack 接口接管底层 TCP 连接. // 返回一个 WebSocketConn 用于收发消息. func UpgradeHTTP(w http.ResponseWriter, r *http.Request) (*WebSocketConn, error) { // 校验请求 if r.Method != http.MethodGet { return nil, fmt.Errorf("websocket: expected GET, got %s", r.Method) } if !headerContains(r.Header, "Connection", "upgrade") { return nil, fmt.Errorf("websocket: missing Connection: upgrade header") } if !headerContains(r.Header, "Upgrade", "websocket") { return nil, fmt.Errorf("websocket: missing Upgrade: websocket header") } key := r.Header.Get("Sec-WebSocket-Key") if key == "" { return nil, fmt.Errorf("websocket: missing Sec-WebSocket-Key") } // 计算 accept key acceptKey := computeAcceptKey(key) // Hijack 连接 hj, ok := w.(http.Hijacker) if !ok { return nil, fmt.Errorf("websocket: response does not support hijacking") } conn, bufrw, err := hj.Hijack() if err != nil { return nil, fmt.Errorf("websocket: hijack failed: %w", err) } // 发送 101 Switching Protocols 响应 resp := "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: " + acceptKey + "\r\n" + "\r\n" if _, err := bufrw.WriteString(resp); err != nil { conn.Close() return nil, fmt.Errorf("websocket: write upgrade response: %w", err) } if err := bufrw.Flush(); err != nil { conn.Close() return nil, fmt.Errorf("websocket: flush upgrade response: %w", err) } return &WebSocketConn{ conn: conn, reader: bufrw.Reader, writer: bufrw.Writer, }, nil } // WebSocketConn 是服务端的 WebSocket 连接. type WebSocketConn struct { conn net.Conn reader *bufio.Reader writer *bufio.Writer mu sync.Mutex } // ReadMessage 读取一条消息. func (c *WebSocketConn) ReadMessage() (int, []byte, error) { b1, err := c.reader.ReadByte() if err != nil { return 0, nil, err } opcode := int(b1 & 0x0F) b2, err := c.reader.ReadByte() if err != nil { return 0, nil, err } masked := (b2 & 0x80) != 0 payloadLen := int64(b2 & 0x7F) switch payloadLen { case 126: var buf [2]byte if _, err := io.ReadFull(c.reader, buf[:]); err != nil { return 0, nil, err } payloadLen = int64(binary.BigEndian.Uint16(buf[:])) case 127: var buf [8]byte if _, err := io.ReadFull(c.reader, buf[:]); err != nil { return 0, nil, err } payloadLen = int64(binary.BigEndian.Uint64(buf[:])) } if payloadLen > wsMaxMessageSize { return 0, nil, fmt.Errorf("websocket: message too large: %d", payloadLen) } var maskKey [4]byte if masked { if _, err := io.ReadFull(c.reader, maskKey[:]); err != nil { return 0, nil, err } } payload := make([]byte, payloadLen) if payloadLen > 0 { if _, err := io.ReadFull(c.reader, payload); err != nil { return 0, nil, err } } if masked { for i := range payload { payload[i] ^= maskKey[i%4] } } return opcode, payload, nil } // WriteMessage 写入一条消息(服务端帧不需要 mask). func (c *WebSocketConn) WriteMessage(opcode int, payload []byte) error { c.mu.Lock() defer c.mu.Unlock() c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) defer c.conn.SetWriteDeadline(time.Time{}) // FIN=1 + opcode c.writer.WriteByte(0x80 | byte(opcode)) // 服务端不 mask payloadLen := len(payload) switch { case payloadLen <= 125: c.writer.WriteByte(byte(payloadLen)) case payloadLen <= 65535: c.writer.WriteByte(126) var buf [2]byte binary.BigEndian.PutUint16(buf[:], uint16(payloadLen)) c.writer.Write(buf[:]) default: c.writer.WriteByte(127) var buf [8]byte binary.BigEndian.PutUint64(buf[:], uint64(payloadLen)) c.writer.Write(buf[:]) } c.writer.Write(payload) return c.writer.Flush() } // Close 关闭连接. func (c *WebSocketConn) Close() error { c.mu.Lock() defer c.mu.Unlock() // 发送 Close 帧 payload := make([]byte, 2) binary.BigEndian.PutUint16(payload, 1000) c.writer.WriteByte(0x80 | opClose) c.writer.WriteByte(byte(len(payload))) c.writer.Write(payload) c.writer.Flush() return c.conn.Close() } // headerContains 检查 HTTP header 是否包含指定值(不区分大小写). func headerContains(h http.Header, key, value string) bool { for _, v := range h[key] { for _, s := range strings.Split(v, ",") { if strings.EqualFold(strings.TrimSpace(s), value) { return true } } } return false } // Ensure WebSocketTransport 满足 Transport 接口. var _ Transport = (*WebSocketTransport)(nil)