package builtin import ( "context" "encoding/json" "strings" "testing" "git.flytoex.net/yuanwei/flyto-agent/pkg/permission" "git.flytoex.net/yuanwei/flyto-agent/pkg/tools" ) func TestSQLValidator_KeywordGate(t *testing.T) { v := NewSQLValidatorTool(SQLValidatorConfig{}) accepted := []string{ "SELECT 1", "select 1", "Select * From t", "WITH x AS (SELECT 1) SELECT * FROM x", "EXPLAIN SELECT 1", " \n SELECT 1 ", } for _, sql := range accepted { if _, reason, ok := v.Validate(sql); !ok { t.Errorf("expected accept, got reject: %q reason=%q", sql, reason) } } rejected := []string{ "DELETE FROM t", "UPDATE t SET x = 1", "INSERT INTO t VALUES (1)", "DROP TABLE t", "CREATE TABLE t (x INT)", "CALL proc()", "PRAGMA foreign_keys", "SHOW TABLES", "", " ", } for _, sql := range rejected { if _, _, ok := v.Validate(sql); ok { t.Errorf("expected reject, got accept: %q", sql) } } } func TestSQLValidator_MultiStatement(t *testing.T) { v := NewSQLValidatorTool(SQLValidatorConfig{}) cases := []struct { name string sql string wantErr bool }{ {"two statements rejected", "SELECT 1; SELECT 2", true}, {"trailing semicolon ok", "SELECT 1;", false}, {"semicolon inside single quote", "SELECT 'a;b' FROM t", false}, {"semicolon inside double quote", `SELECT "x;y" FROM t`, false}, {"escaped single quote '' inside string", "SELECT 'a''b' FROM t", false}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { _, reason, ok := v.Validate(tc.sql) if tc.wantErr && ok { t.Errorf("expected reject, got accept for %q", tc.sql) } if !tc.wantErr && !ok { t.Errorf("expected accept, got reject for %q reason=%q", tc.sql, reason) } }) } } func TestSQLValidator_Comments(t *testing.T) { v := NewSQLValidatorTool(SQLValidatorConfig{}) accepted := []string{ "-- leading comment\nSELECT 1", "/* block */ SELECT 1", "SELECT 1 /* inner */ FROM t", "SELECT 'not -- a comment' FROM t", "SELECT '/*' FROM t", "SELECT 1 -- trailing comment", } for _, sql := range accepted { if _, reason, ok := v.Validate(sql); !ok { t.Errorf("expected accept: %q reason=%q", sql, reason) } } // Comment-only input should reject (empty after strip). if _, _, ok := v.Validate("-- comment only"); ok { t.Error("expected reject for comment-only input") } if _, _, ok := v.Validate("/* only */"); ok { t.Error("expected reject for block-comment-only input") } // DROP TABLE hidden in a comment stays neutralized -- the leading SELECT // wins because the comment is stripped before keyword detection. if _, _, ok := v.Validate("-- DROP TABLE t\nSELECT 1"); !ok { t.Error("comment containing DROP should not affect keyword detection") } // A comment-hidden statement type still rejects -- the actual first // keyword of the stripped SQL is DELETE. if _, _, ok := v.Validate("/* SELECT */ DELETE FROM t"); ok { t.Error("expected reject when actual first keyword is DELETE") } } func TestSQLValidator_Limit(t *testing.T) { t.Run("existing LIMIT preserved", func(t *testing.T) { v := NewSQLValidatorTool(SQLValidatorConfig{DefaultLimit: 100, RequireLimit: true}) normalized, _, ok := v.Validate("SELECT * FROM t LIMIT 50") if !ok { t.Fatal("expected accept") } if !strings.Contains(normalized, "LIMIT 50") { t.Errorf("expected LIMIT 50 preserved, got %q", normalized) } if strings.Contains(normalized, "LIMIT 100") { t.Errorf("injection should not happen when LIMIT exists, got %q", normalized) } }) t.Run("missing LIMIT injected when DefaultLimit>0", func(t *testing.T) { v := NewSQLValidatorTool(SQLValidatorConfig{DefaultLimit: 100}) normalized, _, ok := v.Validate("SELECT * FROM t") if !ok { t.Fatal("expected accept after injection") } if !strings.Contains(normalized, "LIMIT 100") { t.Errorf("expected LIMIT 100 injected, got %q", normalized) } }) t.Run("missing LIMIT rejected when RequireLimit and no default", func(t *testing.T) { v := NewSQLValidatorTool(SQLValidatorConfig{RequireLimit: true}) _, reason, ok := v.Validate("SELECT * FROM t") if ok { t.Fatal("expected reject") } if !strings.Contains(reason, "LIMIT") { t.Errorf("expected reason to mention LIMIT, got %q", reason) } }) t.Run("missing LIMIT accepted in lenient mode", func(t *testing.T) { v := NewSQLValidatorTool(SQLValidatorConfig{}) if _, _, ok := v.Validate("SELECT * FROM t"); !ok { t.Error("expected accept under zero-value config") } }) t.Run("normalized SQL is semicolon-free after split", func(t *testing.T) { v := NewSQLValidatorTool(SQLValidatorConfig{DefaultLimit: 10}) normalized, _, _ := v.Validate("SELECT 1;") if normalized != "SELECT 1 LIMIT 10" { t.Errorf("expected 'SELECT 1 LIMIT 10', got %q", normalized) } }) t.Run("case-insensitive LIMIT detection", func(t *testing.T) { v := NewSQLValidatorTool(SQLValidatorConfig{DefaultLimit: 100, RequireLimit: true}) normalized, _, ok := v.Validate("select * from t limit 5") if !ok { t.Fatal("expected accept") } if strings.Contains(normalized, "LIMIT 100") { t.Errorf("lowercase limit should be detected, injection skipped, got %q", normalized) } }) } func TestSQLValidator_Whitelist(t *testing.T) { v := NewSQLValidatorTool(SQLValidatorConfig{ AllowedTables: []string{"orders", "items"}, }) accepted := []string{ "SELECT * FROM orders", "SELECT * FROM Orders", "SELECT * FROM public.orders", "SELECT * FROM orders JOIN items ON orders.id = items.order_id", } for _, sql := range accepted { if _, reason, ok := v.Validate(sql); !ok { t.Errorf("expected accept: %q reason=%q", sql, reason) } } rejected := []string{ "SELECT * FROM customers", "SELECT * FROM orders JOIN customers ON orders.cid = customers.id", } for _, sql := range rejected { if _, _, ok := v.Validate(sql); ok { t.Errorf("expected reject: %q", sql) } } } func TestSQLValidator_WhitelistDisabledWhenEmpty(t *testing.T) { v := NewSQLValidatorTool(SQLValidatorConfig{}) if _, _, ok := v.Validate("SELECT * FROM anything_goes"); !ok { t.Error("expected accept when whitelist is empty") } } func TestSQLValidator_Execute(t *testing.T) { tool := NewSQLValidatorTool(SQLValidatorConfig{DefaultLimit: 100}) t.Run("accept", func(t *testing.T) { input := json.RawMessage(`{"sql":"SELECT * FROM t"}`) res, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if res.IsError { t.Errorf("expected success, got error Output=%q", res.Output) } if !strings.Contains(res.Output, "accepted") { t.Errorf("expected 'accepted' in output, got %q", res.Output) } if !strings.Contains(res.Output, "LIMIT 100") { t.Errorf("expected LIMIT 100 injected in output, got %q", res.Output) } data, ok := res.Data.(map[string]any) if !ok { t.Fatalf("expected Data to be map, got %T", res.Data) } if data["valid"] != true { t.Error("expected Data[valid]=true") } }) t.Run("reject", func(t *testing.T) { input := json.RawMessage(`{"sql":"DROP TABLE t"}`) res, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if !res.IsError { t.Error("expected IsError=true") } if !strings.Contains(res.Output, "rejected") { t.Errorf("expected 'rejected' in output, got %q", res.Output) } }) t.Run("empty sql", func(t *testing.T) { input := json.RawMessage(`{"sql":""}`) res, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if !res.IsError { t.Error("expected IsError=true for empty sql") } }) t.Run("invalid json", func(t *testing.T) { input := json.RawMessage(`{bad json`) _, err := tool.Execute(context.Background(), input, nil) if err == nil { t.Error("expected error for invalid json") } }) } func TestSQLValidator_ToolInterface(t *testing.T) { tool := NewSQLValidatorTool(SQLValidatorConfig{ AllowedTables: []string{"orders"}, DefaultLimit: 100, RequireLimit: true, }) if tool.Name() != "SQLValidate" { t.Errorf("expected Name=SQLValidate, got %q", tool.Name()) } desc := tool.Description(context.Background()) for _, fragment := range []string{"SELECT", "WITH", "EXPLAIN", "orders", "LIMIT 100", "required"} { if !strings.Contains(desc, fragment) { t.Errorf("Description missing %q: %s", fragment, desc) } } var schema map[string]any if err := json.Unmarshal(tool.InputSchema(), &schema); err != nil { t.Fatalf("InputSchema not valid JSON: %v", err) } props, _ := schema["properties"].(map[string]any) if _, ok := props["sql"]; !ok { t.Error("InputSchema missing sql property") } md := tool.Metadata() if !md.ReadOnly { t.Error("expected ReadOnly=true") } if !md.ConcurrencySafe { t.Error("expected ConcurrencySafe=true") } if md.PermissionClass != permission.PermClassReadOnly { t.Errorf("expected PermissionClass=readonly, got %q", md.PermissionClass) } // Compile-time assertions: tool satisfies Tool + MetadataProvider. var _ tools.Tool = tool var _ tools.MetadataProvider = tool } func TestStripSQLComments(t *testing.T) { cases := []struct { in, want string }{ {"SELECT 1", "SELECT 1"}, {"-- comment\nSELECT 1", "\nSELECT 1"}, {"SELECT 1 -- end", "SELECT 1 "}, {"/* block */ SELECT 1", " SELECT 1"}, {"SELECT '--x' FROM t", "SELECT '--x' FROM t"}, {"SELECT '/*' FROM t", "SELECT '/*' FROM t"}, {"SELECT 'a''b' FROM t", "SELECT 'a''b' FROM t"}, {"SELECT /* unterminated", "SELECT "}, } for _, tc := range cases { got := stripSQLComments(tc.in) if got != tc.want { t.Errorf("stripSQLComments(%q) = %q, want %q", tc.in, got, tc.want) } } } func TestSplitStatementsQuoteAware(t *testing.T) { cases := []struct { in string want []string }{ {"SELECT 1", []string{"SELECT 1"}}, {"SELECT 1; SELECT 2", []string{"SELECT 1", "SELECT 2"}}, {"SELECT 1;", []string{"SELECT 1"}}, {"SELECT 'a;b' FROM t", []string{"SELECT 'a;b' FROM t"}}, {";;;", nil}, {"", nil}, } for _, tc := range cases { got := splitStatementsQuoteAware(tc.in) if len(got) != len(tc.want) { t.Errorf("splitStatementsQuoteAware(%q) len=%d, want %d (%v vs %v)", tc.in, len(got), len(tc.want), got, tc.want) continue } for i := range got { if got[i] != tc.want[i] { t.Errorf("splitStatementsQuoteAware(%q)[%d] = %q, want %q", tc.in, i, got[i], tc.want[i]) } } } }