package builtin // SQLValidator -- read-only SQL pre-flight gate (pure string parsing). // // Purpose: stop the LLM from executing destructive SQL by shape, before any // database driver sees the statement. Rejects non-SELECT/WITH/EXPLAIN first // keywords, multi-statements, and optionally queries that miss LIMIT or // reference tables outside a whitelist. On accept, returns a normalized // statement (with LIMIT injected if configured). // // Zero DB dependency -- this tool does not open a *sql.DB, import any // driver, or contact a server. It is schema-agnostic and driver-agnostic by // construction; it is deliberately a crude gate, not a full SQL parser. // The consumer is expected to pair it with a real executor (Dry-run / query // tool) to bind the normalized SQL. // // Schema-agnostic via DI (no driver): this is why the tool lives in the // engine core rather than the platform consumer tier. See tools/doc.go for // the wider ToolCapability protocol. // // SQLValidator -- 只读 SQL 前置校验器 (纯字符串解析). // // 目的: 在任何数据库 driver 看到语句前, 从形态上阻止 LLM 执行破坏性 SQL. // 拒绝非 SELECT/WITH/EXPLAIN 起首的语句, 拒绝多条语句, 可选拒绝缺 LIMIT 的查询 // 以及访问白名单外表的查询. 通过时返回规范化语句 (按配置注入 LIMIT). // // 零 DB 依赖 -- 此工具不打开 *sql.DB, 不 import 任何 driver, 不连任何服务器. // 构造上即 schema 无关 / driver 无关; 刻意做成粗校验而非完整 SQL parser. // 使用方应搭配真 executor (Dry-run / 查询工具) 绑定规范化 SQL. // // 经由 DI 实现 schema 无关 (无 driver): 此所以该工具位于引擎层而非平台消费层. // 完整 ToolCapability 协议见 tools/doc.go. import ( "context" "encoding/json" "fmt" "regexp" "strings" "git.flytoex.net/yuanwei/flyto-agent/pkg/permission" "git.flytoex.net/yuanwei/flyto-agent/pkg/tools" ) // SQLValidatorConfig configures the three tunable rules. // // SQLValidatorConfig 配置三条可调规则. type SQLValidatorConfig struct { // AllowedTables is the table-name whitelist. Nil or empty disables the // whitelist check entirely. Names are matched case-insensitively. // // AllowedTables 是表名白名单. nil 或空 slice 完全禁用白名单检查. 名称 // 大小写不敏感匹配. 
AllowedTables []string // DefaultLimit, when > 0, is the LIMIT value auto-injected when the // caller's SQL has no LIMIT clause. Zero disables injection. // // DefaultLimit 大于 0 时, 用于在调用方 SQL 无 LIMIT 子句时自动注入的 // LIMIT 值. 零值禁用注入. DefaultLimit int // RequireLimit, when true, rejects SQL that lacks a LIMIT clause after // optional injection (i.e. RequireLimit=true + DefaultLimit=0 means // "LIMIT must be present in caller input"). // // RequireLimit 为 true 时, 可选注入后仍缺 LIMIT 子句的 SQL 被拒 (即 // RequireLimit=true + DefaultLimit=0 表示 "调用方必须自带 LIMIT"). RequireLimit bool } // SQLValidatorTool implements the Tool interface for SQL pre-flight validation. // // SQLValidatorTool 实现 Tool 接口做 SQL 前置校验. type SQLValidatorTool struct { cfg SQLValidatorConfig } // NewSQLValidatorTool constructs a validator with the given config. Zero- // value config accepts any SELECT/WITH/EXPLAIN single statement, no LIMIT // enforcement, no table whitelist. // // NewSQLValidatorTool 以给定配置构造校验器. 零值配置接受任意单条 // SELECT/WITH/EXPLAIN 语句, 不强制 LIMIT, 不校验表白名单. func NewSQLValidatorTool(cfg SQLValidatorConfig) *SQLValidatorTool { return &SQLValidatorTool{cfg: cfg} } type sqlValidatorInput struct { SQL string `json:"sql"` } // Name returns the tool name. // // Name 返回工具名. func (t *SQLValidatorTool) Name() string { return "SQLValidate" } // Description is what the LLM reads when deciding to call this tool. // Kept close to the rules so the model self-serves rewrites on rejection. // // Description 是 LLM 决定调用本工具时阅读的说明. 贴近规则书写, 便于模型 // 在被拒时自行改写. func (t *SQLValidatorTool) Description(ctx context.Context) string { var b strings.Builder b.WriteString("Validates a SQL statement for read-only use. ") b.WriteString("Accepts only single statements whose first keyword is SELECT, WITH, or EXPLAIN. ") b.WriteString("Rejects multi-statement input (statements separated by ';'). ") if len(t.cfg.AllowedTables) > 0 { fmt.Fprintf(&b, "Table whitelist: %s. 
", strings.Join(t.cfg.AllowedTables, ", ")) } if t.cfg.DefaultLimit > 0 { fmt.Fprintf(&b, "Auto-injects LIMIT %d if caller omits LIMIT. ", t.cfg.DefaultLimit) } if t.cfg.RequireLimit { b.WriteString("LIMIT clause is required. ") } b.WriteString("On accept, returns the normalized SQL for downstream execution. ") b.WriteString("On reject, returns the reason so the caller can rewrite and retry.") return b.String() } // InputSchema returns the JSON Schema for the tool's input. // // InputSchema 返回工具输入的 JSON Schema. func (t *SQLValidatorTool) InputSchema() json.RawMessage { return json.RawMessage(`{ "type": "object", "properties": { "sql": { "type": "string", "description": "The SQL statement to validate (single statement, SELECT/WITH/EXPLAIN only)" } }, "required": ["sql"] }`) } // Metadata declares the tool as read-only. The validator does not contact // any DB, so it is always safe to run concurrently and always permitted. // // Metadata 声明工具为只读. 校验器不连 DB, 始终可并发且始终放行. func (t *SQLValidatorTool) Metadata() tools.Metadata { return tools.Metadata{ ConcurrencySafe: true, ReadOnly: true, Destructive: false, SearchHint: "sql validate readonly select pre-flight gate", PermissionClass: permission.PermClassReadOnly, AuditOperation: "read", } } // Execute parses input, runs Validate, and packages the result. Rejection // is signalled via IsError=true with a human-readable reason; acceptance // returns the normalized SQL in Output and a structured payload in Data. // // Execute 解析输入, 调 Validate, 打包结果. 拒绝通过 IsError=true 携带人类 // 可读的原因; 通过时 Output 为规范化 SQL, Data 为结构化 payload. 
func (t *SQLValidatorTool) Execute(ctx context.Context, input json.RawMessage, progress tools.ProgressFunc) (*tools.Result, error) { var params sqlValidatorInput if err := json.Unmarshal(input, ¶ms); err != nil { return nil, fmt.Errorf("sqlvalidate: invalid input: %w", err) } if strings.TrimSpace(params.SQL) == "" { return &tools.Result{ Output: "SQL rejected: empty input", IsError: true, }, nil } normalized, reason, ok := t.Validate(params.SQL) if !ok { return &tools.Result{ Output: "SQL rejected: " + reason, IsError: true, }, nil } return &tools.Result{ Output: "SQL accepted. Normalized:\n" + normalized, Data: map[string]any{ "valid": true, "normalized": normalized, }, }, nil } // Validate runs the full rule chain against a caller SQL. Returns the // normalized statement on accept (may equal input when no injection // happens); on reject, reason is human-readable and normalized is empty. // // Exported so other builtin tools (Dry-run, CAS) can reuse the check // without re-wrapping the Tool.Execute boundary. // // Validate 对调用方 SQL 跑完整规则链. 通过时返回规范化语句 (未注入时与输入 // 相同); 拒绝时 reason 为人类可读原因, normalized 为空. // // 导出以便其他 builtin 工具 (Dry-run / CAS) 复用本校验, 无需绕回 Tool.Execute. 
func (t *SQLValidatorTool) Validate(sql string) (normalized string, reason string, ok bool) { stripped := stripSQLComments(sql) trimmed := strings.TrimSpace(stripped) if trimmed == "" { return "", "empty SQL after stripping comments", false } stmts := splitStatementsQuoteAware(trimmed) if len(stmts) > 1 { return "", "multiple statements not allowed (only one SQL per call)", false } single := strings.TrimSpace(stmts[0]) if single == "" { return "", "empty SQL", false } kw := firstKeyword(single) if !isReadKeyword(kw) { return "", fmt.Sprintf("statement type %q not allowed (only SELECT / WITH / EXPLAIN)", kw), false } if len(t.cfg.AllowedTables) > 0 { tables := extractTableRefs(single) whitelist := make(map[string]struct{}, len(t.cfg.AllowedTables)) for _, tbl := range t.cfg.AllowedTables { whitelist[strings.ToLower(tbl)] = struct{}{} } for _, tbl := range tables { if _, ok := whitelist[strings.ToLower(tbl)]; !ok { return "", fmt.Sprintf("table %q not in whitelist %v", tbl, t.cfg.AllowedTables), false } } } if !hasLimitClause(single) { if t.cfg.DefaultLimit > 0 { single = injectLimit(single, t.cfg.DefaultLimit) } else if t.cfg.RequireLimit { return "", "missing LIMIT clause (RequireLimit is set, DefaultLimit=0 so no injection)", false } } return single, "", true } // stripSQLComments removes -- line comments and /* */ block comments that // sit OUTSIDE string/identifier quotes. String literals (single-quoted), // double-quoted identifiers, and backtick identifiers are preserved // verbatim -- their contents are never scanned for comment markers. // // Kept intentionally simple: unterminated quote / comment is treated as // extending to end-of-input. This is tighter than a real parser but fine // for a pre-flight gate -- malformed SQL rejected at the DB layer anyway. // // stripSQLComments 移除 quote 外的 -- 行注释和 /* */ 块注释. 单引号字符串 / // 双引号 identifier / 反引号 identifier 内容原样保留, 绝不扫描其中的注释标记. // // 刻意保持简单: 未终止的 quote 或注释视为延伸至输入末尾. 
// stripSQLComments removes "--" line comments and "/* */" block comments
// that sit outside quotes. Single-quoted strings, double-quoted identifiers
// and backtick identifiers are copied through untouched, so comment markers
// inside them are never interpreted.
//
// A line comment is dropped up to, but not including, its newline. A block
// comment is replaced by a single space, so the tokens on either side stay
// separate. An unterminated quote or block comment is treated as running to
// the end of input. This is deliberately simpler than a real parser;
// malformed SQL is left for the DB layer to reject.
func stripSQLComments(sql string) string {
	var b strings.Builder
	b.Grow(len(sql))
	for pos := 0; pos < len(sql); {
		ch := sql[pos]
		switch {
		case ch == '\'' || ch == '"' || ch == '`':
			end := sqlQuoteEnd(sql, pos)
			b.WriteString(sql[pos:end])
			pos = end
		case ch == '-' && pos+1 < len(sql) && sql[pos+1] == '-':
			if nl := strings.IndexByte(sql[pos:], '\n'); nl >= 0 {
				pos += nl // keep the newline; it is written on the next pass
			} else {
				pos = len(sql)
			}
		case ch == '/' && pos+1 < len(sql) && sql[pos+1] == '*':
			if endIdx := strings.Index(sql[pos+2:], "*/"); endIdx >= 0 {
				pos += 2 + endIdx + 2
			} else {
				pos = len(sql)
			}
			b.WriteByte(' ')
		default:
			b.WriteByte(ch)
			pos++
		}
	}
	return b.String()
}

