package builtin

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"strings"
	"testing"

	_ "modernc.org/sqlite"

	"git.flytoex.net/yuanwei/flyto-agent/pkg/permission"
	"git.flytoex.net/yuanwei/flyto-agent/pkg/tools"
)

// newDryRunTestDB opens an in-memory SQLite database and seeds an orders
// table with 5 rows (ids 1-5, status='new', wave_id=1, qty=id*10).
//
// The pool is pinned to a single connection. Every SQLite ":memory:"
// connection is its own independent database, so a second pooled
// connection would see an empty schema without the seeded orders table.
// The database is closed automatically via t.Cleanup.
func newDryRunTestDB(t *testing.T) StagingDB {
	t.Helper()
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		t.Fatalf("open: %v", err)
	}
	db.SetMaxOpenConns(1)
	t.Cleanup(func() { _ = db.Close() })
	if _, err := db.Exec(`CREATE TABLE orders (
		id INTEGER PRIMARY KEY,
		status TEXT,
		wave_id INTEGER,
		qty INTEGER
	)`); err != nil {
		t.Fatalf("create: %v", err)
	}
	for i := 1; i <= 5; i++ {
		if _, err := db.Exec(
			`INSERT INTO orders (id, status, wave_id, qty) VALUES (?, 'new', 1, ?)`,
			i, i*10,
		); err != nil {
			t.Fatalf("seed: %v", err)
		}
	}
	return StagingDB{DB: db}
}

// dryRunInput marshals raw into the JSON input expected by the tool,
// failing the test if marshalling fails.
func dryRunInput(t *testing.T, raw map[string]any) json.RawMessage {
	t.Helper()
	b, err := json.Marshal(raw)
	if err != nil {
		t.Fatalf("marshal: %v", err)
	}
	return b
}

// countOrders returns the current number of rows in the orders table.
func countOrders(t *testing.T, db StagingDB) int {
	t.Helper()
	var n int
	if err := db.QueryRow("SELECT COUNT(*) FROM orders").Scan(&n); err != nil {
		t.Fatalf("count: %v", err)
	}
	return n
}

// requireDryRunResult asserts that a tool result's Data payload is a
// non-nil *tools.DryRunResult and returns it. On failure it stops the test
// with the actual dynamic type and the tool output instead of panicking.
func requireDryRunResult(t *testing.T, data any, output string) *tools.DryRunResult {
	t.Helper()
	dr, ok := data.(*tools.DryRunResult)
	if !ok || dr == nil {
		t.Fatalf("Data is %T (%v), want non-nil *tools.DryRunResult; output: %s", data, data, output)
	}
	return dr
}

func TestSQLDryRun_UpdateHappyPath(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql":               "UPDATE orders SET status='picked' WHERE wave_id=?",
		"args":              []any{1},
		"preview_predicate": "SELECT id, status, wave_id FROM orders WHERE wave_id=?",
		"preview_args":      []any{1},
	}), nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if res.IsError {
		t.Fatalf("expected success: %s", res.Output)
	}
	dr := requireDryRunResult(t, res.Data, res.Output)
	if !strings.Contains(dr.WouldAffect, "UPDATE") {
		t.Errorf("WouldAffect missing UPDATE: %q", dr.WouldAffect)
	}
	if !strings.Contains(dr.Preview, "Before:") || !strings.Contains(dr.Preview, "After:") {
		t.Errorf("Preview missing Before/After sections: %q", dr.Preview)
	}
	if !strings.Contains(dr.Preview, "new") || !strings.Contains(dr.Preview, "picked") {
		t.Errorf("Preview missing before 'new' or after 'picked': %q", dr.Preview)
	}
	if dr.EstimatedImpact["operation"] != "UPDATE" {
		t.Errorf("Impact operation=%v, want UPDATE", dr.EstimatedImpact["operation"])
	}
	if dr.EstimatedImpact["rows_affected"] != int64(5) {
		t.Errorf("rows_affected=%v, want 5", dr.EstimatedImpact["rows_affected"])
	}
	if dr.EstimatedImpact["consistency"] != "pass" {
		t.Errorf("consistency=%v, want pass", dr.EstimatedImpact["consistency"])
	}
}

