package transport_test import ( "context" "encoding/json" "net/http/httptest" "strings" "testing" "time" "git.flytoex.net/yuanwei/flyto-agent/pkg/bridge" "git.flytoex.net/yuanwei/flyto-agent/pkg/bridge/transport" "git.flytoex.net/yuanwei/flyto-agent/pkg/websocket" ) // wsURLFromHTTP 把 httptest.Server 的 http:// 前缀换成 ws://. func wsURLFromHTTP(httpURL, path string) string { return "ws://" + strings.TrimPrefix(httpURL, "http://") + path } // newClient 帮助函数: 用 pkg/websocket 的客户端连到 transport 服务器. // 返回 cli (调用方 defer Close), 测试失败时直接 t.Fatal. func newClient(t *testing.T, url string) *websocket.WebSocketTransport { t.Helper() cli := websocket.NewWebSocket(websocket.WebSocketConfig{URL: url}) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() if err := cli.Connect(ctx); err != nil { t.Fatalf("client connect: %v", err) } return cli } func TestWebSocketTransport_Roundtrip(t *testing.T) { tr := transport.NewWebSocketTransport("/bridge", bridge.DefaultBridgeConfig()) srv := httptest.NewServer(tr) defer srv.Close() defer tr.Close() cli := newClient(t, wsURLFromHTTP(srv.URL, "/bridge/ws?session_id=s1")) defer cli.Close() // 等 Accept 拿到 SessionConn var conn bridge.SessionConn select { case conn = <-tr.Accept(): case <-time.After(2 * time.Second): t.Fatal("no accept within 2s") } if conn.SessionID() != "s1" { t.Fatalf("session id: got %q want %q", conn.SessionID(), "s1") } // 下行: 服务端 Send(BridgeEvent) → 客户端 Receive 拿到 JSON evt := bridge.BridgeEvent{ ID: "e1", Type: "text_delta", SessionID: "s1", Payload: json.RawMessage(`{"text":"hi"}`), Timestamp: time.Now(), } sendCtx, sendCancel := context.WithTimeout(context.Background(), 2*time.Second) defer sendCancel() if err := conn.Send(sendCtx, evt); err != nil { t.Fatalf("server send: %v", err) } select { case raw := <-cli.Receive(): var got bridge.BridgeEvent if err := json.Unmarshal(raw, &got); err != nil { t.Fatalf("client unmarshal event: %v", err) } if got.ID != "e1" || got.Type != "text_delta" { t.Fatalf("event mismatch: got %+v", got) } case <-time.After(2 * time.Second): t.Fatal("no downstream event in 2s") } // 上行: 客户端 Send(ClientMessage JSON) → 服务端 Recv() msg := bridge.ClientMessage{ Type: bridge.ClientMessagePrompt, SessionID: "s1", Prompt: "hello", } raw, _ := json.Marshal(msg) if err := cli.Send(raw); err != nil { t.Fatalf("client send: %v", err) } select { case got := <-conn.Recv(): if got.Prompt != "hello" || got.Type != bridge.ClientMessagePrompt { t.Fatalf("upstream msg mismatch: got %+v", got) } case <-time.After(2 * time.Second): t.Fatal("no upstream message in 2s") } // 关闭 conn.Close() select { case <-conn.Done(): case <-time.After(1 * time.Second): t.Fatal("conn.Done not closed after Close()") } } func TestWebSocketTransport_MissingSessionID(t *testing.T) { tr := transport.NewWebSocketTransport("/bridge", bridge.DefaultBridgeConfig()) srv := httptest.NewServer(tr) defer srv.Close() defer tr.Close() cli := websocket.NewWebSocket(websocket.WebSocketConfig{ URL: wsURLFromHTTP(srv.URL, "/bridge/ws"), }) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() err := cli.Connect(ctx) if err == nil { cli.Close() t.Fatal("expected Connect error for missing session_id, got nil") } } func TestWebSocketTransport_TransportClose(t *testing.T) { tr := transport.NewWebSocketTransport("/bridge", bridge.DefaultBridgeConfig()) if err := tr.Close(); err != nil { t.Fatalf("first close: %v", err) } // 幂等 if err := tr.Close(); err != nil { t.Fatalf("idempotent close: %v", err) } // Accept channel 已关闭 if _, ok := <-tr.Accept(); ok { t.Fatal("Accept channel should be closed after Close()") } } func TestWebSocketTransport_Reconnect_ReplaceOldConn(t *testing.T) { tr := transport.NewWebSocketTransport("/bridge", bridge.DefaultBridgeConfig()) srv := httptest.NewServer(tr) defer srv.Close() defer tr.Close() url := wsURLFromHTTP(srv.URL, "/bridge/ws?session_id=reuse") // 第一次连接 cli1 := newClient(t, url) var conn1 bridge.SessionConn select { case conn1 = <-tr.Accept(): case <-time.After(2 * time.Second): t.Fatal("first accept timeout") } // 同 sessionID 第二次连接, 旧 conn 应被关闭 cli2 := newClient(t, url) defer cli2.Close() select { case <-conn1.Done(): // 预期: 旧 conn 被 replace 时 Done 会 fire case <-time.After(2 * time.Second): t.Fatal("old conn1.Done not fired after reconnect") } cli1.Close() var conn2 bridge.SessionConn select { case conn2 = <-tr.Accept(): case <-time.After(2 * time.Second): t.Fatal("second accept timeout") } if conn2.SessionID() != "reuse" { t.Fatalf("new session id: got %q want reuse", conn2.SessionID()) } }