// orchestrator_test.go -- 工具编排器的单元测试. // // 覆盖场景: // - partition 并发分组算法 // - ExecuteBatch 执行批次 // - 未知工具报错 // - 串行/并行执行分组正确 // - MinConfidence safety gate 真 wire: 拒绝路径 + 放行后 stripped input 达工具 package tools import ( "context" "encoding/json" "strings" "sync/atomic" "testing" ) // concurrentTool 是一个标记为可并发的 mock 工具 type concurrentTool struct{ name string } func (c *concurrentTool) Name() string { return c.name } func (c *concurrentTool) Description(ctx context.Context) string { return "" } func (c *concurrentTool) InputSchema() json.RawMessage { return json.RawMessage(`{}`) } func (c *concurrentTool) Execute(ctx context.Context, input json.RawMessage, progress ProgressFunc) (*Result, error) { return &Result{Output: "ok from " + c.name}, nil } func (c *concurrentTool) Metadata() Metadata { return Metadata{ConcurrencySafe: true, ReadOnly: true} } // serialTool 是一个不可并发的 mock 工具 type serialTool struct{ name string } func (s *serialTool) Name() string { return s.name } func (s *serialTool) Description(ctx context.Context) string { return "" } func (s *serialTool) InputSchema() json.RawMessage { return json.RawMessage(`{}`) } func (s *serialTool) Execute(ctx context.Context, input json.RawMessage, progress ProgressFunc) (*Result, error) { return &Result{Output: "ok from " + s.name}, nil } func (s *serialTool) Metadata() Metadata { return Metadata{ConcurrencySafe: false, ReadOnly: false} } // TestPartition 测试并发分组算法 func TestPartition(t *testing.T) { registry := NewRegistry() registry.Register(&concurrentTool{name: "Glob"}) registry.Register(&concurrentTool{name: "Grep"}) registry.Register(&serialTool{name: "Edit"}) o := NewOrchestrator(registry, 10) // 输入: [Glob, Grep, Edit, Glob, Grep] // 期望: [[Glob, Grep](并发), [Edit](串行), [Glob, Grep](并发)] calls := []ToolCall{ {ID: "1", Name: "Glob"}, {ID: "2", Name: "Grep"}, {ID: "3", Name: "Edit"}, {ID: "4", Name: "Glob"}, {ID: "5", Name: "Grep"}, } batches := o.partition(calls) if len(batches) != 3 { t.Fatalf("期望 3 个批次, 实际 %d", len(batches)) } // 第一批应为并发(Glob, Grep) if len(batches[0]) != 2 || !batches[0][0].concurrent { t.Errorf("第一批应为 2 个并发工具, 实际: %d 个, concurrent=%v", len(batches[0]), batches[0][0].concurrent) } // 第二批应为串行(Edit) if len(batches[1]) != 1 || batches[1][0].concurrent { t.Errorf("第二批应为 1 个串行工具") } // 第三批应为并发(Glob, Grep) if len(batches[2]) != 2 || !batches[2][0].concurrent { t.Errorf("第三批应为 2 个并发工具") } } // TestPartition_UnknownTool 测试未知工具单独分组 func TestPartition_UnknownTool(t *testing.T) { registry := NewRegistry() o := NewOrchestrator(registry, 10) calls := []ToolCall{ {ID: "1", Name: "Unknown"}, } batches := o.partition(calls) if len(batches) != 1 { t.Fatalf("期望 1 个批次, 实际 %d", len(batches)) } if batches[0][0].concurrent { t.Error("未知工具不应标记为并发") } } // TestExecuteBatch 测试批量执行 func TestExecuteBatch(t *testing.T) { registry := NewRegistry() registry.Register(&concurrentTool{name: "Glob"}) registry.Register(&serialTool{name: "Edit"}) o := NewOrchestrator(registry, 10) calls := []ToolCall{ {ID: "1", Name: "Glob", Input: json.RawMessage(`{}`)}, {ID: "2", Name: "Edit", Input: json.RawMessage(`{}`)}, } results := make(chan ToolCallResult, 10) go func() { o.ExecuteBatch(context.Background(), calls, results) close(results) }() var collected []ToolCallResult for r := range results { collected = append(collected, r) } if len(collected) != 2 { t.Fatalf("期望 2 个结果, 实际 %d", len(collected)) } // 验证结果中都有输出 for _, r := range collected { if r.Output == "" { t.Errorf("工具 %s 输出为空", r.Name) } if r.IsError { t.Errorf("工具 %s 不应报错: %s", r.Name, r.Output) } } } // TestExecuteBatch_UnknownTool 测试执行未知工具报错 func TestExecuteBatch_UnknownTool(t *testing.T) { registry := NewRegistry() o := NewOrchestrator(registry, 10) calls := []ToolCall{ {ID: "1", Name: "NonExistent", Input: json.RawMessage(`{}`)}, } results := make(chan ToolCallResult, 10) go func() { o.ExecuteBatch(context.Background(), calls, results) close(results) }() r := <-results if !r.IsError { t.Error("未知工具应报错") } if r.Output == "" { t.Error("应有错误输出") } } // TestNewOrchestrator_DefaultConcurrency 测试默认并发数 func TestNewOrchestrator_DefaultConcurrency(t *testing.T) { registry := NewRegistry() // maxConcurrency <= 0 应使用默认值 10 o := NewOrchestrator(registry, 0) if o.maxConcurrency != 10 { t.Errorf("默认并发数应为 10, 实际: %d", o.maxConcurrency) } o = NewOrchestrator(registry, -1) if o.maxConcurrency != 10 { t.Errorf("负数应使用默认值 10, 实际: %d", o.maxConcurrency) } } // gatedTool 是一个声明了 MinConfidence 的 mock 工具, 用于 orchestrator 层 // 验证 safety gate 真 wire 到 executeSingle (不是只测 helper). type gatedTool struct { name string threshold int // executedWithInput 记录工具实际收到的 input JSON, 用于断言保留字段剥除. executedWithInput atomic.Value // calls 记录 Execute 被真正调用的次数 (用于验证拒绝路径没执行工具). calls atomic.Int32 } func (g *gatedTool) Name() string { return g.name } func (g *gatedTool) Description(ctx context.Context) string { return "gated tool" } func (g *gatedTool) InputSchema() json.RawMessage { return json.RawMessage(`{}`) } func (g *gatedTool) Capability() ToolCapability { return ToolCapability{MinConfidence: g.threshold} } func (g *gatedTool) Execute(ctx context.Context, input json.RawMessage, progress ProgressFunc) (*Result, error) { g.calls.Add(1) g.executedWithInput.Store(string(input)) return &Result{Output: "ran"}, nil } // TestOrchestrator_GateRejects_MissingConfidence 验证 gate 在 Execute 前拦 // 截: MinConfidence>0 工具收到无 _flyto_confidence 的 input -> 返回 IsError // 且工具 Execute 根本没被调用. func TestOrchestrator_GateRejects_MissingConfidence(t *testing.T) { g := &gatedTool{name: "gated", threshold: 80} registry := NewRegistry() registry.Register(g) o := NewOrchestrator(registry, 10) result := o.executeSingle(context.Background(), ToolCall{ ID: "1", Name: "gated", Input: json.RawMessage(`{"path":"/tmp/x"}`), }) if !result.IsError { t.Fatal("缺 _flyto_confidence 应 IsError=true") } if !strings.Contains(result.Output, "_flyto_confidence") { t.Errorf("错误输出应提示保留字段名, 实际: %q", result.Output) } if g.calls.Load() != 0 { t.Errorf("Execute 不应被调用 (gate 拦截前置), 实际调用 %d 次", g.calls.Load()) } } // TestOrchestrator_GatePasses_StripsFieldBeforeExecute 验证 gate 放行后 // 工具 Execute 收到的 input 已剥除保留字段 (业务工具不看到 _flyto_confidence). func TestOrchestrator_GatePasses_StripsFieldBeforeExecute(t *testing.T) { g := &gatedTool{name: "gated", threshold: 80} registry := NewRegistry() registry.Register(g) o := NewOrchestrator(registry, 10) result := o.executeSingle(context.Background(), ToolCall{ ID: "2", Name: "gated", Input: json.RawMessage(`{"_flyto_confidence":95,"path":"/tmp/x"}`), }) if result.IsError { t.Fatalf("confidence=95 >= min=80 应放行, 实际 IsError=true: %q", result.Output) } if g.calls.Load() != 1 { t.Fatalf("Execute 应被调用 1 次, 实际 %d", g.calls.Load()) } received, _ := g.executedWithInput.Load().(string) if strings.Contains(received, "_flyto_confidence") { t.Errorf("工具收到的 input 不应含 _flyto_confidence, 实际: %s", received) } if !strings.Contains(received, `"path":"/tmp/x"`) { t.Errorf("业务字段 path 应保留, 实际: %s", received) } } // TestOrchestrator_NoGate_OnZeroMinConfidence 验证 MinConfidence=0 工具 // (即当前所有 builtin) 完全 bypass gate, input 原样达工具, 保持向后兼容. func TestOrchestrator_NoGate_OnZeroMinConfidence(t *testing.T) { g := &gatedTool{name: "free", threshold: 0} registry := NewRegistry() registry.Register(g) o := NewOrchestrator(registry, 10) inputJSON := `{"path":"/tmp/x"}` result := o.executeSingle(context.Background(), ToolCall{ ID: "3", Name: "free", Input: json.RawMessage(inputJSON), }) if result.IsError { t.Fatalf("MinConfidence=0 应不拦截, 实际 IsError=true: %q", result.Output) } received, _ := g.executedWithInput.Load().(string) if received != inputJSON { t.Errorf("MinConfidence=0 应原样透传 input (无 JSON 重编码), 实际: %s (期望: %s)", received, inputJSON) } }