// splitStatementsQuoteAware splits sql on semicolons that sit outside quotes
// and returns the whitespace-trimmed pieces. Empty pieces, such as those
// from a trailing ';', from ";;" or from separator-only input, are dropped,
// so the result may be nil.
func splitStatementsQuoteAware(sql string) []string {
	var stmts []string
	emit := func(seg string) {
		if seg = strings.TrimSpace(seg); seg != "" {
			stmts = append(stmts, seg)
		}
	}

	segStart := 0
	for pos := 0; pos < len(sql); {
		switch sql[pos] {
		case '\'', '"', '`':
			pos = sqlQuoteEnd(sql, pos)
		case ';':
			emit(sql[segStart:pos])
			pos++
			segStart = pos
		default:
			pos++
		}
	}
	emit(sql[segStart:])
	return stmts
}

// sqlQuoteEnd returns the index just past the quoted run that opens at
// sql[start], which must be ', " or `. Inside single quotes, a doubled ''
// is an escaped quote rather than the terminator. An unterminated run ends
// at len(sql).
func sqlQuoteEnd(sql string, start int) int {
	q := sql[start]
	for j := start + 1; j < len(sql); j++ {
		if sql[j] != q {
			continue
		}
		if q == '\'' && j+1 < len(sql) && sql[j+1] == '\'' {
			j++ // skip the second half of the doubled ''
			continue
		}
		return j + 1
	}
	return len(sql)
}
// firstKeyword returns the upper-cased leading word of sql, made of ASCII
// letters and '_' only. Leading spaces, tabs, CR/LF and opening parentheses
// are skipped. It returns "" when sql does not start with such a word.
func firstKeyword(sql string) string {
	trimmed := strings.TrimLeft(sql, " \t\r\n(")
	end := 0
	for end < len(trimmed) {
		c := trimmed[end]
		if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_' {
			end++
			continue
		}
		break
	}
	if end == 0 {
		return ""
	}
	return strings.ToUpper(trimmed[:end])
}

