// registry_test.go -- 工具注册表的单元测试. // // 覆盖场景: // - Register 注册新工具 // - 重复注册报错 // - Get 按名称获取工具(含别名) // - All 返回所有工具(按注册顺序) // - Names 返回所有工具名称 // - Unregister 注销工具 // - Filter 按名称过滤工具 // - Count 返回工具数量 // - GetMetadata 安全获取元数据 package tools import ( "context" "encoding/json" "testing" ) // mockTool 是用于测试的 mock 工具实现. type mockTool struct { name string aliases []string } func (m *mockTool) Name() string { return m.name } func (m *mockTool) Description(ctx context.Context) string { return "mock tool: " + m.name } func (m *mockTool) InputSchema() json.RawMessage { return json.RawMessage(`{}`) } func (m *mockTool) Execute(ctx context.Context, input json.RawMessage, progress ProgressFunc) (*Result, error) { return &Result{Output: "ok"}, nil } func (m *mockTool) Metadata() Metadata { return Metadata{ ConcurrencySafe: true, ReadOnly: true, Aliases: m.aliases, } } // TestRegistry_Register 测试注册工具 func TestRegistry_Register(t *testing.T) { r := NewRegistry() tool := &mockTool{name: "TestTool"} if err := r.Register(tool); err != nil { t.Fatalf("注册失败: %v", err) } if r.Count() != 1 { t.Errorf("注册后数量应为 1, 实际: %d", r.Count()) } } // TestRegistry_DuplicateRegister 测试重复注册报错 func TestRegistry_DuplicateRegister(t *testing.T) { r := NewRegistry() tool := &mockTool{name: "TestTool"} r.Register(tool) err := r.Register(tool) if err == nil { t.Error("重复注册应报错") } } // TestRegistry_Get 测试按名称获取工具 func TestRegistry_Get(t *testing.T) { r := NewRegistry() tool := &mockTool{name: "MyTool"} r.Register(tool) got, ok := r.Get("MyTool") if !ok { t.Fatal("应能找到工具") } if got.Name() != "MyTool" { t.Errorf("工具名不匹配: %q", got.Name()) } // 查找不存在的工具 _, ok = r.Get("NonExistent") if ok { t.Error("不应找到不存在的工具") } } // TestRegistry_GetByAlias 测试通过别名获取工具 func TestRegistry_GetByAlias(t *testing.T) { r := NewRegistry() tool := &mockTool{name: "Read", aliases: []string{"FileRead"}} r.Register(tool) got, ok := r.Get("FileRead") if !ok { t.Fatal("应能通过别名找到工具") } if got.Name() != "Read" { t.Errorf("工具名不匹配: %q", got.Name()) } } // TestRegistry_All 测试返回所有工具(按注册顺序) func TestRegistry_All(t *testing.T) { r := NewRegistry() r.Register(&mockTool{name: "A"}) r.Register(&mockTool{name: "B"}) r.Register(&mockTool{name: "C"}) all := r.All() if len(all) != 3 { t.Fatalf("期望 3 个工具, 实际 %d", len(all)) } if all[0].Name() != "A" || all[1].Name() != "B" || all[2].Name() != "C" { t.Error("工具应按注册顺序返回") } } // TestRegistry_Names 测试返回所有工具名称 func TestRegistry_Names(t *testing.T) { r := NewRegistry() r.Register(&mockTool{name: "A"}) r.Register(&mockTool{name: "B"}) names := r.Names() if len(names) != 2 { t.Fatalf("期望 2 个名称, 实际 %d", len(names)) } if names[0] != "A" || names[1] != "B" { t.Errorf("名称不匹配: %v", names) } } // TestRegistry_Unregister 测试注销工具 func TestRegistry_Unregister(t *testing.T) { r := NewRegistry() tool := &mockTool{name: "MyTool", aliases: []string{"OldName"}} r.Register(tool) ok := r.Unregister("MyTool") if !ok { t.Error("注销应返回 true") } if r.Count() != 0 { t.Errorf("注销后数量应为 0, 实际: %d", r.Count()) } // 别名也应被清理 _, found := r.Get("OldName") if found { t.Error("注销后别名不应可查") } // 注销不存在的工具 ok = r.Unregister("NonExistent") if ok { t.Error("注销不存在的工具应返回 false") } } // TestRegistry_Filter 测试按名称过滤 func TestRegistry_Filter(t *testing.T) { r := NewRegistry() r.Register(&mockTool{name: "A"}) r.Register(&mockTool{name: "B"}) r.Register(&mockTool{name: "C"}) filtered := r.Filter([]string{"A", "C"}) if len(filtered) != 2 { t.Fatalf("期望 2 个工具, 实际 %d", len(filtered)) } if filtered[0].Name() != "A" || filtered[1].Name() != "C" { t.Error("过滤结果不匹配") } } // TestGetMetadata_WithProvider 测试有元数据的工具 func TestGetMetadata_WithProvider(t *testing.T) { tool := &mockTool{name: "Test"} meta := GetMetadata(tool) if !meta.ConcurrencySafe { t.Error("应返回工具声明的元数据") } if !meta.ReadOnly { t.Error("应返回工具声明的元数据") } } // TestGetMetadata_WithoutProvider 测试无元数据的工具(保守默认值) type simpleTool struct{} func (s *simpleTool) Name() string { return "Simple" } func (s *simpleTool) Description(ctx context.Context) string { return "" } func (s *simpleTool) InputSchema() json.RawMessage { return json.RawMessage(`{}`) } func (s *simpleTool) Execute(ctx context.Context, input json.RawMessage, progress ProgressFunc) (*Result, error) { return &Result{Output: "ok"}, nil } func TestGetMetadata_WithoutProvider(t *testing.T) { tool := &simpleTool{} meta := GetMetadata(tool) if meta.ConcurrencySafe { t.Error("默认应为不可并发") } if meta.ReadOnly { t.Error("默认应为非只读") } if meta.RequiresReverseThinking { t.Error("默认应为不需要反向思维") } if meta.Destructive { t.Error("默认应为非破坏性") } } // reverseThinkingTool 是仅声明 RequiresReverseThinking 的 mock, 验证 Metadata // 字段透传不被其他字段干扰. 与 RequiresCheckpoint 同源测试模式. type reverseThinkingTool struct{} func (r *reverseThinkingTool) Name() string { return "ReverseThinking" } func (r *reverseThinkingTool) Description(ctx context.Context) string { return "" } func (r *reverseThinkingTool) InputSchema() json.RawMessage { return json.RawMessage(`{}`) } func (r *reverseThinkingTool) Execute(ctx context.Context, input json.RawMessage, progress ProgressFunc) (*Result, error) { return &Result{Output: "ok"}, nil } func (r *reverseThinkingTool) Metadata() Metadata { return Metadata{RequiresReverseThinking: true} } // TestGetMetadata_RequiresReverseThinking 验证 RequiresReverseThinking 字段 // 经 GetMetadata 透传, 与 RequiresCheckpoint 字段语义独立可叠加. func TestGetMetadata_RequiresReverseThinking(t *testing.T) { tool := &reverseThinkingTool{} meta := GetMetadata(tool) if !meta.RequiresReverseThinking { t.Error("RequiresReverseThinking 应透传 true") } if meta.RequiresCheckpoint { t.Error("RequiresCheckpoint 不应被 RequiresReverseThinking 干扰, 应保持 false") } }