func TestSQLDryRun_UpdateRollbackReallyRolls(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	_, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql":               "UPDATE orders SET status='picked' WHERE wave_id=?",
		"args":              []any{1},
		"preview_predicate": "SELECT id FROM orders WHERE wave_id=?",
		"preview_args":      []any{1},
	}), nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	var status string
	if err := db.QueryRow("SELECT status FROM orders WHERE id=1").Scan(&status); err != nil {
		t.Fatalf("readback: %v", err)
	}
	if status != "new" {
		t.Errorf("expected status unchanged after Dry-run ROLLBACK, got %q", status)
	}
}

func TestSQLDryRun_UpdateAfterPredicateMismatchSignal(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	// The predicate references the column that the UPDATE modifies. After
	// the UPDATE, the predicate matches no row -> after_predicate_mismatch.
	res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql":               "UPDATE orders SET status='picked' WHERE status='new'",
		"args":              []any{},
		"preview_predicate": "SELECT id, status FROM orders WHERE status='new'",
		"preview_args":      []any{},
	}), nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if res.IsError {
		t.Fatalf("expected success (mismatch is a signal not error): %s", res.Output)
	}
	dr := requireDryRunResult(t, res.Data, res.Output)
	if dr.EstimatedImpact["after_predicate_mismatch"] != true {
		t.Errorf("expected after_predicate_mismatch=true, got %v", dr.EstimatedImpact["after_predicate_mismatch"])
	}
	if dr.EstimatedImpact["preview_rows_before"] != int64(5) {
		t.Errorf("before=%v, want 5", dr.EstimatedImpact["preview_rows_before"])
	}
	if dr.EstimatedImpact["preview_rows_after"] != int64(0) {
		t.Errorf("after=%v, want 0", dr.EstimatedImpact["preview_rows_after"])
	}
	// The before count matches RowsAffected, so the main consistency is still pass.
	if dr.EstimatedImpact["consistency"] != "pass" {
		t.Errorf("consistency=%v, want pass (before matches RowsAffected)", dr.EstimatedImpact["consistency"])
	}
}

func TestSQLDryRun_UpdatePredicateCountMismatch(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	// The predicate matches 5 rows (wave_id=1), but UPDATE WHERE id=1 only
	// affects 1 row -> consistency=mismatch.
	res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql":               "UPDATE orders SET status='picked' WHERE id=?",
		"args":              []any{1},
		"preview_predicate": "SELECT id FROM orders WHERE wave_id=?",
		"preview_args":      []any{1},
	}), nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	dr := requireDryRunResult(t, res.Data, res.Output)
	if dr.EstimatedImpact["rows_affected"] != int64(1) {
		t.Errorf("rows_affected=%v, want 1", dr.EstimatedImpact["rows_affected"])
	}
	if dr.EstimatedImpact["consistency"] != "mismatch" {
		t.Errorf("expected consistency=mismatch, got %v", dr.EstimatedImpact["consistency"])
	}
	if _, has := dr.EstimatedImpact["mismatch_reason"]; !has {
		t.Error("expected mismatch_reason present")
	}
}

func TestSQLDryRun_DeleteHappyPath(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql":               "DELETE FROM orders WHERE wave_id=?",
		"args":              []any{1},
		"preview_predicate": "SELECT id, status FROM orders WHERE wave_id=?",
		"preview_args":      []any{1},
	}), nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if res.IsError {
		t.Fatalf("expected success: %s", res.Output)
	}
	dr := requireDryRunResult(t, res.Data, res.Output)
	if dr.EstimatedImpact["operation"] != "DELETE" {
		t.Errorf("operation=%v, want DELETE", dr.EstimatedImpact["operation"])
	}
	if dr.EstimatedImpact["rows_affected"] != int64(5) {
		t.Errorf("rows_affected=%v, want 5", dr.EstimatedImpact["rows_affected"])
	}
	if !strings.Contains(dr.Preview, "to be deleted") {
		t.Errorf("Preview missing 'to be deleted' label: %q", dr.Preview)
	}
	if n := countOrders(t, db); n != 5 {
		t.Errorf("ROLLBACK failed: expected 5 rows still, got %d", n)
	}
}