// isReadKeyword reports whether kw is one of the allowed read-only starting
// keywords. This is conservative: any unlisted first keyword, including
// CALL, PRAGMA and SHOW, is rejected.
func isReadKeyword(kw string) bool {
	switch kw {
	case "SELECT", "WITH", "EXPLAIN":
		return true
	}
	return false
}

var (
	// reLimit recognizes a "LIMIT <integer>" clause, case-insensitively.
	reLimit = regexp.MustCompile(`(?i)\blimit\s+\d+`)

	// reTable captures the identifier after FROM / JOIN and drops an
	// optional "schema." qualifier. The leading \b stops identifiers that
	// merely end in "from" or "join" (valid_from, left_join) from matching.
	// Without it, such a match would capture the real FROM keyword as the
	// "table" and hide the table that follows.
	reTable = regexp.MustCompile(`(?i)\b(?:from|join)\s+(?:[a-zA-Z_][a-zA-Z0-9_]*\.)?([a-zA-Z_][a-zA-Z0-9_]*)`)
)

// hasLimitClause reports whether sql carries a "LIMIT <integer>" clause at
// its outermost level. Two things do not count:
//   - a LIMIT inside a subquery, which only bounds that subquery;
//   - text inside string literals or quoted identifiers.
//
// Comments are expected to have been stripped already.
//
// NOTE(review): LIMIT ALL and placeholder forms (LIMIT ?, LIMIT $1) are not
// recognized as a LIMIT clause.
func hasLimitClause(sql string) bool {
	return reLimit.MatchString(topLevelSQL(sql))
}

