package builtin import ( "context" "encoding/json" "errors" "regexp" "strings" "testing" "git.flytoex.net/yuanwei/flyto-agent/pkg/tools" "git.flytoex.net/yuanwei/flyto-agent/pkg/validator" ) // -- fakes -- type fakeInnerTool struct { name string result *tools.Result err error execCalls int } func (t *fakeInnerTool) Name() string { return t.name } func (t *fakeInnerTool) Description(_ context.Context) string { return "fake" } func (t *fakeInnerTool) InputSchema() json.RawMessage { return json.RawMessage(`{"type":"object"}`) } func (t *fakeInnerTool) Execute(_ context.Context, _ json.RawMessage, _ tools.ProgressFunc) (*tools.Result, error) { t.execCalls++ return t.result, t.err } type fakeInnerToolWithMeta struct { fakeInnerTool meta tools.Metadata } func (t *fakeInnerToolWithMeta) Metadata() tools.Metadata { return t.meta } type fakeInnerValidator struct { verdict validator.Verdict err error callCount int lastDiff validator.DiffInput } func (v *fakeInnerValidator) Name() string { return "fake-validator" } func (v *fakeInnerValidator) Validate(_ context.Context, diff validator.DiffInput) (validator.Verdict, error) { v.callCount++ v.lastDiff = diff return v.verdict, v.err } // -- ValidatedTool core behaviour -- func TestValidatedTool_ApprovedPassesThrough(t *testing.T) { tool := &fakeInnerTool{name: "inner", result: &tools.Result{Output: "ok"}} vdr := &fakeInnerValidator{verdict: validator.Verdict{Approved: true, Severity: validator.SeverityWarn}} vt := NewValidatedTool(tool, vdr, DefaultExtractor(), nil) got, err := vt.Execute(context.Background(), json.RawMessage(`{}`), nil) if err != nil { t.Fatalf("unexpected error: %v", err) } if got.IsError { t.Errorf("approved should not set IsError, got %+v", got) } if got.Output != "ok" { t.Errorf("output should be unchanged, got %q", got.Output) } if vdr.callCount != 1 { t.Errorf("validator should run once, got %d", vdr.callCount) } } func TestValidatedTool_BlockedRewritesResult(t *testing.T) { tool := &fakeInnerTool{name: "inner", result: &tools.Result{Output: "rows updated"}} vdr := &fakeInnerValidator{verdict: validator.Verdict{ Approved: false, Severity: validator.SeverityBlock, Reason: "DROP detected", ValidatorName: "rules", PolicyVersion: "v1", }} vt := NewValidatedTool(tool, vdr, DefaultExtractor(), nil) got, _ := vt.Execute(context.Background(), json.RawMessage(`{}`), nil) if !got.IsError { t.Errorf("blocked should set IsError=true, got %+v", got) } if !strings.Contains(got.Output, "rules") { t.Errorf("output should include ValidatorName, got %q", got.Output) } if !strings.Contains(got.Output, "v1") { t.Errorf("output should include PolicyVersion, got %q", got.Output) } if !strings.Contains(got.Output, "DROP detected") { t.Errorf("output should include Reason, got %q", got.Output) } if !strings.Contains(got.Output, "rows updated") { t.Errorf("output should include original output, got %q", got.Output) } } func TestValidatedTool_WarnPassesThrough(t *testing.T) { // Warn does not rewrite the result -- it's advisory; the circuit // breaker logs the sample but the write proceeds. tool := &fakeInnerTool{name: "inner", result: &tools.Result{Output: "did it"}} vdr := &fakeInnerValidator{verdict: validator.Verdict{ Approved: false, Severity: validator.SeverityWarn, Reason: "minor concern", }} vt := NewValidatedTool(tool, vdr, DefaultExtractor(), nil) got, _ := vt.Execute(context.Background(), json.RawMessage(`{}`), nil) if got.IsError { t.Errorf("warn should not set IsError, got %+v", got) } if got.Output != "did it" { t.Errorf("warn should not rewrite output, got %q", got.Output) } } func TestValidatedTool_ValidatorErrorBlocks(t *testing.T) { tool := &fakeInnerTool{name: "inner", result: &tools.Result{Output: "ok"}} vdr := &fakeInnerValidator{err: errors.New("backend down")} vt := NewValidatedTool(tool, vdr, DefaultExtractor(), nil) got, _ := vt.Execute(context.Background(), json.RawMessage(`{}`), nil) if !got.IsError { t.Errorf("validator error should escalate to Block, got %+v", got) } if !strings.Contains(got.Output, "backend down") { t.Errorf("output should include error message, got %q", got.Output) } } func TestValidatedTool_InnerErrorSkipsValidation(t *testing.T) { tool := &fakeInnerTool{name: "inner", err: errors.New("tool failure")} vdr := &fakeInnerValidator{} vt := NewValidatedTool(tool, vdr, DefaultExtractor(), nil) _, err := vt.Execute(context.Background(), json.RawMessage(`{}`), nil) if err == nil || err.Error() != "tool failure" { t.Errorf("inner error should bubble up, got %v", err) } if vdr.callCount != 0 { t.Errorf("validator should not run when inner errors, got %d", vdr.callCount) } } func TestValidatedTool_InnerIsErrorSkipsValidation(t *testing.T) { tool := &fakeInnerTool{name: "inner", result: &tools.Result{Output: "fail", IsError: true}} vdr := &fakeInnerValidator{} vt := NewValidatedTool(tool, vdr, DefaultExtractor(), nil) got, _ := vt.Execute(context.Background(), json.RawMessage(`{}`), nil) if !got.IsError { t.Errorf("inner IsError should pass through, got %+v", got) } if vdr.callCount != 0 { t.Errorf("validator should not run on inner IsError, got %d", vdr.callCount) } } // -- VerdictSink -- func TestValidatedTool_SinkCalledOnEveryVerdict(t *testing.T) { tool := &fakeInnerTool{name: "inner", result: &tools.Result{Output: "ok"}} vdr := &fakeInnerValidator{verdict: validator.Verdict{Approved: true, Severity: validator.SeverityWarn}} var sinkCalls int var lastName string var lastVerdict validator.Verdict sink := func(name string, v validator.Verdict) { sinkCalls++ lastName = name lastVerdict = v } vt := NewValidatedTool(tool, vdr, DefaultExtractor(), sink) _, _ = vt.Execute(context.Background(), json.RawMessage(`{}`), nil) if sinkCalls != 1 { t.Errorf("sink should fire once per validated call, got %d", sinkCalls) } if lastName != "inner" { t.Errorf("sink should receive inner tool name, got %q", lastName) } if !lastVerdict.Approved { t.Errorf("sink should receive the actual verdict, got %+v", lastVerdict) } } func TestValidatedTool_SinkFiresOnBlock(t *testing.T) { tool := &fakeInnerTool{name: "inner", result: &tools.Result{Output: "ok"}} vdr := &fakeInnerValidator{verdict: validator.Verdict{Severity: validator.SeverityBlock, Reason: "no"}} var sinkCalled bool sink := func(_ string, _ validator.Verdict) { sinkCalled = true } vt := NewValidatedTool(tool, vdr, DefaultExtractor(), sink) _, _ = vt.Execute(context.Background(), json.RawMessage(`{}`), nil) if !sinkCalled { t.Errorf("sink should fire on Block verdict") } } func TestValidatedTool_NilSinkOK(t *testing.T) { tool := &fakeInnerTool{name: "inner", result: &tools.Result{Output: "ok"}} vdr := &fakeInnerValidator{verdict: validator.Verdict{Approved: true}} vt := NewValidatedTool(tool, vdr, DefaultExtractor(), nil) _, err := vt.Execute(context.Background(), json.RawMessage(`{}`), nil) if err != nil { t.Errorf("nil sink should not panic/error, got %v", err) } } // -- method forwarding -- func TestValidatedTool_NameForwarded(t *testing.T) { tool := &fakeInnerTool{name: "my-tool"} vt := NewValidatedTool(tool, &fakeInnerValidator{}, DefaultExtractor(), nil) if vt.Name() != "my-tool" { t.Errorf("Name should forward, got %q", vt.Name()) } } func TestValidatedTool_MetadataForwarded(t *testing.T) { tool := &fakeInnerToolWithMeta{ fakeInnerTool: fakeInnerTool{name: "inner"}, meta: tools.Metadata{ReadOnly: false, Destructive: true}, } vt := NewValidatedTool(tool, &fakeInnerValidator{}, DefaultExtractor(), nil) got := vt.Metadata() if !got.Destructive { t.Errorf("Destructive should forward, got %+v", got) } } // -- extractors -- func TestDefaultExtractor(t *testing.T) { ext := DefaultExtractor() diff := ext("SQLCAS", &tools.Result{Data: map[string]any{"x": 1}}) if diff.SourceTool != "SQLCAS" { t.Errorf("SourceTool should be toolName, got %q", diff.SourceTool) } if !strings.Contains(string(diff.Raw), `"x":1`) { t.Errorf("Raw should contain Data JSON, got %q", string(diff.Raw)) } if diff.Metadata != nil { t.Errorf("DefaultExtractor should not set Metadata, got %v", diff.Metadata) } } func TestExtractorSQLCAS(t *testing.T) { ext := ExtractorSQLCAS() diff := ext("SQLCAS", &tools.Result{Data: map[string]any{ "affected_rows": 5, "table_name": "orders", "other": "ignored", }}) // JSON round-trip converts int to float64. if diff.Metadata["affected_rows"] != float64(5) { t.Errorf("affected_rows should be extracted, got %v", diff.Metadata["affected_rows"]) } if diff.Metadata["table_name"] != "orders" { t.Errorf("table_name should be extracted, got %v", diff.Metadata["table_name"]) } if _, ok := diff.Metadata["other"]; ok { t.Errorf("unexpected key 'other' in metadata, got %v", diff.Metadata) } } func TestExtractorSQLDryRun_NormalizesRowCount(t *testing.T) { // DryRun exposes "affected_row_count" (not "affected_rows"); // extractor normalises so DiffSizeRule keyed on "affected_rows" // works uniformly across SQL tools. ext := ExtractorSQLDryRun() diff := ext("SQLDryRun", &tools.Result{Data: map[string]any{ "affected_row_count": 10, "table_name": "items", "after_predicate_mismatch": true, }}) if diff.Metadata["affected_rows"] != float64(10) { t.Errorf("affected_row_count should normalise to affected_rows, got %v", diff.Metadata["affected_rows"]) } if diff.Metadata["table_name"] != "items" { t.Errorf("table_name should be extracted, got %v", diff.Metadata["table_name"]) } if diff.Metadata["after_predicate_mismatch"] != true { t.Errorf("after_predicate_mismatch should be extracted, got %v", diff.Metadata["after_predicate_mismatch"]) } } // -- end-to-end with real Validator -- func TestValidatedTool_EndToEnd_RuleBlocksDrop(t *testing.T) { tool := &fakeInnerTool{name: "SQLCAS", result: &tools.Result{ Output: "executed", Data: map[string]any{"sql": "DROP TABLE orders"}, }} rule := validator.NewRuleValidator("rules", "v1", &validator.PatternRule{Patterns: []*regexp.Regexp{regexp.MustCompile(`(?i)\bDROP\s+TABLE\b`)}}) vt := NewValidatedTool(tool, rule, DefaultExtractor(), nil) got, err := vt.Execute(context.Background(), json.RawMessage(`{}`), nil) if err != nil { t.Fatalf("unexpected error: %v", err) } if !got.IsError { t.Errorf("DROP TABLE should be blocked, got %+v", got) } if !strings.Contains(got.Output, "pattern") { t.Errorf("output should mention the pattern rule, got %q", got.Output) } if !strings.Contains(got.Output, "rules") { t.Errorf("output should mention ValidatorName, got %q", got.Output) } } func TestValidatedTool_EndToEnd_ExtractorFeedsRule(t *testing.T) { // ExtractorSQLCAS pulls affected_rows into Metadata; DiffSizeRule // reads it to block oversized diffs. tool := &fakeInnerTool{name: "SQLCAS", result: &tools.Result{ Data: map[string]any{"affected_rows": 500}, }} rule := validator.NewRuleValidator("rules", "v1", &validator.DiffSizeRule{MaxRows: 100}) vt := NewValidatedTool(tool, rule, ExtractorSQLCAS(), nil) got, _ := vt.Execute(context.Background(), json.RawMessage(`{}`), nil) if !got.IsError { t.Errorf("500 affected_rows should be blocked (limit 100), got %+v", got) } if !strings.Contains(got.Output, "500") { t.Errorf("block reason should mention 500 rows, got %q", got.Output) } } // TestNewValidatedTool_NilValidator_Panics guards the construction-time // fail-fast: a nil Validator signals mis-wiring (the industry forgot // to pass a real Validator or validator.AlwaysApprove{} for opt-out). // Panic at construction surfaces it at startup instead of letting the // first Execute call nil-deref. // // TestNewValidatedTool_NilValidator_Panics 守住构造期 fail-fast: nil // Validator 表示错配置 (industry 忘传真 Validator 或 AlwaysApprove{} // opt-out). 构造期 panic 让错配在启动暴露, 而非首次 Execute 时 nil-deref. func TestNewValidatedTool_NilValidator_Panics(t *testing.T) { defer func() { r := recover() if r == nil { t.Fatal("NewValidatedTool with nil Validator must panic") } msg, ok := r.(string) if !ok || !strings.Contains(msg, "validator") || !strings.Contains(msg, "AlwaysApprove") { t.Errorf("panic message must mention validator nil and AlwaysApprove opt-out, got %v", r) } }() tool := &fakeInnerTool{name: "any"} _ = NewValidatedTool(tool, nil, DefaultExtractor(), nil) } // TestNewValidatedTool_NilExtractor_Panics mirrors the Validator guard: // nil extractor would nil-deref inside Execute. Construction-time // panic makes the mis-wiring visible. // // TestNewValidatedTool_NilExtractor_Panics 镜像 Validator 守护: nil // extractor 会在 Execute 里 nil-deref. 构造期 panic 让错配可见. func TestNewValidatedTool_NilExtractor_Panics(t *testing.T) { defer func() { if r := recover(); r == nil { t.Fatal("NewValidatedTool with nil extractor must panic") } }() tool := &fakeInnerTool{name: "any"} _ = NewValidatedTool(tool, validator.AlwaysApprove{}, nil, nil) } // TestValidatedTool_WithAlwaysApprove_PassesThrough confirms the // explicit opt-out path: AlwaysApprove approves everything, the // inner Result passes through unchanged, and the VerdictSink (if // wired) receives the audit-visible ValidatorName "always-approve". // // TestValidatedTool_WithAlwaysApprove_PassesThrough 确认显式 opt-out // 路径: AlwaysApprove 全批准, inner Result 原样透传, VerdictSink (如 // 已接) 收到审计可见的 ValidatorName "always-approve". func TestValidatedTool_WithAlwaysApprove_PassesThrough(t *testing.T) { tool := &fakeInnerTool{name: "unchecked", result: &tools.Result{Output: "raw output"}} var sinkCalls []validator.Verdict sink := func(_ string, v validator.Verdict) { sinkCalls = append(sinkCalls, v) } vt := NewValidatedTool(tool, validator.AlwaysApprove{}, DefaultExtractor(), sink) got, err := vt.Execute(context.Background(), json.RawMessage(`{}`), nil) if err != nil { t.Fatalf("unexpected error: %v", err) } if got.IsError { t.Errorf("AlwaysApprove must not set IsError, got %+v", got) } if got.Output != "raw output" { t.Errorf("output must pass through unchanged, got %q", got.Output) } if len(sinkCalls) != 1 { t.Fatalf("sink must fire once, got %d", len(sinkCalls)) } if sinkCalls[0].ValidatorName != "always-approve" { t.Errorf("sink must see ValidatorName=always-approve for audit filtering, got %q", sinkCalls[0].ValidatorName) } if !sinkCalls[0].Approved { t.Error("sink must see Approved=true") } }