func TestSQLDryRun_InsertHappyPath(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql":  "INSERT INTO orders (id, status, wave_id, qty) VALUES (?, ?, ?, ?)",
		"args": []any{99, "new", 2, 100},
	}), nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if res.IsError {
		t.Fatalf("expected success: %s", res.Output)
	}
	dr := requireDryRunResult(t, res.Data, res.Output)
	if dr.EstimatedImpact["operation"] != "INSERT" {
		t.Errorf("operation=%v, want INSERT", dr.EstimatedImpact["operation"])
	}
	if dr.EstimatedImpact["rows_affected"] != int64(1) {
		t.Errorf("rows_affected=%v, want 1", dr.EstimatedImpact["rows_affected"])
	}
	if !strings.Contains(dr.Preview, "INSERT args:") {
		t.Errorf("Preview missing args label: %q", dr.Preview)
	}
	if n := countOrders(t, db); n != 5 {
		t.Errorf("ROLLBACK failed: expected 5 rows, got %d", n)
	}
}

func TestSQLDryRun_InsertRejectsPredicate(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql":               "INSERT INTO orders (id, status, wave_id, qty) VALUES (99, 'new', 2, 100)",
		"preview_predicate": "SELECT * FROM orders",
	}), nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !res.IsError {
		t.Fatalf("expected reject, got success: %s", res.Output)
	}
	if !strings.Contains(res.Output, "must not be supplied for INSERT") {
		t.Errorf("unexpected Output: %q", res.Output)
	}
}

func TestSQLDryRun_UpdateRequiresPredicate(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql":  "UPDATE orders SET status='picked' WHERE wave_id=1",
		"args": []any{},
	}), nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !res.IsError {
		t.Fatalf("expected reject")
	}
	if !strings.Contains(res.Output, "required for UPDATE") {
		t.Errorf("unexpected Output: %q", res.Output)
	}
}

func TestSQLDryRun_DeleteRequiresPredicate(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql": "DELETE FROM orders WHERE wave_id=1",
	}), nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !res.IsError {
		t.Fatalf("expected reject")
	}
	if !strings.Contains(res.Output, "required for DELETE") {
		t.Errorf("unexpected Output: %q", res.Output)
	}
}

func TestSQLDryRun_RejectNonDML(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	// The loop variable is named stmt (not sql) so it does not shadow the
	// imported database/sql package.
	for _, stmt := range []string{
		"SELECT * FROM orders",
		"DROP TABLE orders",
		"CREATE TABLE x (id INT)",
		"ALTER TABLE orders ADD COLUMN x INT",
	} {
		t.Run(stmt, func(t *testing.T) {
			res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
				"sql": stmt,
			}), nil)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if !res.IsError {
				t.Fatalf("expected reject for %q: %s", stmt, res.Output)
			}
			if !strings.Contains(res.Output, "not supported") {
				t.Errorf("unexpected Output: %q", res.Output)
			}
		})
	}
}

func TestSQLDryRun_EmptySQL(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql": "",
	}), nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !res.IsError {
		t.Fatal("expected reject for empty sql")
	}
}

func TestSQLDryRun_DriverErrorBadSQL(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql":               "UPDATE nonexistent_table SET x=1 WHERE y=2",
		"preview_predicate": "SELECT * FROM nonexistent_table",
	}), nil)
	if err == nil {
		t.Fatal("expected Go error for driver error (audit channel)")
	}
	if !res.IsError {
		t.Fatal("expected IsError=true")
	}
	if !strings.Contains(res.Output, "driver error") {
		t.Errorf("unexpected Output: %q", res.Output)
	}
}

