// roundtrip_test.go - WebSocket 端到端测试 // // 精妙之处(CLEVER): 用 httptest.NewServer + UpgradeHTTP 起一个真实的 WS server, // 客户端通过 ws://127.0.0.1:port 真实建立连接,验证完整握手 + 帧编解码 + close 握手. // 不 mock 任何东西,证明协议实现整体 round-trip 正确. // // 历史包袱(LEGACY): httptest.NewServer 的 ResponseWriter 实现了 http.Hijacker // 接口(自 Go 1.x 起),所以 UpgradeHTTP 在 httptest 里能工作.如果改用 httptest.Server // 的 wrapped handler 模式失败,回退方案是用 net.Listen("tcp", ":0") 自起 server. package websocket import ( "context" "net/http" "net/http/httptest" "strings" "sync" "testing" "time" ) // echoServer 起一个 WebSocket echo server,把收到的文本消息原样发回. func echoServer(t *testing.T) (*httptest.Server, string) { t.Helper() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := UpgradeHTTP(w, r) if err != nil { t.Logf("server upgrade: %v", err) return } defer conn.Close() for { opcode, payload, err := conn.ReadMessage() if err != nil { return } switch opcode { case opText, opBinary: _ = conn.WriteMessage(opcode, payload) case opClose: // 客户端发了 close 帧,回应后退出 _ = conn.WriteMessage(opClose, payload) return case opPing: _ = conn.WriteMessage(opPong, payload) } } })) wsURL := strings.Replace(server.URL, "http://", "ws://", 1) return server, wsURL } // === 端到端 Roundtrip === func TestRoundTrip_TextMessage(t *testing.T) { server, wsURL := echoServer(t) defer server.Close() ws := NewWebSocket(WebSocketConfig{URL: wsURL}) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := ws.Connect(ctx); err != nil { t.Fatalf("Connect 失败: %v", err) } defer ws.Close() if err := ws.Send([]byte("hello")); err != nil { t.Fatalf("Send 失败: %v", err) } select { case msg := <-ws.Receive(): if string(msg) != "hello" { t.Errorf("收到 %q, 期望 hello", msg) } case <-time.After(2 * time.Second): t.Fatal("Receive 超时") } } func TestRoundTrip_BinaryMessage(t *testing.T) { server, wsURL := echoServer(t) defer server.Close() ws := NewWebSocket(WebSocketConfig{URL: wsURL}) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := ws.Connect(ctx); err != nil { t.Fatal(err) } defer ws.Close() binData := []byte{0x00, 0x01, 0x02, 0xFF, 0xFE} if err := ws.SendBinary(binData); err != nil { t.Fatalf("SendBinary 失败: %v", err) } select { case msg := <-ws.Receive(): if len(msg) != len(binData) { t.Fatalf("长度 %d, 期望 %d", len(msg), len(binData)) } for i := range binData { if msg[i] != binData[i] { t.Errorf("byte[%d] = %x, 期望 %x", i, msg[i], binData[i]) } } case <-time.After(2 * time.Second): t.Fatal("Receive 超时") } } func TestRoundTrip_MultipleMessages(t *testing.T) { server, wsURL := echoServer(t) defer server.Close() ws := NewWebSocket(WebSocketConfig{URL: wsURL}) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := ws.Connect(ctx); err != nil { t.Fatal(err) } defer ws.Close() messages := []string{"first", "second", "third"} for _, m := range messages { if err := ws.Send([]byte(m)); err != nil { t.Fatalf("Send %q: %v", m, err) } } for i, want := range messages { select { case got := <-ws.Receive(): if string(got) != want { t.Errorf("msg[%d] = %q, want %q", i, got, want) } case <-time.After(2 * time.Second): t.Fatalf("Receive #%d 超时", i) } } } func TestRoundTrip_LargeMessage(t *testing.T) { // 测试 16-bit 长度(126 ≤ len ≤ 65535)和 64-bit 长度(> 65535)的帧编解码 server, wsURL := echoServer(t) defer server.Close() ws := NewWebSocket(WebSocketConfig{URL: wsURL}) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := ws.Connect(ctx); err != nil { t.Fatal(err) } defer ws.Close() // 200 字节(短帧 < 126,单字节长度) short := strings.Repeat("a", 200) // 1000 字节(中帧 126 ≤ len ≤ 65535,2 字节长度) medium := strings.Repeat("b", 1000) // 70000 字节(长帧 > 65535,8 字节长度) long := strings.Repeat("c", 70000) for _, payload := range []string{short, medium, long} { if err := ws.Send([]byte(payload)); err != nil { t.Fatalf("Send len=%d 失败: %v", len(payload), err) } } for _, want := range []string{short, medium, long} { select { case got := <-ws.Receive(): if len(got) != len(want) { t.Errorf("收到长度 %d, 期望 %d", len(got), len(want)) } if string(got) != want { t.Error("payload 内容不匹配") } case <-time.After(5 * time.Second): t.Fatalf("收 len=%d 超时", len(want)) } } } func TestRoundTrip_CloseFromClient(t *testing.T) { server, wsURL := echoServer(t) defer server.Close() ws := NewWebSocket(WebSocketConfig{URL: wsURL}) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := ws.Connect(ctx); err != nil { t.Fatal(err) } // 主动 close if err := ws.Close(); err != nil { t.Errorf("Close 失败: %v", err) } if !ws.isClosed() { t.Error("Close 后 isClosed 应为 true") } // Close 后 Send 应失败 if err := ws.Send([]byte("after close")); err == nil { t.Error("Close 后 Send 应失败") } } // === UpgradeHTTP 错误路径 === func TestUpgradeHTTP_NonGET(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := UpgradeHTTP(w, r) if err == nil { t.Error("非 GET 应失败") } })) defer server.Close() resp, err := http.Post(server.URL, "text/plain", nil) if err == nil { resp.Body.Close() } // 服务端检查在 handler 内,状态码可能正常但 hijack 没发生 } func TestUpgradeHTTP_MissingConnectionHeader(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := UpgradeHTTP(w, r) if err == nil { t.Error("缺 Connection: upgrade 应失败") } if !strings.Contains(err.Error(), "Connection") { t.Errorf("错误应提到 Connection: %v", err) } })) defer server.Close() req, _ := http.NewRequest("GET", server.URL, nil) // 故意不设 Connection header resp, err := http.DefaultClient.Do(req) if err == nil { resp.Body.Close() } } func TestUpgradeHTTP_MissingUpgradeHeader(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := UpgradeHTTP(w, r) if err == nil { t.Error("缺 Upgrade: websocket 应失败") } if !strings.Contains(err.Error(), "Upgrade") { t.Errorf("错误应提到 Upgrade: %v", err) } })) defer server.Close() req, _ := http.NewRequest("GET", server.URL, nil) req.Header.Set("Connection", "upgrade") // 不设 Upgrade resp, err := http.DefaultClient.Do(req) if err == nil { resp.Body.Close() } } func TestUpgradeHTTP_MissingSecWebSocketKey(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := UpgradeHTTP(w, r) if err == nil { t.Error("缺 Sec-WebSocket-Key 应失败") } if !strings.Contains(err.Error(), "Sec-WebSocket-Key") { t.Errorf("错误应提到 Sec-WebSocket-Key: %v", err) } })) defer server.Close() req, _ := http.NewRequest("GET", server.URL, nil) req.Header.Set("Connection", "upgrade") req.Header.Set("Upgrade", "websocket") // 不设 Sec-WebSocket-Key resp, err := http.DefaultClient.Do(req) if err == nil { resp.Body.Close() } } // === Connect 错误路径 === func TestConnect_BadURL(t *testing.T) { ws := NewWebSocket(WebSocketConfig{URL: "ws://127.0.0.1:1"}) // 不可达端口 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() err := ws.Connect(ctx) if err == nil { t.Error("不可达 URL 应失败") } } func TestConnect_NonWebSocketServer(t *testing.T) { // 起一个普通 HTTP server,不实现 WS upgrade server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("not a websocket")) })) defer server.Close() wsURL := strings.Replace(server.URL, "http://", "ws://", 1) ws := NewWebSocket(WebSocketConfig{URL: wsURL}) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() err := ws.Connect(ctx) if err == nil { t.Error("非 WS 服务器应导致 Connect 失败") } // 应该是 status 不是 101 if err != nil && !strings.Contains(err.Error(), "101") && !strings.Contains(err.Error(), "upgrade") && !strings.Contains(err.Error(), "Accept") { t.Logf("错误信息(仅供参考): %v", err) } } // === 并发安全 === func TestRoundTrip_ConcurrentSend(t *testing.T) { server, wsURL := echoServer(t) defer server.Close() ws := NewWebSocket(WebSocketConfig{URL: wsURL}) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := ws.Connect(ctx); err != nil { t.Fatal(err) } defer ws.Close() // 10 goroutines × 5 msgs = 50 messages var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func(id int) { defer wg.Done() for j := 0; j < 5; j++ { _ = ws.Send([]byte("concurrent")) } }(i) } wg.Wait() // 排空 receive channel 直到 50 条消息或超时 count := 0 timeout := time.After(3 * time.Second) loop: for count < 50 { select { case <-ws.Receive(): count++ case <-timeout: break loop } } if count != 50 { t.Errorf("收到 %d 条消息,期望 50", count) } }