// topLevelSQL returns a copy of sql, of the same length, in which three
// kinds of byte are replaced by spaces:
//   - all quoted text ('strings', "identifiers", `identifiers`), with a
//     doubled '' treated as an escaped quote;
//   - anything not at the outermost parenthesis level;
//   - the parentheses that open and close that level.
//
// The outermost level is the smallest depth at which any unquoted,
// non-blank token appears. A statement wrapped whole in parentheses, such
// as "(SELECT ... LIMIT 5)", is therefore judged by its contents.
func topLevelSQL(sql string) string {
	const quoted = -1
	depth := make([]int, len(sql))
	level, outer := 0, -1
	for i := 0; i < len(sql); i++ {
		c := sql[i]
		switch c {
		case '\'', '"', '`':
			j := i + 1
			for j < len(sql) {
				if sql[j] == c {
					if c == '\'' && j+1 < len(sql) && sql[j+1] == '\'' {
						j += 2
						continue
					}
					break
				}
				j++
			}
			if j >= len(sql) { // unterminated: quote runs to end of input
				j = len(sql) - 1
			}
			for k := i; k <= j; k++ {
				depth[k] = quoted
			}
			i = j
		case '(':
			depth[i] = level
			level++
		case ')':
			if level > 0 {
				level--
			}
			depth[i] = level
		case ' ', '\t', '\r', '\n':
			depth[i] = level
		default:
			depth[i] = level
			if outer < 0 || level < outer {
				outer = level
			}
		}
	}

	out := []byte(sql)
	for i, d := range depth {
		if d == quoted || d != outer {
			out[i] = ' '
		}
	}
	return string(out)
}
// - Subqueries (FROM (SELECT ...)) yield bogus "table" like SELECT; // whitelist rejects them as a false positive (acceptable -- caller // rewrites with a plain table ref). // // extractTableRefs 返回 FROM / JOIN 之后出现的表名. 朴素 identifier 提取: // schema 限定名 (schema.table) 只返回 "table", 子查询 / 别名忽略. 调用方 // 应视为 best-effort 样本, 非权威 parse. // // 已知盲区 (刻意不处理 -- 白名单是前置 gate 非授权层; DB 自身的 GRANT // SELECT 才是真边界): // - 反引号 / 双引号 identifier (`orders`, "orders") 不被提取; 白名单静默 // 放行. 启用白名单时建议使用 plain identifier. // - 子查询 (FROM (SELECT ...)) 会得到 SELECT 这类伪 "table"; 白名单视为 // 假阳性并拒绝 (可接受 -- 调用方改写为 plain table ref). func extractTableRefs(sql string) []string { matches := reTable.FindAllStringSubmatch(sql, -1) var out []string for _, m := range matches { if len(m) >= 2 { out = append(out, m[1]) } } return out } // injectLimit appends a LIMIT clause to sql. The caller has already // stripped any trailing semicolon via splitStatementsQuoteAware, so the // output is semicolon-free by design. // // injectLimit 将 LIMIT 子句追加到 sql 末尾. 调用方已通过 // splitStatementsQuoteAware 去除 trailing 分号, 输出因此不含分号. func injectLimit(sql string, n int) string { return fmt.Sprintf("%s LIMIT %d", strings.TrimRight(sql, " \t\r\n"), n) }