func TestSQLDryRun_TruncateSnapshot(t *testing.T) {
	db := newDryRunTestDB(t)
	// Seed 101 more rows (ids 100..200, wave_id=2) so the preview overshoots
	// the 100-row cap. Edge case: exactly 100 rows does NOT truncate; at
	// least 101 are needed to trigger truncation.
	for i := 100; i <= 200; i++ {
		if _, err := db.Exec(
			`INSERT INTO orders (id, status, wave_id, qty) VALUES (?, 'new', 2, ?)`,
			i, i,
		); err != nil {
			t.Fatalf("seed: %v", err)
		}
	}
	tool := NewSQLDryRunTool(db)
	res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql":               "UPDATE orders SET status='picked' WHERE wave_id=?",
		"args":              []any{2},
		"preview_predicate": "SELECT id, status FROM orders WHERE wave_id=?",
		"preview_args":      []any{2},
	}), nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	dr := requireDryRunResult(t, res.Data, res.Output)
	if dr.EstimatedImpact["before_truncated"] != true {
		t.Errorf("expected before_truncated=true, got %v", dr.EstimatedImpact["before_truncated"])
	}
	if !strings.Contains(dr.Preview, "truncated at maxSnapshotRows=100") {
		t.Errorf("Preview missing explicit truncation notice: %q", dr.Preview)
	}
}

func TestSQLDryRun_ConstructorPanicNilDB(t *testing.T) {
	defer func() {
		r := recover()
		if r == nil {
			t.Fatal("expected panic on nil db")
		}
		if !strings.Contains(fmt.Sprint(r), "db.DB must not be nil") {
			t.Errorf("unexpected panic: %v", r)
		}
	}()
	_ = NewSQLDryRunTool(StagingDB{})
}

func TestSQLDryRun_ToolInterface(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	if tool.Name() != "SQLDryRun" {
		t.Errorf("Name=%q, want SQLDryRun", tool.Name())
	}
	desc := tool.Description(context.Background())
	for _, frag := range []string{"BEGIN/ROLLBACK", "UPDATE", "DELETE", "INSERT", "STAGING", "100 rows"} {
		if !strings.Contains(desc, frag) {
			t.Errorf("Description missing %q: %s", frag, desc)
		}
	}
	var schema map[string]any
	if err := json.Unmarshal(tool.InputSchema(), &schema); err != nil {
		t.Fatalf("InputSchema JSON: %v", err)
	}
	// Exactly one required field, and it must be "sql".
	required, _ := schema["required"].([]any)
	if len(required) != 1 || required[0] != "sql" {
		t.Errorf("expected required=[sql], got %v", required)
	}
	md := tool.Metadata()
	if md.ReadOnly {
		t.Error("expected ReadOnly=false (DML actually executes pre-ROLLBACK)")
	}
	if md.ConcurrencySafe {
		t.Error("expected ConcurrencySafe=false (holds row locks during TX)")
	}
	if md.Destructive {
		t.Error("expected Destructive=false (no commit, bounded by staging)")
	}
	if md.PermissionClass != permission.PermClassGeneric {
		t.Errorf("expected PermClassGeneric, got %q", md.PermissionClass)
	}
	// Compile-time interface satisfaction checks.
	var _ tools.Tool = tool
	var _ tools.MetadataProvider = tool
}

func TestSQLDryRun_DataIsDryRunResultPointer(t *testing.T) {
	db := newDryRunTestDB(t)
	tool := NewSQLDryRunTool(db)
	res, err := tool.Execute(context.Background(), dryRunInput(t, map[string]any{
		"sql":               "UPDATE orders SET status='picked' WHERE id=?",
		"args":              []any{1},
		"preview_predicate": "SELECT id FROM orders WHERE id=?",
		"preview_args":      []any{1},
	}), nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if _, ok := res.Data.(*tools.DryRunResult); !ok {
		t.Fatalf("Data is %T, want *tools.DryRunResult", res.Data)
	}
}