// elicitation_test.go - MCP 11.5 Elicitation 测试. // // 覆盖场景: // - ElicitationCreateResult / ElicitationCreateParams JSON 序列化 // - ElicitationSchema 结构验证 // - Client.SetElicitationHandler 注入 // - handleElicitationCreate:合法 params → handler 被调用 → 响应发回 // - handleElicitationCreate:无 handler → 自动 cancel // - handleElicitationCreate:无效 params → -32600 error // - handleServerRequest:未知方法 → -32601 MethodNotFound // - dispatchLoop server-to-client 请求路由(有 ID + 有 Method) // - rawMessage 三种类型识别(响应/通知/server 请求) // - Manager.SetElicitationHandler 注入 package mcp import ( "context" "encoding/json" "strings" "sync" "testing" "time" "git.flytoex.net/yuanwei/flyto-agent/pkg/config" "git.flytoex.net/yuanwei/flyto-agent/pkg/execenv" ) // ── ElicitationCreateResult 序列化测试 ────────────────────────────────────── // TestElicitationCreateResult_MarshalAccept 验证 accept 响应正确序列化. func TestElicitationCreateResult_MarshalAccept(t *testing.T) { result := ElicitationCreateResult{ Action: "accept", Content: map[string]any{ "table_name": "orders", }, } data, err := json.Marshal(result) if err != nil { t.Fatalf("marshal: %v", err) } var parsed map[string]any if err := json.Unmarshal(data, &parsed); err != nil { t.Fatalf("unmarshal: %v", err) } if parsed["action"] != "accept" { t.Errorf("action = %v, want accept", parsed["action"]) } content, ok := parsed["content"].(map[string]any) if !ok { t.Fatal("content should be an object") } if content["table_name"] != "orders" { t.Errorf("table_name = %v, want orders", content["table_name"]) } } // TestElicitationCreateResult_MarshalCancel 验证 cancel 响应不含 content 字段(omitempty). 
func TestElicitationCreateResult_MarshalCancel(t *testing.T) {
	result := ElicitationCreateResult{Action: "cancel"}
	data, err := json.Marshal(result)
	if err != nil {
		t.Fatalf("marshal: %v", err)
	}

	var parsed map[string]any
	// The decode error must be checked: on failure parsed stays nil, and the
	// "no content field" assertion below would then pass vacuously.
	if err := json.Unmarshal(data, &parsed); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}

	if parsed["action"] != "cancel" {
		t.Errorf("action = %v, want cancel", parsed["action"])
	}
	if _, hasContent := parsed["content"]; hasContent {
		t.Error("cancel result should not have content field (omitempty)")
	}
}

// ── ElicitationSchema tests ────────────────────────────────────────────────

// TestElicitationSchema_Properties verifies that an ElicitationSchema keeps
// the properties it was constructed with.
func TestElicitationSchema_Properties(t *testing.T) {
	schema := &ElicitationSchema{
		Type: "object",
		Properties: map[string]ElicitationProperty{
			"name":  {Type: "string", Title: "姓名"},
			"age":   {Type: "number", Description: "年龄"},
			"agree": {Type: "boolean", Default: false},
		},
		Required: []string{"name"},
	}

	if len(schema.Properties) != 3 {
		t.Errorf("properties count = %d, want 3", len(schema.Properties))
	}
	if schema.Properties["name"].Title != "姓名" {
		t.Error("name.Title should be 姓名")
	}
}

// TestElicitationCreateParams_Unmarshal verifies decoding
// ElicitationCreateParams from JSON.
func TestElicitationCreateParams_Unmarshal(t *testing.T) {
	raw := `{
		"message": "请输入数据库表名",
		"requestedSchema": {
			"type": "object",
			"properties": {
				"table": {"type":"string","title":"表名"},
				"limit": {"type":"number"}
			},
			"required": ["table"]
		}
	}`

	var params ElicitationCreateParams
	if err := json.Unmarshal([]byte(raw), &params); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}

	if params.Message != "请输入数据库表名" {
		t.Errorf("message = %q", params.Message)
	}
	if params.RequestedSchema == nil {
		t.Fatal("requestedSchema should not be nil")
	}
	if params.RequestedSchema.Type != "object" {
		t.Errorf("schema.type = %q, want object", params.RequestedSchema.Type)
	}
	if len(params.RequestedSchema.Properties) != 2 {
		t.Errorf("properties count = %d, want 2", len(params.RequestedSchema.Properties))
	}
	if len(params.RequestedSchema.Required) != 1 || params.RequestedSchema.Required[0] != "table" {
		t.Errorf("required = %v, want [table]", params.RequestedSchema.Required)
	}
}

