// tool_input_guard_test.go - 8.4 tool_use input JSON 预检测试. // // 覆盖场景: // - 合法 JSON:正常解析,不推送错误事件 // - 残缺 JSON:json.Valid 失败,推送 ToolResultEvent(IsError=true) + ToolUseEvent(空 Input) // - 空 partialJSON:视为合法(空 map),不推送错误事件 // - getToolSchemaHint:工具存在时返回 schema,不存在时返回提示 package engine import ( "context" "encoding/json" "testing" "git.flytoex.net/yuanwei/flyto-agent/pkg/tools" ) // mockSchemaTool 是仅用于测试 getToolSchemaHint 的最简工具实现. type mockSchemaTool struct { name string schema json.RawMessage } func (m *mockSchemaTool) Name() string { return m.name } func (m *mockSchemaTool) Description(_ context.Context) string { return "mock" } func (m *mockSchemaTool) InputSchema() json.RawMessage { return m.schema } func (m *mockSchemaTool) Execute(_ context.Context, _ json.RawMessage, _ tools.ProgressFunc) (*tools.Result, error) { return &tools.Result{Output: "mock"}, nil } // newTestEngineWithTool 创建一个有最小配置的 Engine,仅注册 mockSchemaTool. // 不启动 API 连接,仅测试注册表和 getToolSchemaHint 逻辑. func newTestEngineWithTool(t *testing.T, toolName string, schema json.RawMessage) *Engine { t.Helper() reg := tools.NewRegistry() mock := &mockSchemaTool{name: toolName, schema: schema} if err := reg.Register(mock); err != nil { t.Fatalf("register mock tool: %v", err) } // 只需要 tools 字段,其他字段为零值 return &Engine{tools: reg} } // ── getToolSchemaHint 测试 ─────────────────────────────────────────────────── // TestGetToolSchemaHint_ToolNotFound 验证工具不存在时返回友好提示. func TestGetToolSchemaHint_ToolNotFound(t *testing.T) { e := &Engine{tools: tools.NewRegistry()} hint := e.getToolSchemaHint("nonexistent_tool") if hint == "" { t.Error("hint should be non-empty when tool not found") } // 应包含工具名 if !contains(hint, "nonexistent_tool") { t.Errorf("hint %q should mention tool name", hint) } } // TestGetToolSchemaHint_ToolFound 验证工具存在时返回 schema 片段. func TestGetToolSchemaHint_ToolFound(t *testing.T) { schema := json.RawMessage(`{"type":"object","properties":{"command":{"type":"string"}}}`) e := newTestEngineWithTool(t, "bash", schema) hint := e.getToolSchemaHint("bash") if hint == "" { t.Error("hint should be non-empty when tool exists") } if !contains(hint, "bash") { t.Errorf("hint %q should mention tool name", hint) } if !contains(hint, "command") { t.Errorf("hint %q should contain schema content", hint) } } // TestGetToolSchemaHint_NilTools 验证 tools 为 nil 时安全返回空字符串. func TestGetToolSchemaHint_NilTools(t *testing.T) { e := &Engine{tools: nil} hint := e.getToolSchemaHint("any_tool") if hint != "" { t.Errorf("hint should be empty when tools registry is nil, got %q", hint) } } // TestGetToolSchemaHint_LargeSchema 验证超过 512 字节的 schema 会被截断. func TestGetToolSchemaHint_LargeSchema(t *testing.T) { // 构造一个超过 512 字节的 schema large := make([]byte, 600) for i := range large { large[i] = 'x' } // 包成合法 JSON 字符串格式 schemaStr := `{"type":"object","description":"` + string(large) + `"}` e := newTestEngineWithTool(t, "big_tool", json.RawMessage(schemaStr)) hint := e.getToolSchemaHint("big_tool") // 应该被截断,包含截断标记 if !contains(hint, "截断") && !contains(hint, "...") { t.Errorf("large schema hint should be truncated, got len=%d", len(hint)) } } // ── JSON 预检逻辑单元测试 ──────────────────────────────────────────────────── // TestJSONValidPrecheck_ValidJSON 验证合法 JSON 通过预检(json.Valid 应为 true). func TestJSONValidPrecheck_ValidJSON(t *testing.T) { cases := []string{ `{"command":"ls"}`, `{}`, `{"a":1,"b":[1,2,3]}`, `{"nested":{"key":"value"}}`, } for _, c := range cases { if !json.Valid([]byte(c)) { t.Errorf("expected valid JSON: %s", c) } } } // TestJSONValidPrecheck_InvalidJSON 验证残缺 JSON 被 json.Valid 正确识别. func TestJSONValidPrecheck_InvalidJSON(t *testing.T) { cases := []string{ `{"command":`, // 未关闭 `{"command":"ls"`, // 缺少 } `{command:"ls"}`, // key 未加引号 `{`, // 只有开括号 `{"a":1,"b":}`, // 值缺失 `undefined`, // 非 JSON } for _, c := range cases { if json.Valid([]byte(c)) { t.Errorf("expected invalid JSON: %s", c) } } } // TestJSONValidPrecheck_EmptyString 验证空字符串 partialJSON 不触发预检(跳过 Valid 调用). // 引擎逻辑:partialJSON == "" 时直接用空 map,不调用 json.Valid. func TestJSONValidPrecheck_EmptyString(t *testing.T) { // 空字符串本身不是合法 JSON,但引擎在 partialJSON=="" 时不走 Valid 分支 // 此测试验证空字符串不是合法 JSON(确保我们的 "" check 是必要的) if json.Valid([]byte("")) { t.Error("empty string should not be valid JSON") } } // ── elicitation.go 单元测试 ───────────────────────────────────────────────── // TestNoopElicitationHandler_ReturnsCancel 验证 NoopElicitationHandler 返回 cancel. func TestNoopElicitationHandler_ReturnsCancel(t *testing.T) { h := NoopElicitationHandler{} resp, err := h.HandleElicitation(ElicitationRequest{ ServerName: "test-server", Message: "请输入表名", Fields: []ElicitationField{{Name: "table", Type: "string", Required: true}}, }) if err != nil { t.Errorf("NoopElicitationHandler should not return error: %v", err) } if resp.Action != "cancel" { t.Errorf("NoopElicitationHandler should return cancel, got %q", resp.Action) } } // TestElicitationHandlerFunc_Adaptor 验证 ElicitationHandlerFunc 适配器工作正常. func TestElicitationHandlerFunc_Adaptor(t *testing.T) { called := false var capturedReq ElicitationRequest h := ElicitationHandlerFunc(func(req ElicitationRequest) (ElicitationResponse, error) { called = true capturedReq = req return ElicitationResponse{ Action: "accept", Values: map[string]string{"table": "orders"}, }, nil }) req := ElicitationRequest{ ServerName: "db-server", Message: "请输入表名", Fields: []ElicitationField{ {Name: "table", Type: "string", Title: "表名", Required: true}, }, } resp, err := h.HandleElicitation(req) if err != nil { t.Errorf("unexpected error: %v", err) } if !called { t.Error("handler func should have been called") } if capturedReq.ServerName != "db-server" { t.Errorf("server name = %q, want %q", capturedReq.ServerName, "db-server") } if resp.Action != "accept" { t.Errorf("action = %q, want accept", resp.Action) } if resp.Values["table"] != "orders" { t.Errorf("values[table] = %q, want orders", resp.Values["table"]) } } // TestElicitationField_RequiredFlag 验证 Required 字段正确序列化. func TestElicitationField_RequiredFlag(t *testing.T) { req := ElicitationRequest{ ServerName: "srv", Message: "Fill in", Fields: []ElicitationField{ {Name: "required_field", Type: "string", Required: true}, {Name: "optional_field", Type: "string", Required: false}, }, } if len(req.Fields) != 2 { t.Fatalf("expected 2 fields, got %d", len(req.Fields)) } if !req.Fields[0].Required { t.Error("first field should be required") } if req.Fields[1].Required { t.Error("second field should not be required") } } // ── 辅助函数 ───────────────────────────────────────────────────────────────── // contains 检查 s 中是否包含 substr(避免 import strings 仅为一个 Contains 调用). func contains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(substr) == 0 || func() bool { for i := 0; i <= len(s)-len(substr); i++ { if s[i:i+len(substr)] == substr { return true } } return false }()) }