package builtin // SQLCASTool -- safe optimistic-lock UPDATE for staging / shadow tables. // // Purpose: give the agent a bounded write primitive for session-scoped // staging tables (the platform layer provisions these with a version // column; they are NOT production OLTP tables). The tool wraps a single // UPDATE ... WHERE pk=? AND version=? statement, auto-increments the // version column, and re-reads + retries on RowsAffected=0 up to the // configured retry budget. On retry exhaustion the tool returns a // structured failure so the agent can back off and refetch state. // // Scope guard (load-bearing): this tool is for staging / shadow tables // only. Do not point it at production OLTP tables -- it does not model // FK constraints, triggers, or multi-row transactions. Callers are // expected to route production writes through a separate workflow. // // Zero driver dependency: accepts a StagingDB wrapper via constructor // DI; the caller picks the driver. Test layer uses modernc.org/sqlite // (pure Go). StagingDB is a newtype over *sql.DB -- the constructor // refuses to accept a plain *sql.DB, so every call site must write // StagingDB{db} at the integration boundary, making "this db is staging" // an API contract instead of godoc folklore. // // Retry default = 0 (fail fast): for an AI agent a version conflict // means "the world moved"; the right response is usually to refetch // state and re-reason, not silently retry with stale intent. Callers // who know their staging writes are idempotent can opt in to // maxRetries > 0 at construction. // // SQLCASTool -- 针对 staging / 影子表的乐观锁安全写入工具. // // 目的: 给 agent 提供有界的写能力, 用于 session 级 staging 表 (平台层 // 代建, 自带 version 字段; 刻意不用于生产 OLTP 表). 工具封装单条 // UPDATE ... WHERE pk=? AND version=? 语句, 自动递增 version, 在 // RowsAffected=0 时重读 version 并在重试预算内再试. 预算耗尽时返回结构化 // 失败, 供 agent 退避并重新拉状态. // // 作用域约束 (承重): 本工具仅面向 staging / 影子表. 禁止指向生产 OLTP 表 // -- 不建模外键 / 触发器 / 多行事务. 生产写入应另设工作流承载. // // 零 driver 依赖: 通过构造函数 DI 注入 StagingDB 包装, driver 由调用方 // 选. 测试层采用 modernc.org/sqlite (纯 Go). StagingDB 是 *sql.DB 的 // newtype -- 构造函数拒绝裸 *sql.DB, 所以每个调用点都必须在集成边界写 // StagingDB{db}, 让 "此 db 是 staging" 成为 API 契约而非 godoc 俗约. // // 重试默认 0 (fail fast): 对 AI agent 而言 version 冲突意味着 "世界变了", // 正确反应通常是重新拉状态重新推理, 而非 silent retry 以旧意图打新世界. // 集成方若明确知道自己的 staging 写是幂等的, 可在构造时显式 opt-in 到 // maxRetries > 0. import ( "context" "database/sql" "encoding/json" "errors" "fmt" "regexp" "sort" "strings" "git.flytoex.net/yuanwei/flyto-agent/pkg/permission" "git.flytoex.net/yuanwei/flyto-agent/pkg/tools" ) // StagingDB wraps *sql.DB to signal at every call site that the handle // targets a platform-managed staging / shadow table rather than a // production OLTP primary. Embedding *sql.DB means method promotion // keeps the caller API ergonomic (StagingDB{db}.QueryRow works). Does // not defend against StagingDB{prodDB} lies -- it is an intent marker, // not a sandbox -- but it turns the staging-scope constraint from a // godoc claim into a type signature that the compiler and every reader // must acknowledge. // // StagingDB 将 *sql.DB 封装, 在每个调用点表明该 handle 指向平台管理的 // staging / 影子表而非生产 OLTP 主库. 内嵌 *sql.DB 经由 method promotion // 保持调用方 API 人体工学 (StagingDB{db}.QueryRow 可直接用). 不能防 // StagingDB{prodDB} 撒谎 -- 它是意图标记而非沙盒 -- 但把 staging 作用域 // 约束从 godoc 声明升格为类型签名, 编译器和每位阅读者都必须显式认账. type StagingDB struct { *sql.DB } // SQLCASTool implements the Tool interface for optimistic-lock UPDATE on // staging tables. Holds a StagingDB and the retry budget; no other // mutable state, safe to share across goroutines. // // SQLCASTool 实现 Tool 接口做 staging 表乐观锁 UPDATE. 持有 StagingDB 和 // 重试预算; 无其他可变状态, 可 goroutine 间共享. type SQLCASTool struct { db StagingDB maxRetries int } // NewSQLCASTool constructs a CAS tool. maxRetries semantics: initial // attempt + up to maxRetries additional retries (maxRetries=0 means one // attempt then fail; maxRetries=3 means at most 4 round-trips). // // Panics on invalid config (DI contract, surfaces errors at startup // rather than under load): db.DB must be non-nil, maxRetries must be in // [0, 10] (upper bound is a 防呆 guard against typo'd large values; // practical retry budgets rarely exceed single digits). // // NewSQLCASTool 构造 CAS 工具. maxRetries 语义: 初次尝试加最多 maxRetries // 次重试 (maxRetries=0 即一次尝试后直接失败; maxRetries=3 即最多 4 次 // 往返). // // 非法配置 panic (DI 契约, 启动期暴露问题而非负载下炸): db.DB 必须非 nil, // maxRetries 必须在 [0, 10] 区间 (上限是防呆 guard, 防止误填过大数值; // 实际重试预算几乎不超个位数). func NewSQLCASTool(db StagingDB, maxRetries int) *SQLCASTool { if db.DB == nil { panic("builtin.NewSQLCASTool: db.DB must not be nil (方案 β 严格 DI)") } if maxRetries < 0 { panic(fmt.Sprintf("builtin.NewSQLCASTool: maxRetries must be >= 0, got %d", maxRetries)) } if maxRetries > 10 { panic(fmt.Sprintf("builtin.NewSQLCASTool: maxRetries must be <= 10 (防呆上限), got %d", maxRetries)) } return &SQLCASTool{db: db, maxRetries: maxRetries} } // sqlCASInput is the JSON body the LLM produces. pk_val and update_cols // values are kept as json.RawMessage so we can enforce a strict // allowlist of JSON scalar types before binding them to driver // parameters. // // sqlCASInput 是 LLM 产出的 JSON 载荷. pk_val 和 update_cols 的值以 // json.RawMessage 保留原型, 以便在绑定到 driver 参数前强制 JSON 标量类型 // 白名单. type sqlCASInput struct { Table string `json:"table"` PKCol string `json:"pk_col"` PKVal json.RawMessage `json:"pk_val"` VersionCol string `json:"version_col"` ExpectedVersion int64 `json:"expected_version"` UpdateCols map[string]json.RawMessage `json:"update_cols"` } // Name returns the tool name. // // Name 返回工具名. func (t *SQLCASTool) Name() string { return "SQLCAS" } // Description is what the LLM reads when deciding to call this tool. // The staging-only scope and retry budget are both surfaced so the LLM // self-limits to the intended usage. // // Description 是 LLM 决定调用本工具时阅读的说明. staging 作用域和重试 // 预算都显式暴露, 让 LLM 自我约束在预期用法内. func (t *SQLCASTool) Description(ctx context.Context) string { var b strings.Builder b.WriteString("Performs an optimistic-lock UPDATE on a staging / shadow table. ") b.WriteString("The table must have an integer version column; the tool re-reads version on conflict and retries up to its configured budget. ") fmt.Fprintf(&b, "Retry budget: initial attempt + up to %d retries (configured at construction). ", t.maxRetries) b.WriteString("INTENDED FOR STAGING / SHADOW TABLES ONLY -- do not target production OLTP tables. ") b.WriteString("Identifiers (table / pk_col / version_col / update_cols keys) must match [a-zA-Z_]\\w* (plain, unquoted). ") b.WriteString("Value types: pk_val and update_cols values support string / number / boolean / null.") return b.String() } // InputSchema returns the JSON Schema for the tool's input. // // InputSchema 返回工具输入的 JSON Schema. func (t *SQLCASTool) InputSchema() json.RawMessage { return json.RawMessage(`{ "type": "object", "properties": { "table": { "type": "string", "description": "Target staging/shadow table name. Must match [a-zA-Z_]\\w*" }, "pk_col": { "type": "string", "description": "Primary key column name. Must match [a-zA-Z_]\\w*" }, "pk_val": { "description": "Primary key value. Allowed JSON types: string, number, boolean, null" }, "version_col": { "type": "string", "description": "Integer version column name. Must match [a-zA-Z_]\\w*" }, "expected_version": { "type": "integer", "description": "Client-observed version. Used in the first WHERE version=? predicate; re-read from DB on conflict" }, "update_cols": { "type": "object", "description": "Map of column_name -> new_value. At least one entry required. Keys must match [a-zA-Z_]\\w*; values: string/number/bool/null" } }, "required": ["table", "pk_col", "pk_val", "version_col", "expected_version", "update_cols"] }`) } // Metadata declares the tool's cross-cutting properties. CAS is // concurrency-safe by design (the retry loop converges under // contention). Not ReadOnly (UPDATE). Not Destructive (staging scope, // blast radius bounded; production destruction is a separate workflow). // // Metadata 声明工具的跨切面属性. CAS 按设计 concurrency-safe (重试循环 // 在争用下收敛). 非 ReadOnly (UPDATE). 非 Destructive (staging 作用域, // blast radius 有界; 生产级破坏另有工作流). func (t *SQLCASTool) Metadata() tools.Metadata { return tools.Metadata{ ConcurrencySafe: true, ReadOnly: false, Destructive: false, SearchHint: "sql cas optimistic lock version update staging", PermissionClass: permission.PermClassGeneric, AuditOperation: "edit", } } // Execute runs the CAS loop. All business-level failures (row missing, // version type mismatch, identifier rejected, retries exhausted) return // &Result{IsError:true} with a human-readable Output so the LLM can // self-correct; only input-JSON parse errors propagate as Go errors. // // Execute 执行 CAS 循环. 业务级失败 (行缺失 / version 类型不匹配 / // 标识符被拒 / 重试耗尽) 全部以 &Result{IsError:true} 并附人类可读 Output // 返回, 让 LLM 自我纠正; 仅输入 JSON 解析错误作为 Go error 向上传播. func (t *SQLCASTool) Execute(ctx context.Context, input json.RawMessage, progress tools.ProgressFunc) (*tools.Result, error) { var params sqlCASInput if err := json.Unmarshal(input, ¶ms); err != nil { return nil, fmt.Errorf("sqlcas: invalid input: %w", err) } if !isValidIdentifier(params.Table) { return casReject("identifier %q not allowed for table (must match [a-zA-Z_]\\w*, plain unquoted)", params.Table), nil } if !isValidIdentifier(params.PKCol) { return casReject("identifier %q not allowed for pk_col (must match [a-zA-Z_]\\w*, plain unquoted)", params.PKCol), nil } if !isValidIdentifier(params.VersionCol) { return casReject("identifier %q not allowed for version_col (must match [a-zA-Z_]\\w*, plain unquoted)", params.VersionCol), nil } if len(params.UpdateCols) == 0 { return casReject("update_cols is empty. Provide at least one column:value pair to update"), nil } pkArg, err := jsonValueToArg(params.PKVal) if err != nil { return casReject("pk_val has unsupported JSON type: %v. Allowed: string, number, boolean, null", err), nil } colNames := make([]string, 0, len(params.UpdateCols)) for col := range params.UpdateCols { colNames = append(colNames, col) } sort.Strings(colNames) updateArgs := make([]any, 0, len(colNames)) setParts := make([]string, 0, len(colNames)) for _, col := range colNames { if !isValidIdentifier(col) { return casReject("identifier %q not allowed for update_cols key (must match [a-zA-Z_]\\w*, plain unquoted)", col), nil } arg, err := jsonValueToArg(params.UpdateCols[col]) if err != nil { return casReject("update_cols[%q] has unsupported JSON type: %v. Allowed: string, number, boolean, null", col, err), nil } updateArgs = append(updateArgs, arg) setParts = append(setParts, col+"=?") } updateSQL := fmt.Sprintf( "UPDATE %s SET %s, %s=%s+1 WHERE %s=? AND %s=?", params.Table, strings.Join(setParts, ", "), params.VersionCol, params.VersionCol, params.PKCol, params.VersionCol, ) selectSQL := fmt.Sprintf("SELECT %s FROM %s WHERE %s=?", params.VersionCol, params.Table, params.PKCol, ) expected := params.ExpectedVersion for attempt := 0; attempt <= t.maxRetries; attempt++ { args := make([]any, 0, len(updateArgs)+2) args = append(args, updateArgs...) args = append(args, pkArg, expected) res, err := t.db.ExecContext(ctx, updateSQL, args...) if err != nil { return casDriverErr(err), err } n, err := res.RowsAffected() if err != nil { return casDriverErr(err), err } if n == 1 { return &tools.Result{ Output: fmt.Sprintf("CAS applied. Table %s, pk %s=%v, version %d -> %d (attempt %d).", params.Table, params.PKCol, pkArg, expected, expected+1, attempt+1), Data: map[string]any{ "valid": true, "new_version": expected + 1, "rows_affected": int64(1), "attempts": attempt + 1, }, }, nil } var rawVersion any err = t.db.QueryRowContext(ctx, selectSQL, pkArg).Scan(&rawVersion) if errors.Is(err, sql.ErrNoRows) { return casReject("no row where %s=%v in table %s (verify pk_val, or confirm the row was pre-seeded)", params.PKCol, pkArg, params.Table), nil } if err != nil { return casDriverErr(err), err } current, ok := rawVersion.(int64) if !ok { return casReject("version column %s has non-int type %T (only int version columns supported; timestamp / uuid are not)", params.VersionCol, rawVersion), nil } if current == expected { return casDriverErr(fmt.Errorf("RowsAffected=0 but version unchanged (%d); likely driver inconsistency", current)), nil } expected = current } return &tools.Result{ Output: fmt.Sprintf( "CAS failed: version conflict after %d retries (max). Another writer is updating %s pk=%v; back off and re-read state before retrying.", t.maxRetries, params.Table, pkArg), IsError: true, }, nil } // casReject builds a user-visible rejection result (IsError=true) whose // Output is a short, actionable sentence the LLM can act on without // further context. // // casReject 构造用户可见的拒绝结果 (IsError=true), Output 是 LLM 可直接据此 // 纠正的一句短提示, 无需额外上下文. func casReject(format string, args ...any) *tools.Result { return &tools.Result{ Output: "CAS rejected: " + fmt.Sprintf(format, args...), IsError: true, } } // casDriverErr wraps a driver-level error (SQL exec / scan failure) as // a Result so the LLM sees a readable Output; callers of Execute still // receive the underlying error for audit / telemetry. // // casDriverErr 把 driver 级错误 (SQL exec / scan 失败) 包成 Result 让 LLM // 看到可读 Output; Execute 的调用方同时会收到底层 error 用于审计 / 遥测. func casDriverErr(err error) *tools.Result { return &tools.Result{ Output: fmt.Sprintf("CAS failed: driver error: %v", err), IsError: true, } } var reIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) // isValidIdentifier reports whether s is a plain SQL identifier of the // form [a-zA-Z_]\w*. Quoted identifiers (backtick / double-quote) are // deliberately not accepted -- same policy as SQLValidatorTool's // extractTableRefs known blind spots, for consistency. // // isValidIdentifier 判断 s 是否为 [a-zA-Z_]\w* 形态的纯 SQL 标识符. 反引号 // / 双引号 identifier 刻意不接受 -- 与 SQLValidatorTool extractTableRefs // 的已知盲区策略一致, 保持一致性. func isValidIdentifier(s string) bool { return reIdentifier.MatchString(s) } // jsonValueToArg converts a raw JSON scalar to a driver argument, // rejecting arrays / objects. Go's JSON unmarshal maps JSON number to // float64 by default -- that round-trips through database/sql to // integer columns on every driver we care about, so we keep the // conversion unopinionated. // // jsonValueToArg 把原始 JSON 标量转成 driver 参数, 拒绝数组 / 对象. Go 的 // JSON unmarshal 默认把 JSON number 映射为 float64 -- 在我们关心的所有 // driver 上都能正确 round-trip 到整型列, 故此处不做额外类型转换. func jsonValueToArg(raw json.RawMessage) (any, error) { var v any if err := json.Unmarshal(raw, &v); err != nil { return nil, err } switch v.(type) { case nil, string, float64, bool: return v, nil } return nil, fmt.Errorf("%T", v) }