// ── mockElicitHandler ──────────────────────────────────────────────────────

// mockElicitHandler records how many times it was called and the arguments of
// the most recent call, for use in test assertions. It always returns response.
type mockElicitHandler struct {
	mu         sync.Mutex
	calls      int
	lastServer string
	lastMsg    string
	response   ElicitationCreateResult
}

// HandleElicitation records the call under mu and returns the canned response.
// The schema argument is ignored.
func (h *mockElicitHandler) HandleElicitation(serverName, message string, _ *ElicitationSchema) ElicitationCreateResult {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.calls++
	h.lastServer = serverName
	h.lastMsg = message
	return h.response
}

// ── memTransport: in-memory transport for Client unit tests ────────────────

// memTransport is a test-only in-memory implementation of the Transport
// interface.
//
// Recv takes messages one at a time from the msgCh channel and blocks until a
// message arrives or the transport is closed. As a result dispatchLoop does not
// exit just because the queue is momentarily empty, and tests may call the
// handleXxx methods at any point. Send records what was sent and does not block.
//
// CLEVER: a channel is used instead of slice+mutex+condition variable, because
// select on channels gives close detection for free without extra
// synchronization primitives.
type memTransport struct { msgCh chan []byte // 消息供给(Recv 阻塞在此) sentMu sync.Mutex sent [][]byte // 记录所有 Send 的内容 doneCh chan struct{} // 传输关闭信号 once sync.Once // 保证 doneCh 只关闭一次 } func newMemTransport() *memTransport { return &memTransport{ msgCh: make(chan []byte, 64), doneCh: make(chan struct{}), } } // feed 向传输预注入一条消息(异步,不阻塞). func (m *memTransport) feed(msg []byte) { cp := make([]byte, len(msg)) copy(cp, msg) m.msgCh <- cp } // closeAfterN 无效(memTransport 通过 Close() 或 test 结束控制生命周期). // 仅供 dispatchLoop 退出测试使用:test 结束后 defer client.Close() 会触发关闭. func (m *memTransport) closeAfterN(_ int) {} func (m *memTransport) Send(_ context.Context, data []byte) error { select { case <-m.doneCh: return context.Canceled default: } cp := make([]byte, len(data)) copy(cp, data) m.sentMu.Lock() m.sent = append(m.sent, cp) m.sentMu.Unlock() return nil } func (m *memTransport) Recv(_ context.Context) ([]byte, error) { select { case msg, ok := <-m.msgCh: if !ok { return nil, context.Canceled } return msg, nil case <-m.doneCh: return nil, context.Canceled } } func (m *memTransport) Close() error { m.once.Do(func() { close(m.doneCh) }) return nil } func (m *memTransport) getSent() [][]byte { m.sentMu.Lock() defer m.sentMu.Unlock() out := make([][]byte, len(m.sent)) copy(out, m.sent) return out } // newTestClient 创建不发起实际网络连接的测试用 Client. func newTestClient(serverName string, tp Transport) *Client { cfg := config.MCPServerConfig{Name: serverName} return NewClient(cfg, tp) } // ── Client.SetElicitationHandler 测试 ─────────────────────────────────────── // TestClient_SetElicitationHandler 验证 SetElicitationHandler 正确注入. 
func TestClient_SetElicitationHandler(t *testing.T) { tp := newMemTransport() client := newTestClient("test-srv", tp) defer client.Close() handler := &mockElicitHandler{response: ElicitationCreateResult{Action: "accept"}} client.SetElicitationHandler(handler) client.mu.Lock() got := client.elicitationHandler client.mu.Unlock() if got == nil { t.Error("elicitationHandler should be set after SetElicitationHandler") } } // TestClient_HandleElicitationCreate_WithHandler 验证 handler 被调用且响应发回. func TestClient_HandleElicitationCreate_WithHandler(t *testing.T) { tp := newMemTransport() client := newTestClient("db-srv", tp) defer client.Close() handler := &mockElicitHandler{ response: ElicitationCreateResult{ Action: "accept", Content: map[string]any{"table": "users"}, }, } client.SetElicitationHandler(handler) params := ElicitationCreateParams{ Message: "请输入表名", RequestedSchema: &ElicitationSchema{ Type: "object", Properties: map[string]ElicitationProperty{"table": {Type: "string"}}, }, } paramsBytes, _ := json.Marshal(params) client.handleElicitationCreate(42, json.RawMessage(paramsBytes)) handler.mu.Lock() calls := handler.calls lastServer := handler.lastServer lastMsg := handler.lastMsg handler.mu.Unlock() if calls != 1 { t.Errorf("handler.calls = %d, want 1", calls) } if lastServer != "db-srv" { t.Errorf("lastServer = %q, want db-srv", lastServer) } if lastMsg != "请输入表名" { t.Errorf("lastMsg = %q", lastMsg) } // 验证响应被发送 sent := tp.getSent() if len(sent) == 0 { t.Fatal("response should have been sent via transport") } var resp struct { JSONRPC string `json:"jsonrpc"` ID int64 `json:"id"` Result ElicitationCreateResult `json:"result"` } if err := json.Unmarshal(sent[len(sent)-1], &resp); err != nil { t.Fatalf("parse response: %v", err) } if resp.ID != 42 { t.Errorf("response ID = %d, want 42", resp.ID) } if resp.Result.Action != "accept" { t.Errorf("result.action = %q, want accept", resp.Result.Action) } } // TestClient_HandleElicitationCreate_NoHandler_AutoCancel 验证无 
handler 时自动 cancel. func TestClient_HandleElicitationCreate_NoHandler_AutoCancel(t *testing.T) { tp := newMemTransport() client := newTestClient("srv", tp) defer client.Close() // 不设置 handler params := ElicitationCreateParams{Message: "测试"} paramsBytes, _ := json.Marshal(params) client.handleElicitationCreate(99, json.RawMessage(paramsBytes)) sent := tp.getSent() if len(sent) == 0 { t.Fatal("response should be sent even without handler") } var resp struct { Result ElicitationCreateResult `json:"result"` } json.Unmarshal(sent[len(sent)-1], &resp) if resp.Result.Action != "cancel" { t.Errorf("without handler, action should be cancel, got %q", resp.Result.Action) } } // TestClient_HandleElicitationCreate_InvalidParams 验证 params 无效时回复 -32600. func TestClient_HandleElicitationCreate_InvalidParams(t *testing.T) { tp := newMemTransport() client := newTestClient("srv", tp) defer client.Close() client.handleElicitationCreate(1, json.RawMessage(`not json`)) sent := tp.getSent() if len(sent) == 0 { t.Fatal("error response should have been sent") } var resp struct { Error *RPCError `json:"error"` } json.Unmarshal(sent[len(sent)-1], &resp) if resp.Error == nil { t.Fatal("response should have error field") } if resp.Error.Code != -32600 { t.Errorf("error code = %d, want -32600", resp.Error.Code) } } // TestClient_HandleElicitationCreate_MessageTooLong 验证超长 message 被拒绝. 
func TestClient_HandleElicitationCreate_MessageTooLong(t *testing.T) { tp := newMemTransport() client := newTestClient("srv", tp) defer client.Close() // 构造超过 maxElicitationMessageLen 的 message longMessage := strings.Repeat("x", maxElicitationMessageLen+1) params := ElicitationCreateParams{Message: longMessage} paramsBytes, _ := json.Marshal(params) client.handleElicitationCreate(50, json.RawMessage(paramsBytes)) sent := tp.getSent() if len(sent) == 0 { t.Fatal("error response should have been sent for oversized message") } var resp struct { ID int64 `json:"id"` Error *RPCError `json:"error"` } json.Unmarshal(sent[len(sent)-1], &resp) if resp.Error == nil { t.Fatal("response should have error for oversized message") } if resp.Error.Code != -32600 { t.Errorf("error code = %d, want -32600", resp.Error.Code) } if !strings.Contains(resp.Error.Message, "too long") { t.Errorf("error message should mention 'too long', got %q", resp.Error.Message) } } // TestClient_HandleElicitationCreate_MessageAtLimit 验证刚好在限制内的 message 通过. 
func TestClient_HandleElicitationCreate_MessageAtLimit(t *testing.T) { tp := newMemTransport() client := newTestClient("srv", tp) defer client.Close() // 不设置 handler, 期望自动 cancel(不是 error) exactMessage := strings.Repeat("y", maxElicitationMessageLen) params := ElicitationCreateParams{Message: exactMessage} paramsBytes, _ := json.Marshal(params) client.handleElicitationCreate(51, json.RawMessage(paramsBytes)) sent := tp.getSent() if len(sent) == 0 { t.Fatal("response should have been sent") } var resp struct { Result ElicitationCreateResult `json:"result"` Error *RPCError `json:"error"` } json.Unmarshal(sent[len(sent)-1], &resp) if resp.Error != nil { t.Errorf("message at exact limit should not be rejected, got error: %v", resp.Error) } // 无 handler 时应 auto-cancel, 不是 error if resp.Result.Action != "cancel" { t.Errorf("action = %q, want cancel (auto-cancel without handler)", resp.Result.Action) } } // TestClient_HandleServerRequest_UnknownMethod 验证未知方法回复 -32601. func TestClient_HandleServerRequest_UnknownMethod(t *testing.T) { tp := newMemTransport() client := newTestClient("srv", tp) defer client.Close() client.handleServerRequest(77, "unknown/method", json.RawMessage(`{}`)) sent := tp.getSent() if len(sent) == 0 { t.Fatal("error response should have been sent for unknown method") } var resp struct { ID int64 `json:"id"` Error *RPCError `json:"error"` } json.Unmarshal(sent[len(sent)-1], &resp) if resp.ID != 77 { t.Errorf("response ID = %d, want 77", resp.ID) } if resp.Error == nil { t.Fatal("response should have error for unknown method") } if resp.Error.Code != -32601 { t.Errorf("error code = %d, want -32601 (MethodNotFound)", resp.Error.Code) } } // ── dispatchLoop server-to-client 路由测试 ─────────────────────────────────── // TestDispatchLoop_ServerRequest_RoutedToHandler 验证 dispatchLoop 识别 // server-to-client 请求(有 ID + 有 Method)并路由到 elicitation handler. 
func TestDispatchLoop_ServerRequest_RoutedToHandler(t *testing.T) {
	tp := newMemTransport()

	// The handler closes calledCh on its first invocation, so the test can wait
	// for dispatchLoop to reach it without sleeping.
	calledCh := make(chan struct{})
	handler := &notifyElicitHandler{
		calledCh: calledCh,
		response: ElicitationCreateResult{Action: "accept"},
	}

	client := newTestClient("test", tp)
	client.SetElicitationHandler(handler)
	defer client.Close()

	// Build an elicitation/create server-to-client request.
	params := ElicitationCreateParams{Message: "需要输入"}
	paramsBytes, err := json.Marshal(params)
	if err != nil {
		t.Fatalf("marshal params: %v", err)
	}
	msg := map[string]any{
		"jsonrpc": "2.0",
		"id":      int64(123),
		"method":  "elicitation/create",
		"params":  json.RawMessage(paramsBytes),
	}
	msgBytes, err := json.Marshal(msg)
	if err != nil {
		t.Fatalf("marshal request: %v", err)
	}
	tp.feed(msgBytes)

	// Bounded wait so a routing bug fails the test instead of deadlocking it.
	select {
	case <-calledCh:
		// Handler was invoked, as expected.
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for elicitation handler to be called by dispatchLoop")
	}

	handler.mu.Lock()
	calls := handler.calls
	handler.mu.Unlock()

	if calls != 1 {
		t.Errorf("handler.calls = %d, want 1 (dispatchLoop should route to elicitation handler)", calls)
	}
}

// notifyElicitHandler is an ElicitationHandler that counts its calls and
// closes calledCh the first time it is invoked.
type notifyElicitHandler struct {
	mu       sync.Mutex
	calls    int
	calledCh chan struct{}
	once     sync.Once
	response ElicitationCreateResult
}

// HandleElicitation increments calls, signals calledCh (once), and returns the
// canned response. All arguments are ignored.
func (h *notifyElicitHandler) HandleElicitation(_, _ string, _ *ElicitationSchema) ElicitationCreateResult {
	h.mu.Lock()
	h.calls++
	h.mu.Unlock()
	// once prevents a double close (panic) if the handler is called again.
	h.once.Do(func() { close(h.calledCh) })
	return h.response
}

// ── rawMessage message-kind identification tests ───────────────────────────

// TestRawMessage_TypeIdentification verifies that the rawMessage struct can
// distinguish the three JSON-RPC message kinds.
func TestRawMessage_TypeIdentification(t *testing.T) {
	// Each case pairs a wire message with the presence of ID/Method expected
	// after decoding and the kind dispatchLoop should treat it as:
	//   response       – ID set, no Method  → matched against the pending map
	//   notification   – no ID, Method set  → delivered to notificationHandler
	//   server_request – ID and Method set  → handled as a server-to-client request
	cases := []struct {
		name       string
		input      string
		wantID     bool
		wantMethod bool
		kind       string
	}{
		{
			name:       "响应(有 ID,无 Method)",
			input:      `{"jsonrpc":"2.0","id":1,"result":{}}`,
			wantID:     true,
			wantMethod: false,
			kind:       "response",
		},
		{
			name:       "通知(无 ID,有 Method)",
			input:      `{"jsonrpc":"2.0","method":"notifications/tools/list_changed"}`,
			wantID:     false,
			wantMethod: true,
			kind:       "notification",
		},
		{
			name:       "server-to-client 请求(有 ID,有 Method)",
			input:      `{"jsonrpc":"2.0","id":42,"method":"elicitation/create","params":{}}`,
			wantID:     true,
			wantMethod: true,
			kind:       "server_request",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var decoded rawMessage
			if err := json.Unmarshal([]byte(tc.input), &decoded); err != nil {
				t.Fatalf("unmarshal: %v", err)
			}

			gotID, gotMethod := decoded.ID != nil, decoded.Method != ""

			if gotID != tc.wantID {
				t.Errorf("hasID = %v, want %v", gotID, tc.wantID)
			}
			if gotMethod != tc.wantMethod {
				t.Errorf("hasMethod = %v, want %v", gotMethod, tc.wantMethod)
			}

			// Mirror dispatchLoop's routing rule for the declared kind; at most
			// one case can match because each case is keyed on tc.kind.
			switch {
			case tc.kind == "server_request" && !(gotID && gotMethod):
				t.Error("server_request must have both ID and Method for correct routing")
			case tc.kind == "response" && gotID && gotMethod:
				t.Error("response should not have Method")
			case tc.kind == "notification" && (!gotMethod || gotID):
				t.Error("notification should have Method but not ID")
			}
		})
	}
}

// ── Manager.SetElicitationHandler tests ────────────────────────────────────

// TestManager_SetElicitationHandler verifies that SetElicitationHandler stores
// the handler on the Manager. Propagation of the handler to Clients (e.g. via
// ConnectOne) is not exercised by this test.
func TestManager_SetElicitationHandler(t *testing.T) {
	m := NewManager(execenv.DefaultExecutor{})
	handler := &mockElicitHandler{response: ElicitationCreateResult{Action: "decline"}}
	m.SetElicitationHandler(handler)

	m.mu.RLock()
	got := m.elicitationHandler
	m.mu.RUnlock()

	if got == nil {
		t.Error("manager.elicitationHandler should be set")
	}
}

// TestManager_SetElicitationHandler_NilSafe verifies that passing nil does not
// panic and that it replaces a previously installed handler with nil.
func TestManager_SetElicitationHandler_NilSafe(t *testing.T) {
	m := NewManager(execenv.DefaultExecutor{})

	// Install a real handler first. On a fresh Manager the field is presumably
	// already nil, so checking for nil afterwards would pass even if
	// SetElicitationHandler(nil) were silently ignored.
	m.SetElicitationHandler(&mockElicitHandler{})
	m.SetElicitationHandler(nil) // must not panic

	m.mu.RLock()
	got := m.elicitationHandler
	m.mu.RUnlock()

	if got != nil {
		t.Error("nil handler should be stored as nil")
	}
}