package tools import ( "context" "encoding/json" "testing" ) // mockDeferredTool 是用于测试的模拟工具. type mockDeferredTool struct { name string searchHint string aliases []string } func (t *mockDeferredTool) Name() string { return t.name } func (t *mockDeferredTool) Description(ctx context.Context) string { return t.name + " tool" } func (t *mockDeferredTool) InputSchema() json.RawMessage { return json.RawMessage(`{}`) } func (t *mockDeferredTool) Execute(ctx context.Context, input json.RawMessage, progress ProgressFunc) (*Result, error) { return &Result{Output: "ok"}, nil } func (t *mockDeferredTool) Metadata() Metadata { return Metadata{ ConcurrencySafe: true, ReadOnly: true, SearchHint: t.searchHint, Aliases: t.aliases, } } // createTestRegistry 创建一个包含多种工具的测试注册表. func createTestRegistry(count int) *Registry { reg := NewRegistry() // 先注册核心工具 coreNames := []string{"Bash", "Read", "Edit", "Write", "Glob", "Grep", "Agent", "ToolSearch"} for _, name := range coreNames { _ = reg.Register(&mockDeferredTool{name: name}) } // 注册额外的非核心工具 extras := []string{"WebFetch", "WebSearch", "TaskCreate", "TaskList", "TaskUpdate", "NotebookEdit", "ImageGen", "DocParser", "CodeFormat", "MCPTool1", "MCPTool2"} for i := 0; i < count && i < len(extras); i++ { _ = reg.Register(&mockDeferredTool{ name: extras[i], searchHint: "extra tool " + extras[i], }) } return reg } // TestDeferredRegistry_BelowThreshold 测试工具数量未超过阈值时全部返回 func TestDeferredRegistry_BelowThreshold(t *testing.T) { reg := createTestRegistry(3) // 8 核心 + 3 额外 = 11 dr := NewDeferredRegistry(reg, 15) active := dr.ActiveTools() if len(active) != reg.Count() { t.Errorf("未超过阈值时应返回全部工具, 期望 %d, 实际 %d", reg.Count(), len(active)) } } // TestDeferredRegistry_AboveThreshold 测试工具数量超过阈值时只返回核心工具 func TestDeferredRegistry_AboveThreshold(t *testing.T) { reg := createTestRegistry(10) // 8 核心 + 10 额外 = 18 > 15 dr := NewDeferredRegistry(reg, 15) active := dr.ActiveTools() // 应该只有 8 个核心工具 if len(active) != 8 { t.Errorf("超过阈值时应只返回核心工具, 期望 8, 实际 %d", len(active)) } // 验证都是核心工具 for _, tool := range active { if !alwaysLoadTools[tool.Name()] { t.Errorf("非核心工具 %s 不应出现在活跃列表中", tool.Name()) } } } // TestDeferredRegistry_SearchTools 测试工具搜索 func TestDeferredRegistry_SearchTools(t *testing.T) { reg := createTestRegistry(10) // 给 WebFetch 一个特定的搜索提示 reg.Unregister("WebFetch") _ = reg.Register(&mockDeferredTool{ name: "WebFetch", searchHint: "fetch web page http url download", }) dr := NewDeferredRegistry(reg, 15) // 搜索 "web" matches := dr.SearchTools("web") if len(matches) == 0 { t.Error("搜索 'web' 应找到匹配的工具") } found := false for _, m := range matches { if m.Name() == "WebFetch" { found = true break } } if !found { t.Error("搜索 'web' 应找到 WebFetch 工具") } } // TestDeferredRegistry_SearchTools_NoResults 测试搜索无结果 func TestDeferredRegistry_SearchTools_NoResults(t *testing.T) { reg := createTestRegistry(10) dr := NewDeferredRegistry(reg, 15) matches := dr.SearchTools("zzz_nonexistent_xxx") if len(matches) != 0 { t.Errorf("搜索不存在的关键词应返回空结果, 实际: %d", len(matches)) } } // TestDeferredRegistry_ActivateTool 测试激活延迟工具 func TestDeferredRegistry_ActivateTool(t *testing.T) { reg := createTestRegistry(10) dr := NewDeferredRegistry(reg, 15) // 激活前 WebFetch 不在活跃列表中 active := dr.ActiveTools() for _, tool := range active { if tool.Name() == "WebFetch" { t.Error("WebFetch 在激活前不应出现在活跃列表中") } } // 激活 WebFetch dr.ActivateTool("WebFetch") // 激活后应出现在活跃列表中 active = dr.ActiveTools() found := false for _, tool := range active { if tool.Name() == "WebFetch" { found = true break } } if !found { t.Error("WebFetch 激活后应出现在活跃列表中") } } // TestDeferredRegistry_IsDeferred 测试延迟状态检查 func TestDeferredRegistry_IsDeferred(t *testing.T) { reg := createTestRegistry(10) dr := NewDeferredRegistry(reg, 15) // 核心工具不应是延迟的 if dr.IsDeferred("Bash") { t.Error("Bash 是核心工具,不应是延迟的") } // 非核心工具应是延迟的 if !dr.IsDeferred("WebFetch") { t.Error("WebFetch 应是延迟的") } // 不存在的工具不应是延迟的 if dr.IsDeferred("NonExistent") { t.Error("不存在的工具不应是延迟的") } // 激活后不再是延迟的 dr.ActivateTool("WebFetch") if dr.IsDeferred("WebFetch") { t.Error("WebFetch 激活后不应是延迟的") } } // TestDeferredRegistry_BelowThreshold_NotDeferred 测试低于阈值时没有延迟工具 func TestDeferredRegistry_BelowThreshold_NotDeferred(t *testing.T) { reg := createTestRegistry(3) // 低于阈值 dr := NewDeferredRegistry(reg, 15) if dr.IsDeferred("WebFetch") { t.Error("低于阈值时不应有延迟工具") } } // TestDeferredRegistry_ResetActivations 测试重置激活状态 func TestDeferredRegistry_ResetActivations(t *testing.T) { reg := createTestRegistry(10) dr := NewDeferredRegistry(reg, 15) dr.ActivateTool("WebFetch") dr.ActivateTool("WebSearch") dr.ResetActivations() if !dr.IsDeferred("WebFetch") { t.Error("重置后 WebFetch 应回到延迟状态") } if !dr.IsDeferred("WebSearch") { t.Error("重置后 WebSearch 应回到延迟状态") } } // TestDeferredRegistry_DefaultThreshold 测试默认阈值 func TestDeferredRegistry_DefaultThreshold(t *testing.T) { reg := NewRegistry() dr := NewDeferredRegistry(reg, 0) if dr.Threshold() != defaultDeferredThreshold { t.Errorf("默认阈值应为 %d, 实际: %d", defaultDeferredThreshold, dr.Threshold()) } } // --- P1-2:WithAlwaysLoad 和 RegisterDeferredTool 测试 --- // TestDeferredRegistry_WithAlwaysLoad_IsolatesInstances 验证 WithAlwaysLoad 只影响当前实例. func TestDeferredRegistry_WithAlwaysLoad_IsolatesInstances(t *testing.T) { reg := createTestRegistry(10) // 超过阈值 dr1 := NewDeferredRegistry(reg, 15) dr2 := NewDeferredRegistry(reg, 15) // 只对 dr1 添加自定义核心工具 dr1.WithAlwaysLoad("WebFetch") // dr1 应包含 WebFetch(核心工具) active1 := dr1.ActiveTools() found1 := false for _, t := range active1 { if t.Name() == "WebFetch" { found1 = true break } } if !found1 { t.Error("dr1 添加了 WebFetch 为核心工具,应出现在活跃列表") } // dr2 不应受影响,WebFetch 仍是延迟工具 if !dr2.IsDeferred("WebFetch") { t.Error("dr2 未调用 WithAlwaysLoad,WebFetch 应仍是延迟工具(不受 dr1 影响)") } } // TestDeferredRegistry_WithAlwaysLoad_Chaining 验证链式调用. func TestDeferredRegistry_WithAlwaysLoad_Chaining(t *testing.T) { reg := createTestRegistry(10) dr := NewDeferredRegistry(reg, 15). WithAlwaysLoad("WebFetch", "WebSearch") if dr.IsDeferred("WebFetch") { t.Error("链式 WithAlwaysLoad 后 WebFetch 应不再是延迟工具") } if dr.IsDeferred("WebSearch") { t.Error("链式 WithAlwaysLoad 后 WebSearch 应不再是延迟工具") } } // TestRegisterDeferredTool 验证 RegisterDeferredTool 更新全局默认白名单. // 注意:此测试修改全局状态,在测试结束后应尽量不影响其他测试. // 精妙之处(CLEVER): 测试完成后用同名工具再注册一遍是幂等的(map[name]=true 已经存在), // 不会引起问题.但新增的工具名在全局白名单里会保留到测试进程结束--这在测试场景下是可接受的. func TestRegisterDeferredTool_UpdatesGlobal(t *testing.T) { const testToolName = "TestScanBarcode_P1_2" // 注册前不应在全局白名单中 if alwaysLoadTools[testToolName] { t.Skip("测试工具名已存在,跳过(其他测试可能已注册)") } RegisterDeferredTool(testToolName) if !alwaysLoadTools[testToolName] { t.Errorf("RegisterDeferredTool 后 %q 应出现在全局白名单", testToolName) } // 新建实例时应继承全局白名单中的新工具 reg := NewRegistry() for _, name := range []string{"Bash", "Read", "Edit", "Write", "Glob", "Grep", "Agent", "ToolSearch"} { _ = reg.Register(&mockDeferredTool{name: name}) } _ = reg.Register(&mockDeferredTool{name: testToolName}) // 注册额外工具使总数超过阈值 for i := 0; i < 8; i++ { _ = reg.Register(&mockDeferredTool{name: "Extra" + string(rune('A'+i))}) } dr := NewDeferredRegistry(reg, 15) if dr.IsDeferred(testToolName) { t.Errorf("通过 RegisterDeferredTool 注册的 %q 不应是延迟工具", testToolName) } } // TestDeferredRegistry_SearchByAlias 测试通过别名搜索 func TestDeferredRegistry_SearchByAlias(t *testing.T) { reg := NewRegistry() // 注册足够多工具以超过阈值 for i := 0; i < 8; i++ { coreNames := []string{"Bash", "Read", "Edit", "Write", "Glob", "Grep", "Agent", "ToolSearch"} if i < len(coreNames) { _ = reg.Register(&mockDeferredTool{name: coreNames[i]}) } } // 注册带别名的工具 _ = reg.Register(&mockDeferredTool{ name: "FileRead", aliases: []string{"cat", "readfile"}, }) // 再注册几个工具超过阈值 for i := 0; i < 8; i++ { _ = reg.Register(&mockDeferredTool{ name: "Extra" + string(rune('A'+i)), }) } dr := NewDeferredRegistry(reg, 10) matches := dr.SearchTools("cat") found := false for _, m := range matches { if m.Name() == "FileRead" { found = true break } } if !found { t.Error("应该能通过别名 'cat' 搜索到 FileRead 工具") } }