package builtin // SQLDryRunTool -- preview a DML statement without committing it. // // Purpose: let the agent show a human reviewer exactly what a pending // write would do before it happens. The tool opens a transaction, // captures a "before" snapshot via a caller-supplied predicate SELECT, // executes the target SQL, captures an "after" snapshot (UPDATE only), // then ROLLBACKs. The caller's DB is never committed -- the tool is a // read-only primitive by construction, even though it touches write // statements. // // Per-operation asymmetry (deliberate): // - UPDATE: before-SELECT + real SQL + after-SELECT. Two tables let // the human eyeball the diff. // - DELETE: before-SELECT + real SQL. After is implicitly empty. // - INSERT: no predicate needed. Preview = formatted args + affected // row count. // // Caller contract (preview_predicate): // - Required for UPDATE / DELETE; MUST NOT be supplied for INSERT. // - Prefer a stable key (PK). A predicate that references the column // being modified (e.g. WHERE status='new' when UPDATE sets // status='picked') will correctly match before rows but miss after // rows; the tool flags this as "after_predicate_mismatch" in // EstimatedImpact as a *signal* to the reviewer, not an error. // // Consistency check: the tool compares len(before_rows) to // RowsAffected from the real SQL. A mismatch means the LLM's // preview_predicate and the target SQL's WHERE do not describe the // same row set. Flagged as a signal, NOT an error. Does NOT defend // against an adversarial LLM crafting matching-count-but-different- // rows; that is the authorization layer's job, not Dry-run's. // // No row-level diff: the tool does NOT compute per-row before/after // diffs. Row identity requires schema-specific PK awareness (different // per driver / schema), which is a rabbit hole. Side-by-side tables // are honest and cheap; let the human compare. // // Truncation: at most 100 rows per snapshot. When truncated, the // output explicitly says "showing first N rows, more exist" -- no // silent sampling, which is how approval theatre happens. // // Transaction isolation: the caller's StagingDB.BeginTx defaults // apply. Inside the same transaction the tool's UPDATE sees its own // writes (standard SQL behavior). BeginTx holds row locks until // ROLLBACK; in production prefer a read replica if lock contention on // hot rows matters. // // Future replacement: on branching DBs (Neon / PlanetScale), a better // implementation opens a branch, executes the real SQL without // ROLLBACK, and presents the branch state to the reviewer. This tool // is the generic fallback for non-branching DBs. // // SQLDryRunTool -- 预览 DML 语句而不提交. // // 目的: 让 agent 在实际写发生前向人类审核者展示即将发生的影响. 工具 // 开启事务, 经调用方提供的 predicate SELECT 捕获 "改前" 快照, 执行目标 // SQL, 捕获 "改后" 快照 (仅 UPDATE), 然后 ROLLBACK. 调用方的 DB 从不 // commit -- 尽管触及写语句, 工具构造上即为只读原语. // // 按 operation 刻意不对称: // - UPDATE: 改前 SELECT + 真 SQL + 改后 SELECT. 两张表让人眼球对照 // diff. // - DELETE: 改前 SELECT + 真 SQL. 改后隐式为空. // - INSERT: 无需 predicate. 预览 = 格式化参数 + 影响行数. // // 调用方契约 (preview_predicate): // - UPDATE / DELETE 必填; INSERT 禁填. // - 建议使用稳定键 (PK). 若 predicate 引用被修改的列 (如 UPDATE 改 // status='picked', predicate 写 WHERE status='new'), 改前能正确匹配 // 改后会查不到; 工具在 EstimatedImpact 以 "after_predicate_mismatch" // 标注此情况, 作为*信号*供审核者参考, 非错误. // // 一致性检查: 工具比对 len(before_rows) 与真 SQL 的 RowsAffected. 不一致 // 意味 LLM 的 preview_predicate 与目标 SQL 的 WHERE 描述的行集不同. // 标记为信号不作为错误. 不对抗恶意 LLM 构造 "行数匹配但内容不同" 的伪 // predicate; 那是授权层职责, 不是 Dry-run 的事. // // 不算行级 diff: 工具不计算逐行改前/改后 diff. 行 identity 需要 schema // 特定的 PK 意识 (每个 driver / schema 不同), 是个坑. 改前/改后两张表 // 并列诚实且便宜; 让人类对比. // // 截断: 每份快照最多 100 行. 截断时输出显式写 "显示前 N 行, 仍有更多" // -- 不 silent sampling, 以免演变成 approval theatre. // // 事务隔离: 调用方 StagingDB.BeginTx 的默认隔离级别生效. 同事务内工具 // 的 UPDATE 可看到自己的写 (标准 SQL 行为). BeginTx 持行锁直到 ROLLBACK; // 生产上若热点行锁竞争是问题, 建议走 read replica. // // 未来替换: branching DB (Neon / PlanetScale) 上, 更优实现是开 branch // 执行真 SQL 无需 ROLLBACK, 把 branch 状态呈现给审核者. 本工具是 // 非 branching DB 的通用 fallback. import ( "context" "database/sql" "encoding/json" "fmt" "sort" "strings" "git.flytoex.net/yuanwei/flyto-agent/pkg/permission" "git.flytoex.net/yuanwei/flyto-agent/pkg/tools" ) // maxSnapshotRows caps the number of rows materialized per before / // after snapshot. Truncation beyond this is explicit in the output. // // maxSnapshotRows 限制每份改前 / 改后快照物化的行数. 超出截断在输出中 // 显式标注. const maxSnapshotRows = 100 // SQLDryRunTool previews DML statements on a staging database without // committing. Holds a StagingDB; no mutable state, safe to share // across goroutines. // // SQLDryRunTool 在 staging 数据库上预览 DML 语句而不提交. 持有 StagingDB, // 无可变状态, 可跨 goroutine 共享. type SQLDryRunTool struct { db StagingDB } // NewSQLDryRunTool constructs a Dry-run tool. Panics on nil db.DB (DI // contract; surfaces config errors at startup). // // NewSQLDryRunTool 构造 Dry-run 工具. db.DB 为 nil 时 panic (DI 契约, // 启动期暴露配置错误). func NewSQLDryRunTool(db StagingDB) *SQLDryRunTool { if db.DB == nil { panic("builtin.NewSQLDryRunTool: db.DB must not be nil (方案 β 严格 DI)") } return &SQLDryRunTool{db: db} } type sqlDryRunInput struct { SQL string `json:"sql"` Args []json.RawMessage `json:"args,omitempty"` PreviewPredicate string `json:"preview_predicate,omitempty"` PreviewArgs []json.RawMessage `json:"preview_args,omitempty"` } // Name returns the tool name. // // Name 返回工具名. func (t *SQLDryRunTool) Name() string { return "SQLDryRun" } // Description informs the LLM of the contract (staging only, three // operation paths, predicate requirement). // // Description 告知 LLM 契约 (仅 staging, 三种 operation 路径, predicate // 要求). func (t *SQLDryRunTool) Description(ctx context.Context) string { var b strings.Builder b.WriteString("Previews a DML statement inside a BEGIN/ROLLBACK transaction without committing. ") b.WriteString("UPDATE: returns before-SELECT + after-SELECT tables via preview_predicate. ") b.WriteString("DELETE: returns before-SELECT table via preview_predicate. ") b.WriteString("INSERT: returns formatted args + affected row count (no predicate). ") b.WriteString("INTENDED FOR STAGING / SHADOW TABLES ONLY -- holds row locks until ROLLBACK; in production prefer a read replica. ") b.WriteString("preview_predicate should use a stable key (PK). ") b.WriteString("Up to 100 rows per snapshot; truncation is explicit.") return b.String() } // InputSchema returns the JSON Schema for the tool's input. // // InputSchema 返回工具输入的 JSON Schema. func (t *SQLDryRunTool) InputSchema() json.RawMessage { return json.RawMessage(`{ "type": "object", "properties": { "sql": { "type": "string", "description": "The DML statement to preview (UPDATE / DELETE / INSERT)" }, "args": { "type": "array", "description": "Parameters bound to '?' placeholders in sql", "items": {} }, "preview_predicate": { "type": "string", "description": "SELECT statement defining the 'rows of interest'. REQUIRED for UPDATE/DELETE, forbidden for INSERT. Prefer a stable key (PK) predicate" }, "preview_args": { "type": "array", "description": "Parameters bound to '?' placeholders in preview_predicate", "items": {} } }, "required": ["sql"] }`) } // Metadata declares cross-cutting properties. ReadOnly=false because // UPDATE/DELETE/INSERT actually execute (triggers fire, constraints // check, indexes update) even though ROLLBACK reverts. ConcurrencySafe // =false because BeginTx holds row locks until the tool's ROLLBACK -- // two concurrent Dry-runs on the same row would contend. // // Metadata 声明跨切面属性. ReadOnly=false -- 尽管 ROLLBACK 最终撤销, // UPDATE/DELETE/INSERT 实际会触发 trigger / 约束检查 / 索引更新. // ConcurrencySafe=false -- BeginTx 持行锁直到工具 ROLLBACK, 同行两个并发 // Dry-run 会争用. func (t *SQLDryRunTool) Metadata() tools.Metadata { return tools.Metadata{ ConcurrencySafe: false, ReadOnly: false, Destructive: false, SearchHint: "sql dry-run preview update delete insert staging rollback", PermissionClass: permission.PermClassGeneric, AuditOperation: "invoke", } } // Execute classifies sql by first keyword and dispatches to the // per-operation path. Business-level failures (wrong operation, missing // predicate, driver error) return &Result{IsError:true} with an Output // the LLM can act on. // // Execute 按首 keyword 分类 sql, 分派到对应 operation 路径. 业务级失败 // (operation 不支持 / predicate 缺失 / driver 错误) 返回 &Result{IsError: // true}, Output 给 LLM 据以自纠. func (t *SQLDryRunTool) Execute(ctx context.Context, input json.RawMessage, progress tools.ProgressFunc) (*tools.Result, error) { var params sqlDryRunInput if err := json.Unmarshal(input, ¶ms); err != nil { return nil, fmt.Errorf("sqldryrun: invalid input: %w", err) } if strings.TrimSpace(params.SQL) == "" { return dryRunReject("sql is empty"), nil } op := firstKeyword(params.SQL) switch op { case "UPDATE": if params.PreviewPredicate == "" { return dryRunReject("preview_predicate is required for UPDATE"), nil } return t.runUpdate(ctx, params) case "DELETE": if params.PreviewPredicate == "" { return dryRunReject("preview_predicate is required for DELETE"), nil } return t.runDelete(ctx, params) case "INSERT": if params.PreviewPredicate != "" { return dryRunReject("preview_predicate must not be supplied for INSERT"), nil } return t.runInsert(ctx, params) default: return dryRunReject("operation %q not supported (only UPDATE / DELETE / INSERT)", op), nil } } func (t *SQLDryRunTool) runUpdate(ctx context.Context, params sqlDryRunInput) (*tools.Result, error) { tx, err := t.db.BeginTx(ctx, nil) if err != nil { return dryRunDriverErr(err), err } defer func() { _ = tx.Rollback() }() before, err := captureSnapshot(ctx, tx, params.PreviewPredicate, rawArgs(params.PreviewArgs)) if err != nil { return dryRunDriverErr(fmt.Errorf("preview_predicate (before): %w", err)), err } execRes, err := tx.ExecContext(ctx, params.SQL, rawArgs(params.Args)...) if err != nil { return dryRunDriverErr(fmt.Errorf("target sql: %w", err)), err } rowsAffected, err := execRes.RowsAffected() if err != nil { return dryRunDriverErr(fmt.Errorf("RowsAffected: %w", err)), err } after, err := captureSnapshot(ctx, tx, params.PreviewPredicate, rawArgs(params.PreviewArgs)) if err != nil { return dryRunDriverErr(fmt.Errorf("preview_predicate (after): %w", err)), err } impact := map[string]any{ "operation": "UPDATE", "rows_affected": rowsAffected, "preview_rows_before": int64(len(before.rows)), "preview_rows_after": int64(len(after.rows)), "consistency": "pass", } if int64(len(before.rows)) != rowsAffected { impact["consistency"] = "mismatch" impact["mismatch_reason"] = fmt.Sprintf( "preview_predicate matched %d rows before, but target SQL affected %d. "+ "Likely the predicate's WHERE and the target SQL's WHERE describe different row sets -- "+ "check for overlap or use a PK-based predicate.", len(before.rows), rowsAffected, ) } if int64(len(after.rows)) != rowsAffected { impact["after_predicate_mismatch"] = true impact["after_predicate_mismatch_reason"] = fmt.Sprintf( "preview_predicate matched %d rows before and %d rows after, but target SQL affected %d. "+ "The predicate likely references a column the target SQL modifies -- "+ "switch to a PK-based predicate for a stable view.", len(before.rows), len(after.rows), rowsAffected, ) } if before.truncated { impact["before_truncated"] = true } if after.truncated { impact["after_truncated"] = true } preview := formatSnapshot(before, "Before") + "\n" + formatSnapshot(after, "After") wouldAffect := fmt.Sprintf("UPDATE (%d rows affected, %d rows previewed before)", rowsAffected, len(before.rows)) return dryRunResult(wouldAffect, preview, impact), nil } func (t *SQLDryRunTool) runDelete(ctx context.Context, params sqlDryRunInput) (*tools.Result, error) { tx, err := t.db.BeginTx(ctx, nil) if err != nil { return dryRunDriverErr(err), err } defer func() { _ = tx.Rollback() }() before, err := captureSnapshot(ctx, tx, params.PreviewPredicate, rawArgs(params.PreviewArgs)) if err != nil { return dryRunDriverErr(fmt.Errorf("preview_predicate: %w", err)), err } execRes, err := tx.ExecContext(ctx, params.SQL, rawArgs(params.Args)...) if err != nil { return dryRunDriverErr(fmt.Errorf("target sql: %w", err)), err } rowsAffected, err := execRes.RowsAffected() if err != nil { return dryRunDriverErr(fmt.Errorf("RowsAffected: %w", err)), err } impact := map[string]any{ "operation": "DELETE", "rows_affected": rowsAffected, "preview_rows": int64(len(before.rows)), "consistency": "pass", } if int64(len(before.rows)) != rowsAffected { impact["consistency"] = "mismatch" impact["mismatch_reason"] = fmt.Sprintf( "preview_predicate matched %d rows, but target SQL affected %d.", len(before.rows), rowsAffected, ) } if before.truncated { impact["before_truncated"] = true } preview := formatSnapshot(before, "Before (to be deleted)") wouldAffect := fmt.Sprintf("DELETE (%d rows affected, %d rows previewed)", rowsAffected, len(before.rows)) return dryRunResult(wouldAffect, preview, impact), nil } func (t *SQLDryRunTool) runInsert(ctx context.Context, params sqlDryRunInput) (*tools.Result, error) { tx, err := t.db.BeginTx(ctx, nil) if err != nil { return dryRunDriverErr(err), err } defer func() { _ = tx.Rollback() }() execRes, err := tx.ExecContext(ctx, params.SQL, rawArgs(params.Args)...) if err != nil { return dryRunDriverErr(fmt.Errorf("target sql: %w", err)), err } rowsAffected, err := execRes.RowsAffected() if err != nil { return dryRunDriverErr(fmt.Errorf("RowsAffected: %w", err)), err } impact := map[string]any{ "operation": "INSERT", "rows_affected": rowsAffected, } var preview strings.Builder preview.WriteString("INSERT args:\n") if len(params.Args) == 0 { preview.WriteString("(no bound args)\n") } else { for i, arg := range params.Args { fmt.Fprintf(&preview, " [%d] %s\n", i, strings.TrimSpace(string(arg))) } } wouldAffect := fmt.Sprintf("INSERT (%d rows affected)", rowsAffected) return dryRunResult(wouldAffect, preview.String(), impact), nil } // snapshot holds a materialized table-shaped view of a SELECT result, // capped at maxSnapshotRows. // // snapshot 持有 SELECT 结果的表格化物化视图, 上限 maxSnapshotRows 行. type snapshot struct { columns []string rows [][]any truncated bool } // captureSnapshot runs a SELECT inside the given tx, materializing up // to maxSnapshotRows rows. Sets truncated=true when more rows remain // unread. // // captureSnapshot 在给定 tx 内跑 SELECT, 物化至多 maxSnapshotRows 行. // 仍有未读行时 truncated=true. func captureSnapshot(ctx context.Context, tx *sql.Tx, predicate string, args []any) (*snapshot, error) { rows, err := tx.QueryContext(ctx, predicate, args...) if err != nil { return nil, err } defer func() { _ = rows.Close() }() cols, err := rows.Columns() if err != nil { return nil, err } out := &snapshot{columns: cols} for rows.Next() { if len(out.rows) >= maxSnapshotRows { out.truncated = true break } row := make([]any, len(cols)) ptrs := make([]any, len(cols)) for i := range row { ptrs[i] = &row[i] } if err := rows.Scan(ptrs...); err != nil { return nil, err } out.rows = append(out.rows, row) } if err := rows.Err(); err != nil { return nil, err } return out, nil } // formatSnapshot renders a snapshot as a fixed-width text table. Column // widths are capped at 40 chars; overflowing values are truncated with // an ellipsis. Footer line is explicit about truncation so reviewers // never see silently sampled data. // // formatSnapshot 把 snapshot 渲染为定宽文本表. 列宽上限 40 字符; 超长值 // 省略号截断. 页脚显式标注截断, 审核者不会看到被静默采样的数据. func formatSnapshot(s *snapshot, label string) string { var b strings.Builder fmt.Fprintf(&b, "%s:\n", label) if len(s.columns) == 0 { b.WriteString("(no columns)\n") return b.String() } if len(s.rows) == 0 { b.WriteString("(no rows)\n") return b.String() } widths := make([]int, len(s.columns)) for i, col := range s.columns { widths[i] = len(col) } for _, row := range s.rows { for i, v := range row { vs := formatSnapshotValue(v) if len(vs) > widths[i] { widths[i] = len(vs) } } } const maxColWidth = 40 for i := range widths { if widths[i] > maxColWidth { widths[i] = maxColWidth } } for i, col := range s.columns { if i > 0 { b.WriteString(" | ") } fmt.Fprintf(&b, "%-*s", widths[i], truncateForColumn(col, widths[i])) } b.WriteString("\n") for i := range widths { if i > 0 { b.WriteString("-+-") } b.WriteString(strings.Repeat("-", widths[i])) } b.WriteString("\n") for _, row := range s.rows { for i, v := range row { if i > 0 { b.WriteString(" | ") } fmt.Fprintf(&b, "%-*s", widths[i], truncateForColumn(formatSnapshotValue(v), widths[i])) } b.WriteString("\n") } if s.truncated { fmt.Fprintf(&b, "(showing first %d rows, more exist -- truncated at maxSnapshotRows=%d)\n", len(s.rows), maxSnapshotRows) } else { fmt.Fprintf(&b, "(%d rows)\n", len(s.rows)) } return b.String() } func formatSnapshotValue(v any) string { switch x := v.(type) { case nil: return "NULL" case []byte: return string(x) default: return fmt.Sprintf("%v", x) } } func truncateForColumn(s string, w int) string { if len(s) <= w { return s } if w <= 3 { return s[:w] } return s[:w-3] + "..." } // rawArgs converts json.RawMessage args to driver arguments by // unmarshaling each into any. Unsupported types (object / array) flow // through to the driver which will error -- we deliberately do not // pre-validate here because the DML statement's parameter types are // driver-dependent. // // rawArgs 把 json.RawMessage 参数逐个 unmarshal 为 any 后作为 driver 参数. // 不支持的类型 (object / array) 透传给 driver 由其报错 -- 此处刻意不做前 // 置校验, 因 DML 语句的参数类型是 driver 相关的. func rawArgs(raws []json.RawMessage) []any { if len(raws) == 0 { return nil } out := make([]any, len(raws)) for i, r := range raws { var v any _ = json.Unmarshal(r, &v) out[i] = v } return out } // dryRunResult packages a DryRunResult into a Tool Result: structured // payload goes to Data, LLM-readable text goes to Output. // // dryRunResult 把 DryRunResult 封装为 Tool Result: 结构化载荷写入 Data, // LLM 可读文本写入 Output. func dryRunResult(wouldAffect, preview string, impact map[string]any) *tools.Result { dr := &tools.DryRunResult{ WouldAffect: wouldAffect, Preview: preview, EstimatedImpact: impact, } var out strings.Builder out.WriteString("Dry-run: ") out.WriteString(wouldAffect) out.WriteString("\n\n") out.WriteString(preview) out.WriteString("\nImpact:") keys := make([]string, 0, len(impact)) for k := range impact { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { fmt.Fprintf(&out, " %s=%v", k, impact[k]) } out.WriteString("\n") return &tools.Result{ Output: out.String(), Data: dr, } } func dryRunReject(format string, args ...any) *tools.Result { return &tools.Result{ Output: "Dry-run rejected: " + fmt.Sprintf(format, args...), IsError: true, } } func dryRunDriverErr(err error) *tools.Result { return &tools.Result{ Output: fmt.Sprintf("Dry-run failed: driver error: %v", err), IsError: true, } }